fkat-0.1.2-py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- fkat/__init__.py +147 -0
- fkat/data/__init__.py +15 -0
- fkat/data/data_module.py +198 -0
- fkat/data/datasets/__init__.py +19 -0
- fkat/data/datasets/dict.py +78 -0
- fkat/data/datasets/json.py +176 -0
- fkat/data/datasets/map.py +90 -0
- fkat/data/datasets/parquet.py +242 -0
- fkat/data/datasets/sized.py +31 -0
- fkat/data/dict.py +42 -0
- fkat/data/samplers/__init__.py +9 -0
- fkat/data/samplers/dict.py +38 -0
- fkat/data/samplers/sized.py +16 -0
- fkat/data/samplers/strategies.py +68 -0
- fkat/data/sharded.py +718 -0
- fkat/data/shm.py +364 -0
- fkat/predict.py +32 -0
- fkat/py.typed +0 -0
- fkat/pytorch/__init__.py +3 -0
- fkat/pytorch/actions/__init__.py +11 -0
- fkat/pytorch/actions/aws/__init__.py +3 -0
- fkat/pytorch/actions/aws/batch.py +29 -0
- fkat/pytorch/actions/aws/ec2.py +61 -0
- fkat/pytorch/callbacks/__init__.py +2 -0
- fkat/pytorch/callbacks/cuda/__init__.py +16 -0
- fkat/pytorch/callbacks/cuda/cache.py +115 -0
- fkat/pytorch/callbacks/cuda/memory.py +200 -0
- fkat/pytorch/callbacks/cuda/nsys.py +199 -0
- fkat/pytorch/callbacks/cuda/nvtx.py +288 -0
- fkat/pytorch/callbacks/cuda/xid.py +173 -0
- fkat/pytorch/callbacks/debugging/__init__.py +9 -0
- fkat/pytorch/callbacks/debugging/introspection.py +569 -0
- fkat/pytorch/callbacks/debugging/optimizer.py +45 -0
- fkat/pytorch/callbacks/gc.py +146 -0
- fkat/pytorch/callbacks/loggers.py +211 -0
- fkat/pytorch/callbacks/logging/__init__.py +12 -0
- fkat/pytorch/callbacks/logging/heartbeat.py +76 -0
- fkat/pytorch/callbacks/logging/throughput.py +253 -0
- fkat/pytorch/callbacks/logging/validation_metrics.py +94 -0
- fkat/pytorch/callbacks/monitoring/__init__.py +14 -0
- fkat/pytorch/callbacks/monitoring/crash.py +162 -0
- fkat/pytorch/callbacks/monitoring/dp.py +130 -0
- fkat/pytorch/callbacks/monitoring/hardware_stats.py +135 -0
- fkat/pytorch/callbacks/monitoring/shutdown.py +170 -0
- fkat/pytorch/callbacks/profiling/__init__.py +13 -0
- fkat/pytorch/callbacks/profiling/flops.py +574 -0
- fkat/pytorch/callbacks/profiling/memray.py +212 -0
- fkat/pytorch/callbacks/profiling/torch.py +197 -0
- fkat/pytorch/callbacks/profiling/viztracer.py +197 -0
- fkat/pytorch/loggers.py +284 -0
- fkat/pytorch/schedule/__init__.py +27 -0
- fkat/pytorch/schedule/base.py +308 -0
- fkat/pytorch/schedule/mlflow.py +143 -0
- fkat/pytorch/utilities.py +49 -0
- fkat/test.py +31 -0
- fkat/train.py +32 -0
- fkat/utils/__init__.py +28 -0
- fkat/utils/aws/__init__.py +3 -0
- fkat/utils/aws/imds.py +137 -0
- fkat/utils/boto3.py +24 -0
- fkat/utils/config.py +194 -0
- fkat/utils/cuda/__init__.py +3 -0
- fkat/utils/cuda/preflight/__init__.py +3 -0
- fkat/utils/cuda/preflight/health_check/aws_instance_config.py +82 -0
- fkat/utils/cuda/preflight/health_check/constants.py +23 -0
- fkat/utils/cuda/preflight/health_check/ddb_client.py +82 -0
- fkat/utils/cuda/preflight/health_check/gpu_connection_test.py +104 -0
- fkat/utils/cuda/preflight/health_check/gpu_stress_test.py +122 -0
- fkat/utils/cuda/preflight/health_check/helpers.py +297 -0
- fkat/utils/cuda/preflight/health_check/logger.py +205 -0
- fkat/utils/cuda/preflight/health_check/timer.py +31 -0
- fkat/utils/cuda/preflight/run.py +560 -0
- fkat/utils/cuda/xid.py +48 -0
- fkat/utils/logging.py +28 -0
- fkat/utils/mlflow.py +33 -0
- fkat/utils/pandas.py +25 -0
- fkat/utils/pdb.py +84 -0
- fkat/utils/pool.py +81 -0
- fkat/utils/profiler.py +18 -0
- fkat/utils/pyarrow.py +21 -0
- fkat/utils/rng.py +27 -0
- fkat/utils/shm.py +184 -0
- fkat/validate.py +31 -0
- fkat-0.1.2.dist-info/METADATA +134 -0
- fkat-0.1.2.dist-info/RECORD +88 -0
- fkat-0.1.2.dist-info/WHEEL +4 -0
- fkat-0.1.2.dist-info/licenses/LICENSE +175 -0
- fkat-0.1.2.dist-info/licenses/NOTICE +1 -0

fkat/data/datasets/map.py
ADDED
@@ -0,0 +1,90 @@
# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
# SPDX-License-Identifier: Apache-2.0
from typing import TypeVar
from collections.abc import Callable, Iterator

from fkat.data.datasets import SizedDataset
from torch.utils.data import IterableDataset

T_in = TypeVar("T_in", contravariant=True)
T_from = TypeVar("T_from", covariant=True)
T_to = TypeVar("T_to", covariant=True)


class MapDataset(SizedDataset[T_in, T_to]):
    """A :class:`Dataset` that transforms the samples from another :class:`Dataset` using a function."""

    def __init__(
        self,
        dataset: SizedDataset[T_in, T_from],
        fn: Callable[[T_from], T_to],
    ) -> None:
        """Create a :class:`Dataset` that maps samples of another :class:`Dataset` using a function.

        Args:
            dataset (SizedDataset): Source :class:`Dataset`.
            fn (Callable[[T_from], T_to]): Sample transformation function.

        Returns:
            None
        """
        self.dataset = dataset
        self.fn = fn

    def __len__(self) -> int:
        """Get :class:`Dataset` size.

        Returns:
            int: :class:`Dataset` size.
        """
        return len(self.dataset)

    def __getitems__(self, idxs: list[T_in]) -> list[T_to]:
        """Get a batch of samples at the specified indices.

        Args:
            idxs (list[T_in]): Samples' indices.

        Returns:
            list[T_to]: A batch of samples.
        """
        if getitems := getattr(self.dataset, "__getitems__", None):
            batch = getitems(idxs)
        else:
            batch = [self.dataset[idx] for idx in idxs]
        samples = [self.fn(sample) for sample in batch]
        return samples

    def __getitem__(self, idx: T_in) -> T_to:
        """Get a sample at the specified index.

        Args:
            idx (T_in): Sample index.

        Returns:
            T_to: A sample.
        """
        sample = self.fn(self.dataset[idx])
        return sample


class IterableMapDataset(IterableDataset[T_to]):
    """An :class:`IterableDataset` that transforms the samples from another
    :class:`IterableDataset` using a function."""

    def __init__(
        self,
        dataset: IterableDataset[T_from],
        fn: Callable[[T_from], T_to],
    ) -> None:
        self.dataset = dataset
        self.fn = fn

    def __iter__(self) -> Iterator[T_to]:
        """Get :class:`IterableDataset` iterator.

        Yields:
            T_to: A sample.
        """
        for sample in iter(self.dataset):
            yield self.fn(sample)
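
For reference, a minimal usage sketch of MapDataset (not part of the diff; the plain Python list below stands in for any object satisfying the SizedDataset protocol, and str.split is an arbitrary transform):

    from fkat.data.datasets.map import MapDataset

    # Any object with __len__ and __getitem__ satisfies SizedDataset structurally,
    # so a plain list works as the source dataset here.
    raw = ["a b c", "d e"]
    tokenized = MapDataset(raw, fn=str.split)

    print(len(tokenized))                   # 2
    print(tokenized[0])                     # ['a', 'b', 'c']
    print(tokenized.__getitems__([0, 1]))   # [['a', 'b', 'c'], ['d', 'e']]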

fkat/data/datasets/parquet.py
ADDED
@@ -0,0 +1,242 @@
# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
# SPDX-License-Identifier: Apache-2.0
from typing import Any
from collections.abc import Iterator

import awswrangler as s3wr
from pyarrow.fs import FileSystem, S3FileSystem  # type: ignore[possibly-unbound-import]
import numpy as np
import pyarrow as pa
from torch.utils.data import IterableDataset

from fkat.data.datasets import SizedDataset
from fkat.utils.pyarrow import iter_rows as pa_iter_rows
from fkat.utils.pandas import iter_rows as pd_iter_rows
from fkat.utils.boto3 import session


class IterableParquetDataset(IterableDataset[dict[str, Any]]):
    """
    An :class:`IterableDataset` backed by Parquet data.

    .. note:: If you want to keep the original types from the Parquet data,
        set ``dtype_backend='pyarrow'``.

    Example config:

    .. code-block:: yaml

        _target_: fkat.data.datasets.parquet.IterableParquetDataset
        uri: s3://path/to/fkat.parquet
        dtype_backend: pyarrow

    Args:
        uri (str | list[str]): URI of the Parquet data.
        columns (list[str] | None): Columns to load.
            Defaults to ``None``.
        use_threads (bool): Use multi-threaded processing.
            Defaults to ``True``.
        chunk_size (int): Maximum number of rows in each DataFrame yielded by the iterator.
            Defaults to ``10000``.
        replace_nan (bool): Whether to replace ``np.nan`` with ``None``.
            Defaults to ``True``.
        s3wr_args (dict): Extra keyword arguments for ``awswrangler.s3.read_parquet``;
            see https://aws-sdk-pandas.readthedocs.io/en/3.5.1/stubs/awswrangler.s3.read_parquet.html
    """

    def __init__(
        self,
        uri: str | list[str],
        columns: list[str] | None = None,
        use_threads: bool = True,
        chunk_size: int = 10000,
        replace_nan: bool = True,
        **s3wr_args: Any,
    ) -> None:
        fs: FileSystem
        path: str
        if isinstance(uri, str):
            fs, path = FileSystem.from_uri(uri)
        else:
            fs, path = FileSystem.from_uri(uri[0])
        self.s3_file = isinstance(fs, S3FileSystem)
        if self.s3_file:
            self.uri = uri
            self.use_threads = use_threads
            self.columns = columns
            self.chunk_size = chunk_size
            self.replace_nan = replace_nan
            self.s3wr_args = s3wr_args
        else:
            # otherwise, use pyarrow
            path_list = []
            if isinstance(uri, str):
                _, path = FileSystem.from_uri(uri)
                path_list.append(path)
            else:
                path_list = []
                for each in uri:
                    _, path = FileSystem.from_uri(each)
                    path_list.append(path)
            pds = pa.parquet.ParquetDataset(path_list, filesystem=fs)  # type: ignore
            self.tbl = pds.read(columns, use_threads=use_threads, use_pandas_metadata=False)
            self.chunk_size = chunk_size

    def __iter__(self) -> Iterator[dict[str, Any]]:
        """Create the dataset iterator.

        Returns:
            Iterator[dict[str, Any]]: Dataset iterator.
        """
        if self.s3_file:
            return pd_iter_rows(
                s3wr.s3.read_parquet(
                    self.uri,
                    use_threads=self.use_threads,
                    columns=self.columns,
                    chunked=self.chunk_size,
                    boto3_session=session(clients=["s3"]),
                    path_suffix="parquet",
                    **self.s3wr_args,
                ),
                self.replace_nan,
            )
        else:
            return pa_iter_rows(self.tbl, chunk_size=self.chunk_size)


class ParquetDataset(SizedDataset[int, dict[str, Any]]):
    """
    A :class:`Dataset` backed by Parquet data.

    Create a :class:`Dataset` from Parquet data at the specified URI.

    .. note:: If you want to keep the original types from the Parquet data,
        set ``dtype_backend='pyarrow'``.

    Example config:

    .. code-block:: yaml

        _target_: fkat.data.datasets.parquet.ParquetDataset
        uri: s3://path/to/fkat.parquet
        dtype_backend: pyarrow

    Difference between ``dtype_backend='pyarrow'`` and ``dtype_backend='numpy_nullable'``:

    .. code-block:: python

        import numpy as np

        from fkat.data.datasets.parquet import ParquetDataset
        from fkat.utils.s3_utils import fs_save_prediction_output_parquet

        uri = "s3://path/to/fkat.parquet"

        saved_data = {
            "purchased_items": [
                [
                    {"product_id": "PROD001", "item_index": 12345, "quantity": "1"},
                    {"product_id": "PROD002", "item_index": None, "quantity": "1"},
                ],
                [{"product_id": "PROD001", "item_index": 12345, "quantity": "1"}],
            ],
            "ground_truth": [[1, 2, 3], [1, 2]],
            "embeddings": [np.random.randn(128), np.random.randn(128)],
        }
        fs_save_prediction_output_parquet()(saved_data, uri)

        dataset = ParquetDataset(uri)  # dtype_backend: numpy_nullable
        print(type(dataset[0]["embeddings"]))       # type: numpy.ndarray
        print(type(dataset[0]["purchased_items"]))  # type: numpy.ndarray of object
        print(type(dataset[0]["ground_truth"]))     # type: numpy.ndarray of object

        pyarrow_dataset = ParquetDataset(uri, dtype_backend="pyarrow")  # dtype_backend: pyarrow
        print(type(pyarrow_dataset[0]["embeddings"]))       # type: list
        print(type(pyarrow_dataset[0]["purchased_items"]))  # type: list of dictionary
        print(type(pyarrow_dataset[0]["ground_truth"]))     # type: list of int

    Args:
        uri (str | list[str]): URI of the Parquet data.
        columns (list[str] | None): Columns to load.
            Defaults to ``None``.
        use_threads (bool): Use multi-threaded processing.
            Defaults to ``True``.
        replace_nan (bool): Whether to replace ``np.nan`` with ``None``.
            Defaults to ``True``.
        s3wr_args (dict): Extra keyword arguments for ``awswrangler.s3.read_parquet``;
            see https://aws-sdk-pandas.readthedocs.io/en/3.5.1/stubs/awswrangler.s3.read_parquet.html
    """

    def __init__(
        self,
        uri: str | list[str],
        columns: list[str] | None = None,
        use_threads: bool = True,
        replace_nan: bool = True,
        **s3wr_args: Any,
    ) -> None:
        # check FileSystem
        fs: FileSystem
        path: str
        if isinstance(uri, str):
            fs, path = FileSystem.from_uri(uri)
        else:
            fs, path = FileSystem.from_uri(uri[0])
        if isinstance(fs, S3FileSystem):
            # if the data is in S3, use awswrangler
            self.df = s3wr.s3.read_parquet(
                uri,
                use_threads=use_threads,
                columns=columns,
                boto3_session=session(clients=["s3"]),
                path_suffix="parquet",
                **s3wr_args,
            )
            if replace_nan:
                self.df = self.df.replace({np.nan: None})
        else:
            # otherwise, use pyarrow
            path_list = []
            if isinstance(uri, str):
                fs, path = FileSystem.from_uri(uri)
                path_list.append(path)
            elif isinstance(uri, list):
                path_list = []
                for each in uri:
                    _, path = FileSystem.from_uri(each)
                    path_list.append(path)
            else:
                raise Exception(f"ParquetDataset can't support uri as {type(uri)}")
            pds = pa.parquet.ParquetDataset(path_list, filesystem=fs)  # type: ignore
            tbl = pds.read(columns, use_threads=use_threads, use_pandas_metadata=False)
            self.df = tbl.to_pandas()

    def __len__(self) -> int:
        """Get :class:`Dataset` size.

        Returns:
            int: :class:`Dataset` size.
        """
        return len(self.df)

    def __getitems__(self, idxs: list[int]) -> list[dict[str, Any]]:
        """Get a batch of samples at the specified indices.

        Args:
            idxs (list[int]): Samples' indices.

        Returns:
            list[dict[str, Any]]: A batch of samples.
        """
        series = self.df.iloc[idxs]
        samples = [series.iloc[i].to_dict() for i in range(len(idxs))]
        return samples

    def __getitem__(self, idx: int) -> dict[str, Any]:
        """Get a sample at the specified index.

        Args:
            idx (int): Sample index.

        Returns:
            dict[str, Any]: A sample.
        """
        series = self.df.iloc[idx]
        sample = series.to_dict()
        return sample
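
For reference, a minimal local-filesystem sketch of ParquetDataset (not part of the diff; the path is hypothetical, and an S3 URI would instead go through the awswrangler branch shown above):

    import pandas as pd
    from fkat.data.datasets.parquet import ParquetDataset

    # Write a tiny Parquet file to a hypothetical local path.
    path = "/tmp/example.parquet"
    pd.DataFrame({"x": [1, 2, 3], "y": ["a", "b", "c"]}).to_parquet(path)

    # A local path resolves to a non-S3 filesystem, so the pyarrow branch is used.
    ds = ParquetDataset(path)
    print(len(ds))   # 3
    print(ds[0])     # e.g. {'x': 1, 'y': 'a'}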

fkat/data/datasets/sized.py
ADDED
@@ -0,0 +1,31 @@
# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
# SPDX-License-Identifier: Apache-2.0
from typing import TypeVar

from typing_extensions import Protocol

T_in = TypeVar("T_in", contravariant=True)
T_out = TypeVar("T_out", covariant=True)


class SizedDataset(Protocol[T_in, T_out]):
    """A :class:`Dataset` with a known size."""

    def __len__(self) -> int:
        """Get :class:`Dataset` size.

        Returns:
            int: :class:`Dataset` size.
        """
        ...

    def __getitem__(self, idx: T_in) -> T_out:
        """Get a sample at the specified index.

        Args:
            idx (T_in): Sample index.

        Returns:
            T_out: A sample.
        """
        ...
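
For reference, a sketch showing that SizedDataset is a structural Protocol, so a class only needs the right methods rather than inheritance (not part of the diff):

    from fkat.data.datasets import SizedDataset

    class SquaresDataset:
        """Implements __len__ and __getitem__, so it satisfies SizedDataset[int, int]."""

        def __init__(self, n: int) -> None:
            self.n = n

        def __len__(self) -> int:
            return self.n

        def __getitem__(self, idx: int) -> int:
            return idx * idx

    ds: SizedDataset[int, int] = SquaresDataset(5)  # accepted by static type checkers
    print(len(ds), ds[3])                           # 5 9
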
fkat/data/dict.py
ADDED
@@ -0,0 +1,42 @@
# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
# SPDX-License-Identifier: Apache-2.0
from typing import Any
from collections.abc import Iterable, Iterator

from typing_extensions import override

from fkat.data.samplers.strategies import SamplerStrategy
from fkat.utils.config import to_primitive_container


def wrap(key: str, name: str, batch: dict[str, Any]) -> dict[str, Any]:
    if not isinstance(batch, dict) or key in batch:
        raise RuntimeError(f"DataLoaders must return a batch dict without {key} key")
    batch[key] = name
    return batch


class DictDataLoader(Iterable[dict[str, Any]]):
    """An :class:`Iterable` that multiplexes batches from multiple named :class:`DataLoader`\\s
    according to a :class:`SamplerStrategy`, tagging each batch with the name of its source."""

    def __init__(
        self,
        dataloaders: dict[str, Iterable[dict[str, Any]]],
        strategy: SamplerStrategy,
        key: str = "dataset",
    ) -> None:
        self.dataloaders = to_primitive_container(dataloaders)
        self.strategy = strategy
        self.key = key

    @override
    def __iter__(self) -> Iterator[dict[str, Any]]:
        iters = {k: iter(self.dataloaders[k]) for k in self.dataloaders}
        for name in self.strategy:
            if not iters:
                return
            if it := iters.get(name):
                try:
                    yield wrap(self.key, name, next(it))
                except StopIteration:
                    del iters[name]
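
For reference, a minimal sketch of DictDataLoader multiplexing two in-memory loaders (not part of the diff; plain lists of batch dicts stand in for torch DataLoaders, and to_primitive_container is assumed to pass plain dicts through unchanged):

    from fkat.data.dict import DictDataLoader
    from fkat.data.samplers.strategies import RoundRobin

    loaders = {
        "books": [{"text": ["b1"]}, {"text": ["b2"]}],
        "news": [{"text": ["n1"]}],
    }
    combined = DictDataLoader(loaders, strategy=RoundRobin(["books", "news"]))

    # Each yielded batch gains a "dataset" key naming the loader it came from;
    # exhausted loaders are dropped until none remain.
    for batch in combined:
        print(batch)
    # {'text': ['b1'], 'dataset': 'books'}
    # {'text': ['n1'], 'dataset': 'news'}
    # {'text': ['b2'], 'dataset': 'books'}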

fkat/data/samplers/dict.py
ADDED
@@ -0,0 +1,38 @@
# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
# SPDX-License-Identifier: Apache-2.0
from collections.abc import Iterator

from typing_extensions import override

from fkat.data.samplers.sized import SizedSampler
from fkat.data.samplers.strategies import SamplerStrategy


class DictBatchSampler(SizedSampler[tuple[str, list[int]]]):
    def __init__(self, strategy: SamplerStrategy, samplers: dict[str, SizedSampler[list[int]]]) -> None:
        self.strategy = strategy
        self.samplers = samplers
        self.len = sum(len(sampler) for sampler in samplers.values())

    @override
    def __len__(self) -> int:
        return self.len

    @override
    def __iter__(self) -> Iterator[tuple[str, list[int]]]:
        rem_samplers = {name: iter(sampler) for name, sampler in self.samplers.items()}
        for key in self.strategy:
            if not rem_samplers:
                # no more iterators left, stopping iteration
                return
            if key not in rem_samplers:
                # this sampler is exhausted, skipping to sample the remaining ones
                # this for example will adjust the effective weights for the remaining ones when sampling
                # or sample next ones for sequential or frequency-based samplers
                continue
            try:
                batch = next(rem_samplers[key])
                yield (key, batch)
            except StopIteration:
                # this sampler is exhausted, removing from consideration
                del rem_samplers[key]
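
For reference, a sketch of DictBatchSampler pairing a strategy with per-dataset index batches (not part of the diff; the plain lists of index batches stand in for SizedSampler implementations):

    from fkat.data.samplers.dict import DictBatchSampler
    from fkat.data.samplers.strategies import RoundRobin

    samplers = {
        "a": [[0, 1], [2, 3]],  # two index batches for dataset "a"
        "b": [[0, 1]],          # one index batch for dataset "b"
    }
    batch_sampler = DictBatchSampler(RoundRobin(["a", "b"]), samplers)

    print(len(batch_sampler))   # 3
    print(list(batch_sampler))  # [('a', [0, 1]), ('b', [0, 1]), ('a', [2, 3])]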

fkat/data/samplers/sized.py
ADDED
@@ -0,0 +1,16 @@
# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
# SPDX-License-Identifier: Apache-2.0
from typing import TypeVar
from collections.abc import Iterator

from typing_extensions import Protocol

T_co = TypeVar("T_co", covariant=True)


class SizedSampler(Protocol[T_co]):
    """A Sampler with a known size."""

    def __len__(self) -> int: ...

    def __iter__(self) -> Iterator[T_co]: ...

fkat/data/samplers/strategies.py
ADDED
@@ -0,0 +1,68 @@
# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
# SPDX-License-Identifier: Apache-2.0
from typing import Any
from collections.abc import Iterator

import numpy as np
from typing_extensions import Protocol, override


class SamplerStrategy(Protocol):
    """
    A strategy that decides which sampler to use next and yields its label.
    """

    def __iter__(self) -> Iterator[str]: ...


class Weighted(SamplerStrategy):
    """
    Samples the label of the next microbatch from the provided weight distribution.
    For a uniform distribution, use the same weight for all labels.
    """

    def __init__(self, weights: dict[str, float]) -> None:
        self.names = []
        self.weights = []
        for name, weight in weights.items():
            self.names.append(name)
            self.weights.append(weight)

    @override
    def __iter__(self) -> Iterator[str]:
        while True:
            yield np.random.choice(self.names, p=self.weights)


class RoundRobin(SamplerStrategy):
    """
    Specifies the order of labels to generate microbatches from, cycling through them in turn.
    """

    def __init__(self, names: list[str]) -> None:
        self.names = names

    @override
    def __iter__(self) -> Iterator[str]:
        i = -1
        while True:
            i = (i + 1) % len(self.names)
            yield self.names[i]


class Frequency(SamplerStrategy):
    """
    Specifies the order and number of microbatches to generate for specific labels.
    E.g. [["first", 2], ["second", 1], ["first", 3], ["third", 1]]
    """

    def __init__(self, freq: list[list[Any]]) -> None:
        assert all(isinstance(e[0], str) and isinstance(e[1], int) for e in freq)
        self.freq = freq

    @override
    def __iter__(self) -> Iterator[str]:
        while True:
            for name, count in self.freq:
                for _ in range(count):
                    yield name
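
For reference, a short sketch of the three strategies (not part of the diff); each yields an endless stream of labels, so consumers such as DictBatchSampler decide when to stop:

    from itertools import islice
    from fkat.data.samplers.strategies import Frequency, RoundRobin, Weighted

    print(list(islice(iter(RoundRobin(["a", "b"])), 5)))           # ['a', 'b', 'a', 'b', 'a']
    print(list(islice(iter(Frequency([["a", 2], ["b", 1]])), 6)))  # ['a', 'a', 'b', 'a', 'a', 'b']
    print(list(islice(iter(Weighted({"a": 0.9, "b": 0.1})), 5)))   # mostly 'a', occasionally 'b'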