pytorch-kito 0.2.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
kito/data/__init__.py ADDED
@@ -0,0 +1,35 @@
+ # src/kito/data/__init__.py
+ """Kito Data - Datasets, preprocessing, and pipelines"""
+
+ from kito.data.datasets import KitoDataset, H5Dataset, MemDataset
+ from kito.data.preprocessing import (
+     Preprocessing,
+     Pipeline,
+     Normalize,
+     Standardization,
+     ToTensor
+ )
+ from kito.data.datapipeline import GenericDataPipeline, BaseDataPipeline
+ from kito.data.registry import DATASETS, PREPROCESSING
+
+ __all__ = [
+     # Datasets
+     "KitoDataset",
+     "H5Dataset",
+     "MemDataset",
+
+     # Preprocessing
+     "Preprocessing",
+     "Pipeline",
+     "Normalize",
+     "Standardization",
+     "ToTensor",
+
+     # Pipelines
+     "GenericDataPipeline",
+     "BaseDataPipeline",
+
+     # Registries
+     "DATASETS",
+     "PREPROCESSING",
+ ]
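
Because kito/data/__init__.py re-exports these names, downstream code can import them from the package root. A minimal sketch, assuming only the names listed in __all__ above (array shapes are illustrative):

    import numpy as np

    from kito.data import MemDataset

    # Small in-memory dataset built from the re-exported class
    x = np.random.randn(32, 10).astype("float32")
    y = np.random.randn(32, 1).astype("float32")
    dataset = MemDataset(x, y)
    data, labels = dataset[0]

    # H5Dataset and MemDataset are also registered in the DATASETS registry
    # (see kito/data/datasets.py below); the registry's lookup API is not
    # shown in this diff.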
kito/data/datapipeline.py ADDED
@@ -0,0 +1,273 @@
+ """
+ DataPipeline classes - Orchestrate data pipeline.
+
+ DataPipeline encapsulates:
+ - Dataset creation
+ - Preprocessing setup
+ - Train/val/test splitting
+ - DataLoader configuration
+
+ This separates data concerns from model training logic.
+ """
+ from abc import ABC, abstractmethod
+ from typing import Optional
+
+ from torch.utils.data import DataLoader, Subset, DistributedSampler
+
+ from kito.config.moduleconfig import KitoModuleConfig
+ from kito.data.datasets import KitoDataset
+ from kito.data.preprocessed_dataset import PreprocessedDataset
+ from kito.data.preprocessing import Preprocessing
+
+
+ class BaseDataPipeline(ABC):
+     """
+     Base class for data pipelines.
+
+     A DataPipeline encapsulates all data-related logic:
+     - How to load data (dataset)
+     - How to preprocess data
+     - How to split into train/val/test
+     - How to create DataLoaders
+
+     Usage:
+         class MyDataPipeline(BaseDataPipeline):
+             def setup(self):
+                 self.dataset = MyDataset(self.data_config.dataset_path)
+                 # ... setup preprocessing, splits, etc.
+
+             def train_dataloader(self):
+                 return DataLoader(self.train_dataset, ...)
+
+         # In training
+         dp = MyDataPipeline(config)
+         dp.setup()
+         engine.fit(datamodule=dp)
+     """
+
+     def __init__(self, config: KitoModuleConfig):
+         """
+         Initialize the DataPipeline.
+
+         Args:
+             config: Full module configuration (contains data config)
+         """
+         self.config = config
+         self.data_config = config.data
+
+         # Will be set by setup()
+         self.dataset = None
+         self.train_dataset = None
+         self.val_dataset = None
+         self.test_dataset = None
+
+         self.train_loader = None
+         self.val_loader = None
+         self.test_loader = None
+
+     @abstractmethod
+     def setup(self):
+         """
+         Setup datasets and preprocessing.
+
+         This method should:
+         1. Load raw dataset
+         2. Apply preprocessing (if any)
+         3. Create train/val/test splits
+         4. Create DataLoaders
+
+         Called once before training.
+         """
+         pass
+
+     @abstractmethod
+     def train_dataloader(self) -> DataLoader:
+         """Return training DataLoader."""
+         pass
+
+     @abstractmethod
+     def val_dataloader(self) -> DataLoader:
+         """Return validation DataLoader."""
+         pass
+
+     @abstractmethod
+     def test_dataloader(self) -> DataLoader:
+         """Return test DataLoader."""
+         pass
+
+
+ class GenericDataPipeline(BaseDataPipeline):
+     """
+     Generic DataPipeline that works with any registered dataset.
+
+     This is the standard DataPipeline used by DataModuleFactory.
+     It handles:
+     - Consuming a pre-instantiated dataset (typically created from the registry)
+     - Applying the preprocessing pipeline
+     - Creating train/val/test splits
+     - Setting up DataLoaders with DDP support
+
+     Args:
+         config: Full module configuration
+         dataset: Pre-instantiated dataset (optional)
+         preprocessing: Pre-instantiated preprocessing (optional)
+
+     Example:
+         >>> dp = GenericDataPipeline(config, dataset=my_dataset)
+         >>> dp.setup()
+         >>> train_loader = dp.train_dataloader()
+     """
+
+     def __init__(
+         self,
+         config: KitoModuleConfig,
+         dataset: Optional[KitoDataset] = None,
+         preprocessing: Optional[Preprocessing] = None
+     ):
+         super().__init__(config)
+         self._dataset = dataset  # Pre-instantiated dataset
+         self._preprocessing = preprocessing  # Pre-instantiated preprocessing
+
+     def setup(self):
+         """
+         Setup the data pipeline.
+
+         Steps:
+         1. Use the pre-instantiated dataset (raises if none was provided)
+         2. Wrap it with preprocessing if provided
+         3. Create train/val/test splits
+         4. Create DataLoaders
+         """
+         # 1. Get dataset
+         if self._dataset is not None:
+             raw_dataset = self._dataset
+         else:
+             raise ValueError("Dataset not provided to GenericDataPipeline")
+
+         # 2. Apply preprocessing
+         if self._preprocessing is not None:
+             self.dataset = PreprocessedDataset(raw_dataset, self._preprocessing)
+         else:
+             self.dataset = raw_dataset
+
+         # 3. Create splits
+         self._create_splits()
+
+         # 4. Create DataLoaders
+         self._create_dataloaders()
+
+     def _create_splits(self):
+         """
+         Create train/val/test splits using Subset.
+
+         Splits based on data_config.train_ratio and data_config.val_ratio.
+         """
+         total_samples = self.data_config.total_samples
+         if total_samples is None:
+             total_samples = len(self.dataset)
+
+         # Calculate split indices
+         train_size = int(total_samples * self.data_config.train_ratio)
+         val_size = int(total_samples * self.data_config.val_ratio)
+         test_size = total_samples - train_size - val_size
+
+         # Create subsets
+         self.train_dataset = Subset(
+             self.dataset,
+             list(range(0, train_size))
+         )
+         self.val_dataset = Subset(
+             self.dataset,
+             list(range(train_size, train_size + val_size))
+         )
+         self.test_dataset = Subset(
+             self.dataset,
+             list(range(train_size + val_size, total_samples))
+         )
+
+     def _create_dataloaders(self):
+         """
+         Create DataLoaders with proper settings.
+
+         Handles:
+         - Distributed training (DistributedSampler)
+         - num_workers, pin_memory, etc.
+         - Batch size from config
+         """
+         batch_size = self.config.training.batch_size
+         distributed = self.config.training.distributed_training
+
+         # DataLoader kwargs
+         loader_kwargs = {
+             'batch_size': batch_size,
+             'num_workers': self.data_config.num_workers,
+             'pin_memory': self.data_config.pin_memory,
+             'persistent_workers': self.data_config.persistent_workers if self.data_config.num_workers > 0 else False,
+             'prefetch_factor': self.data_config.prefetch_factor if self.data_config.num_workers > 0 else None,
+         }
+
+         # Train loader (with shuffling)
+         if distributed:
+             train_sampler = DistributedSampler(self.train_dataset, shuffle=True)
+             self.train_loader = DataLoader(
+                 self.train_dataset,
+                 sampler=train_sampler,
+                 shuffle=False,  # Don't shuffle when using sampler
+                 **loader_kwargs
+             )
+         else:
+             self.train_loader = DataLoader(
+                 self.train_dataset,
+                 shuffle=True,
+                 **loader_kwargs
+             )
+
+         # Val loader (no shuffling)
+         if distributed:
+             val_sampler = DistributedSampler(self.val_dataset, shuffle=False)
+             self.val_loader = DataLoader(
+                 self.val_dataset,
+                 sampler=val_sampler,
+                 shuffle=False,
+                 **loader_kwargs
+             )
+         else:
+             self.val_loader = DataLoader(
+                 self.val_dataset,
+                 shuffle=False,
+                 **loader_kwargs
+             )
+
+         # Test loader (no shuffling)
+         if distributed:
+             test_sampler = DistributedSampler(self.test_dataset, shuffle=False)
+             self.test_loader = DataLoader(
+                 self.test_dataset,
+                 sampler=test_sampler,
+                 shuffle=False,
+                 **loader_kwargs
+             )
+         else:
+             self.test_loader = DataLoader(
+                 self.test_dataset,
+                 shuffle=False,
+                 **loader_kwargs
+             )
+
+     def train_dataloader(self) -> DataLoader:
+         """Return training DataLoader."""
+         if self.train_loader is None:
+             raise RuntimeError("Call setup() before accessing dataloaders")
+         return self.train_loader
+
+     def val_dataloader(self) -> DataLoader:
+         """Return validation DataLoader."""
+         if self.val_loader is None:
+             raise RuntimeError("Call setup() before accessing dataloaders")
+         return self.val_loader
+
+     def test_dataloader(self) -> DataLoader:
+         """Return test DataLoader."""
+         if self.test_loader is None:
+             raise RuntimeError("Call setup() before accessing dataloaders")
+         return self.test_loader
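
For orientation, a rough usage sketch of GenericDataPipeline follows. It assumes a KitoModuleConfig instance whose data and training sections define the fields referenced above (train_ratio, val_ratio, num_workers, batch_size, ...); those definitions are not part of this diff. As a worked example of the split arithmetic, train_ratio=0.8 and val_ratio=0.1 over 100 samples give contiguous 80/10/10 index ranges.

    # Hypothetical usage; building `config` depends on KitoModuleConfig,
    # which is defined outside this diff.
    from kito.data import GenericDataPipeline, MemDataset

    dataset = MemDataset(x, y)          # any KitoDataset subclass works here
    pipeline = GenericDataPipeline(config, dataset=dataset)
    pipeline.setup()                    # wraps, splits, and builds the loaders

    train_loader = pipeline.train_dataloader()
    val_loader = pipeline.val_dataloader()

    for data, labels in train_loader:
        ...                             # hand batches to the training engine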
kito/data/datasets.py ADDED
@@ -0,0 +1,166 @@
+ import h5py
+
+ from torch.utils.data import Dataset
+ from abc import ABC, abstractmethod
+
+ from kito.data.registry import DATASETS
+
+
+ class KitoDataset(Dataset, ABC):
+     """
+     Abstract base class for all datasets.
+
+     Provides the common data loading pattern.
+     Subclasses must implement _load_sample() to define how data is loaded.
+
+     The standard workflow is:
+     1. _load_sample(index) - Load raw data (subclass implements)
+     2. Return data and labels
+
+     Note: Preprocessing is handled by the PreprocessedDataset wrapper,
+     so no _preprocess_data() step exists here.
+     """
+
+     @abstractmethod
+     def _load_sample(self, index):
+         """
+         Load a single raw sample from the data source.
+
+         Args:
+             index: Sample index
+
+         Returns:
+             Tuple of (data, labels) as numpy arrays or tensors
+         """
+         pass
+
+     def __getitem__(self, index):
+         """
+         Get a single sample from the dataset.
+
+         Standard workflow:
+         1. Load raw sample
+         2. Return it (preprocessing happens in the PreprocessedDataset wrapper)
+         """
+         data, labels = self._load_sample(index)
+         return data, labels
+
+     @abstractmethod
+     def __len__(self):
+         """Get the total number of samples in the dataset."""
+         pass
+
+
+ @DATASETS.register('h5dataset')
+ class H5Dataset(KitoDataset):
+     """
+     HDF5 dataset for PyTorch with lazy loading.
+
+     Loads 'data' and 'labels' from HDF5 file with lazy loading for
+     multiprocessing compatibility (DataLoader with num_workers > 0).
+
+     Args:
+         path: Path to HDF5 file
+
+     HDF5 Structure Expected:
+         - 'data': Input data array (N, ...)
+         - 'labels': Target labels array (N, ...)
+
+     Example:
+         >>> dataset = H5Dataset("train.h5")
+         >>> data, labels = dataset[0]
+         >>>
+         >>> # Works with DataLoader
+         >>> loader = DataLoader(dataset, batch_size=32, num_workers=4)
+     """
+
+     def __init__(self, path: str):
+         self.file_path = path
+
+         # Lazy-loaded attributes (set in _lazy_load)
+         self.dataset_data = None
+         self.dataset_labels = None
+         self.h5file = None
+
+         # Get dataset length (without lazy loading)
+         with h5py.File(self.file_path, 'r') as file:
+             self.dataset_len = len(file["data"])
+
+     def _lazy_load(self):
+         """
+         Open HDF5 file and get dataset references.
+         Called automatically in _load_sample().
+         """
+         if self.dataset_data is None or self.dataset_labels is None:
+             try:
+                 self.h5file = h5py.File(self.file_path, 'r')
+                 self.dataset_data = self.h5file["data"]
+                 self.dataset_labels = self.h5file["labels"]
+             except (OSError, KeyError) as e:
+                 raise RuntimeError(f"Failed to load H5 file '{self.file_path}': {e}")
+
+     def _load_sample(self, index):
+         """Load sample from HDF5 file with lazy loading."""
+         self._lazy_load()
+         return self.dataset_data[index], self.dataset_labels[index]
+
+     def __len__(self):
+         return self.dataset_len
+
+     def __del__(self):
+         """Close HDF5 file when object is destroyed."""
+         if hasattr(self, 'h5file') and self.h5file is not None:
+             self.h5file.close()
+
+     def __getstate__(self):
+         """
+         Prepare object for pickling (needed for DataLoader with num_workers > 0).
+         Remove non-picklable HDF5 file handles.
+         """
+         state = self.__dict__.copy()
+         # Remove HDF5 file handle and dataset references
+         state['dataset_data'] = None
+         state['dataset_labels'] = None
+         if 'h5file' in state:
+             del state['h5file']
+         return state
+
+     def __setstate__(self, state):
+         """Restore state after unpickling. File handles will be reloaded lazily."""
+         self.__dict__.update(state)
+
+
+ @DATASETS.register('memdataset')
+ class MemDataset(KitoDataset):
+     """
+     In-memory dataset for PyTorch.
+
+     Stores data and labels in memory (as numpy arrays or tensors).
+     Useful when dataset fits in RAM for faster training.
+
+     Args:
+         x: Input data array (N, ...)
+         y: Target labels array (N, ...)
+
+     Example:
+         >>> import numpy as np
+         >>> x = np.random.randn(100, 10, 64, 64, 1)
+         >>> y = np.random.randn(100, 10, 64, 64, 1)
+         >>> dataset = MemDataset(x, y)
+         >>> data, labels = dataset[0]
+     """
+
+     def __init__(self, x, y):
+         self.x = x
+         self.y = y
+
+         # Validate shapes
+         if len(x) != len(y):
+             raise ValueError(f"x and y must have same length. Got {len(x)} and {len(y)}")
+
+     def _load_sample(self, index):
+         """Load sample from memory arrays."""
+         return self.x[index], self.y[index]
+
+     def __len__(self):
+         return self.x.shape[0]
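
The HDF5 layout H5Dataset expects can be produced with plain h5py. A minimal sketch of writing a compatible file and reading it back (the file name and array shapes are illustrative):

    import h5py
    import numpy as np
    from torch.utils.data import DataLoader

    from kito.data import H5Dataset

    # Write a file containing the 'data' and 'labels' datasets H5Dataset expects
    x = np.random.randn(100, 10, 64, 64, 1).astype("float32")
    y = np.random.randn(100, 10, 64, 64, 1).astype("float32")
    with h5py.File("train.h5", "w") as f:
        f.create_dataset("data", data=x)
        f.create_dataset("labels", data=y)

    dataset = H5Dataset("train.h5")
    # Lazy loading plus __getstate__/__setstate__ means each worker process
    # reopens the file itself, so num_workers > 0 is safe.
    loader = DataLoader(dataset, batch_size=32, num_workers=4)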
kito/data/preprocessed_dataset.py ADDED
@@ -0,0 +1,57 @@
+ """
+ PreprocessedDataset - Wraps a dataset and applies preprocessing.
+
+ Separates data loading from preprocessing for flexibility.
+
+ Example:
+     raw_dataset = H5Dataset('data.h5')
+     preprocessing = Pipeline([Detrend(), Standardization()])
+     dataset = PreprocessedDataset(raw_dataset, preprocessing)
+ """
+ from torch.utils.data import Dataset
+
+
+ class PreprocessedDataset(Dataset):
+     """
+     Wraps a base dataset and applies preprocessing on-the-fly.
+
+     This allows:
+     - Keeping datasets "dumb" (just load data)
+     - Composable preprocessing
+     - Easy experimentation (swap preprocessing without changing dataset)
+
+     Args:
+         base_dataset: Underlying dataset (H5Dataset, MemDataset, etc.)
+         preprocessing: Preprocessing instance or None
+
+     Example:
+         >>> raw_dataset = H5Dataset('data.h5')
+         >>> preprocessing = Standardization(mean=0, std=1)
+         >>> dataset = PreprocessedDataset(raw_dataset, preprocessing)
+         >>> data, labels = dataset[0]  # Automatically preprocessed
+     """
+
+     def __init__(self, base_dataset: Dataset, preprocessing=None):
+         self.base_dataset = base_dataset
+         self.preprocessing = preprocessing
+
+     def __getitem__(self, index):
+         # Load raw data
+         data, labels = self.base_dataset[index]
+
+         # Apply preprocessing if specified
+         if self.preprocessing is not None:
+             data, labels = self.preprocessing(data, labels)
+
+         return data, labels
+
+     def __len__(self):
+         return len(self.base_dataset)
+
+     def __repr__(self):
+         return (
+             f"PreprocessedDataset(\n"
+             f" base_dataset={self.base_dataset},\n"
+             f" preprocessing={self.preprocessing}\n"
+             f")"
+         )
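
Since PreprocessedDataset only calls preprocessing(data, labels) and expects a (data, labels) pair back, any callable with that signature can stand in for the built-in transforms. A small sketch with a hypothetical scaling function:

    import numpy as np

    from kito.data import MemDataset
    from kito.data.preprocessed_dataset import PreprocessedDataset

    def scale_inputs(data, labels):
        # Hypothetical preprocessing: rescale inputs, leave labels untouched
        return data / 255.0, labels

    x = np.random.randint(0, 256, size=(16, 8, 8)).astype("float32")
    y = np.random.randn(16, 1).astype("float32")

    dataset = PreprocessedDataset(MemDataset(x, y), preprocessing=scale_inputs)
    data, labels = dataset[0]   # scaling applied on access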