pytorch-kito 0.2.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -0,0 +1,318 @@
1
+ """
2
+ Preprocessing classes for data transformation.
3
+
4
+ All preprocessing classes inherit from Preprocessing base class
5
+ and implement __call__ method.
6
+
7
+ Preprocessing can be:
8
+ - Composed using Pipeline
9
+ - Configured via config files
10
+ - Registered for factory instantiation
11
+ """
12
+ from abc import ABC, abstractmethod
13
+ from typing import Tuple, List
14
+
15
+ import numpy as np
16
+ import torch
17
+
18
+ from kito.data.registry import PREPROCESSING
19
+
20
+
21
class Preprocessing(ABC):
    """Abstract interface for one (data, labels) transformation step.

    A step receives the input data together with its labels and hands both
    back, possibly modified. Concrete steps only have to implement
    ``__call__``.

    Example:
        class MyPreprocessing(Preprocessing):
            def __call__(self, data, labels):
                processed_data = data  # transform the data here
                return processed_data, labels
    """

    @abstractmethod
    def __call__(self, data, labels) -> Tuple:
        """Transform a single (data, labels) pair.

        Args:
            data: Input data (numpy array or tensor).
            labels: Target labels (numpy array or tensor).

        Returns:
            A ``(processed_data, processed_labels)`` tuple.
        """
        ...
49
+
50
+
51
@PREPROCESSING.register('pipeline')
class Pipeline(Preprocessing):
    """
    Run several preprocessing steps one after another.

    Each step receives the (data, labels) pair produced by the previous
    step; the pair returned by the last step is the pipeline's result.
    An empty step list returns the inputs unchanged.

    Args:
        steps: Preprocessing instances, applied in list order.

    Example:
        >>> pipeline = Pipeline([
        ...     Detrend(),
        ...     Standardization(mean=0.5, std=0.2)
        ... ])
        >>> data, labels = pipeline(data, labels)
    """

    def __init__(self, steps: List[Preprocessing]):
        # The list is stored as given (not copied).
        self.steps = steps

    def __call__(self, data, labels):
        for transform in self.steps:
            # Each step must return exactly two values.
            data, labels = transform(data, labels)
        return data, labels

    def __repr__(self):
        """Show the pipeline as the class names of its steps, in order."""
        names = ', '.join(type(transform).__name__ for transform in self.steps)
        return f"Pipeline([{names}])"
80
+
81
+
82
@PREPROCESSING.register('normalize')
class Normalize(Preprocessing):
    """
    Min-max normalization: scale data to [min_val, max_val].

    The minimum and maximum are taken over all elements of the input, on
    every call. Constant input (max == min) cannot be stretched, so every
    element is mapped to ``min_val`` -- the same convention scikit-learn's
    ``MinMaxScaler`` uses for zero-range features. If the range is not a
    number (e.g. the input contains NaN), the data is returned unchanged.

    Args:
        min_val: Minimum value after normalization
        max_val: Maximum value after normalization

    Example:
        >>> norm = Normalize(min_val=0, max_val=1)
        >>> data, labels = norm(data, labels)
    """

    def __init__(self, min_val: float = 0.0, max_val: float = 1.0):
        self.min_val = min_val
        self.max_val = max_val

    def __call__(self, data, labels):
        data_min = data.min()
        data_range = data.max() - data_min

        if data_range > 0:
            data = (data - data_min) / data_range
            data = data * (self.max_val - self.min_val) + self.min_val
        elif data_range == 0:
            # Constant input: (data - data_min) is exactly zero everywhere,
            # so this yields min_val with data's shape. Previously the data
            # was returned untouched, i.e. possibly outside the target range.
            data = (data - data_min) + self.min_val
        # A NaN range compares False to both tests: leave data unchanged.

        return data, labels
109
+
110
+
111
@PREPROCESSING.register('standardization')
class Standardization(Preprocessing):
    """
    Standardize data: (data - mean) / (std + eps).

    Statistics left as ``None`` are estimated from the data passed to the
    first call and then kept: later calls reuse the same values. Mean and
    std are computed over all elements. The std is the population
    (ddof=0) standard deviation for both NumPy arrays and torch tensors,
    so both backends produce the same statistics.

    Args:
        mean: Mean for standardization (None = compute from data)
        std: Standard deviation (None = compute from data)
        eps: Small constant to avoid division by zero

    Example:
        >>> # Compute mean/std from data
        >>> std = Standardization()
        >>> data, labels = std(data, labels)

        >>> # Use fixed mean/std
        >>> std = Standardization(mean=0.5, std=0.2)
        >>> data, labels = std(data, labels)
    """

    def __init__(self, mean: float = None, std: float = None, eps: float = 1e-8):
        self.mean = mean
        self.std = std
        self.eps = eps
        self._fitted = False

    def __call__(self, data, labels):
        # Fill in any missing statistic from the first batch seen, once.
        if not self._fitted:
            if self.mean is None:
                self.mean = float(data.mean())
            if self.std is None:
                self.std = float(self._population_std(data))
            self._fitted = True

        data = (data - self.mean) / (self.std + self.eps)

        return data, labels

    @staticmethod
    def _population_std(data):
        """Return the ddof=0 standard deviation over all elements of ``data``."""
        if isinstance(data, torch.Tensor):
            # torch defaults to the unbiased (n - 1) estimator, unlike NumPy,
            # and returns NaN for a single element; request ddof=0 explicitly.
            return data.std(unbiased=False)
        return data.std()
160
+
161
+
162
@PREPROCESSING.register('clip_outliers')
class ClipOutliers(Preprocessing):
    """
    Clip outliers beyond n standard deviations.

    Values are clipped to [mean - n_std * std, mean + n_std * std], where
    mean and std are computed over all elements of the current input on
    every call. The std is the population (ddof=0) one for both NumPy
    arrays and torch tensors, so both backends clip identically.

    Args:
        n_std: Number of standard deviations for clipping

    Example:
        >>> clip = ClipOutliers(n_std=3)
        >>> data, labels = clip(data, labels)
    """

    def __init__(self, n_std: float = 3.0):
        self.n_std = n_std

    def __call__(self, data, labels):
        is_tensor = isinstance(data, torch.Tensor)

        mean = data.mean()
        # torch's default std is the unbiased (n - 1) estimator -- different
        # from NumPy and NaN for a single element, which would make clamp
        # return NaN -- so request ddof=0 explicitly.
        std = data.std(unbiased=False) if is_tensor else data.std()

        half_width = self.n_std * std
        lower = mean - half_width
        upper = mean + half_width

        clip = torch.clamp if is_tensor else np.clip
        data = clip(data, lower, upper)

        return data, labels
191
+
192
+
193
@PREPROCESSING.register('detrend')
class Detrend(Preprocessing):
    """
    Subtract the mean from data (constant-offset removal).

    NOTE: despite the class name, no linear trend or best-fit plane is
    fitted -- only the mean is removed, so spatial gradients are left in
    place. Subclass and override ``__call__`` for true linear/plane
    detrending.

    Args:
        axis: Axis passed to ``data.mean``; the mean is taken along it
            (with dims kept so it broadcasts back onto ``data``).
            None subtracts the single global mean of all elements.

    Example:
        >>> detrend = Detrend()
        >>> data, labels = detrend(data, labels)
    """

    def __init__(self, axis: int = None):
        self.axis = axis

    def __call__(self, data, labels):
        if self.axis is not None:
            # keepdims=True so the reduced mean broadcasts against data.
            mean = data.mean(axis=self.axis, keepdims=True)
        else:
            mean = data.mean()

        data = data - mean

        return data, labels
224
+
225
+
226
@PREPROCESSING.register('add_noise')
class AddNoise(Preprocessing):
    """
    Add Gaussian noise to data (for data augmentation).

    Tensors get noise from ``torch.randn_like`` (same dtype/device as the
    input); arrays get noise from NumPy's global random state via
    ``np.random.randn``. Floating-point arrays keep their dtype.

    Args:
        std: Standard deviation of noise
        mean: Mean of noise

    Example:
        >>> noise = AddNoise(std=0.01)
        >>> data, labels = noise(data, labels)
    """

    def __init__(self, std: float = 0.01, mean: float = 0.0):
        self.std = std
        self.mean = mean

    def __call__(self, data, labels):
        if isinstance(data, torch.Tensor):
            noise = torch.randn_like(data) * self.std + self.mean
        else:
            noise = np.random.randn(*data.shape) * self.std + self.mean
            # np.random.randn always yields float64, which would silently
            # upcast float32/float16 data. Cast for floating inputs only, so
            # integer inputs still promote to float instead of truncating
            # the noise. np.asarray also handles the 0-d case, where randn()
            # returns a plain Python float.
            if np.issubdtype(data.dtype, np.floating):
                noise = np.asarray(noise, dtype=data.dtype)

        data = data + noise

        return data, labels
253
+
254
+
255
@PREPROCESSING.register('log_transform')
class LogTransform(Preprocessing):
    """
    Apply log transform: log_base(data + offset).

    Args:
        offset: Constant added before taking the log, to ensure positivity
        base: Logarithm base: 'e', '10' or '2' (case-insensitive). The
            integers 10 and 2 are accepted too, since config files (e.g. a
            YAML ``base: 10``) often produce numbers rather than strings.

    Raises:
        ValueError: If ``base`` is not a supported value. Previously an
            unknown base silently returned ``data + offset`` without a log.

    Example:
        >>> log = LogTransform(offset=1.0, base='e')
        >>> data, labels = log(data, labels)
    """

    # Normalized base -> (torch implementation, NumPy implementation).
    _LOG_FUNCS = {
        'e': (torch.log, np.log),
        '10': (torch.log10, np.log10),
        '2': (torch.log2, np.log2),
    }

    def __init__(self, offset: float = 1.0, base: str = 'e'):
        key = str(base).lower()
        if key not in self._LOG_FUNCS:
            raise ValueError(
                f"Unsupported log base {base!r}; expected one of "
                f"{sorted(self._LOG_FUNCS)}"
            )
        self.offset = offset
        # Stored normalized so that base=10 and base='10' behave identically.
        self.base = key

    def __call__(self, data, labels):
        torch_log, numpy_log = self._LOG_FUNCS[self.base]

        data = data + self.offset
        log_fn = torch_log if isinstance(data, torch.Tensor) else numpy_log
        data = log_fn(data)

        return data, labels
293
+
294
+
295
@PREPROCESSING.register('to_tensor')
class ToTensor(Preprocessing):
    """
    Convert data and labels to PyTorch tensors.

    Non-tensor values (NumPy arrays, NumPy scalars, Python numbers or
    sequences) are converted with ``torch.as_tensor`` to ``dtype``. Values
    that are already tensors are returned unchanged (no dtype cast), and
    ``None`` (e.g. missing labels) passes through.

    Args:
        dtype: Target dtype (e.g., torch.float32). Applied to both data and
            labels when they are converted.

    Example:
        >>> to_tensor = ToTensor(dtype=torch.float32)
        >>> data, labels = to_tensor(data, labels)
    """

    def __init__(self, dtype=torch.float32):
        self.dtype = dtype

    def __call__(self, data, labels):
        return self._to_tensor(data), self._to_tensor(labels)

    def _to_tensor(self, value):
        """Return ``value`` as a tensor of ``self.dtype``; tensors and None are untouched."""
        if value is None or isinstance(value, torch.Tensor):
            return value
        # torch.from_numpy rejects NumPy scalars (e.g. np.float64 labels),
        # Python numbers and lists; as_tensor accepts them, and for an
        # ndarray that already has the target dtype it still shares memory.
        return torch.as_tensor(value, dtype=self.dtype)
kito/data/registry.py ADDED
@@ -0,0 +1,96 @@
1
+ """
2
+ Registry system for datasets and preprocessing.
3
+
4
+ Allows declarative configuration by registering classes with string names.
5
+
6
+ Example:
7
+ @DATASETS.register('h5dataset')
8
+ class H5Dataset:
9
+ pass
10
+
11
+ # Later
12
+ dataset_cls = DATASETS.get('h5dataset')
13
+ dataset = dataset_cls(path='data.h5')
14
+ """
15
+
16
+
17
class Registry:
    """
    Lookup table from string keys to classes, for config-driven construction.

    Classes are added with the ``register`` decorator and fetched back with
    ``get``. Typical contents are dataset types ('h5dataset', 'memdataset')
    and preprocessing steps ('detrend', 'standardization').
    """

    def __init__(self, name: str):
        """
        Create an empty registry.

        Args:
            name: Human-readable label used in error messages.
        """
        self.name = name
        self._registry = {}  # registered key -> class

    def register(self, name: str):
        """
        Build a class decorator that stores the decorated class under ``name``.

        Args:
            name: Key under which the class is stored.

        Returns:
            A decorator that records the class and returns it unchanged.
            Applying it raises ValueError if ``name`` is already taken.

        Example:
            >>> @DATASETS.register('h5dataset')
            ... class H5Dataset:
            ...     pass
        """

        def _record(cls):
            if name not in self._registry:
                self._registry[name] = cls
                return cls
            raise ValueError(
                f"'{name}' already registered in {self.name}. "
                f"Existing: {self._registry[name]}, New: {cls}"
            )

        return _record

    def get(self, name: str):
        """
        Look up the class registered under ``name``.

        Args:
            name: Key used at registration time.

        Returns:
            The registered class.

        Raises:
            KeyError: If nothing is registered under ``name``; the message
                lists the available keys.
        """
        if name in self._registry:
            return self._registry[name]
        known = list(self._registry)
        raise KeyError(
            f"'{name}' not found in {self.name}. "
            f"Available: {known}"
        )

    def list_registered(self):
        """Return the registered keys, in registration order."""
        return [key for key in self._registry]

    def __contains__(self, name: str):
        """Support ``name in registry``."""
        return name in self._registry
92
+
93
+
94
# Module-level registries shared by every importer of this module.
# Decorate classes with ``<REGISTRY>.register(key)``; fetch them with
# ``<REGISTRY>.get(key)``.
DATASETS, PREPROCESSING = Registry('DATASETS'), Registry('PREPROCESSING')