agentic-base 0.1.0__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (38) hide show
  1. agentic_base-0.1.0/PKG-INFO +23 -0
  2. agentic_base-0.1.0/README.md +0 -0
  3. agentic_base-0.1.0/agentic_base/__init__.py +0 -0
  4. agentic_base-0.1.0/agentic_base/connector/__init__.py +71 -0
  5. agentic_base-0.1.0/agentic_base/connector/minio.py +46 -0
  6. agentic_base-0.1.0/agentic_base/connector/parquet.py +21 -0
  7. agentic_base-0.1.0/agentic_base/connector/settings.py +22 -0
  8. agentic_base-0.1.0/agentic_base/constants.py +1 -0
  9. agentic_base-0.1.0/agentic_base/dataset/__init__.py +155 -0
  10. agentic_base-0.1.0/agentic_base/dataset/mixins.py +114 -0
  11. agentic_base-0.1.0/agentic_base/dataset/polars.py +43 -0
  12. agentic_base-0.1.0/agentic_base/dataset/preprocess/__init__.py +28 -0
  13. agentic_base-0.1.0/agentic_base/dataset/preprocess/preprocess.py +98 -0
  14. agentic_base-0.1.0/agentic_base/dataset/preprocess/token_counter.py +37 -0
  15. agentic_base-0.1.0/agentic_base/dataset/settings.py +59 -0
  16. agentic_base-0.1.0/agentic_base/enums.py +11 -0
  17. agentic_base-0.1.0/agentic_base/evaluation/__init__.py +163 -0
  18. agentic_base-0.1.0/agentic_base/evaluation/async_batcher.py +79 -0
  19. agentic_base-0.1.0/agentic_base/evaluation/embedders.py +27 -0
  20. agentic_base-0.1.0/agentic_base/evaluation/evaluators/__init__.py +37 -0
  21. agentic_base-0.1.0/agentic_base/evaluation/evaluators/python/__init__.py +113 -0
  22. agentic_base-0.1.0/agentic_base/evaluation/evaluators/python/evaluators.py +72 -0
  23. agentic_base-0.1.0/agentic_base/evaluation/evaluators/python/scores.py +59 -0
  24. agentic_base-0.1.0/agentic_base/evaluation/processors.py +16 -0
  25. agentic_base-0.1.0/agentic_base/evaluation/rerankers.py +1 -0
  26. agentic_base-0.1.0/agentic_base/evaluation/retrievers/__init__.py +104 -0
  27. agentic_base-0.1.0/agentic_base/evaluation/retrievers/dense/__init__.py +50 -0
  28. agentic_base-0.1.0/agentic_base/evaluation/retrievers/dense/retrievers.py +86 -0
  29. agentic_base-0.1.0/agentic_base/evaluation/settings.py +176 -0
  30. agentic_base-0.1.0/agentic_base/evaluation/vector_stores.py +149 -0
  31. agentic_base-0.1.0/agentic_base/ingestion/__init__.py +52 -0
  32. agentic_base-0.1.0/agentic_base/ingestion/settings.py +9 -0
  33. agentic_base-0.1.0/agentic_base/mixins.py +87 -0
  34. agentic_base-0.1.0/agentic_base/py.typed +0 -0
  35. agentic_base-0.1.0/agentic_base/settings.py +41 -0
  36. agentic_base-0.1.0/agentic_base/types.py +48 -0
  37. agentic_base-0.1.0/agentic_base/utils.py +39 -0
  38. agentic_base-0.1.0/pyproject.toml +43 -0
@@ -0,0 +1,23 @@
1
+ Metadata-Version: 2.4
2
+ Name: agentic-base
3
+ Version: 0.1.0
4
+ Summary:
5
+ Author: jalal
6
+ Author-email: jalalkhaldi3@gmail.com
7
+ Requires-Python: >=3.11,<3.13
8
+ Classifier: Programming Language :: Python :: 3
9
+ Classifier: Programming Language :: Python :: 3.11
10
+ Classifier: Programming Language :: Python :: 3.12
11
+ Requires-Dist: faiss-cpu (>=1.13.2,<2.0.0)
12
+ Requires-Dist: langchain (>=1.2.10,<2.0.0)
13
+ Requires-Dist: minio (>=7.2.20,<8.0.0)
14
+ Requires-Dist: numpy (>=2.4.2,<3.0.0)
15
+ Requires-Dist: openai (>=2.21.0,<3.0.0)
16
+ Requires-Dist: polars (>=1.38.1,<2.0.0)
17
+ Requires-Dist: pydantic-settings (>=2.13.0,<3.0.0)
18
+ Requires-Dist: qdrant-client (>=1.16.2,<2.0.0)
19
+ Requires-Dist: rank-bm25 (>=0.2.2,<0.3.0)
20
+ Requires-Dist: transformers (>=5.2.0,<6.0.0)
21
+ Description-Content-Type: text/markdown
22
+
23
+
File without changes
File without changes
@@ -0,0 +1,71 @@
1
+ import logging
2
+ from abc import ABC, abstractmethod
3
+ from typing import TYPE_CHECKING, Any, Generic, Union
4
+
5
+ import polars as pl
6
+
7
+ from agentic_base.mixins import FromConfigMixin
8
+ from agentic_base.types import TCDatasetConnector
9
+
10
+ if TYPE_CHECKING:
11
+ from agentic_base.dataset import Dataset, TextDataset
12
+
13
+ _logger = logging.getLogger(__name__)
14
+
15
+
16
class DatasetConnector(
    FromConfigMixin[TCDatasetConnector],
    ABC,
    Generic[TCDatasetConnector],
):
    """Abstract base for connectors that load/store datasets as Polars frames.

    Subclasses implement ``_load`` (read the raw data) and ``to`` (persist a
    dataset); ``load``/``load_text`` wrap the result in the appropriate
    Dataset type.
    """

    def __init__(self, config: TCDatasetConnector):
        super().__init__(config)

        _logger.info(
            f"Initializing dataset connector | "
            f"class={self.__class__.__name__} | "
            f"module={self.__class__.__module__}"
        )

    @abstractmethod
    def _load(self) -> Union[pl.DataFrame, pl.LazyFrame]:
        """
        Load raw data as Polars DataFrame or LazyFrame.
        """
        raise NotImplementedError()

    @abstractmethod
    def to(self, ds: "Dataset[Any]") -> None:
        """Persist *ds* to this connector's backing store."""
        raise NotImplementedError()

    def load(self) -> "Dataset[Union[pl.DataFrame, pl.LazyFrame]]":
        """Load data and wrap it in a generic PolarsDataset."""
        # Local import avoids a circular import with agentic_base.dataset.
        from agentic_base.dataset.polars import PolarsDataset

        _logger.info(f"Loading dataset | connector={self.__class__.__name__}")

        df = self._load()
        self._log_polars_info(df)

        return PolarsDataset.from_polars(df)

    def load_text(self) -> "TextDataset[Union[pl.DataFrame, pl.LazyFrame]]":
        """Load data and wrap it in a schema-validated PolarsTextDataset."""
        from agentic_base.dataset.polars import PolarsTextDataset

        _logger.info(f"Loading text dataset | connector={self.__class__.__name__}")

        df = self._load()
        self._log_polars_info(df)

        return PolarsTextDataset.from_polars(df)

    # ------------------------------------------------------------------
    # Helpers
    # ------------------------------------------------------------------

    def _log_polars_info(self, df: Union[pl.DataFrame, pl.LazyFrame]) -> None:
        """
        Log dataset structure without forcing materialization.

        Fix: LazyFrames were previously ignored entirely; their schema is now
        resolved via ``collect_schema()``, which inspects the plan without
        collecting any data.
        """
        if isinstance(df, pl.DataFrame):
            schema = df.schema
        else:
            schema = df.collect_schema()
        _logger.info(
            f"Loaded frame | "
            f"type={type(df).__name__} | "
            f"columns={len(schema)} | "
            f"schema={list(schema.items())}"
        )
@@ -0,0 +1,46 @@
1
+ import io
2
+ from typing import TYPE_CHECKING, Any, Union
3
+
4
+ import polars as pl
5
+ from minio import Minio
6
+
7
+ from agentic_base.connector import DatasetConnector
8
+ from agentic_base.connector.settings import MinioDatasetConnectorSettings
9
+
10
+ if TYPE_CHECKING:
11
+ from agentic_base.dataset import Dataset
12
+
13
+
14
class MinioDatasetConnector(DatasetConnector[MinioDatasetConnectorSettings]):
    """Dataset connector that reads and writes Parquet objects in MinIO."""

    def __init__(self, config: MinioDatasetConnectorSettings):
        super().__init__(config)
        endpoint = self.config.endpoint
        # The MinIO client wants a bare host[:port]; the scheme only decides TLS.
        use_tls = endpoint.startswith("https://")
        host = endpoint.replace("http://", "").replace("https://", "")
        self.client = Minio(
            host,
            access_key=self.config.access_key.get_secret_value(),
            secret_key=self.config.secret_key.get_secret_value(),
            secure=use_tls,
        )

    def _load(self) -> Union[pl.DataFrame, pl.LazyFrame]:
        """Fetch the configured object and parse it as Parquet."""
        obj = self.client.get_object(self.config.bucket, self.config.key)
        try:
            payload = io.BytesIO(obj.read())
        finally:
            # Always return the connection to the pool, even if read() fails.
            obj.close()
            obj.release_conn()
        return pl.read_parquet(payload)

    def to(self, ds: "Dataset[Any]") -> None:
        """Serialize *ds* to Parquet in memory and upload it to MinIO."""
        out = io.BytesIO()
        ds.polars.write_parquet(out)
        out.seek(0)
        self.client.put_object(
            bucket_name=self.config.bucket,
            object_name=self.config.key,
            data=out,
            length=out.getbuffer().nbytes,
            content_type="application/octet-stream",
        )
@@ -0,0 +1,21 @@
1
+ from typing import TYPE_CHECKING, Any, Union
2
+
3
+ import polars as pl
4
+
5
+ from agentic_base.connector import DatasetConnector
6
+ from agentic_base.connector.settings import ParquetDatasetConnectorSettings
7
+
8
+ if TYPE_CHECKING:
9
+ from agentic_base.dataset import Dataset
10
+
11
+
12
class ParquetDatasetConnector(DatasetConnector[ParquetDatasetConnectorSettings]):
    """Dataset connector backed by a Parquet file on the local filesystem."""

    def __init__(self, config: ParquetDatasetConnectorSettings):
        super().__init__(config)

    def _load(self) -> Union[pl.DataFrame, pl.LazyFrame]:
        """Read the configured path; lazily (scan) or eagerly per config."""
        if self.config.lazy:
            return pl.scan_parquet(self.config.path)
        return pl.read_parquet(self.config.path)

    def to(self, ds: "Dataset[Any]") -> None:
        """Write the materialized frame back to the configured path."""
        ds.polars.write_parquet(self.config.path)
@@ -0,0 +1,22 @@
1
+ from pydantic import SecretStr
2
+ from pydantic_settings import SettingsConfigDict
3
+
4
+ from agentic_base.settings import FromConfigMixinSettings
5
+
6
+
7
class DatasetConnectorSettings(FromConfigMixinSettings):
    # Marker base: common settings surface shared by all dataset connectors.
    pass


class ParquetDatasetConnectorSettings(DatasetConnectorSettings):
    # Filesystem path of the Parquet file to read/write.
    path: str
    # True -> pl.scan_parquet (lazy); False -> pl.read_parquet (eager).
    lazy: bool


class MinioDatasetConnectorSettings(DatasetConnectorSettings):
    # May include an http:// or https:// scheme; the connector strips it and
    # uses the scheme only to decide whether to enable TLS.
    endpoint: str
    bucket: str
    key: str
    # SecretStr keeps credentials out of reprs and logs.
    access_key: SecretStr
    secret_key: SecretStr
    # Values can come from MINIO_* environment variables; unknown env keys are ignored.
    model_config = SettingsConfigDict(env_prefix='MINIO_', extra="ignore")
@@ -0,0 +1 @@
1
# Default path to the application's YAML configuration file.
CONFIG_PATH = "/config/config.yaml"
@@ -0,0 +1,155 @@
1
+ import logging
2
+ from abc import ABC, abstractmethod
3
+ from typing import (
4
+ TYPE_CHECKING,
5
+ Any,
6
+ Dict,
7
+ Generic,
8
+ Iterable,
9
+ Iterator,
10
+ Literal,
11
+ Optional,
12
+ Self,
13
+ Tuple,
14
+ Union,
15
+ overload,
16
+ )
17
+
18
+ import polars as pl
19
+ from polars._typing import ColumnNameOrSelector, IntoExpr, IntoExprColumn
20
+
21
+ from agentic_base.dataset.mixins import TextDatasetMixin
22
+ from agentic_base.types import TDataset
23
+ from agentic_base.utils import _get_minio_connector, _get_parquet_connector
24
+
25
+ if TYPE_CHECKING:
26
+ pass
27
+
28
+ _logger = logging.getLogger(__name__)
29
+
30
+
31
class Dataset(ABC, Generic[TDataset]):
    """
    Abstract dataset wrapper around a backing service object, exposing a
    Polars-based API (eager and lazy views, filtering, column ops) plus
    Parquet and MinIO round-trips via the connector helpers.
    """

    def __init__(self, service: TDataset):
        super().__init__()
        self.service = service
        # Cached eager/lazy materializations; filled on first property access.
        self._polars: Optional[pl.DataFrame] = None
        self._lazy_polars: Optional[pl.LazyFrame] = None

    @classmethod
    @abstractmethod
    def from_polars(cls, df: Union[pl.DataFrame, pl.LazyFrame]) -> Self:
        """Build a dataset instance from an existing Polars frame."""
        raise NotImplementedError

    @abstractmethod
    def to_polars(self) -> pl.DataFrame:
        """Materialize the dataset as an eager DataFrame."""
        raise NotImplementedError

    @abstractmethod
    def to_lazy_polars(self) -> pl.LazyFrame:
        """Expose the dataset as a LazyFrame."""
        raise NotImplementedError

    @classmethod
    def from_parquet(
        cls,
        path: str,
        *,
        lazy: bool = True,
    ) -> "Dataset[Union[pl.DataFrame, pl.LazyFrame]]":
        """Load a dataset from a Parquet file (lazy scan by default)."""
        return _get_parquet_connector(path, lazy=lazy).load()

    def to_parquet(self, path: str, *, lazy: bool = True) -> None:
        """Write the dataset to a Parquet file at *path*."""
        return _get_parquet_connector(path, lazy=lazy).to(self)

    @classmethod
    def from_minio(
        cls,
        bucket: str,
        key: str,
        endpoint: str,
        access_key: str,
        secret_key: str,
    ) -> "Dataset[Union[pl.DataFrame, pl.LazyFrame]]":
        """Load a dataset from a Parquet object stored in MinIO."""
        return _get_minio_connector(bucket, key, endpoint, access_key, secret_key).load()

    def to_minio(self, bucket: str, key: str, endpoint: str, access_key: str, secret_key: str) -> None:
        """Upload the dataset as a Parquet object to MinIO."""
        return _get_minio_connector(bucket, key, endpoint, access_key, secret_key).to(self)

    @property
    def polars(self) -> pl.DataFrame:
        # Materializes on first access (collects a lazy backing frame), then caches.
        if self._polars is None:
            self._polars = self.to_polars()
        return self._polars

    @property
    def lazy_polars(self) -> pl.LazyFrame:
        if self._lazy_polars is None:
            self._lazy_polars = self.to_lazy_polars()
        return self._lazy_polars

    @property
    def polars_shape(self) -> Tuple[int, int]:
        # (rows, columns); forces materialization via the eager property.
        return self.polars.shape

    def with_columns(
        self,
        *exprs: IntoExpr | Iterable[IntoExpr],
        **named_exprs: IntoExpr,
    ) -> Self:
        """Return a new dataset of the same class with columns added/replaced."""
        df = self.polars.with_columns(*exprs, **named_exprs)
        return self.__class__.from_polars(df)

    def filter(
        self,
        *predicates: (IntoExprColumn | Iterable[IntoExprColumn] | bool | list[bool]),
        **constraints: Any,
    ) -> Self:
        """Return a new dataset containing only rows matching the predicates."""
        return self.__class__.from_polars(self.polars.filter(*predicates, **constraints))

    def drop(
        self,
        *columns: ColumnNameOrSelector | Iterable[ColumnNameOrSelector],
        strict: bool = True,
    ) -> Self:
        """Return a new dataset with the given columns removed."""
        return self.__class__.from_polars(self.polars.drop(*columns, strict=strict))

    def __len__(self) -> int:
        return self.polars_shape[0]

    @overload
    def iter_rows(self, *, named: Literal[False] = ..., buffer_size: int = ...) -> Iterator[Tuple[Any, ...]]: ...

    @overload
    def iter_rows(self, *, named: Literal[True], buffer_size: int = ...) -> Iterator[Dict[str, Any]]: ...

    def iter_rows(
        self, *, named: Literal[False, True] = False, buffer_size: int = 512
    ) -> Iterator[Tuple[Any, ...]] | Iterator[Dict[str, Any]]:
        """Iterate rows as tuples (named=False) or dicts (named=True)."""
        # NOTE(review): collect(streaming=True) is deprecated in newer Polars in
        # favor of engine="streaming" — confirm against the pinned version.
        df_stream = self.lazy_polars.collect(streaming=True)  # type: ignore[call-overload]
        return df_stream.iter_rows(named=named, buffer_size=buffer_size)  # type: ignore[no-any-return]
129
+
130
+
131
class TextDataset(Dataset[TDataset], TextDatasetMixin[TDataset], Generic[TDataset]):
    """
    Dataset specialized for text rows (page_content + metadata). Validates the
    required schema on construction and loads via the connectors' text entry
    points so loaded data is validated too.
    """

    def __init__(self, service: TDataset):
        super().__init__(service)
        # Fail fast on missing required columns; note this reads self.polars,
        # which materializes a lazy backing frame.
        self._validate_schema()

    @classmethod
    def from_parquet(
        cls,
        path: str,
        *,
        lazy: bool = True,
    ) -> "TextDataset[Union[pl.DataFrame, pl.LazyFrame]]":
        """Load a text dataset from a Parquet file (lazy scan by default)."""
        return _get_parquet_connector(path, lazy=lazy).load_text()

    @classmethod
    def from_minio(
        cls,
        bucket: str,
        key: str,
        endpoint: str,
        access_key: str,
        secret_key: str,
    ) -> "TextDataset[Union[pl.DataFrame, pl.LazyFrame]]":
        """Load a text dataset from a Parquet object stored in MinIO."""
        return _get_minio_connector(bucket, key, endpoint, access_key, secret_key).load_text()
@@ -0,0 +1,114 @@
1
+ import json
2
+ from pathlib import Path
3
+ from typing import (
4
+ TYPE_CHECKING,
5
+ Any,
6
+ Dict,
7
+ Generic,
8
+ Iterator,
9
+ List,
10
+ Protocol,
11
+ Tuple,
12
+ Union,
13
+ runtime_checkable,
14
+ )
15
+
16
+ import polars as pl
17
+ from langchain_core.documents import Document
18
+
19
+ from agentic_base.types import TDataset
20
+
21
+ if TYPE_CHECKING:
22
+
23
+ from agentic_base.dataset import TextDataset
24
+
25
+
26
@runtime_checkable
class SupportsPolarsDataset(Protocol[TDataset]):
    """Structural interface for objects exposing Polars views of their data."""

    @property
    def polars(self) -> pl.DataFrame: ...

    def to_polars(self) -> pl.DataFrame: ...

    @classmethod
    def from_polars(cls, df: Union[pl.DataFrame, pl.LazyFrame]) -> "TextDataset[TDataset]":
        # Alternate constructor; concrete datasets override this.
        raise NotImplementedError
36
+
37
+
38
class TextDatasetMixin(SupportsPolarsDataset[TDataset], Generic[TDataset]):
    """
    Text-specific helpers for datasets whose rows are (page_content, metadata)
    pairs: schema validation, LangChain Document conversion, and file dumps.
    """

    # Columns every text dataset must provide.
    REQUIRED_COLUMNS = {"page_content", "metadata"}

    def _validate_schema(self) -> None:
        """Raise ValueError if any required column is missing."""
        df = self.polars
        missing = self.REQUIRED_COLUMNS - set(df.columns)
        if missing:
            raise ValueError(
                f"{self.__class__.__name__} requires columns {self.REQUIRED_COLUMNS}, " f"but missing {missing}"
            )

    def iter_documents(self) -> Iterator[Document]:
        """Yield one LangChain Document per row."""
        df = self.to_polars()
        for row in df.iter_rows(named=True):
            yield Document(
                page_content=row["page_content"],
                metadata=row["metadata"],
            )

    @classmethod
    def from_records(
        cls,
        records: List[Tuple[str, Dict[str, Any]]],
    ) -> "TextDataset[TDataset]":
        """Build a text dataset from (text, metadata) tuples.

        Raises:
            ValueError: if *records* is empty.
        """
        if not records:
            raise ValueError("records must be non-empty")
        df = pl.DataFrame(
            {
                "page_content": [text for text, _ in records],
                "metadata": [meta for _, meta in records],
            }
        )
        return cls.from_polars(df)

    def dump_documents(
        self,
        out_dir: str,
        prefix: str = "doc",
        ext: str = ".md",
        encoding: str = "utf-8",
    ) -> None:
        """Write each row to its own file under *out_dir* for inspection."""
        out_dir_path = Path(out_dir)
        out_dir_path.mkdir(parents=True, exist_ok=True)

        for i, row in enumerate(self.polars.iter_rows(named=True)):
            path = out_dir_path / f"{prefix}_{i:05d}{ext}"
            page_content = row.get("page_content", "")
            metadata = row.get("metadata", {})
            try:
                metadata_str = json.dumps(metadata, indent=2, ensure_ascii=False)
            except TypeError:
                # Metadata may contain non-JSON-serializable values; fall back to str().
                metadata_str = str(metadata)
            content = "###PAGE_CONTENT###\n" f"{page_content}\n\n" "###METADATA###\n" f"{metadata_str}\n"
            path.write_text(content, encoding=encoding)

    def to_langchain_documents(
        self,
    ) -> List[Document]:
        """Materialize all rows as LangChain Documents.

        Fix: delegates to iter_documents() instead of duplicating the
        row-to-Document conversion inline.
        """
        return list(self.iter_documents())

    @classmethod
    def from_documents(cls, docs: List[Document]) -> "TextDataset[TDataset]":
        """Build a text dataset from LangChain Documents (None metadata -> {})."""
        df = pl.DataFrame(
            {
                "page_content": [d.page_content for d in docs],
                "metadata": [d.metadata or {} for d in docs],
            }
        )
        return cls.from_polars(df)
@@ -0,0 +1,43 @@
1
+ from typing import Self, Union
2
+
3
+ import polars as pl
4
+
5
+ from agentic_base.dataset import Dataset, TextDataset
6
+
7
+
8
class PolarsDataset(Dataset[Union[pl.DataFrame, pl.LazyFrame]]):
    """Concrete Dataset whose backing service is a Polars frame (eager or lazy)."""

    def __init__(self, service: Union[pl.DataFrame, pl.LazyFrame]):
        super().__init__(service)

    @classmethod
    def from_polars(cls, df: Union[pl.DataFrame, pl.LazyFrame]) -> Self:
        return cls(df)

    def to_polars(self) -> pl.DataFrame:
        # Collect only when the backing frame is lazy.
        frame = self.service
        return frame.collect() if isinstance(frame, pl.LazyFrame) else frame

    def to_lazy_polars(self) -> pl.LazyFrame:
        frame = self.service
        return frame if isinstance(frame, pl.LazyFrame) else frame.lazy()
25
+
26
+
27
class PolarsTextDataset(TextDataset[Union[pl.DataFrame, pl.LazyFrame]]):
    """Schema-validated text dataset backed by a Polars frame (eager or lazy)."""

    def __init__(self, service: Union[pl.DataFrame, pl.LazyFrame]):
        super().__init__(service)

    @classmethod
    def from_polars(cls, df: Union[pl.DataFrame, pl.LazyFrame]) -> Self:
        return cls(df)

    def to_polars(self) -> pl.DataFrame:
        # Collect only when the backing frame is lazy.
        frame = self.service
        return frame.collect() if isinstance(frame, pl.LazyFrame) else frame

    def to_lazy_polars(self) -> pl.LazyFrame:
        frame = self.service
        return frame if isinstance(frame, pl.LazyFrame) else frame.lazy()
@@ -0,0 +1,28 @@
1
+ from abc import ABC, abstractmethod
2
+ from typing import Any, Generic
3
+
4
+ from agentic_base.dataset.polars import TextDataset
5
+ from agentic_base.mixins import FromConfigMixin
6
+ from agentic_base.types import TCTextPreprocessor, TCTokenCounter
7
+
8
+
9
class TextPreprocessor(FromConfigMixin[TCTextPreprocessor], ABC, Generic[TCTextPreprocessor]):
    """Abstract transformation that maps a TextDataset to a new TextDataset."""

    def __init__(self, config: TCTextPreprocessor) -> None:
        # Consistency fix: route through FromConfigMixin (which stores the
        # config) like DatasetConnector does, instead of assigning
        # self.config directly.
        super().__init__(config)

    @abstractmethod
    def apply(self, ds: TextDataset[Any]) -> TextDataset[Any]:
        """Return the transformed dataset."""
        raise NotImplementedError
17
+
18
+
19
class TokenCounter(FromConfigMixin[TCTokenCounter], ABC, Generic[TCTokenCounter]):
    """Abstract token counter; subclasses implement ``count`` for one string."""

    def __init__(self, config: TCTokenCounter) -> None:
        # Consistency fix: route through FromConfigMixin (which stores the
        # config) like DatasetConnector does, instead of assigning
        # self.config directly.
        super().__init__(config)

    @abstractmethod
    def count(self, text: str) -> int:
        """Return the number of tokens in *text*."""
        raise NotImplementedError()

    def count_batch(self, texts: list[str]) -> list[int]:
        """Count each text independently; subclasses may override to batch."""
        return [self.count(t) for t in texts]
@@ -0,0 +1,98 @@
1
+ from typing import Any, List
2
+
3
+ import polars as pl
4
+
5
+ from agentic_base.dataset import TextDataset
6
+ from agentic_base.dataset.preprocess import TextPreprocessor, TokenCounter
7
+ from agentic_base.dataset.settings import (
8
+ MaxTokenFilterSettings,
9
+ MinTokenFilterSettings,
10
+ PreprocessPipelineSettings,
11
+ QuantileTokenFilterSettings,
12
+ SigmaBandTokenFilterSettings,
13
+ )
14
+ from agentic_base.utils import load_class
15
+
16
+
17
class MinTokenFilter(TextPreprocessor[MinTokenFilterSettings[Any]]):
    """Drop rows whose page_content has fewer than ``min_tokens`` tokens."""

    def __init__(self, config: MinTokenFilterSettings[Any]):
        super().__init__(config)
        # Token counter implementation is resolved dynamically from config.
        self.token_counter: TokenCounter[Any] = load_class(self.config.token_counter.module_path).from_config(
            self.config.token_counter
        )

    def apply(self, ds: TextDataset[Any]) -> TextDataset[Any]:
        # Fix: pass return_dtype to map_elements so Polars does not warn and
        # the temporary n_tokens column has a stable integer dtype.
        return ds.__class__.from_polars(
            ds.polars.with_columns(
                pl.col("page_content")
                .map_elements(self.token_counter.count, return_dtype=pl.Int64)
                .alias("n_tokens")
            )
            .filter(pl.col("n_tokens") >= self.config.min_tokens)
            .drop("n_tokens")
        )
31
+
32
+
33
class MaxTokenFilter(TextPreprocessor[MaxTokenFilterSettings[Any]]):
    """Drop rows whose page_content has more than ``max_tokens`` tokens."""

    def __init__(self, config: MaxTokenFilterSettings[Any]):
        super().__init__(config)
        # Token counter implementation is resolved dynamically from config.
        self.token_counter: TokenCounter[Any] = load_class(self.config.token_counter.module_path).from_config(
            self.config.token_counter
        )

    def apply(self, ds: TextDataset[Any]) -> TextDataset[Any]:
        # Fix: pass return_dtype to map_elements so Polars does not warn and
        # the temporary n_tokens column has a stable integer dtype.
        return ds.__class__.from_polars(
            ds.polars.with_columns(
                pl.col("page_content")
                .map_elements(self.token_counter.count, return_dtype=pl.Int64)
                .alias("n_tokens")
            )
            .filter(pl.col("n_tokens") <= self.config.max_tokens)
            .drop("n_tokens")
        )
46
+
47
+
48
class QuantileTokenFilter(TextPreprocessor[QuantileTokenFilterSettings[Any]]):
    """Keep only rows at or below the ``q``-quantile of token counts."""

    def __init__(self, config: QuantileTokenFilterSettings[Any]):
        super().__init__(config)
        # Token counter implementation is resolved dynamically from config.
        self.token_counter: TokenCounter[Any] = load_class(self.config.token_counter.module_path).from_config(
            self.config.token_counter
        )

    def apply(self, ds: TextDataset[Any]) -> TextDataset[Any]:
        # Fix: pass return_dtype to map_elements so Polars does not warn and
        # the temporary n_tokens column has a stable integer dtype.
        df = ds.polars.with_columns(
            pl.col("page_content").map_elements(self.token_counter.count, return_dtype=pl.Int64).alias("n_tokens")
        )
        cutoff = df.select(pl.col("n_tokens").quantile(self.config.q)).item()
        return ds.__class__.from_polars(df.filter(pl.col("n_tokens") <= cutoff).drop("n_tokens"))
59
+
60
+
61
class SigmaBandTokenFilter(TextPreprocessor[SigmaBandTokenFilterSettings[Any]]):
    """Keep rows whose token count lies within ``z`` standard deviations of the mean."""

    def __init__(self, config: SigmaBandTokenFilterSettings[Any]):
        super().__init__(config)
        # Token counter implementation is resolved dynamically from config.
        self.token_counter: TokenCounter[Any] = load_class(self.config.token_counter.module_path).from_config(
            self.config.token_counter
        )

    def apply(self, ds: TextDataset[Any]) -> TextDataset[Any]:
        # Fix: pass return_dtype to map_elements so Polars does not warn and
        # the temporary n_tokens column has a stable integer dtype.
        df = ds.polars.with_columns(
            pl.col("page_content").map_elements(self.token_counter.count, return_dtype=pl.Int64).alias("n_tokens")
        )
        mu, sigma = df.select(
            pl.col("n_tokens").mean(),
            pl.col("n_tokens").std(),
        ).row(0)
        # Fix: std() is null for fewer than 2 rows, which previously raised a
        # TypeError in the band arithmetic; with no meaningful band, pass the
        # dataset through unchanged.
        if sigma is None:
            return ds.__class__.from_polars(df.drop("n_tokens"))
        half_width = self.config.z * sigma
        return ds.__class__.from_polars(
            df.filter(pl.col("n_tokens").is_between(mu - half_width, mu + half_width)).drop("n_tokens")
        )
79
+
80
+
81
class PreprocessPipeline(TextPreprocessor[PreprocessPipelineSettings]):
    """Composite preprocessor that applies its configured steps in order."""

    def __init__(self, config: PreprocessPipelineSettings):
        super().__init__(config)
        self.steps = self._load_steps()

    def _load_steps(self) -> List[TextPreprocessor[Any]]:
        # Instantiate each configured step from its dotted module path.
        return [load_class(cfg.module_path)(cfg) for cfg in self.config.steps]

    def apply(self, ds: TextDataset[Any]) -> TextDataset[Any]:
        """Thread the dataset through every step, in configuration order."""
        result = ds
        for step in self.steps:
            result = step.apply(result)
        return result
@@ -0,0 +1,37 @@
1
+ from transformers import AutoTokenizer
2
+
3
+ from agentic_base.dataset.preprocess import TokenCounter
4
+ from agentic_base.dataset.settings import (
5
+ HeuristicTokenCounterSettings,
6
+ HuggingFaceTokenCounterSettings,
7
+ )
8
+
9
+
10
class HeuristicTokenCounter(TokenCounter[HeuristicTokenCounterSettings]):
    """Estimate token counts from character length via a chars-per-token ratio."""

    def __init__(self, config: HeuristicTokenCounterSettings) -> None:
        super().__init__(config)

    def count(self, text: str) -> int:
        # Truncating division by the configured ratio; never report below 1.
        estimate = int(len(text) / self.config.chars_per_token)
        return estimate if estimate > 1 else 1
+ return max(1, int(len(text) / self.config.chars_per_token))
17
+
18
+
19
class HuggingFaceTokenCounter(TokenCounter[HuggingFaceTokenCounterSettings]):
    """Count tokens using a Hugging Face tokenizer loaded from the Hub."""

    def __init__(self, config: HuggingFaceTokenCounterSettings) -> None:
        super().__init__(config)
        # SECURITY NOTE(review): trust_remote_code=True executes arbitrary code
        # from the model repository — only use with vetted model names/revisions.
        self.tokenizer = AutoTokenizer.from_pretrained(
            self.config.name,
            revision=self.config.revision,
            trust_remote_code=True,
        )

    def count(self, text: str) -> int:
        # Tokenize without truncation so the full length is reported; skip
        # attention-mask/type-id outputs since only input_ids are needed.
        return len(
            self.tokenizer(
                text,
                add_special_tokens=self.config.add_special_tokens,
                truncation=False,
                return_attention_mask=False,
                return_token_type_ids=False,
            )["input_ids"]
        )
+ )