pytorch_kito-0.2.0-py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- kito/__init__.py +49 -0
- kito/callbacks/__init__.py +20 -0
- kito/callbacks/callback_base.py +107 -0
- kito/callbacks/csv_logger.py +66 -0
- kito/callbacks/ddp_aware_callback.py +60 -0
- kito/callbacks/early_stopping_callback.py +45 -0
- kito/callbacks/modelcheckpoint.py +78 -0
- kito/callbacks/tensorboard_callback_images.py +298 -0
- kito/callbacks/tensorboard_callbacks.py +132 -0
- kito/callbacks/txt_logger.py +57 -0
- kito/config/__init__.py +0 -0
- kito/config/moduleconfig.py +201 -0
- kito/data/__init__.py +35 -0
- kito/data/datapipeline.py +273 -0
- kito/data/datasets.py +166 -0
- kito/data/preprocessed_dataset.py +57 -0
- kito/data/preprocessing.py +318 -0
- kito/data/registry.py +96 -0
- kito/engine.py +841 -0
- kito/module.py +447 -0
- kito/strategies/__init__.py +0 -0
- kito/strategies/logger_strategy.py +51 -0
- kito/strategies/progress_bar_strategy.py +57 -0
- kito/strategies/readiness_validator.py +85 -0
- kito/utils/__init__.py +0 -0
- kito/utils/decorators.py +45 -0
- kito/utils/gpu_utils.py +94 -0
- kito/utils/loss_utils.py +38 -0
- kito/utils/ssim_utils.py +94 -0
- pytorch_kito-0.2.0.dist-info/METADATA +328 -0
- pytorch_kito-0.2.0.dist-info/RECORD +34 -0
- pytorch_kito-0.2.0.dist-info/WHEEL +5 -0
- pytorch_kito-0.2.0.dist-info/licenses/LICENSE +21 -0
- pytorch_kito-0.2.0.dist-info/top_level.txt +1 -0
kito/data/__init__.py
ADDED
@@ -0,0 +1,35 @@
# src/kito/data/__init__.py
"""Kito Data - Datasets, preprocessing, and pipelines"""

from kito.data.datasets import KitoDataset, H5Dataset, MemDataset
from kito.data.preprocessing import (
    Preprocessing,
    Pipeline,
    Normalize,
    Standardization,
    ToTensor
)
from kito.data.datapipeline import GenericDataPipeline, BaseDataPipeline
from kito.data.registry import DATASETS, PREPROCESSING

__all__ = [
    # Datasets
    "KitoDataset",
    "H5Dataset",
    "MemDataset",

    # Preprocessing
    "Preprocessing",
    "Pipeline",
    "Normalize",
    "Standardization",
    "ToTensor",

    # Pipelines
    "GenericDataPipeline",
    "BaseDataPipeline",

    # Registries
    "DATASETS",
    "PREPROCESSING",
]
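
For orientation, a minimal sketch of how these exports compose. The HDF5 path is a placeholder, `config` is assumed to be an already-constructed KitoModuleConfig, and the preprocessing arguments mirror the docstring examples in kito/data/preprocessed_dataset.py further below rather than a confirmed API.

# Illustrative sketch only - not part of the package.
from kito.config.moduleconfig import KitoModuleConfig
from kito.data import (
    GenericDataPipeline,
    H5Dataset,
    Pipeline,
    Standardization,
    ToTensor,
)


def build_pipeline(config: KitoModuleConfig) -> GenericDataPipeline:
    dataset = H5Dataset("train.h5")                         # placeholder path; lazy-loading HDF5 dataset
    preprocessing = Pipeline([Standardization(mean=0, std=1), ToTensor()])
    dm = GenericDataPipeline(config, dataset=dataset, preprocessing=preprocessing)
    dm.setup()                                              # builds splits and DataLoaders
    return dm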

kito/data/datapipeline.py
ADDED
@@ -0,0 +1,273 @@
"""
DataPipeline classes - orchestrate the data pipeline.

DataPipeline encapsulates:
- Dataset creation
- Preprocessing setup
- Train/val/test splitting
- DataLoader configuration

This separates data concerns from model training logic.
"""
from abc import ABC, abstractmethod
from typing import Optional

from torch.utils.data import DataLoader, Subset, DistributedSampler

from kito.config.moduleconfig import KitoModuleConfig
from kito.data.datasets import KitoDataset
from kito.data.preprocessed_dataset import PreprocessedDataset
from kito.data.preprocessing import Preprocessing


class BaseDataPipeline(ABC):
    """
    Base class for data pipelines.

    A data pipeline encapsulates all data-related logic:
    - How to load data (dataset)
    - How to preprocess data
    - How to split into train/val/test
    - How to create DataLoaders

    Usage:
        class MyDataPipeline(BaseDataPipeline):
            def setup(self):
                self.dataset = MyDataset(self.data_config.dataset_path)
                # ... setup preprocessing, splits, etc.

            def train_dataloader(self):
                return DataLoader(self.train_dataset, ...)

        # In training
        dm = MyDataPipeline(config)
        dm.setup()
        engine.fit(datamodule=dm)
    """

    def __init__(self, config: KitoModuleConfig):
        """
        Initialize the data pipeline.

        Args:
            config: Full module configuration (contains data config)
        """
        self.config = config
        self.data_config = config.data

        # Will be set by setup()
        self.dataset = None
        self.train_dataset = None
        self.val_dataset = None
        self.test_dataset = None

        self.train_loader = None
        self.val_loader = None
        self.test_loader = None

    @abstractmethod
    def setup(self):
        """
        Setup datasets and preprocessing.

        This method should:
        1. Load raw dataset
        2. Apply preprocessing (if any)
        3. Create train/val/test splits
        4. Create DataLoaders

        Called once before training.
        """
        pass

    @abstractmethod
    def train_dataloader(self) -> DataLoader:
        """Return training DataLoader."""
        pass

    @abstractmethod
    def val_dataloader(self) -> DataLoader:
        """Return validation DataLoader."""
        pass

    @abstractmethod
    def test_dataloader(self) -> DataLoader:
        """Return test DataLoader."""
        pass


class GenericDataPipeline(BaseDataPipeline):
    """
    Generic data pipeline that works with any registered dataset.

    This is the standard data pipeline used by DataModuleFactory.
    It handles:
    - Loading dataset from registry
    - Applying preprocessing pipeline
    - Creating train/val/test splits
    - Setting up DataLoaders with DDP support

    Args:
        config: Full module configuration
        dataset: Pre-instantiated dataset (optional)
        preprocessing: Pre-instantiated preprocessing (optional)

    Example:
        >>> dm = GenericDataPipeline(config)
        >>> dm.setup()
        >>> train_loader = dm.train_dataloader()
    """

    def __init__(
        self,
        config: KitoModuleConfig,
        dataset: Optional[KitoDataset] = None,
        preprocessing: Optional[Preprocessing] = None
    ):
        super().__init__(config)
        self._dataset = dataset  # Pre-instantiated dataset
        self._preprocessing = preprocessing  # Pre-instantiated preprocessing

    def setup(self):
        """
        Setup data pipeline.

        Steps:
        1. Use pre-instantiated dataset or load from registry
        2. Wrap with preprocessing if provided
        3. Create train/val/test splits
        4. Create DataLoaders
        """
        # 1. Get dataset
        if self._dataset is not None:
            raw_dataset = self._dataset
        else:
            raise ValueError("Dataset not provided to GenericDataPipeline")

        # 2. Apply preprocessing
        if self._preprocessing is not None:
            self.dataset = PreprocessedDataset(raw_dataset, self._preprocessing)
        else:
            self.dataset = raw_dataset

        # 3. Create splits
        self._create_splits()

        # 4. Create DataLoaders
        self._create_dataloaders()

    def _create_splits(self):
        """
        Create train/val/test splits using Subset.

        Splits based on data_config.train_ratio and data_config.val_ratio.
        """
        total_samples = self.data_config.total_samples
        if total_samples is None:
            total_samples = len(self.dataset)

        # Calculate split indices
        train_size = int(total_samples * self.data_config.train_ratio)
        val_size = int(total_samples * self.data_config.val_ratio)
        test_size = total_samples - train_size - val_size

        # Create subsets
        self.train_dataset = Subset(
            self.dataset,
            list(range(0, train_size))
        )
        self.val_dataset = Subset(
            self.dataset,
            list(range(train_size, train_size + val_size))
        )
        self.test_dataset = Subset(
            self.dataset,
            list(range(train_size + val_size, total_samples))
        )

    def _create_dataloaders(self):
        """
        Create DataLoaders with proper settings.

        Handles:
        - Distributed training (DistributedSampler)
        - num_workers, pin_memory, etc.
        - Batch size from config
        """
        batch_size = self.config.training.batch_size
        distributed = self.config.training.distributed_training

        # DataLoader kwargs
        loader_kwargs = {
            'batch_size': batch_size,
            'num_workers': self.data_config.num_workers,
            'pin_memory': self.data_config.pin_memory,
            'persistent_workers': self.data_config.persistent_workers if self.data_config.num_workers > 0 else False,
            'prefetch_factor': self.data_config.prefetch_factor if self.data_config.num_workers > 0 else None,
        }

        # Train loader (with shuffling)
        if distributed:
            train_sampler = DistributedSampler(self.train_dataset, shuffle=True)
            self.train_loader = DataLoader(
                self.train_dataset,
                sampler=train_sampler,
                shuffle=False,  # Don't shuffle when using sampler
                **loader_kwargs
            )
        else:
            self.train_loader = DataLoader(
                self.train_dataset,
                shuffle=True,
                **loader_kwargs
            )

        # Val loader (no shuffling)
        if distributed:
            val_sampler = DistributedSampler(self.val_dataset, shuffle=False)
            self.val_loader = DataLoader(
                self.val_dataset,
                sampler=val_sampler,
                shuffle=False,
                **loader_kwargs
            )
        else:
            self.val_loader = DataLoader(
                self.val_dataset,
                shuffle=False,
                **loader_kwargs
            )

        # Test loader (no shuffling)
        if distributed:
            test_sampler = DistributedSampler(self.test_dataset, shuffle=False)
            self.test_loader = DataLoader(
                self.test_dataset,
                sampler=test_sampler,
                shuffle=False,
                **loader_kwargs
            )
        else:
            self.test_loader = DataLoader(
                self.test_dataset,
                shuffle=False,
                **loader_kwargs
            )

    def train_dataloader(self) -> DataLoader:
        """Return training DataLoader."""
        if self.train_loader is None:
            raise RuntimeError("Call setup() before accessing dataloaders")
        return self.train_loader

    def val_dataloader(self) -> DataLoader:
        """Return validation DataLoader."""
        if self.val_loader is None:
            raise RuntimeError("Call setup() before accessing dataloaders")
        return self.val_loader

    def test_dataloader(self) -> DataLoader:
        """Return test DataLoader."""
        if self.test_loader is None:
            raise RuntimeError("Call setup() before accessing dataloaders")
        return self.test_loader
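
To make the BaseDataPipeline contract above concrete, here is a hedged sketch of a minimal subclass built on the package's own MemDataset. The random arrays, the split sizes, and the use of torch's random_split are illustrative choices, not part of kito; the batch size is read from config.training.batch_size as GenericDataPipeline does.

# Sketch: a minimal concrete pipeline (illustrative only).
import numpy as np
from torch.utils.data import DataLoader, random_split

from kito.data.datapipeline import BaseDataPipeline
from kito.data.datasets import MemDataset


class ToyDataPipeline(BaseDataPipeline):
    def setup(self):
        # Random toy data shaped like the MemDataset docstring example.
        x = np.random.randn(100, 10, 64, 64, 1).astype("float32")
        y = np.random.randn(100, 10, 64, 64, 1).astype("float32")
        self.dataset = MemDataset(x, y)

        # Fixed 80/10/10 split via torch's random_split (not the package's Subset logic).
        n_val = n_test = 10
        n_train = len(self.dataset) - n_val - n_test
        self.train_dataset, self.val_dataset, self.test_dataset = random_split(
            self.dataset, [n_train, n_val, n_test]
        )

    def train_dataloader(self) -> DataLoader:
        return DataLoader(self.train_dataset, batch_size=self.config.training.batch_size, shuffle=True)

    def val_dataloader(self) -> DataLoader:
        return DataLoader(self.val_dataset, batch_size=self.config.training.batch_size)

    def test_dataloader(self) -> DataLoader:
        return DataLoader(self.test_dataset, batch_size=self.config.training.batch_size)
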
kito/data/datasets.py
ADDED
@@ -0,0 +1,166 @@
import h5py

from torch.utils.data import Dataset
from abc import ABC, abstractmethod

from kito.data.registry import DATASETS


class KitoDataset(Dataset, ABC):
    """
    Abstract base class for all datasets.

    Provides common data loading and preprocessing pattern.
    Subclasses must implement _load_sample() to define how data is loaded.

    The standard workflow is:
    1. _load_sample(index) - Load raw data (subclass implements)
    2. Return data and labels

    Note: Preprocessing is now handled by PreprocessedDataset wrapper,
    so _preprocess_data() is removed from here.
    """

    @abstractmethod
    def _load_sample(self, index):
        """
        Load a single raw sample from the data source.

        Args:
            index: Sample index

        Returns:
            Tuple of (data, labels) as numpy arrays or tensors
        """
        pass

    def __getitem__(self, index):
        """
        Get a single sample from the dataset.

        Standard workflow:
        1. Load raw sample
        2. Return (preprocessing happens in PreprocessedDataset wrapper)
        """
        data, labels = self._load_sample(index)
        return data, labels

    @abstractmethod
    def __len__(self):
        """Get the total number of samples in the dataset."""
        pass


@DATASETS.register('h5dataset')
class H5Dataset(KitoDataset):
    """
    HDF5 dataset for PyTorch with lazy loading.

    Loads 'data' and 'labels' from HDF5 file with lazy loading for
    multiprocessing compatibility (DataLoader with num_workers > 0).

    Args:
        path: Path to HDF5 file

    HDF5 Structure Expected:
        - 'data': Input data array (N, ...)
        - 'labels': Target labels array (N, ...)

    Example:
        >>> dataset = H5Dataset("train.h5")
        >>> data, labels = dataset[0]
        >>>
        >>> # Works with DataLoader
        >>> loader = DataLoader(dataset, batch_size=32, num_workers=4)
    """

    def __init__(self, path: str):
        self.file_path = path

        # Lazy-loaded attributes (set in _lazy_load)
        self.dataset_data = None
        self.dataset_labels = None
        self.h5file = None

        # Get dataset length (without lazy loading)
        with h5py.File(self.file_path, 'r') as file:
            self.dataset_len = len(file["data"])

    def _lazy_load(self):
        """
        Open HDF5 file and get dataset references.
        Called automatically in _load_sample().
        """
        if self.dataset_data is None or self.dataset_labels is None:
            try:
                self.h5file = h5py.File(self.file_path, 'r')
                self.dataset_data = self.h5file["data"]
                self.dataset_labels = self.h5file["labels"]
            except (OSError, KeyError) as e:
                raise RuntimeError(f"Failed to load H5 file '{self.file_path}': {e}")

    def _load_sample(self, index):
        """Load sample from HDF5 file with lazy loading."""
        self._lazy_load()
        return self.dataset_data[index], self.dataset_labels[index]

    def __len__(self):
        return self.dataset_len

    def __del__(self):
        """Close HDF5 file when object is destroyed."""
        if hasattr(self, 'h5file') and self.h5file is not None:
            self.h5file.close()

    def __getstate__(self):
        """
        Prepare object for pickling (needed for DataLoader with num_workers > 0).
        Remove non-picklable HDF5 file handles.
        """
        state = self.__dict__.copy()
        # Remove HDF5 file handle and dataset references
        state['dataset_data'] = None
        state['dataset_labels'] = None
        if 'h5file' in state:
            del state['h5file']
        return state

    def __setstate__(self, state):
        """Restore state after unpickling. File handles will be reloaded lazily."""
        self.__dict__.update(state)


@DATASETS.register('memdataset')
class MemDataset(KitoDataset):
    """
    In-memory dataset for PyTorch.

    Stores data and labels in memory (as numpy arrays or tensors).
    Useful when dataset fits in RAM for faster training.

    Args:
        x: Input data array (N, ...)
        y: Target labels array (N, ...)

    Example:
        >>> import numpy as np
        >>> x = np.random.randn(100, 10, 64, 64, 1)
        >>> y = np.random.randn(100, 10, 64, 64, 1)
        >>> dataset = MemDataset(x, y)
        >>> data, labels = dataset[0]
    """

    def __init__(self, x, y):
        self.x = x
        self.y = y

        # Validate shapes
        if len(x) != len(y):
            raise ValueError(f"x and y must have same length. Got {len(x)} and {len(y)}")

    def _load_sample(self, index):
        """Load sample from memory arrays."""
        return self.x[index], self.y[index]

    def __len__(self):
        return self.x.shape[0]
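
Since H5Dataset expects an HDF5 file with 'data' and 'labels' datasets and reopens the file lazily in each worker, a small self-contained sketch follows; "toy.h5" is a throwaway file name and the array shapes are arbitrary.

# Sketch: H5Dataset end to end with a multi-worker DataLoader (illustrative only).
import h5py
import numpy as np
from torch.utils.data import DataLoader

from kito.data.datasets import H5Dataset

if __name__ == "__main__":
    # Build a tiny HDF5 file in the layout H5Dataset expects.
    with h5py.File("toy.h5", "w") as f:
        f.create_dataset("data", data=np.random.randn(64, 16).astype("float32"))
        f.create_dataset("labels", data=np.random.randn(64, 1).astype("float32"))

    dataset = H5Dataset("toy.h5")
    # Each worker pickles the dataset (see __getstate__) and reopens the file lazily.
    loader = DataLoader(dataset, batch_size=8, num_workers=2)
    for data, labels in loader:
        print(data.shape, labels.shape)  # torch.Size([8, 16]) torch.Size([8, 1])
        break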

kito/data/preprocessed_dataset.py
ADDED
@@ -0,0 +1,57 @@
"""
PreprocessedDataset - Wraps a dataset and applies preprocessing.

Separates data loading from preprocessing for flexibility.

Example:
    raw_dataset = H5Dataset('data.h5')
    preprocessing = Pipeline([Detrend(), Standardization()])
    dataset = PreprocessedDataset(raw_dataset, preprocessing)
"""
from torch.utils.data import Dataset


class PreprocessedDataset(Dataset):
    """
    Wraps a base dataset and applies preprocessing on-the-fly.

    This allows:
    - Keeping datasets "dumb" (just load data)
    - Composable preprocessing
    - Easy experimentation (swap preprocessing without changing dataset)

    Args:
        base_dataset: Underlying dataset (H5Dataset, MemDataset, etc.)
        preprocessing: Preprocessing instance or None

    Example:
        >>> raw_dataset = H5Dataset('data.h5')
        >>> preprocessing = Standardization(mean=0, std=1)
        >>> dataset = PreprocessedDataset(raw_dataset, preprocessing)
        >>> data, labels = dataset[0]  # Automatically preprocessed
    """

    def __init__(self, base_dataset: Dataset, preprocessing=None):
        self.base_dataset = base_dataset
        self.preprocessing = preprocessing

    def __getitem__(self, index):
        # Load raw data
        data, labels = self.base_dataset[index]

        # Apply preprocessing if specified
        if self.preprocessing is not None:
            data, labels = self.preprocessing(data, labels)

        return data, labels

    def __len__(self):
        return len(self.base_dataset)

    def __repr__(self):
        return (
            f"PreprocessedDataset(\n"
            f"  base_dataset={self.base_dataset},\n"
            f"  preprocessing={self.preprocessing}\n"
            f")"
        )
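
A short composition sketch tying the wrapper to the datasets above. Standardization(mean=0, std=1) and the Pipeline list follow the docstring examples, while the no-argument ToTensor() constructor and the "data.h5" path are assumptions.

# Sketch: wrapping a raw dataset with PreprocessedDataset (illustrative only).
from kito.data.datasets import H5Dataset
from kito.data.preprocessed_dataset import PreprocessedDataset
from kito.data.preprocessing import Pipeline, Standardization, ToTensor

raw_dataset = H5Dataset("data.h5")    # assumed file with 'data' and 'labels' datasets
preprocessing = Pipeline([Standardization(mean=0, std=1), ToTensor()])

dataset = PreprocessedDataset(raw_dataset, preprocessing)
data, labels = dataset[0]             # preprocessing is applied on-the-fly at access time
print(dataset)                        # __repr__ lists the wrapped dataset and preprocessing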