dataeval 0.84.1__py3-none-any.whl → 0.86.0__py3-none-any.whl

This diff shows the changes between two publicly released versions of the package, as published to their respective public registries, and is provided for informational purposes only.
Files changed (64)
  1. dataeval/__init__.py +1 -1
  2. dataeval/data/__init__.py +19 -0
  3. dataeval/{utils/data → data}/_embeddings.py +137 -17
  4. dataeval/{utils/data → data}/_metadata.py +20 -8
  5. dataeval/{utils/data → data}/_selection.py +22 -9
  6. dataeval/{utils/data → data}/_split.py +1 -1
  7. dataeval/data/selections/__init__.py +19 -0
  8. dataeval/{utils/data → data}/selections/_classbalance.py +1 -2
  9. dataeval/data/selections/_classfilter.py +110 -0
  10. dataeval/{utils/data → data}/selections/_indices.py +1 -1
  11. dataeval/{utils/data → data}/selections/_limit.py +1 -1
  12. dataeval/{utils/data → data}/selections/_prioritize.py +2 -2
  13. dataeval/{utils/data → data}/selections/_reverse.py +1 -1
  14. dataeval/{utils/data → data}/selections/_shuffle.py +1 -1
  15. dataeval/detectors/drift/__init__.py +4 -1
  16. dataeval/detectors/drift/_base.py +1 -1
  17. dataeval/detectors/drift/_cvm.py +2 -2
  18. dataeval/detectors/drift/_ks.py +2 -2
  19. dataeval/detectors/drift/_mmd.py +2 -2
  20. dataeval/detectors/drift/_mvdc.py +92 -0
  21. dataeval/detectors/drift/_nml/__init__.py +6 -0
  22. dataeval/detectors/drift/_nml/_base.py +68 -0
  23. dataeval/detectors/drift/_nml/_chunk.py +404 -0
  24. dataeval/detectors/drift/_nml/_domainclassifier.py +192 -0
  25. dataeval/detectors/drift/_nml/_result.py +98 -0
  26. dataeval/detectors/drift/_nml/_thresholds.py +280 -0
  27. dataeval/detectors/linters/duplicates.py +1 -1
  28. dataeval/detectors/linters/outliers.py +1 -1
  29. dataeval/metadata/_distance.py +1 -1
  30. dataeval/metadata/_ood.py +4 -4
  31. dataeval/metrics/bias/_balance.py +1 -1
  32. dataeval/metrics/bias/_diversity.py +1 -1
  33. dataeval/metrics/bias/_parity.py +1 -1
  34. dataeval/metrics/stats/_labelstats.py +2 -2
  35. dataeval/outputs/__init__.py +2 -1
  36. dataeval/outputs/_bias.py +2 -4
  37. dataeval/outputs/_drift.py +68 -0
  38. dataeval/outputs/_linters.py +1 -6
  39. dataeval/outputs/_stats.py +1 -6
  40. dataeval/typing.py +31 -0
  41. dataeval/utils/__init__.py +2 -2
  42. dataeval/utils/data/__init__.py +5 -20
  43. dataeval/utils/data/collate.py +2 -0
  44. dataeval/utils/datasets/__init__.py +17 -0
  45. dataeval/utils/{data/datasets → datasets}/_base.py +3 -3
  46. dataeval/utils/{data/datasets → datasets}/_cifar10.py +2 -2
  47. dataeval/utils/{data/datasets → datasets}/_milco.py +2 -2
  48. dataeval/utils/{data/datasets → datasets}/_mnist.py +2 -2
  49. dataeval/utils/{data/datasets → datasets}/_ships.py +2 -2
  50. dataeval/utils/{data/datasets → datasets}/_voc.py +3 -3
  51. {dataeval-0.84.1.dist-info → dataeval-0.86.0.dist-info}/METADATA +3 -2
  52. dataeval-0.86.0.dist-info/RECORD +114 -0
  53. dataeval/utils/data/datasets/__init__.py +0 -17
  54. dataeval/utils/data/selections/__init__.py +0 -19
  55. dataeval/utils/data/selections/_classfilter.py +0 -44
  56. dataeval-0.84.1.dist-info/RECORD +0 -106
  57. /dataeval/{utils/data → data}/_images.py +0 -0
  58. /dataeval/{utils/data → data}/_targets.py +0 -0
  59. /dataeval/utils/{metadata.py → data/metadata.py} +0 -0
  60. /dataeval/utils/{data/datasets → datasets}/_fileio.py +0 -0
  61. /dataeval/utils/{data/datasets → datasets}/_mixin.py +0 -0
  62. /dataeval/utils/{data/datasets → datasets}/_types.py +0 -0
  63. {dataeval-0.84.1.dist-info → dataeval-0.86.0.dist-info}/LICENSE.txt +0 -0
  64. {dataeval-0.84.1.dist-info → dataeval-0.86.0.dist-info}/WHEEL +0 -0

--- a/dataeval/detectors/drift/_cvm.py
+++ b/dataeval/detectors/drift/_cvm.py
@@ -16,9 +16,9 @@ import numpy as np
 from numpy.typing import NDArray
 from scipy.stats import cramervonmises_2samp
 
+from dataeval.data._embeddings import Embeddings
 from dataeval.detectors.drift._base import BaseDriftUnivariate, UpdateStrategy
 from dataeval.typing import Array
-from dataeval.utils.data._embeddings import Embeddings
 
 
 class DriftCVM(BaseDriftUnivariate):
@@ -52,7 +52,7 @@ class DriftCVM(BaseDriftUnivariate):
 
     Example
     -------
-    >>> from dataeval.utils.data import Embeddings
+    >>> from dataeval.data import Embeddings
 
     Use Embeddings to encode images before testing for drift
 

--- a/dataeval/detectors/drift/_ks.py
+++ b/dataeval/detectors/drift/_ks.py
@@ -16,9 +16,9 @@ import numpy as np
 from numpy.typing import NDArray
 from scipy.stats import ks_2samp
 
+from dataeval.data._embeddings import Embeddings
 from dataeval.detectors.drift._base import BaseDriftUnivariate, UpdateStrategy
 from dataeval.typing import Array
-from dataeval.utils.data._embeddings import Embeddings
 
 
 class DriftKS(BaseDriftUnivariate):
@@ -54,7 +54,7 @@ class DriftKS(BaseDriftUnivariate):
 
     Example
     -------
-    >>> from dataeval.utils.data import Embeddings
+    >>> from dataeval.data import Embeddings
 
     Use Embeddings to encode images before testing for drift
 

--- a/dataeval/detectors/drift/_mmd.py
+++ b/dataeval/detectors/drift/_mmd.py
@@ -15,11 +15,11 @@ from typing import Any, Callable
 import torch
 
 from dataeval.config import DeviceLike, get_device
+from dataeval.data._embeddings import Embeddings
 from dataeval.detectors.drift._base import BaseDrift, UpdateStrategy, update_strategy
 from dataeval.outputs import DriftMMDOutput
 from dataeval.outputs._base import set_metadata
 from dataeval.typing import Array
-from dataeval.utils.data._embeddings import Embeddings
 
 
 class DriftMMD(BaseDrift):
@@ -51,7 +51,7 @@ class DriftMMD(BaseDrift):
 
     Example
     -------
-    >>> from dataeval.utils.data import Embeddings
+    >>> from dataeval.data import Embeddings
 
     Use Embeddings to encode images before testing for drift
 
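
Taken together, the three hunks above make a single change: Embeddings now lives at dataeval.data rather than dataeval.utils.data, and the drift detector docstrings are updated to match. A minimal migration sketch for calling code, assuming the detector and Embeddings constructors are otherwise unchanged; the dataset and model objects below are hypothetical placeholders, not part of this diff:

# Old (0.84.x) import, removed by this diff:
#   from dataeval.utils.data import Embeddings
# New (0.86.0) import:
from dataeval.data import Embeddings
from dataeval.detectors.drift import DriftCVM, DriftKS, DriftMMD

# Hypothetical usage; `train_ds`, `test_ds`, and `encoder` stand in for a
# user-supplied dataset pair and embedding model, and the keyword arguments
# shown here are illustrative rather than taken from this diff.
ref_emb = Embeddings(train_ds, batch_size=64, model=encoder)
test_emb = Embeddings(test_ds, batch_size=64, model=encoder)
result = DriftCVM(ref_emb).predict(test_emb)
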

--- /dev/null
+++ b/dataeval/detectors/drift/_mvdc.py
@@ -0,0 +1,92 @@
+from __future__ import annotations
+
+from typing import TYPE_CHECKING
+
+import numpy as np
+import pandas as pd
+from numpy.typing import ArrayLike
+
+if TYPE_CHECKING:
+    from typing import Self
+else:
+    from typing_extensions import Self
+
+from dataeval.detectors.drift._nml._chunk import CountBasedChunker, SizeBasedChunker
+from dataeval.detectors.drift._nml._domainclassifier import DomainClassifierCalculator
+from dataeval.detectors.drift._nml._thresholds import ConstantThreshold
+from dataeval.outputs._drift import DriftMVDCOutput
+from dataeval.utils._array import flatten
+
+
+class DriftMVDC:
+    """Multivariate Domain Classifier
+
+    Parameters
+    ----------
+    n_folds : int, default 5
+        Number of cross-validation (CV) folds.
+    chunk_size : int or None, default None
+        Number of samples in a chunk used in CV, will get one metric & prediction per chunk.
+    chunk_count : int or None, default None
+        Number of total chunks used in CV, will get one metric & prediction per chunk.
+    threshold : Tuple[float, float], default (0.45, 0.65)
+        (lower, upper) metric bounds on roc_auc for identifying :term:`drift<Drift>`.
+    """
+
+    def __init__(
+        self,
+        n_folds: int = 5,
+        chunk_size: int | None = None,
+        chunk_count: int | None = None,
+        threshold: tuple[float, float] = (0.45, 0.65),
+    ) -> None:
+        self.threshold: tuple[float, float] = max(0.0, min(threshold)), min(1.0, max(threshold))
+        chunker = (
+            CountBasedChunker(10 if chunk_count is None else chunk_count)
+            if chunk_size is None
+            else SizeBasedChunker(chunk_size)
+        )
+        self._calc = DomainClassifierCalculator(
+            cv_folds_num=n_folds,
+            chunker=chunker,
+            threshold=ConstantThreshold(lower=self.threshold[0], upper=self.threshold[1]),
+        )
+
+    def fit(self, x_ref: ArrayLike) -> Self:
+        """
+        Fit the domain classifier on the training dataframe
+
+        Parameters
+        ----------
+        x_ref : ArrayLike
+            Reference data with dim[n_samples, n_features].
+
+        Returns
+        -------
+        Self
+
+        """
+        # for 1D input, assume that is 1 sample: dim[1, n_features]
+        self.x_ref: pd.DataFrame = pd.DataFrame(flatten(np.atleast_2d(np.asarray(x_ref))))
+        self.n_features: int = self.x_ref.shape[-1]
+        self._calc.fit(self.x_ref)
+        return self
+
+    def predict(self, x: ArrayLike) -> DriftMVDCOutput:
+        """
+        Perform :term:`inference<Inference>` on the test dataframe
+
+        Parameters
+        ----------
+        x : ArrayLike
+            Test (analysis) data with dim[n_samples, n_features].
+
+        Returns
+        -------
+        DriftMVDCOutput
+        """
+        self.x_test: pd.DataFrame = pd.DataFrame(flatten(np.atleast_2d(np.asarray(x))))
+        if self.x_test.shape[-1] != self.n_features:
+            raise ValueError("Reference and test embeddings have different number of features")
+
+        return self._calc.calculate(self.x_test)
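
The new DriftMVDC wrapper above is thin: fit() flattens the reference array into a DataFrame and trains the domain classifier, while predict() does the same for the test array and returns a DriftMVDCOutput. A short usage sketch, assuming the class is re-exported from dataeval.detectors.drift (consistent with the +4 -1 change to that package's __init__.py in the file list):

import numpy as np

# Assumption: DriftMVDC is exported from the drift detectors package; otherwise
# import it from dataeval.detectors.drift._mvdc directly.
from dataeval.detectors.drift import DriftMVDC

rng = np.random.default_rng(0)
x_ref = rng.normal(size=(1000, 32))            # reference embeddings, dim[n_samples, n_features]
x_test = rng.normal(loc=0.5, size=(1000, 32))  # shifted test embeddings

# One metric & prediction per chunk; drift is flagged when the per-chunk
# roc_auc leaves the (0.45, 0.65) band, per the class docstring above.
detector = DriftMVDC(n_folds=5, chunk_count=10, threshold=(0.45, 0.65))
output = detector.fit(x_ref).predict(x_test)   # DriftMVDCOutput
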

--- /dev/null
+++ b/dataeval/detectors/drift/_nml/__init__.py
@@ -0,0 +1,6 @@
+"""
+Source code derived from NannyML 0.13.0
+https://github.com/NannyML/nannyml/
+
+Licensed under Apache Software License (Apache 2.0)
+"""

--- /dev/null
+++ b/dataeval/detectors/drift/_nml/_base.py
@@ -0,0 +1,68 @@
+"""
+Source code derived from NannyML 0.13.0
+https://github.com/NannyML/nannyml/blob/main/nannyml/base.py
+
+Licensed under Apache Software License (Apache 2.0)
+"""
+
+from __future__ import annotations
+
+import logging
+from abc import ABC, abstractmethod
+from logging import Logger
+from typing import Sequence
+
+import pandas as pd
+from typing_extensions import Self
+
+from dataeval.detectors.drift._nml._chunk import Chunk, Chunker, CountBasedChunker
+from dataeval.outputs._drift import DriftMVDCOutput
+
+
+def _validate(data: pd.DataFrame, expected_features: int | None = None) -> int:
+    if data.empty:
+        raise ValueError("data contains no rows. Please provide a valid data set.")
+    if expected_features is not None and data.shape[-1] != expected_features:
+        raise ValueError(f"expected '{expected_features}' features in data set:\n\t{data}")
+    return data.shape[-1]
+
+
+def _create_multilevel_index(chunks: Sequence[Chunk], result_group_name: str, result_column_names: Sequence[str]):
+    chunk_column_names = (*chunks[0].KEYS, "period")
+    chunk_tuples = [("chunk", chunk_column_name) for chunk_column_name in chunk_column_names]
+    result_tuples = [(result_group_name, column_name) for column_name in result_column_names]
+    return pd.MultiIndex.from_tuples(chunk_tuples + result_tuples)
+
+
+class AbstractCalculator(ABC):
+    """Base class for drift calculation."""
+
+    def __init__(self, chunker: Chunker | None = None, logger: Logger | None = None):
+        self.chunker = chunker if isinstance(chunker, Chunker) else CountBasedChunker(10)
+        self.result: DriftMVDCOutput | None = None
+        self.n_features: int | None = None
+        self._logger = logger if isinstance(logger, Logger) else logging.getLogger(__name__)
+
+    def fit(self, reference_data: pd.DataFrame) -> Self:
+        """Trains the calculator using reference data."""
+        self.n_features = _validate(reference_data)
+
+        self._logger.debug(f"fitting {str(self)}")
+        self.result = self._fit(reference_data)
+        return self
+
+    def calculate(self, data: pd.DataFrame) -> DriftMVDCOutput:
+        """Performs a calculation on the provided data."""
+        if self.result is None:
+            raise RuntimeError("must run fit with reference data before running calculate")
+        _validate(data, self.n_features)
+
+        self._logger.debug(f"calculating {str(self)}")
+        self.result = self._calculate(data)
+        return self.result
+
+    @abstractmethod
+    def _fit(self, reference_data: pd.DataFrame) -> DriftMVDCOutput: ...
+
+    @abstractmethod
+    def _calculate(self, data: pd.DataFrame) -> DriftMVDCOutput: ...
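
AbstractCalculator defines the template the domain classifier plugs into: fit() validates the reference frame, records the feature count, and delegates to _fit(); calculate() refuses to run before fit() and delegates to _calculate(). A sketch of driving the concrete DomainClassifierCalculator directly, assuming the constructor arguments mirror those used by DriftMVDC above (its module _domainclassifier.py is not shown in this diff, and these underscore-prefixed modules are private, so direct use is for illustration only):

import numpy as np
import pandas as pd

from dataeval.detectors.drift._nml._chunk import CountBasedChunker
from dataeval.detectors.drift._nml._domainclassifier import DomainClassifierCalculator
from dataeval.detectors.drift._nml._thresholds import ConstantThreshold

ref = pd.DataFrame(np.random.default_rng(1).normal(size=(500, 8)))
test = pd.DataFrame(np.random.default_rng(2).normal(loc=0.3, size=(500, 8)))

calc = DomainClassifierCalculator(
    cv_folds_num=5,
    chunker=CountBasedChunker(10),
    threshold=ConstantThreshold(lower=0.45, upper=0.65),
)
calc.fit(ref)                  # validates the frame, stores n_features, calls _fit()
result = calc.calculate(test)  # would raise RuntimeError if fit() had not run first
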

--- /dev/null
+++ b/dataeval/detectors/drift/_nml/_chunk.py
@@ -0,0 +1,404 @@
+"""
+NannyML module providing intelligent splitting of data into chunks.
+
+Source code derived from NannyML 0.13.0
+https://github.com/NannyML/nannyml/blob/main/nannyml/chunk.py
+
+Licensed under Apache Software License (Apache 2.0)
+"""
+
+from __future__ import annotations
+
+import copy
+import logging
+import warnings
+from abc import ABC, abstractmethod
+from typing import Any, Generic, Literal, Sequence, TypeVar, cast
+
+import pandas as pd
+from dateutil.parser import ParserError
+from pandas import Index, Period
+from typing_extensions import Self
+
+logger = logging.getLogger(__name__)
+
+
+class Chunk(ABC):
+    """A subset of data that acts as a logical unit during calculations."""
+
+    KEYS: Sequence[str]
+
+    def __init__(
+        self,
+        data: pd.DataFrame,
+    ):
+        self.key: str
+        self.data = data
+
+        self.start_index: int = -1
+        self.end_index: int = -1
+        self.chunk_index: int = -1
+
+    def __repr__(self):
+        attr_str = ", ".join([f"{k}={v}" for k, v in self.dict().items()])
+        return f"{self.__class__.__name__}(data=pd.DataFrame(shape={self.data.shape}), {attr_str})"
+
+    def __len__(self):
+        return self.data.shape[0]
+
+    @abstractmethod
+    def __add__(self, other: Self) -> Self: ...
+
+    @abstractmethod
+    def __lt__(self, other: Self) -> bool: ...
+
+    @abstractmethod
+    def dict(self) -> dict[str, Any]: ...
+
+
+class IndexChunk(Chunk):
+    """Creates a new chunk.
+
+    Parameters
+    ----------
+    data : DataFrame, required
+        The data to be contained within the chunk.
+    start_datetime: datetime
+        The starting point in time for this chunk.
+    end_datetime: datetime
+        The end point in time for this chunk.
+    """
+
+    KEYS = ("key", "chunk_index", "start_index", "end_index")
+
+    def __init__(
+        self,
+        data: pd.DataFrame,
+        start_index: int,
+        end_index: int,
+    ):
+        super().__init__(data)
+        self.key = f"[{start_index}:{end_index}]"
+        self.start_index: int = start_index
+        self.end_index: int = end_index
+
+    def __lt__(self, other: Self) -> bool:
+        return self.end_index < other.start_index
+
+    def __add__(self, other: Self) -> Self:
+        a, b = (self, other) if self < other else (other, self)
+        result = copy.deepcopy(a)
+        result.data = pd.concat([a.data, b.data])
+        result.end_index = b.end_index
+        return result
+
+    def dict(self) -> dict[str, Any]:
+        return dict(zip(self.KEYS, (self.key, self.chunk_index, self.start_index, self.end_index)))
+
+
+class PeriodChunk(Chunk):
+    """Creates a new chunk.
+
+    Parameters
+    ----------
+    data : DataFrame, required
+        The data to be contained within the chunk.
+    start_datetime: datetime
+        The starting point in time for this chunk.
+    end_datetime: datetime
+        The end point in time for this chunk.
+    chunk_size : int
+        The size of the chunk.
+    """
+
+    KEYS = ("key", "chunk_index", "start_date", "end_date", "chunk_size")
+
+    def __init__(self, data: pd.DataFrame, period: Period, chunk_size: int):
+        super().__init__(data)
+        self.key = str(period)
+        self.start_datetime = period.start_time
+        self.end_datetime = period.end_time
+        self.chunk_size = chunk_size
+
+    def __lt__(self, other: Self) -> bool:
+        return self.end_datetime < other.start_datetime
+
+    def __add__(self, other: Self) -> Self:
+        a, b = (self, other) if self < other else (other, self)
+        result = copy.deepcopy(a)
+        result.data = pd.concat([a.data, b.data])
+        result.end_datetime = b.end_datetime
+        result.chunk_size += b.chunk_size
+        return result
+
+    def dict(self) -> dict[str, Any]:
+        return dict(
+            zip(self.KEYS, (self.key, self.chunk_index, self.start_datetime, self.end_datetime, self.chunk_size))
+        )
+
+
+TChunk = TypeVar("TChunk", bound=Chunk)
+
+
+class Chunker(Generic[TChunk]):
+    """Base class for Chunker implementations.
+
+    Inheriting classes will split a DataFrame into a list of Chunks.
+    They will do this based on several constraints, e.g. observation timestamps, number of observations per Chunk
+    or a preferred number of Chunks.
+    """
+
+    def split(self, data: pd.DataFrame) -> list[TChunk]:
+        """Splits a given data frame into a list of chunks.
+
+        This method provides a uniform interface across Chunker implementations to keep them interchangeable.
+
+        After performing the implementation-specific `_split` method, there are some checks on the resulting chunk list.
+
+        If the total number of chunks is low a warning will be written out to the logs.
+
+        We dynamically determine the optimal minimum number of observations per chunk and then check if the resulting
+        chunks contain at least as many. If there are any underpopulated chunks a warning will be written out in
+        the logs.
+
+        Parameters
+        ----------
+        data: DataFrame
+            The data to be split into chunks
+
+        Returns
+        -------
+        chunks: List[Chunk]
+            The list of chunks
+
+        """
+        if data.shape[0] == 0:
+            return []
+
+        chunks = self._split(data)
+        for chunk_index, chunk in enumerate(chunks):
+            chunk.start_index = cast(int, chunk.data.index.min())
+            chunk.end_index = cast(int, chunk.data.index.max())
+            chunk.chunk_index = chunk_index
+
+        if len(chunks) < 6:
+            # TODO wording
+            warnings.warn(
+                "The resulting number of chunks is too low. "
+                "Please consider splitting your data in a different way or continue at your own risk."
+            )
+
+        return chunks
+
+    @abstractmethod
+    def _split(self, data: pd.DataFrame) -> list[TChunk]: ...
+
+
+class PeriodBasedChunker(Chunker[PeriodChunk]):
+    """A Chunker that will split data into Chunks based on a date column in the data.
+
+    Examples
+    --------
+    Chunk using monthly periods and providing a column name
+
+    >>> from nannyml.chunk import PeriodBasedChunker
+    >>> df = pd.read_parquet("/path/to/my/data.pq")
+    >>> chunker = PeriodBasedChunker(timestamp_column_name="observation_date", offset="M")
+    >>> chunks = chunker.split(data=df)
+
+    Or chunk using weekly periods
+
+    >>> from nannyml.chunk import PeriodBasedChunker
+    >>> df = pd.read_parquet("/path/to/my/data.pq")
+    >>> chunker = PeriodBasedChunker(timestamp_column_name=df["observation_date"], offset="W", minimum_chunk_size=50)
+    >>> chunks = chunker.split(data=df)
+
+    """
+
+    def __init__(self, timestamp_column_name: str, offset: str = "W") -> None:
+        """Creates a new PeriodBasedChunker.
+
+        Parameters
+        ----------
+        timestamp_column_name : str
+            The column name containing the timestamp to chunk on
+        offset : str
+            A frequency string representing a pandas.tseries.offsets.DateOffset.
+            The offset determines how the time-based grouping will occur. A list of possible values
+            can be found at <https://pandas.pydata.org/docs/user_guide/timeseries.html#offset-aliases>.
+        """
+        self.timestamp_column_name = timestamp_column_name
+        self.offset = offset
+
+    def _split(self, data: pd.DataFrame) -> list[PeriodChunk]:
+        chunks = []
+        if self.timestamp_column_name is None:
+            raise ValueError("timestamp_column_name must be provided")
+        if self.timestamp_column_name not in data:
+            raise ValueError(f"timestamp column '{self.timestamp_column_name}' not in columns")
+
+        try:
+            grouped = data.groupby(pd.to_datetime(data[self.timestamp_column_name]).dt.to_period(self.offset))
+        except ParserError:
+            raise ValueError(
+                f"could not parse date_column '{self.timestamp_column_name}' values as dates."
+                f"Please verify if you've specified the correct date column."
+            )
+
+        for k, v in grouped.groups.items():
+            period, index = cast(Period, k), cast(Index, v)
+            chunk = PeriodChunk(
+                data=grouped.get_group(period),  # type: ignore | dataframe
+                period=period,
+                chunk_size=len(index),
+            )
+            chunks.append(chunk)
+
+        return chunks
+
+
+class SizeBasedChunker(Chunker[IndexChunk]):
+    """A Chunker that will split data into Chunks based on the preferred number of observations per Chunk.
+
+    Notes
+    -----
+    - Chunks are adjacent, not overlapping
+    - There may be "incomplete" chunks, as the remainder of observations after dividing by `chunk_size`
+      will form a chunk of their own.
+
+    Examples
+    --------
+    Chunk into fixed-size chunks of 2000 observations, dropping any leftover observations
+
+    >>> from nannyml.chunk import SizeBasedChunker
+    >>> df = pd.read_parquet("/path/to/my/data.pq")
+    >>> chunker = SizeBasedChunker(chunk_size=2000, incomplete="drop")
+    >>> chunks = chunker.split(data=df)
+
+    """
+
+    def __init__(
+        self,
+        chunk_size: int,
+        incomplete: Literal["append", "drop", "keep"] = "keep",
+    ):
+        """Create a new SizeBasedChunker.
+
+        Parameters
+        ----------
+        chunk_size: int
+            The preferred size of the resulting Chunks, i.e. the number of observations in each Chunk.
+        incomplete: str, default='keep'
+            Choose how to handle any leftover observations that don't make up a full Chunk.
+            The following options are available:
+
+            - ``'drop'``: drop the leftover observations
+            - ``'keep'``: keep the incomplete Chunk (containing less than ``chunk_size`` observations)
+            - ``'append'``: append leftover observations to the last complete Chunk (overfilling it)
+
+            Defaults to ``'keep'``.
+
+        Returns
+        -------
+        chunker: a size-based instance used to split data into Chunks of a constant size.
+
+        """
+        if not isinstance(chunk_size, int) or chunk_size <= 0:
+            raise ValueError(f"chunk_size={chunk_size} is invalid - provide an integer greater than 0")
+        if incomplete not in ("append", "drop", "keep"):
+            raise ValueError(f"incomplete={incomplete} is invalid - must be one of ['append', 'drop', 'keep']")
+
+        self.chunk_size = chunk_size
+        self.incomplete = incomplete
+
+    def _split(self, data: pd.DataFrame) -> list[IndexChunk]:
+        def _create_chunk(index: int, data: pd.DataFrame, chunk_size: int) -> IndexChunk:
+            chunk_data = data.iloc[index : index + chunk_size]
+            chunk = IndexChunk(
+                data=chunk_data,
+                start_index=index,
+                end_index=index + chunk_size - 1,
+            )
+            return chunk
+
+        chunks = [
+            _create_chunk(index=i, data=data, chunk_size=self.chunk_size)
+            for i in range(0, data.shape[0], self.chunk_size)
+            if i + self.chunk_size - 1 < len(data)
+        ]
+
+        # deal with unassigned observations
+        if data.shape[0] % self.chunk_size != 0 and self.incomplete != "drop":
+            incomplete_chunk = _create_chunk(
+                index=self.chunk_size * (data.shape[0] // self.chunk_size),
+                data=data,
+                chunk_size=(data.shape[0] % self.chunk_size),
+            )
+            if self.incomplete == "append":
+                chunks[-1] += incomplete_chunk
+            else:
+                chunks += [incomplete_chunk]
+
+        return chunks
+
+
+class CountBasedChunker(Chunker[IndexChunk]):
+    """A Chunker that will split data into chunks based on the preferred number of total chunks.
+
+    Notes
+    -----
+    - Chunks are adjacent, not overlapping
+    - There may be "incomplete" chunks, as the remainder of observations after dividing by `chunk_size`
+      will form a chunk of their own.
+
+    Examples
+    --------
+    >>> from nannyml.chunk import CountBasedChunker
+    >>> df = pd.read_parquet("/path/to/my/data.pq")
+    >>> chunker = CountBasedChunker(chunk_number=100)
+    >>> chunks = chunker.split(data=df)
+
+    """
+
+    def __init__(
+        self,
+        chunk_number: int,
+        incomplete: Literal["append", "drop", "keep"] = "keep",
+    ):
+        """Creates a new CountBasedChunker.
+
+        It will calculate the amount of observations per chunk based on the given chunk count.
+        It then continues to split the data into chunks just like a SizeBasedChunker does.
+
+        Parameters
+        ----------
+        chunk_number: int
+            The amount of chunks to split the data in.
+        incomplete: str, default='keep'
+            Choose how to handle any leftover observations that don't make up a full Chunk.
+            The following options are available:
+
+            - ``'drop'``: drop the leftover observations
+            - ``'keep'``: keep the incomplete Chunk (containing less than ``chunk_size`` observations)
+            - ``'append'``: append leftover observations to the last complete Chunk (overfilling it)
+
+            Defaults to ``'keep'``.
+
+        Returns
+        -------
+        chunker: CountBasedChunker
+
+        """
+        if not isinstance(chunk_number, int) or chunk_number <= 0:
+            raise ValueError(f"given chunk_number {chunk_number} is invalid - provide an integer greater than 0")
+        if incomplete not in ("append", "drop", "keep"):
+            raise ValueError(f"incomplete={incomplete} is invalid - must be one of ['append', 'drop', 'keep']")
+
+        self.chunk_number = chunk_number
+        self.incomplete: Literal["append", "drop", "keep"] = incomplete
+
+    def _split(self, data: pd.DataFrame) -> list[IndexChunk]:
+        chunk_size = data.shape[0] // self.chunk_number
+        chunker = SizeBasedChunker(chunk_size, self.incomplete)
+        chunks = chunker.split(data=data)
+        return chunks
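
The chunkers above are the unit the drift calculator iterates over: SizeBasedChunker fixes the number of observations per chunk, while CountBasedChunker derives a chunk size from the requested number of chunks and delegates to SizeBasedChunker. A small sketch of the splitting behaviour shown in this hunk; these are private (underscore-prefixed) modules, so direct use is for illustration only:

import numpy as np
import pandas as pd

from dataeval.detectors.drift._nml._chunk import CountBasedChunker, SizeBasedChunker

df = pd.DataFrame(np.arange(125).reshape(125, 1), columns=["feature"])

# 125 rows / chunk_size 20 -> six full chunks plus one incomplete chunk of 5,
# kept as its own chunk because incomplete defaults to "keep".
size_chunks = SizeBasedChunker(chunk_size=20).split(df)
print([len(c) for c in size_chunks])   # [20, 20, 20, 20, 20, 20, 5]
print(size_chunks[0].key)              # "[0:19]"

# CountBasedChunker(6) computes chunk_size = 125 // 6 = 20 and delegates to
# SizeBasedChunker, so the leftover 5 rows again become a seventh chunk.
count_chunks = CountBasedChunker(chunk_number=6).split(df)
print([c.key for c in count_chunks])
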