copick-torch 0.2.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -0,0 +1,20 @@
1
"""Public API of the ``copick_torch`` package.

Re-exports the dataset, augmentation, sampler and logging helpers from their
submodules so callers can import them directly from ``copick_torch``.
"""

__version__ = "0.2.0"

from copick_torch.augmentations import FourierAugment3D, MixupTransform
from copick_torch.copick import CopickDataset
from copick_torch.dataset import SimpleCopickDataset, SimpleDatasetMixin, SplicedMixupDataset
from copick_torch.logging import setup_logging
from copick_torch.minimal_dataset import MinimalCopickDataset
from copick_torch.samplers import ClassBalancedSampler

# Names exported by ``from copick_torch import *``; keep in sync with the imports above.
__all__ = [
    "CopickDataset",
    "SimpleCopickDataset",
    "SimpleDatasetMixin",
    "SplicedMixupDataset",
    "MinimalCopickDataset",
    "MixupTransform",
    "FourierAugment3D",
    "ClassBalancedSampler",
    "setup_logging",
]
@@ -0,0 +1,262 @@
1
+ """
2
+ Augmentations for 3D volumes based on MONAI transform interface.
3
+
4
+ This module provides MONAI-based implementations of augmentations for 3D tomographic data.
5
+ """
6
+
7
+ from typing import Optional, Sequence, Tuple, Union
8
+
9
+ import numpy as np
10
+ import torch
11
+ from monai.config.type_definitions import NdarrayOrTensor
12
+ from monai.transforms import (
13
+ Fourier,
14
+ MapTransform,
15
+ RandomizableTrait,
16
+ RandomizableTransform,
17
+ Transform,
18
+ )
19
+ from monai.transforms.utils import Fourier as FourierUtils
20
+ from monai.utils import convert_data_type, convert_to_dst_type, convert_to_tensor
21
+
22
+
23
class MixupTransform(RandomizableTransform):
    """
    Implements Mixup augmentation for 3D volumes based on MONAI transform interface.

    Mixup is a data augmentation technique that creates virtual training examples
    by mixing pairs of inputs and their labels with a random proportion.

    All randomness (the apply/skip decision, the mixing coefficient and the batch
    permutation) is drawn from the transform's own random state ``self.R``, so
    ``set_random_state(seed)`` makes the transform reproducible.

    Reference: Zhang et al., "mixup: Beyond Empirical Risk Minimization", ICLR 2018
    https://arxiv.org/abs/1710.09412
    """

    def __init__(self, alpha: float = 0.2, prob: float = 1.0):
        """
        Initialize the Mixup augmentation.

        Args:
            alpha: Parameter for Beta distribution. Higher values result in more mixing.
                Values <= 0 disable mixing (the mixing coefficient is fixed to 1.0).
            prob: Probability of applying the transform.
        """
        RandomizableTransform.__init__(self, prob)
        self.alpha = alpha
        # Mixing coefficient drawn by the most recent ``randomize`` call.
        self.lam = 1.0
        # Batch permutation used by the most recent applied ``__call__``; index
        # label tensors with it to obtain the mixed-in labels.
        self.index = None

    def randomize(self, data=None) -> None:
        """
        Randomize the transform parameters.

        Draws the apply/skip decision and, when the transform is applied, a new
        mixing coefficient ``self.lam`` from Beta(alpha, alpha).

        Args:
            data: Unused; accepted for compatibility with MONAI's ``randomize`` signature.
        """
        super().randomize(None)
        if not self._do_transform:
            return None

        # Draw from self.R (not the global NumPy RNG) so set_random_state() is honoured.
        lam = self.R.beta(self.alpha, self.alpha) if self.alpha > 0 else 1.0
        # Beta samples already lie in [0, 1]; the clamp is a cheap guard. The float()
        # cast keeps ``lam`` a plain Python scalar, matching the annotated return type.
        self.lam = float(min(max(lam, 0.0), 1.0))

    def __call__(
        self,
        img: torch.Tensor,
        randomize: bool = True,
    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, float]:
        """
        Apply mixup to a batch of volumes.

        Args:
            img: Tensor of shape [batch_size, channels, depth, height, width]; the
                first dimension is treated as the batch dimension.
            randomize: Whether to execute randomize function first, default to True.

        Returns:
            Tuple of (mixed_images, img_a, img_b, lam) where:
                - mixed_images: ``lam * img_a + (1 - lam) * img_b``
                - img_a: The input batch (converted to a tensor)
                - img_b: The input batch permuted by ``self.index``
                - lam: Mixing coefficient
            When the transform is skipped, ``img`` is returned unchanged in the
            first three positions and ``lam`` is 1.0.
        """
        if randomize:
            self.randomize()

        if not self._do_transform:
            return img, img, img, 1.0

        img = convert_to_tensor(img)
        batch_size = img.shape[0]

        # The permutation also comes from self.R so the whole transform is seedable.
        self.index = torch.as_tensor(
            self.R.permutation(batch_size), dtype=torch.long, device=img.device
        )
        shuffled = img[self.index]

        mixed_images = self.lam * img + (1.0 - self.lam) * shuffled
        return mixed_images, img, shuffled, self.lam

    @staticmethod
    def mixup_criterion(criterion, pred, y_a, y_b, lam):
        """
        Apply mixup to the loss calculation.

        Args:
            criterion: Loss function
            pred: Model predictions
            y_a: First labels
            y_b: Second (mixed-in) labels
            lam: Mixing coefficient

        Returns:
            ``lam * criterion(pred, y_a) + (1 - lam) * criterion(pred, y_b)``
        """
        return lam * criterion(pred, y_a) + (1 - lam) * criterion(pred, y_b)
117
+
118
+
119
class FourierAugment3D(RandomizableTransform, Fourier):
    """
    Implements Fourier-based augmentation for 3D volumes based on MONAI transform interface.

    This augmentation performs operations in the frequency domain, including
    random frequency dropout (masking), phase noise injection, and intensity scaling.

    It can help the model become more robust to various frequency distortions that
    may occur in tomographic data.

    All random parameters are drawn from the transform's own random state ``self.R``,
    so ``set_random_state(seed)`` makes the augmentation reproducible.
    """

    def __init__(
        self,
        freq_mask_prob: float = 0.3,
        phase_noise_std: float = 0.1,
        intensity_scaling_range: Tuple[float, float] = (0.8, 1.2),
        prob: float = 1.0,
    ) -> None:
        """
        Initialize the Fourier domain augmentation.

        Args:
            freq_mask_prob: Probability of masking a frequency component. It is also
                the probability that masking is enabled at all for a given draw.
            phase_noise_std: Standard deviation of Gaussian noise added to the phase
            intensity_scaling_range: Range for random intensity scaling (min, max)
            prob: Probability of applying the transform
        """
        RandomizableTransform.__init__(self, prob)
        self.freq_mask_prob = freq_mask_prob
        self.phase_noise_std = phase_noise_std
        self.intensity_scaling_range = intensity_scaling_range

        # Randomized parameters (set by ``randomize`` / ``_randomize_params``).
        self._mask = None
        self._phase_noise = None
        self._intensity_scale = None

    def randomize(self, spatial_shape: Optional[Sequence[int]] = None) -> None:
        """
        Randomize the transform parameters.

        Draws the apply/skip decision and, when the transform is applied and a
        shape is given, fresh spectral parameters for that shape.

        Args:
            spatial_shape: Shape (depth, height, width) of the volume the parameters
                will be applied to. If None, only the apply/skip decision is drawn.
        """
        super().randomize(None)
        if not self._do_transform or spatial_shape is None:
            return None
        self._randomize_params(spatial_shape)

    def _randomize_params(self, spatial_shape: Sequence[int]) -> None:
        """Draw the frequency mask, phase noise and intensity scale for ``spatial_shape``."""
        shape = tuple(int(s) for s in spatial_shape)

        # With probability freq_mask_prob masking is enabled; each frequency
        # component is then kept with probability (1 - freq_mask_prob).
        if self.R.rand() < self.freq_mask_prob:
            self._mask = torch.as_tensor(np.asarray(self.R.random_sample(shape) > self.freq_mask_prob))
        else:
            self._mask = None

        noise = np.asarray(self.R.standard_normal(shape), dtype=np.float32)
        self._phase_noise = torch.as_tensor(noise) * self.phase_noise_std

        low, high = self.intensity_scaling_range
        self._intensity_scale = float(self.R.uniform(low=low, high=high))

    def __call__(self, volume: torch.Tensor, randomize: bool = True) -> torch.Tensor:
        """
        Apply Fourier domain augmentation to a volume.

        Args:
            volume: Tensor of shape [depth, height, width] or [channels, depth, height, width]
            randomize: Whether to execute randomize function first, default to True.
                When False, previously drawn parameters are reused (for every channel).

        Returns:
            Augmented volume with same shape as input (the input itself when the
            transform is skipped).

        Raises:
            RuntimeError: If the transform is applied before any parameters were drawn.
        """
        if randomize:
            input_shape = volume.shape
            spatial_shape = input_shape if len(input_shape) == 3 else input_shape[1:]
            self.randomize(spatial_shape)

        if not self._do_transform:
            return volume

        volume = convert_to_tensor(volume)

        if volume.ndim == 4:
            # Each channel gets its own spectral parameters for diversity. Only the
            # parameters are re-drawn here: calling ``self.randomize`` per channel
            # would also re-draw the apply/skip decision and, whenever that came up
            # False, silently reuse the previous channel's (or unset) parameters.
            # Channel 0 uses the parameters already drawn by ``self.randomize`` above.
            output = []
            for idx, channel in enumerate(volume):
                if randomize and idx > 0:
                    self._randomize_params(channel.shape)
                output.append(self._apply_fourier_aug(channel))
            return torch.stack(output)

        # Process 3D volume directly
        return self._apply_fourier_aug(volume)

    def _apply_fourier_aug(self, vol_tensor: torch.Tensor) -> torch.Tensor:
        """
        Apply Fourier augmentation to a single tensor (no channels).

        Args:
            vol_tensor: 3D tensor of shape [depth, height, width]

        Returns:
            Augmented real-valued tensor of same shape

        Raises:
            RuntimeError: If the random parameters have not been drawn yet.
        """
        if self._phase_noise is None or self._intensity_scale is None:
            raise RuntimeError(
                "FourierAugment3D parameters have not been randomized; "
                "call randomize(spatial_shape) first or use randomize=True."
            )

        device = vol_tensor.device

        # FFT with the zero frequency moved to the centre.
        f_shifted = torch.fft.fftshift(torch.fft.fftn(vol_tensor))
        magnitude = torch.abs(f_shifted)
        phase = torch.angle(f_shifted)

        # 1. Random frequency dropout (mask)
        if self._mask is not None:
            magnitude = magnitude * self._mask.to(device)

        # 2. Random phase noise, matched to the spectrum's dtype so torch.polar
        #    receives magnitude and phase of the same precision.
        phase = phase + self._phase_noise.to(device=device, dtype=phase.dtype)

        # 3. Recombine magnitude and noisy phase into a complex spectrum.
        f_augmented = torch.polar(magnitude, phase)

        # IFFT; the perturbed spectrum is not Hermitian, so keep only the real part.
        augmented_volume = torch.fft.ifftn(torch.fft.ifftshift(f_augmented)).real

        # 4. Intensity scaling
        return augmented_volume * self._intensity_scale