pytorch-kito 0.2.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- kito/__init__.py +49 -0
- kito/callbacks/__init__.py +20 -0
- kito/callbacks/callback_base.py +107 -0
- kito/callbacks/csv_logger.py +66 -0
- kito/callbacks/ddp_aware_callback.py +60 -0
- kito/callbacks/early_stopping_callback.py +45 -0
- kito/callbacks/modelcheckpoint.py +78 -0
- kito/callbacks/tensorboard_callback_images.py +298 -0
- kito/callbacks/tensorboard_callbacks.py +132 -0
- kito/callbacks/txt_logger.py +57 -0
- kito/config/__init__.py +0 -0
- kito/config/moduleconfig.py +201 -0
- kito/data/__init__.py +35 -0
- kito/data/datapipeline.py +273 -0
- kito/data/datasets.py +166 -0
- kito/data/preprocessed_dataset.py +57 -0
- kito/data/preprocessing.py +318 -0
- kito/data/registry.py +96 -0
- kito/engine.py +841 -0
- kito/module.py +447 -0
- kito/strategies/__init__.py +0 -0
- kito/strategies/logger_strategy.py +51 -0
- kito/strategies/progress_bar_strategy.py +57 -0
- kito/strategies/readiness_validator.py +85 -0
- kito/utils/__init__.py +0 -0
- kito/utils/decorators.py +45 -0
- kito/utils/gpu_utils.py +94 -0
- kito/utils/loss_utils.py +38 -0
- kito/utils/ssim_utils.py +94 -0
- pytorch_kito-0.2.0.dist-info/METADATA +328 -0
- pytorch_kito-0.2.0.dist-info/RECORD +34 -0
- pytorch_kito-0.2.0.dist-info/WHEEL +5 -0
- pytorch_kito-0.2.0.dist-info/licenses/LICENSE +21 -0
- pytorch_kito-0.2.0.dist-info/top_level.txt +1 -0
|
@@ -0,0 +1,318 @@
|
|
|
1
|
+
"""
|
|
2
|
+
Preprocessing classes for data transformation.
|
|
3
|
+
|
|
4
|
+
All preprocessing classes inherit from Preprocessing base class
|
|
5
|
+
and implement __call__ method.
|
|
6
|
+
|
|
7
|
+
Preprocessing can be:
|
|
8
|
+
- Composed using Pipeline
|
|
9
|
+
- Configured via config files
|
|
10
|
+
- Registered for factory instantiation
|
|
11
|
+
"""
|
|
12
|
+
from abc import ABC, abstractmethod
|
|
13
|
+
from typing import Tuple, List
|
|
14
|
+
|
|
15
|
+
import numpy as np
|
|
16
|
+
import torch
|
|
17
|
+
|
|
18
|
+
from kito.data.registry import PREPROCESSING
|
|
19
|
+
|
|
20
|
+
|
|
21
|
+
class Preprocessing(ABC):
    """
    Abstract base class for every preprocessing operation.

    A preprocessing step maps (data, labels) to
    (processed_data, processed_labels). Concrete subclasses supply the
    transformation by implementing __call__.

    Example:
        class MyPreprocessing(Preprocessing):
            def __call__(self, data, labels):
                # Transform data
                return processed_data, labels
    """

    @abstractmethod
    def __call__(self, data, labels) -> Tuple:
        """
        Apply the transformation.

        Args:
            data: Input data (numpy array or tensor)
            labels: Target labels (numpy array or tensor)

        Returns:
            Tuple of (processed_data, processed_labels)
        """
        ...
|
|
49
|
+
|
|
50
|
+
|
|
51
|
+
@PREPROCESSING.register('pipeline')
class Pipeline(Preprocessing):
    """
    Compose several preprocessing steps into one.

    Each step is applied in order, the output of one step feeding the
    next.

    Args:
        steps: List of Preprocessing instances

    Example:
        >>> pipeline = Pipeline([
        ...     Detrend(),
        ...     Standardization(mean=0.5, std=0.2)
        ... ])
        >>> data, labels = pipeline(data, labels)
    """

    def __init__(self, steps: List[Preprocessing]):
        self.steps = steps

    def __call__(self, data, labels):
        # Thread the running (data, labels) pair through every stage.
        pair = (data, labels)
        for stage in self.steps:
            pair = stage(*pair)
        return pair

    def __repr__(self):
        inner = ', '.join(type(stage).__name__ for stage in self.steps)
        return f"Pipeline([{inner}])"
|
|
80
|
+
|
|
81
|
+
|
|
82
|
+
@PREPROCESSING.register('normalize')
class Normalize(Preprocessing):
    """
    Min-max normalization: scale data to [min_val, max_val].

    Args:
        min_val: Minimum value after normalization
        max_val: Maximum value after normalization

    Example:
        >>> norm = Normalize(min_val=0, max_val=1)
        >>> data, labels = norm(data, labels)
    """

    def __init__(self, min_val: float = 0.0, max_val: float = 1.0):
        self.min_val = min_val
        self.max_val = max_val

    def __call__(self, data, labels):
        data_min = data.min()
        data_max = data.max()
        span = data_max - data_min

        if span > 0:
            # Scale to [0, 1], then stretch/shift into [min_val, max_val].
            data = (data - data_min) / span
            data = data * (self.max_val - self.min_val) + self.min_val
        else:
            # Constant input: no dynamic range to scale. Pin everything to
            # min_val so the output is guaranteed to lie inside the target
            # range (previously constant data was returned unchanged and
            # could fall outside [min_val, max_val]).
            data = data - data_min + self.min_val

        return data, labels
|
|
109
|
+
|
|
110
|
+
|
|
111
|
+
@PREPROCESSING.register('standardization')
class Standardization(Preprocessing):
    """
    Standardize data: (data - mean) / std.

    Args:
        mean: Mean for standardization (None = compute from data)
        std: Standard deviation (None = compute from data)
        eps: Small constant to avoid division by zero

    Note:
        When mean/std are not provided, they are computed from the data of
        the FIRST call and reused unchanged for every subsequent call.

    Example:
        >>> # Compute mean/std from data
        >>> std = Standardization()
        >>> data, labels = std(data, labels)

        >>> # Use fixed mean/std
        >>> std = Standardization(mean=0.5, std=0.2)
        >>> data, labels = std(data, labels)
    """

    def __init__(self, mean: float = None, std: float = None, eps: float = 1e-8):
        self.mean = mean
        self.std = std
        self.eps = eps
        # Becomes True once mean/std have been resolved (fit on first batch).
        self._fitted = False

    def __call__(self, data, labels):
        # Resolve any missing statistics from the first batch only, so the
        # same normalization is applied consistently to all later batches.
        # (Replaces a commented-out dead-code block and three redundant
        # `_fitted` checks with a single fit-on-first-call branch.)
        if not self._fitted:
            if self.mean is None:
                self.mean = float(data.mean())
            if self.std is None:
                self.std = float(data.std())
            self._fitted = True

        # Standardize
        data = (data - self.mean) / (self.std + self.eps)

        return data, labels
|
|
160
|
+
|
|
161
|
+
|
|
162
|
+
@PREPROCESSING.register('clip_outliers')
class ClipOutliers(Preprocessing):
    """
    Clip outliers beyond n standard deviations.

    Args:
        n_std: Number of standard deviations for clipping

    Example:
        >>> clip = ClipOutliers(n_std=3)
        >>> data, labels = clip(data, labels)
    """

    def __init__(self, n_std: float = 3.0):
        self.n_std = n_std

    def __call__(self, data, labels):
        center = data.mean()
        margin = self.n_std * data.std()

        # torch and numpy expose different clipping entry points.
        clip_fn = torch.clamp if isinstance(data, torch.Tensor) else np.clip
        data = clip_fn(data, center - margin, center + margin)

        return data, labels
|
|
191
|
+
|
|
192
|
+
|
|
193
|
+
@PREPROCESSING.register('detrend')
class Detrend(Preprocessing):
    """
    Remove linear trend from data.

    Subtracts best-fit plane from each spatial slice.
    Useful for InSAR data with atmospheric gradients.

    Args:
        axis: Axis along which to detrend (None = all spatial axes)

    Example:
        >>> detrend = Detrend()
        >>> data, labels = detrend(data, labels)
    """

    def __init__(self, axis: int = None):
        self.axis = axis

    def __call__(self, data, labels):
        # Simple linear detrend: subtract the mean (per-axis if requested).
        # Override this method for more sophisticated detrending.
        if self.axis is None:
            baseline = data.mean()
        else:
            baseline = data.mean(axis=self.axis, keepdims=True)

        return data - baseline, labels
|
|
224
|
+
|
|
225
|
+
|
|
226
|
+
@PREPROCESSING.register('add_noise')
class AddNoise(Preprocessing):
    """
    Add Gaussian noise to data (for data augmentation).

    Args:
        std: Standard deviation of noise
        mean: Mean of noise

    Example:
        >>> noise = AddNoise(std=0.01)
        >>> data, labels = noise(data, labels)
    """

    def __init__(self, std: float = 0.01, mean: float = 0.0):
        self.std = std
        self.mean = mean

    def __call__(self, data, labels):
        # Draw standard-normal samples with the matching backend, then
        # scale/shift them into the configured distribution.
        if isinstance(data, torch.Tensor):
            unit = torch.randn_like(data)
        else:
            unit = np.random.randn(*data.shape)

        return data + (unit * self.std + self.mean), labels
|
|
253
|
+
|
|
254
|
+
|
|
255
|
+
@PREPROCESSING.register('log_transform')
class LogTransform(Preprocessing):
    """
    Apply log transform: log(data + offset).

    Args:
        offset: Offset to ensure positivity
        base: Logarithm base ('e', '10', '2')

    Raises:
        ValueError: If base is not one of 'e', '10', '2'.

    Example:
        >>> log = LogTransform(offset=1.0, base='e')
        >>> data, labels = log(data, labels)
    """

    # Maps each supported base to its (torch, numpy) implementation pair.
    _LOG_FNS = {
        'e': (torch.log, np.log),
        '10': (torch.log10, np.log10),
        '2': (torch.log2, np.log2),
    }

    def __init__(self, offset: float = 1.0, base: str = 'e'):
        # Validate eagerly: previously an unrecognized base silently skipped
        # the log (only the offset was applied), hiding configuration errors.
        if base not in self._LOG_FNS:
            raise ValueError(
                f"Unsupported log base '{base}'. "
                f"Expected one of: {list(self._LOG_FNS.keys())}"
            )
        self.offset = offset
        self.base = base

    def __call__(self, data, labels):
        data = data + self.offset

        torch_fn, np_fn = self._LOG_FNS[self.base]
        if isinstance(data, torch.Tensor):
            data = torch_fn(data)
        else:
            data = np_fn(data)

        return data, labels
|
|
293
|
+
|
|
294
|
+
|
|
295
|
+
@PREPROCESSING.register('to_tensor')
class ToTensor(Preprocessing):
    """
    Convert numpy arrays to PyTorch tensors.

    Inputs that are already tensors pass through unchanged; anything
    array-like (numpy array, list, scalar) is converted to a tensor of
    the requested dtype.

    Args:
        dtype: Target dtype (e.g., torch.float32)

    Example:
        >>> to_tensor = ToTensor(dtype=torch.float32)
        >>> data, labels = to_tensor(data, labels)
    """

    def __init__(self, dtype=torch.float32):
        self.dtype = dtype

    def __call__(self, data, labels):
        # torch.as_tensor shares memory with numpy arrays (like from_numpy)
        # but also accepts lists and scalars, which from_numpy rejects.
        if not isinstance(data, torch.Tensor):
            data = torch.as_tensor(data, dtype=self.dtype)

        if not isinstance(labels, torch.Tensor):
            labels = torch.as_tensor(labels, dtype=self.dtype)

        return data, labels
|
kito/data/registry.py
ADDED
|
@@ -0,0 +1,96 @@
|
|
|
1
|
+
"""
|
|
2
|
+
Registry system for datasets and preprocessing.
|
|
3
|
+
|
|
4
|
+
Allows declarative configuration by registering classes with string names.
|
|
5
|
+
|
|
6
|
+
Example:
|
|
7
|
+
@DATASETS.register('h5dataset')
|
|
8
|
+
class H5Dataset:
|
|
9
|
+
pass
|
|
10
|
+
|
|
11
|
+
# Later
|
|
12
|
+
dataset_cls = DATASETS.get('h5dataset')
|
|
13
|
+
dataset = dataset_cls(path='data.h5')
|
|
14
|
+
"""
|
|
15
|
+
|
|
16
|
+
|
|
17
|
+
class Registry:
    """
    Simple registry for mapping string names to classes.

    Used for:
        - Dataset types ('h5dataset', 'memdataset')
        - Preprocessing types ('detrend', 'standardization')

    This enables config-based instantiation.
    """

    def __init__(self, name: str):
        """
        Initialize registry.

        Args:
            name: Registry name (for error messages)
        """
        self.name = name
        self._registry = {}

    def register(self, name: str):
        """
        Decorator to register a class under a string identifier.

        Args:
            name: String identifier for the class

        Returns:
            Decorator function

        Raises:
            ValueError: If name is already registered.

        Example:
            >>> @DATASETS.register('h5dataset')
            >>> class H5Dataset:
            ...     pass
        """

        def decorator(cls):
            # Refuse silent overwrites: duplicate names are a config bug.
            if name in self._registry:
                raise ValueError(
                    f"'{name}' already registered in {self.name}. "
                    f"Existing: {self._registry[name]}, New: {cls}"
                )
            self._registry[name] = cls
            return cls

        return decorator

    def get(self, name: str):
        """
        Look up a registered class by name.

        Args:
            name: String identifier

        Returns:
            Registered class

        Raises:
            KeyError: If name not registered
        """
        try:
            return self._registry[name]
        except KeyError:
            raise KeyError(
                f"'{name}' not found in {self.name}. "
                f"Available: {list(self._registry.keys())}"
            ) from None

    def list_registered(self):
        """List all registered names."""
        return [*self._registry]

    def __contains__(self, name: str):
        """Support `name in registry` membership tests."""
        return name in self._registry
|
|
92
|
+
|
|
93
|
+
|
|
94
|
+
# Global registries

# Module-level singletons shared across the package: classes register
# themselves here at import time (via the @register decorator) and
# factories later look them up by their string name.
DATASETS = Registry('DATASETS')
PREPROCESSING = Registry('PREPROCESSING')
|