fkat-0.1.2-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (88)
  1. fkat/__init__.py +147 -0
  2. fkat/data/__init__.py +15 -0
  3. fkat/data/data_module.py +198 -0
  4. fkat/data/datasets/__init__.py +19 -0
  5. fkat/data/datasets/dict.py +78 -0
  6. fkat/data/datasets/json.py +176 -0
  7. fkat/data/datasets/map.py +90 -0
  8. fkat/data/datasets/parquet.py +242 -0
  9. fkat/data/datasets/sized.py +31 -0
  10. fkat/data/dict.py +42 -0
  11. fkat/data/samplers/__init__.py +9 -0
  12. fkat/data/samplers/dict.py +38 -0
  13. fkat/data/samplers/sized.py +16 -0
  14. fkat/data/samplers/strategies.py +68 -0
  15. fkat/data/sharded.py +718 -0
  16. fkat/data/shm.py +364 -0
  17. fkat/predict.py +32 -0
  18. fkat/py.typed +0 -0
  19. fkat/pytorch/__init__.py +3 -0
  20. fkat/pytorch/actions/__init__.py +11 -0
  21. fkat/pytorch/actions/aws/__init__.py +3 -0
  22. fkat/pytorch/actions/aws/batch.py +29 -0
  23. fkat/pytorch/actions/aws/ec2.py +61 -0
  24. fkat/pytorch/callbacks/__init__.py +2 -0
  25. fkat/pytorch/callbacks/cuda/__init__.py +16 -0
  26. fkat/pytorch/callbacks/cuda/cache.py +115 -0
  27. fkat/pytorch/callbacks/cuda/memory.py +200 -0
  28. fkat/pytorch/callbacks/cuda/nsys.py +199 -0
  29. fkat/pytorch/callbacks/cuda/nvtx.py +288 -0
  30. fkat/pytorch/callbacks/cuda/xid.py +173 -0
  31. fkat/pytorch/callbacks/debugging/__init__.py +9 -0
  32. fkat/pytorch/callbacks/debugging/introspection.py +569 -0
  33. fkat/pytorch/callbacks/debugging/optimizer.py +45 -0
  34. fkat/pytorch/callbacks/gc.py +146 -0
  35. fkat/pytorch/callbacks/loggers.py +211 -0
  36. fkat/pytorch/callbacks/logging/__init__.py +12 -0
  37. fkat/pytorch/callbacks/logging/heartbeat.py +76 -0
  38. fkat/pytorch/callbacks/logging/throughput.py +253 -0
  39. fkat/pytorch/callbacks/logging/validation_metrics.py +94 -0
  40. fkat/pytorch/callbacks/monitoring/__init__.py +14 -0
  41. fkat/pytorch/callbacks/monitoring/crash.py +162 -0
  42. fkat/pytorch/callbacks/monitoring/dp.py +130 -0
  43. fkat/pytorch/callbacks/monitoring/hardware_stats.py +135 -0
  44. fkat/pytorch/callbacks/monitoring/shutdown.py +170 -0
  45. fkat/pytorch/callbacks/profiling/__init__.py +13 -0
  46. fkat/pytorch/callbacks/profiling/flops.py +574 -0
  47. fkat/pytorch/callbacks/profiling/memray.py +212 -0
  48. fkat/pytorch/callbacks/profiling/torch.py +197 -0
  49. fkat/pytorch/callbacks/profiling/viztracer.py +197 -0
  50. fkat/pytorch/loggers.py +284 -0
  51. fkat/pytorch/schedule/__init__.py +27 -0
  52. fkat/pytorch/schedule/base.py +308 -0
  53. fkat/pytorch/schedule/mlflow.py +143 -0
  54. fkat/pytorch/utilities.py +49 -0
  55. fkat/test.py +31 -0
  56. fkat/train.py +32 -0
  57. fkat/utils/__init__.py +28 -0
  58. fkat/utils/aws/__init__.py +3 -0
  59. fkat/utils/aws/imds.py +137 -0
  60. fkat/utils/boto3.py +24 -0
  61. fkat/utils/config.py +194 -0
  62. fkat/utils/cuda/__init__.py +3 -0
  63. fkat/utils/cuda/preflight/__init__.py +3 -0
  64. fkat/utils/cuda/preflight/health_check/aws_instance_config.py +82 -0
  65. fkat/utils/cuda/preflight/health_check/constants.py +23 -0
  66. fkat/utils/cuda/preflight/health_check/ddb_client.py +82 -0
  67. fkat/utils/cuda/preflight/health_check/gpu_connection_test.py +104 -0
  68. fkat/utils/cuda/preflight/health_check/gpu_stress_test.py +122 -0
  69. fkat/utils/cuda/preflight/health_check/helpers.py +297 -0
  70. fkat/utils/cuda/preflight/health_check/logger.py +205 -0
  71. fkat/utils/cuda/preflight/health_check/timer.py +31 -0
  72. fkat/utils/cuda/preflight/run.py +560 -0
  73. fkat/utils/cuda/xid.py +48 -0
  74. fkat/utils/logging.py +28 -0
  75. fkat/utils/mlflow.py +33 -0
  76. fkat/utils/pandas.py +25 -0
  77. fkat/utils/pdb.py +84 -0
  78. fkat/utils/pool.py +81 -0
  79. fkat/utils/profiler.py +18 -0
  80. fkat/utils/pyarrow.py +21 -0
  81. fkat/utils/rng.py +27 -0
  82. fkat/utils/shm.py +184 -0
  83. fkat/validate.py +31 -0
  84. fkat-0.1.2.dist-info/METADATA +134 -0
  85. fkat-0.1.2.dist-info/RECORD +88 -0
  86. fkat-0.1.2.dist-info/WHEEL +4 -0
  87. fkat-0.1.2.dist-info/licenses/LICENSE +175 -0
  88. fkat-0.1.2.dist-info/licenses/NOTICE +1 -0
fkat/data/datasets/map.py ADDED
@@ -0,0 +1,90 @@
+ # Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
+ # SPDX-License-Identifier: Apache-2.0
+ from typing import TypeVar
+ from collections.abc import Callable, Iterator
+
+ from fkat.data.datasets import SizedDataset
+ from torch.utils.data import IterableDataset
+
+ T_in = TypeVar("T_in", contravariant=True)
+ T_from = TypeVar("T_from", covariant=True)
+ T_to = TypeVar("T_to", covariant=True)
+
+
+ class MapDataset(SizedDataset[T_in, T_to]):
+     """A :class:`Dataset` that transforms the samples from another :class:`Dataset` using a function."""
+
+     def __init__(
+         self,
+         dataset: SizedDataset[T_in, T_from],
+         fn: Callable[[T_from], T_to],
+     ) -> None:
+         """Create a :class:`Dataset` that maps samples of another :class:`Dataset` using a function.
+
+         Args:
+             dataset (SizedDataset): Source :class:`Dataset`.
+             fn (Callable[[T_from], T_to]): Sample transformation function.
+
+         Returns:
+             None
+         """
+         self.dataset = dataset
+         self.fn = fn
+
+     def __len__(self) -> int:
+         """Get :class:`Dataset` size.
+
+         Returns:
+             int: :class:`Dataset` size.
+         """
+         return len(self.dataset)
+
+     def __getitems__(self, idxs: list[T_in]) -> list[T_to]:
+         """Get a batch of samples at the specified indices.
+
+         Args:
+             idxs (List[T_in]): Samples' indices.
+
+         Returns:
+             List[T_to]: A batch of samples.
+         """
+         if getitems := getattr(self.dataset, "__getitems__", None):
+             batch = getitems(idxs)
+         else:
+             batch = [self.dataset[idx] for idx in idxs]
+         samples = [self.fn(sample) for sample in batch]
+         return samples
+
+     def __getitem__(self, idx: T_in) -> T_to:
+         """Get a sample at the specified index.
+
+         Args:
+             idx (T_in): Sample index.
+
+         Returns:
+             T_to: A sample.
+         """
+         sample = self.fn(self.dataset[idx])
+         return sample
+
+
+ class IterableMapDataset(IterableDataset[T_to]):
+     """An :class:`IterableDataset` that transforms the samples from another
+     :class:`IterableDataset` using a function."""
+
+     def __init__(
+         self,
+         dataset: IterableDataset[T_from],
+         fn: Callable[[T_from], T_to],
+     ) -> None:
+         self.dataset = dataset
+         self.fn = fn
+
+     def __iter__(self) -> Iterator[T_to]:
+         """Get :class:`IterableDataset` iterator.
+
+         Yields:
+             T_to: A sample.
+         """
+         for sample in iter(self.dataset):
+             yield self.fn(sample)
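A minimal usage sketch of MapDataset (illustrative only, not part of the package); ListDataset is a hypothetical in-memory class that satisfies the SizedDataset protocol:

    # Hypothetical example, not from the package: wrap an in-memory dataset
    # so that every sample is squared on access.
    from fkat.data.datasets.map import MapDataset

    class ListDataset:
        """Toy dataset satisfying the SizedDataset protocol structurally."""

        def __init__(self, items: list[int]) -> None:
            self.items = items

        def __len__(self) -> int:
            return len(self.items)

        def __getitem__(self, idx: int) -> int:
            return self.items[idx]

    squared = MapDataset(ListDataset([1, 2, 3]), fn=lambda x: x * x)
    assert len(squared) == 3
    assert squared[1] == 4
    # __getitems__ falls back to per-index access because ListDataset has no __getitems__
    assert squared.__getitems__([0, 2]) == [1, 9]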
fkat/data/datasets/parquet.py ADDED
@@ -0,0 +1,242 @@
+ # Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
+ # SPDX-License-Identifier: Apache-2.0
+ from typing import Any
+ from collections.abc import Iterator
+ import awswrangler as s3wr
+ from pyarrow.fs import FileSystem, S3FileSystem  # type: ignore[possibly-unbound-import]
+ import numpy as np
+ import pyarrow as pa
+ from torch.utils.data import IterableDataset
+
+ from fkat.data.datasets import SizedDataset
+ from fkat.utils.pyarrow import iter_rows as pa_iter_rows
+ from fkat.utils.pandas import iter_rows as pd_iter_rows
+ from fkat.utils.boto3 import session
+
+
+ class IterableParquetDataset(IterableDataset[dict[str, Any]]):
+     """
+     An :class:`IterableDataset` backed by Parquet data.
+
+     .. note:: If you want to keep the original types when reading Parquet,
+         set ``dtype_backend='pyarrow'``.
+
+     Example config:
+
+     .. code-block:: yaml
+
+         _target_: fkat.data.datasets.parquet.IterableParquetDataset
+         uri: s3://path/to/fkat.parquet
+         dtype_backend: pyarrow
+
+     Args:
+         uri (str or list[str]): URI of Parquet data.
+         columns (list[str], optional): Columns to load.
+             Defaults to ``None``.
+         use_threads (bool): Use multi-threaded processing.
+             Defaults to ``True``.
+         chunk_size (int): Maximum number of rows per DataFrame chunk when iterating.
+             Defaults to ``10000``.
+         replace_nan (bool): Whether to replace ``np.nan`` with ``None``.
+             Defaults to ``True``.
+         s3wr_args (dict): Extra keyword arguments for ``s3wr.s3.read_parquet``;
+             refer to https://aws-sdk-pandas.readthedocs.io/en/3.5.1/stubs/awswrangler.s3.read_parquet.html
+     """
+
+     def __init__(
+         self,
+         uri: str | list[str],
+         columns: list[str] | None = None,
+         use_threads: bool = True,
+         chunk_size: int = 10000,
+         replace_nan: bool = True,
+         **s3wr_args: Any,
+     ) -> None:
+         fs: FileSystem
+         path: str
+         if isinstance(uri, str):
+             fs, path = FileSystem.from_uri(uri)
+         else:
+             fs, path = FileSystem.from_uri(uri[0])
+         self.s3_file = isinstance(fs, S3FileSystem)
+         if self.s3_file:
+             self.uri = uri
+             self.use_threads = use_threads
+             self.columns = columns
+             self.chunk_size = chunk_size
+             self.replace_nan = replace_nan
+             self.s3wr_args = s3wr_args
+         else:
+             # otherwise, use pyarrow
+             path_list = []
+             if isinstance(uri, str):
+                 _, path = FileSystem.from_uri(uri)
+                 path_list.append(path)
+             else:
+                 path_list = []
+                 for each in uri:
+                     _, path = FileSystem.from_uri(each)
+                     path_list.append(path)
+             pds = pa.parquet.ParquetDataset(path_list, filesystem=fs)  # type: ignore
+             self.tbl = pds.read(columns, use_threads=use_threads, use_pandas_metadata=False)
+             self.chunk_size = chunk_size
+
+     def __iter__(self) -> Iterator[dict[str, Any]]:
+         """Creates the dataset iterator.
+         Returns:
+             Iterator[dict[str, Any]]: dataset iterator
+         """
+         if self.s3_file:
+             return pd_iter_rows(
+                 s3wr.s3.read_parquet(
+                     self.uri,
+                     use_threads=self.use_threads,
+                     columns=self.columns,
+                     chunked=self.chunk_size,
+                     boto3_session=session(clients=["s3"]),
+                     path_suffix="parquet",
+                     **self.s3wr_args,
+                 ),
+                 self.replace_nan,
+             )
+         else:
+             return pa_iter_rows(self.tbl, chunk_size=self.chunk_size)
+
+
+ class ParquetDataset(SizedDataset[int, dict[str, Any]]):
+     """
+     A :class:`Dataset` backed by Parquet data.
+
+     Create a :class:`Dataset` from Parquet data at the specified URI.
+
+     .. note:: If you want to keep the original types when reading Parquet,
+         set ``dtype_backend='pyarrow'``.
+
+     Example config:
+
+     .. code-block:: yaml
+
+         _target_: fkat.data.datasets.parquet.ParquetDataset
+         uri: s3://path/to/fkat.parquet
+         dtype_backend: pyarrow
+
+     Difference for ``dtype_backend`` between ``pyarrow`` and ``numpy_nullable``:
+
+     .. code-block:: python
+
+         from fkat.data.datasets.parquet import ParquetDataset
+         from fkat.utils.s3_utils import fs_save_prediction_output_parquet
+
+         uri = "s3://path/to/fkat.parquet"
+
+         saved_data = {
+             "purchased_items": [
+                 [
+                     {"product_id": "PROD001", "item_index": 12345, "quantity": "1"},
+                     {"product_id": "PROD002", "item_index": None, "quantity": "1"},
+                 ],
+                 [{"product_id": "PROD001", "item_index": 12345, "quantity": "1"}],
+             ],
+             "ground_truth": [[1, 2, 3], [1, 2]],
+             "embeddings": [np.random.randn(128), np.random.randn(128)],
+         }
+         fs_save_prediction_output_parquet()(saved_data, uri)
+
+         dataset = ParquetDataset(uri)  # dtype_backend: numpy_nullable
+         print(type(dataset[0]["embeddings"]))  # type: numpy.ndarray
+         print(type(dataset[0]["purchased_items"]))  # type: numpy.ndarray of object
+         print(type(dataset[0]["ground_truth"]))  # type: numpy.ndarray of object
+
+         pyarrow_dataset = ParquetDataset(uri, dtype_backend="pyarrow")  # dtype_backend: pyarrow
+         print(type(pyarrow_dataset[0]["embeddings"]))  # type: list
+         print(type(pyarrow_dataset[0]["purchased_items"]))  # type: list of dictionary
+         print(type(pyarrow_dataset[0]["ground_truth"]))  # type: list of int
+
+     Args:
+         uri (str | list[str]): URI of Parquet data.
+         columns (list[str] | None): Columns to load.
+         use_threads (bool): Use multi-threaded processing.
+             Defaults to ``True``.
+         replace_nan (bool): Whether to replace ``np.nan`` with ``None``.
+             Defaults to ``True``.
+         s3wr_args (dict): Extra keyword arguments for ``s3wr.s3.read_parquet``;
+             refer to https://aws-sdk-pandas.readthedocs.io/en/3.5.1/stubs/awswrangler.s3.read_parquet.html
+     """
+
+     def __init__(
+         self,
+         uri: str | list[str],
+         columns: list[str] | None = None,
+         use_threads: bool = True,
+         replace_nan: bool = True,
+         **s3wr_args: Any,
+     ) -> None:
+         # check FileSystem
+         fs: FileSystem
+         path: str
+         if isinstance(uri, str):
+             fs, path = FileSystem.from_uri(uri)
+         else:
+             fs, path = FileSystem.from_uri(uri[0])
+         if isinstance(fs, S3FileSystem):
+             # if file is in S3, then use awswrangler
+             self.df = s3wr.s3.read_parquet(
+                 uri,
+                 use_threads=use_threads,
+                 columns=columns,
+                 boto3_session=session(clients=["s3"]),
+                 path_suffix="parquet",
+                 **s3wr_args,
+             )
+             if replace_nan:
+                 self.df = self.df.replace({np.nan: None})
+         else:
+             # otherwise, use pyarrow
+             path_list = []
+             if isinstance(uri, str):
+                 fs, path = FileSystem.from_uri(uri)
+                 path_list.append(path)
+             elif isinstance(uri, list):
+                 path_list = []
+                 for each in uri:
+                     _, path = FileSystem.from_uri(each)
+                     path_list.append(path)
+             else:
+                 raise Exception(f"ParquetDataset can't support uri as {type(uri)}")
+             pds = pa.parquet.ParquetDataset(path_list, filesystem=fs)  # type: ignore
+             tbl = pds.read(columns, use_threads=use_threads, use_pandas_metadata=False)
+             self.df = tbl.to_pandas()
+
+     def __len__(self) -> int:
+         """Get :class:`Dataset` size.
+
+         Returns:
+             int: :class:`Dataset` size.
+         """
+         return len(self.df)
+
+     def __getitems__(self, idxs: list[int]) -> list[dict[str, Any]]:
+         """Get a batch of samples at the specified indices.
+
+         Args:
+             idxs (list[int]): Samples' indices.
+
+         Returns:
+             list[dict[str, Any]]: A batch of samples.
+         """
+         series = self.df.iloc[idxs]
+         samples = [series.iloc[i].to_dict() for i in range(len(idxs))]
+         return samples
+
+     def __getitem__(self, idx: int) -> dict[str, Any]:
+         """Get a sample at the specified index.
+
+         Args:
+             idx (int): Sample index.
+
+         Returns:
+             dict[str, Any]: A sample.
+         """
+         series = self.df.iloc[idx]
+         sample = series.to_dict()
+         return sample
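A minimal usage sketch (illustrative only, not part of the package) of the non-S3 code path, which reads through pyarrow rather than awswrangler. The file path and column names are placeholders:

    # Hypothetical example, not from the package: read a local Parquet file.
    from fkat.data.datasets.parquet import IterableParquetDataset, ParquetDataset

    ds = ParquetDataset("file:///tmp/example.parquet", columns=["a", "b"])
    print(len(ds), ds[0])  # number of rows, first row as a dict

    it_ds = IterableParquetDataset("file:///tmp/example.parquet", chunk_size=1_000)
    for row in it_ds:   # rows are yielded one dict at a time,
        print(row)      # at most chunk_size rows per pyarrow chunk
        break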
fkat/data/datasets/sized.py ADDED
@@ -0,0 +1,31 @@
+ # Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
+ # SPDX-License-Identifier: Apache-2.0
+ from typing import TypeVar
+
+ from typing_extensions import Protocol
+
+ T_in = TypeVar("T_in", contravariant=True)
+ T_out = TypeVar("T_out", covariant=True)
+
+
+ class SizedDataset(Protocol[T_in, T_out]):
+     """A :class:`Dataset` with a known size."""
+
+     def __len__(self) -> int:
+         """Get :class:`Dataset` size.
+
+         Returns:
+             int: :class:`Dataset` size.
+         """
+         ...
+
+     def __getitem__(self, idx: T_in) -> T_out:
+         """Get a sample at the specified index.
+
+         Args:
+             idx (T_in): Sample index.
+
+         Returns:
+             T_out: A sample.
+         """
+         ...
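Because SizedDataset is a Protocol, conformance is structural. A short sketch (illustrative only, not part of the package); RangeDataset is a hypothetical class:

    # Hypothetical example: any class with compatible __len__ and __getitem__
    # satisfies SizedDataset structurally -- no inheritance required.
    from fkat.data.datasets import SizedDataset

    class RangeDataset:
        def __len__(self) -> int:
            return 10

        def __getitem__(self, idx: int) -> int:
            return idx * 2

    ds: SizedDataset[int, int] = RangeDataset()  # accepted by static type checkers
    print(len(ds), ds[3])  # 10 6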
fkat/data/dict.py ADDED
@@ -0,0 +1,42 @@
+ # Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
+ # SPDX-License-Identifier: Apache-2.0
+ from typing import Any
+ from collections.abc import Iterable, Iterator
+
+ from typing_extensions import override
+
+ from fkat.data.samplers.strategies import SamplerStrategy
+ from fkat.utils.config import to_primitive_container
+
+
+ def wrap(key: str, name: str, batch: dict[str, Any]) -> dict[str, Any]:
+     if not isinstance(batch, dict) or key in batch:
+         raise RuntimeError(f"DataLoaders must return a batch dict without {key} key")
+     batch[key] = name
+     return batch
+
+
+ class DictDataLoader(Iterable[dict[str, Any]]):
+     """An :class:`Iterable` that multiplexes several named :class:`DataLoader`\\s, tagging each yielded batch with the name of the loader it came from."""
+
+     def __init__(
+         self,
+         dataloaders: dict[str, Iterable[dict[str, Any]]],
+         strategy: SamplerStrategy,
+         key: str = "dataset",
+     ) -> None:
+         self.dataloaders = to_primitive_container(dataloaders)
+         self.strategy = strategy
+         self.key = key
+
+     @override
+     def __iter__(self) -> Iterator[dict[str, Any]]:
+         iters = {k: iter(self.dataloaders[k]) for k in self.dataloaders}
+         for name in self.strategy:
+             if not iters:
+                 return
+             if it := iters.get(name):
+                 try:
+                     yield wrap(self.key, name, next(it))
+                 except StopIteration:
+                     del iters[name]
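A minimal usage sketch of DictDataLoader (illustrative only, not part of the package). Plain lists stand in for DataLoaders, and this assumes to_primitive_container passes ordinary dicts through unchanged:

    # Hypothetical example: interleave two loaders round-robin; each yielded
    # batch is tagged with the name of the loader it came from.
    from fkat.data.dict import DictDataLoader
    from fkat.data.samplers.strategies import RoundRobin

    loaders = {
        "a": [{"x": 1}, {"x": 2}],
        "b": [{"x": 10}],
    }
    dl = DictDataLoader(loaders, strategy=RoundRobin(["a", "b"]), key="dataset")
    print(list(dl))
    # [{'x': 1, 'dataset': 'a'}, {'x': 10, 'dataset': 'b'}, {'x': 2, 'dataset': 'a'}]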
fkat/data/samplers/__init__.py ADDED
@@ -0,0 +1,9 @@
+ # Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
+ # SPDX-License-Identifier: Apache-2.0
+ from .dict import DictBatchSampler
+ from .sized import SizedSampler
+
+ __all__ = [
+     "SizedSampler",
+     "DictBatchSampler",
+ ]
fkat/data/samplers/dict.py ADDED
@@ -0,0 +1,38 @@
+ # Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
+ # SPDX-License-Identifier: Apache-2.0
+ from collections.abc import Iterator
+
+ from typing_extensions import override
+
+ from fkat.data.samplers.sized import SizedSampler
+ from fkat.data.samplers.strategies import SamplerStrategy
+
+
+ class DictBatchSampler(SizedSampler[tuple[str, list[int]]]):
+     def __init__(self, strategy: SamplerStrategy, samplers: dict[str, SizedSampler[list[int]]]) -> None:
+         self.strategy = strategy
+         self.samplers = samplers
+         self.len = sum(len(sampler) for sampler in samplers.values())
+
+     @override
+     def __len__(self) -> int:
+         return self.len
+
+     @override
+     def __iter__(self) -> Iterator[tuple[str, list[int]]]:
+         rem_samplers = {name: iter(sampler) for name, sampler in self.samplers.items()}
+         for key in self.strategy:
+             if not rem_samplers:
+                 # no more iterators left, stop iteration
+                 return
+             if key not in rem_samplers:
+                 # this sampler is exhausted; skip it and keep sampling the remaining ones
+                 # (this effectively re-normalizes the weights for weighted strategies,
+                 # and moves on to the next label for sequential or frequency-based strategies)
+                 continue
+             try:
+                 batch = next(rem_samplers[key])
+                 yield (key, batch)
+             except StopIteration:
+                 # this sampler is exhausted, remove it from consideration
+                 del rem_samplers[key]
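A minimal usage sketch of DictBatchSampler (illustrative only, not part of the package), combining two torch BatchSamplers under a weighted strategy:

    # Hypothetical example: pick the source of each index batch by weight.
    from torch.utils.data import BatchSampler, SequentialSampler

    from fkat.data.samplers import DictBatchSampler
    from fkat.data.samplers.strategies import Weighted

    samplers = {
        "small": BatchSampler(SequentialSampler(range(4)), batch_size=2, drop_last=False),
        "large": BatchSampler(SequentialSampler(range(100)), batch_size=2, drop_last=False),
    }
    batch_sampler = DictBatchSampler(Weighted({"small": 0.1, "large": 0.9}), samplers)
    print(len(batch_sampler))  # 52 == 2 + 50 batches in total
    for name, idxs in batch_sampler:
        # e.g. ("large", [0, 1]); route idxs to the dataset registered under `name`
        pass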
fkat/data/samplers/sized.py ADDED
@@ -0,0 +1,16 @@
+ # Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
+ # SPDX-License-Identifier: Apache-2.0
+ from typing import TypeVar
+ from collections.abc import Iterator
+
+ from typing_extensions import Protocol
+
+ T_co = TypeVar("T_co", covariant=True)
+
+
+ class SizedSampler(Protocol[T_co]):
+     """A Sampler with a known size."""
+
+     def __len__(self) -> int: ...
+
+     def __iter__(self) -> Iterator[T_co]: ...
fkat/data/samplers/strategies.py ADDED
@@ -0,0 +1,68 @@
+ # Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
+ # SPDX-License-Identifier: Apache-2.0
+ from typing import Any
+ from collections.abc import Iterator
+
+ import numpy as np
+ from typing_extensions import Protocol, override
+
+
+ class SamplerStrategy(Protocol):
+     """
+     A strategy that decides which Sampler to draw from next and yields its label.
+     """
+
+     def __iter__(self) -> Iterator[str]: ...
+
+
+ class Weighted(SamplerStrategy):
+     """
+     Samples the label to generate the next microbatch from, using the provided weight distribution.
+     For a uniform distribution, use equal weights (the weights must sum to 1).
+     """
+
+     def __init__(self, weights: dict[str, float]) -> None:
+         self.names = []
+         self.weights = []
+         for name, weight in weights.items():
+             self.names.append(name)
+             self.weights.append(weight)
+
+     @override
+     def __iter__(self) -> Iterator[str]:
+         while True:
+             yield np.random.choice(self.names, p=self.weights)
+
+
+ class RoundRobin(SamplerStrategy):
+     """
+     Cycles through the given labels in order to generate microbatches from.
+     """
+
+     def __init__(self, names: list[str]) -> None:
+         self.names = names
+
+     @override
+     def __iter__(self) -> Iterator[str]:
+         i = -1
+         while True:
+             i = (i + 1) % len(self.names)
+             yield self.names[i]
+
+
+ class Frequency(SamplerStrategy):
+     """
+     Specifies the order and number of microbatches to generate for specific labels,
+     e.g. [["first", 2], ["second", 1], ["first", 3], ["third", 1]].
+     """
+
+     def __init__(self, freq: list[list[Any]]) -> None:
+         assert all(isinstance(e[0], str) and isinstance(e[1], int) for e in freq)
+         self.freq = freq
+
+     @override
+     def __iter__(self) -> Iterator[str]:
+         while True:
+             for name, count in self.freq:
+                 for _ in range(count):
+                     yield name
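A short sketch (illustrative only, not part of the package) showing that each strategy is an infinite label generator; consumers such as DictBatchSampler stop once every underlying sampler is exhausted:

    # Hypothetical example: take the first five labels from each strategy.
    from itertools import islice

    from fkat.data.samplers.strategies import Frequency, RoundRobin, Weighted

    print(list(islice(iter(RoundRobin(["a", "b"])), 5)))           # ['a', 'b', 'a', 'b', 'a']
    print(list(islice(iter(Frequency([["a", 2], ["b", 1]])), 5)))  # ['a', 'a', 'b', 'a', 'a']
    print(list(islice(iter(Weighted({"a": 0.5, "b": 0.5})), 5)))   # five labels drawn from the weights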