copick-torch 0.2.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- copick_torch/__init__.py +20 -0
- copick_torch/augmentations.py +262 -0
- copick_torch/copick.py +1273 -0
- copick_torch/dataset.py +1159 -0
- copick_torch/logging.py +9 -0
- copick_torch/minimal_dataset.py +776 -0
- copick_torch/samplers.py +59 -0
- copick_torch-0.2.0.dist-info/METADATA +372 -0
- copick_torch-0.2.0.dist-info/RECORD +12 -0
- copick_torch-0.2.0.dist-info/WHEEL +4 -0
- copick_torch-0.2.0.dist-info/entry_points.txt +2 -0
- copick_torch-0.2.0.dist-info/licenses/LICENSE +21 -0
copick_torch/__init__.py
ADDED
|
@@ -0,0 +1,20 @@
|
|
|
1
|
+
"""Top-level package for copick-torch.

Re-exports the public classes and helpers defined in the package's submodules
(augmentations, datasets, sampler, logging) so they can be imported directly
from ``copick_torch``. The exported names are listed in ``__all__``.
"""

__version__ = "0.2.0"

from copick_torch.augmentations import FourierAugment3D, MixupTransform
from copick_torch.copick import CopickDataset
from copick_torch.dataset import SimpleCopickDataset, SimpleDatasetMixin, SplicedMixupDataset
from copick_torch.logging import setup_logging
from copick_torch.minimal_dataset import MinimalCopickDataset
from copick_torch.samplers import ClassBalancedSampler

# Public API: the names re-exported above (used by ``from copick_torch import *``).
__all__ = [
    "CopickDataset",
    "SimpleCopickDataset",
    "SimpleDatasetMixin",
    "SplicedMixupDataset",
    "MinimalCopickDataset",
    "MixupTransform",
    "FourierAugment3D",
    "ClassBalancedSampler",
    "setup_logging",
]
|
|
copick_torch/augmentations.py
ADDED
|
@@ -0,0 +1,262 @@
|
|
|
1
|
+
"""
|
|
2
|
+
Augmentations for 3D volumes based on MONAI transform interface.
|
|
3
|
+
|
|
4
|
+
This module provides MONAI-based implementations of augmentations for 3D tomographic data.
|
|
5
|
+
"""
|
|
6
|
+
|
|
7
|
+
from typing import Optional, Sequence, Tuple, Union
|
|
8
|
+
|
|
9
|
+
import numpy as np
|
|
10
|
+
import torch
|
|
11
|
+
from monai.config.type_definitions import NdarrayOrTensor
|
|
12
|
+
from monai.transforms import (
|
|
13
|
+
Fourier,
|
|
14
|
+
MapTransform,
|
|
15
|
+
RandomizableTrait,
|
|
16
|
+
RandomizableTransform,
|
|
17
|
+
Transform,
|
|
18
|
+
)
|
|
19
|
+
from monai.transforms.utils import Fourier as FourierUtils
|
|
20
|
+
from monai.utils import convert_data_type, convert_to_dst_type, convert_to_tensor
|
|
21
|
+
|
|
22
|
+
|
|
23
|
+
class MixupTransform(RandomizableTransform):
    """
    Implements Mixup augmentation for 3D volumes based on MONAI transform interface.

    Mixup is a data augmentation technique that creates virtual training examples
    by blending each sample of a batch with another, randomly chosen sample of the
    same batch, using a mixing coefficient ``lam`` drawn from ``Beta(alpha, alpha)``.

    This transform mixes only the images it is given. The partner permutation of the
    last applied call is stored in ``self.index`` and the coefficient in ``self.lam``,
    so callers can mix labels (or loss terms, see :meth:`mixup_criterion`)
    consistently with the images.

    All randomness is drawn from the transform-local random state ``self.R``, so
    ``set_random_state`` makes the transform reproducible.

    Reference: Zhang et al., "mixup: Beyond Empirical Risk Minimization", ICLR 2018
    https://arxiv.org/abs/1710.09412
    """

    def __init__(self, alpha: float = 0.2, prob: float = 1.0):
        """
        Initialize the Mixup augmentation.

        Args:
            alpha: Parameter for Beta distribution. Higher values result in more mixing.
                A value <= 0 disables mixing (``lam`` is fixed to 1.0).
            prob: Probability of applying the transform.
        """
        RandomizableTransform.__init__(self, prob)
        self.alpha = alpha
        # Current mixing coefficient; 1.0 means the original images are kept unchanged.
        self.lam = 1.0
        # Partner permutation used by the last applied call (None until then).
        self.index = None

    def randomize(self, data=None) -> None:
        """
        Decide whether to apply the transform and, if so, draw a new ``lam``.

        Args:
            data: Unused; accepted for compatibility with MONAI's ``randomize`` signature.
        """
        super().randomize(None)
        if not self._do_transform:
            return None

        if self.alpha > 0:
            # Draw from the transform-local RNG (not the global np.random) so that
            # set_random_state() controls the mixing coefficient.
            lam = float(self.R.beta(self.alpha, self.alpha))
        else:
            lam = 1.0

        # Beta samples already lie in [0, 1]; the clamp is purely defensive.
        self.lam = min(max(lam, 0.0), 1.0)

    def __call__(
        self,
        img: torch.Tensor,
        randomize: bool = True,
    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, float]:
        """
        Apply mixup augmentation to a batch of images.

        Args:
            img: Batch tensor whose first dimension is the batch dimension,
                e.g. [batch_size, channels, depth, height, width].
            randomize: Whether to execute randomize function first, default to True.

        Returns:
            Tuple of (mixed_images, img_a, img_b, lam) where:
            - mixed_images: ``lam * img_a + (1 - lam) * img_b``
            - img_a: The original batch (converted to a tensor)
            - img_b: The batch reordered by ``self.index`` (the mixing partners)
            - lam: Mixing coefficient
            If the transform is skipped, ``(img, img, img, 1.0)`` is returned unchanged.
        """
        if randomize:
            self.randomize()

        if not self._do_transform:
            return img, img, img, 1.0

        img = convert_to_tensor(img)
        batch_size = img.shape[0]

        # Pick mixing partners with the transform-local RNG (not torch's global
        # generator) so that set_random_state() also makes the pairing reproducible.
        permutation = self.R.permutation(batch_size)
        self.index = torch.as_tensor(permutation, dtype=torch.long, device=img.device)

        partners = img[self.index]
        mixed_images = self.lam * img + (1 - self.lam) * partners

        return mixed_images, img, partners, self.lam

    @staticmethod
    def mixup_criterion(criterion, pred, y_a, y_b, lam):
        """
        Apply mixup to the loss calculation.

        Args:
            criterion: Loss function called as ``criterion(pred, target)``
            pred: Model predictions
            y_a: Targets of the original samples
            y_b: Targets of the mixed-in samples (e.g. original targets indexed by ``self.index``)
            lam: Mixing coefficient

        Returns:
            ``lam * criterion(pred, y_a) + (1 - lam) * criterion(pred, y_b)``
        """
        return lam * criterion(pred, y_a) + (1 - lam) * criterion(pred, y_b)
|
|
117
|
+
|
|
118
|
+
|
|
119
|
+
class FourierAugment3D(RandomizableTransform, Fourier):
    """
    Implements Fourier-based augmentation for 3D volumes based on MONAI transform interface.

    This augmentation performs operations in the frequency domain, including
    random frequency dropout (masking), phase noise injection, and intensity scaling.

    It can help the model become more robust to various frequency distortions that
    may occur in tomographic data.

    All random parameters are drawn from the transform-local random state ``self.R``,
    so ``set_random_state`` makes the augmentation reproducible.
    """

    def __init__(
        self,
        freq_mask_prob: float = 0.3,
        phase_noise_std: float = 0.1,
        intensity_scaling_range: Tuple[float, float] = (0.8, 1.2),
        prob: float = 1.0,
    ) -> None:
        """
        Initialize the Fourier domain augmentation.

        Args:
            freq_mask_prob: Frequency-masking probability. It is used twice. First, it is
                the probability that a mask is drawn at all for a given parameter draw.
                Second, when a mask is drawn, it is the probability that each individual
                frequency component is dropped.
            phase_noise_std: Standard deviation of Gaussian noise added to the phase
            intensity_scaling_range: Range for random intensity scaling (min, max)
            prob: Probability of applying the transform
        """
        RandomizableTransform.__init__(self, prob)
        self.freq_mask_prob = freq_mask_prob
        self.phase_noise_std = phase_noise_std
        self.intensity_scaling_range = intensity_scaling_range

        # Randomized parameters (None until first drawn)
        self._mask = None
        self._phase_noise = None
        self._intensity_scale = None

    def randomize(self, spatial_shape=None) -> None:
        """
        Decide whether to apply the transform and, if so, draw new parameters.

        Args:
            spatial_shape: Shape of the 3D volume the parameters will be applied to.
                When None, only the apply/skip decision is re-drawn and previously
                drawn parameters are kept.
        """
        super().randomize(None)
        if not self._do_transform or spatial_shape is None:
            return None
        self._randomize_params(spatial_shape)

    def _randomize_params(self, spatial_shape) -> None:
        """Draw a new frequency mask, phase-noise field and intensity scale for ``spatial_shape``."""
        shape = tuple(int(s) for s in spatial_shape)

        # Randomize masking: True entries keep the frequency component.
        if self.R.rand() < self.freq_mask_prob:
            self._mask = torch.from_numpy(self.R.random_sample(shape) > self.freq_mask_prob)
        else:
            self._mask = None

        # Randomize phase noise
        noise = self.R.standard_normal(shape) * self.phase_noise_std
        self._phase_noise = torch.from_numpy(noise.astype(np.float32))

        # Randomize intensity scaling
        self._intensity_scale = float(
            self.R.uniform(
                low=self.intensity_scaling_range[0],
                high=self.intensity_scaling_range[1],
            )
        )

    def __call__(self, volume: torch.Tensor, randomize: bool = True) -> torch.Tensor:
        """
        Apply Fourier domain augmentation to a volume.

        Whether to augment at all is decided once per call. For a channel-first
        input, each channel is augmented with its own independently drawn parameters.

        Args:
            volume: Tensor of shape [depth, height, width] or [channels, depth, height, width]
            randomize: Whether to execute randomize function first, default to True.

        Returns:
            Augmented volume with same shape as input
        """
        if randomize:
            input_shape = volume.shape
            spatial_shape = input_shape if len(input_shape) == 3 else input_shape[1:]
            # Single apply/skip decision for the whole call. For channel-first
            # input, the parameters drawn here are used for channel 0.
            self.randomize(spatial_shape)

        if not self._do_transform:
            return volume

        # Ensure volume is a torch tensor
        volume = convert_to_tensor(volume)

        if len(volume.shape) != 4:
            # Process 3D volume directly
            return self._apply_fourier_aug(volume)

        output = []
        for channel_idx in range(volume.shape[0]):
            # Draw fresh parameters for each further channel WITHOUT re-rolling the
            # apply/skip decision. Re-rolling it per channel made channels whose
            # re-roll failed silently reuse stale parameters.
            if randomize and channel_idx > 0:
                self._randomize_params(volume[channel_idx].shape)
            output.append(self._apply_fourier_aug(volume[channel_idx]))
        return torch.stack(output)

    def _apply_fourier_aug(self, vol_tensor: torch.Tensor) -> torch.Tensor:
        """
        Apply Fourier augmentation to a single tensor (no channels).

        Parameters that have never been drawn (e.g. on a first call with
        ``randomize=False``) are skipped instead of raising.

        Args:
            vol_tensor: 3D tensor of shape [depth, height, width]

        Returns:
            Augmented tensor of same shape
        """
        device = vol_tensor.device

        # FFT, centered, split into magnitude and phase
        f_shifted = torch.fft.fftshift(torch.fft.fftn(vol_tensor))
        magnitude = torch.abs(f_shifted)
        phase = torch.angle(f_shifted)

        # 1. Random frequency dropout (mask)
        if self._mask is not None:
            magnitude = magnitude * self._mask.to(device)

        # 2. Random phase noise
        if self._phase_noise is not None:
            phase = phase + self._phase_noise.to(device)

        # 3. Combine magnitude and noisy phase
        f_augmented = torch.complex(magnitude * torch.cos(phase), magnitude * torch.sin(phase))

        # IFFT. The random mask and noise generally break Hermitian symmetry, so the
        # result has an imaginary part, which is discarded.
        augmented_volume = torch.fft.ifftn(torch.fft.ifftshift(f_augmented)).real

        # 4. Intensity scaling
        if self._intensity_scale is not None:
            augmented_volume = augmented_volume * self._intensity_scale

        return augmented_volume
|