careamics 0.0.3__py3-none-any.whl → 0.0.4.1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release: this version of careamics might be problematic.
- careamics/careamist.py +25 -17
- careamics/cli/__init__.py +5 -0
- careamics/cli/conf.py +391 -0
- careamics/cli/main.py +134 -0
- careamics/config/architectures/lvae_model.py +0 -4
- careamics/config/configuration_factory.py +480 -177
- careamics/config/configuration_model.py +1 -2
- careamics/config/data_model.py +1 -15
- careamics/config/fcn_algorithm_model.py +14 -9
- careamics/config/likelihood_model.py +21 -4
- careamics/config/nm_model.py +31 -5
- careamics/config/optimizer_models.py +3 -1
- careamics/config/support/supported_optimizers.py +1 -1
- careamics/config/support/supported_transforms.py +1 -0
- careamics/config/training_model.py +35 -6
- careamics/config/transformations/__init__.py +4 -1
- careamics/config/transformations/transform_union.py +20 -0
- careamics/config/vae_algorithm_model.py +2 -36
- careamics/dataset/tiling/lvae_tiled_patching.py +90 -8
- careamics/lightning/lightning_module.py +10 -8
- careamics/lightning/train_data_module.py +2 -2
- careamics/losses/loss_factory.py +3 -3
- careamics/losses/lvae/losses.py +2 -2
- careamics/lvae_training/dataset/__init__.py +15 -0
- careamics/lvae_training/dataset/{vae_data_config.py → config.py} +25 -81
- careamics/lvae_training/dataset/lc_dataset.py +28 -20
- careamics/lvae_training/dataset/{vae_dataset.py → multich_dataset.py} +91 -51
- careamics/lvae_training/dataset/multifile_dataset.py +334 -0
- careamics/lvae_training/dataset/types.py +43 -0
- careamics/lvae_training/dataset/utils/__init__.py +0 -0
- careamics/lvae_training/dataset/utils/data_utils.py +114 -0
- careamics/lvae_training/dataset/utils/empty_patch_fetcher.py +65 -0
- careamics/lvae_training/dataset/utils/index_manager.py +232 -0
- careamics/lvae_training/dataset/utils/index_switcher.py +165 -0
- careamics/lvae_training/eval_utils.py +109 -64
- careamics/lvae_training/get_config.py +1 -1
- careamics/lvae_training/train_lvae.py +1 -1
- careamics/model_io/bioimage/bioimage_utils.py +4 -2
- careamics/model_io/bmz_io.py +6 -5
- careamics/models/lvae/likelihoods.py +18 -9
- careamics/models/lvae/lvae.py +12 -16
- careamics/models/lvae/noise_models.py +1 -1
- careamics/transforms/compose.py +90 -15
- careamics/transforms/n2v_manipulate.py +6 -2
- careamics/transforms/normalize.py +14 -3
- careamics/transforms/xy_flip.py +16 -6
- careamics/transforms/xy_random_rotate90.py +16 -7
- careamics/utils/metrics.py +204 -24
- careamics/utils/serializers.py +60 -0
- {careamics-0.0.3.dist-info → careamics-0.0.4.1.dist-info}/METADATA +4 -3
- {careamics-0.0.3.dist-info → careamics-0.0.4.1.dist-info}/RECORD +54 -43
- careamics-0.0.4.1.dist-info/entry_points.txt +2 -0
- careamics/lvae_training/dataset/data_utils.py +0 -701
- careamics/lvae_training/dataset/lc_dataset_config.py +0 -13
- {careamics-0.0.3.dist-info → careamics-0.0.4.1.dist-info}/WHEEL +0 -0
- {careamics-0.0.3.dist-info → careamics-0.0.4.1.dist-info}/licenses/LICENSE +0 -0
careamics/lvae_training/dataset/{vae_data_config.py → config.py}

@@ -1,63 +1,13 @@
 from typing import Any, Optional
-
-
-
-
-
-
-class DataType(Enum):
-    MNIST = 0
-    Places365 = 1
-    NotMNIST = 2
-    OptiMEM100_014 = 3
-    CustomSinosoid = 4
-    Prevedel_EMBL = 5
-    AllenCellMito = 6
-    SeparateTiffData = 7
-    CustomSinosoidThreeCurve = 8
-    SemiSupBloodVesselsEMBL = 9
-    Pavia2 = 10
-    Pavia2VanillaSplitting = 11
-    ExpansionMicroscopyMitoTub = 12
-    ShroffMitoEr = 13
-    HTIba1Ki67 = 14
-    BSD68 = 15
-    BioSR_MRC = 16
-    TavernaSox2Golgi = 17
-    Dao3Channel = 18
-    ExpMicroscopyV2 = 19
-    Dao3ChannelWithInput = 20
-    TavernaSox2GolgiV2 = 21
-    TwoDset = 22
-    PredictedTiffData = 23
-    Pavia3SeqData = 24
-    # Here, we have 16 splitting tasks.
-    NicolaData = 25
-
-
-class DataSplitType(Enum):
-    All = 0
-    Train = 1
-    Val = 2
-    Test = 3
-
-
-class GridAlignement(Enum):
-    """
-    A patch is formed by padding the grid with content. If the grids are 'Center' aligned, then padding is to done equally on all 4 sides.
-    On the other hand, if grids are 'LeftTop' aligned, padding is to be done on the right and bottom end of the grid.
-    In the former case, one needs (patch_size - grid_size)//2 amount of content on the right end of the frame.
-    In the latter case, one needs patch_size - grid_size amount of content on the right end of the frame.
-    """
-
-    LeftTop = 0
-    Center = 1
-
-
-# TODO: for all bool params check if they are taking different values in Disentangle repo
+
+from pydantic import BaseModel, ConfigDict
+
+from .types import DataType, DataSplitType, TilingMode
+
+
 # TODO: check if any bool logic can be removed
-class VaeDatasetConfig(BaseModel):
-    model_config = ConfigDict(validate_assignment=True)
+class DatasetConfig(BaseModel):
+    model_config = ConfigDict(validate_assignment=True, extra="forbid")
 
     data_type: Optional[DataType]
     """Type of the dataset, should be one of DataType"""

@@ -132,15 +82,10 @@ class VaeDatasetConfig(BaseModel):
     # TODO: why is this not used?
     enable_rotation_aug: Optional[bool] = False
 
-    grid_alignment: GridAlignement = GridAlignement.LeftTop
-
     max_val: Optional[float] = None
     """Maximum data in the dataset. Is calculated for train split, and should be
    externally set for val and test splits."""
 
-    trim_boundary: Optional[bool] = True
-    """Whether to trim boundary of the image"""
-
     overlapping_padding_kwargs: Any = None
     """Parameters for np.pad method"""
 

@@ -157,23 +102,22 @@ class VaeDatasetConfig(BaseModel):
     train_aug_rotate: Optional[bool] = False
     enable_random_cropping: Optional[bool] = True
 
-    # TODO: not used?
     multiscale_lowres_count: Optional[int] = None
+    """Number of LC scales"""
+
+    tiling_mode: Optional[TilingMode] = TilingMode.ShiftBoundary
+
+    target_separate_normalization: Optional[bool] = True
+
+    mode_3D: Optional[bool] = False
+    """If training in 3D mode or not"""
+
+    trainig_datausage_fraction: Optional[float] = 1.0
+
+    validtarget_random_fraction: Optional[float] = None
+
+    validation_datausage_fraction: Optional[float] = 1.0
+
+    random_flip_z_3D: Optional[bool] = False
 
-
-    @property
-    def padding_kwargs(self) -> dict:
-        kwargs_dict = {}
-        padding_kwargs = {}
-        if (
-            self.multiscale_lowres_count is not None
-            and self.multiscale_lowres_count is not None
-        ):
-            # Get padding attributes
-            if "padding_kwargs" not in kwargs_dict:
-                padding_kwargs = {}
-                padding_kwargs["mode"] = "constant"
-                padding_kwargs["constant_values"] = 0
-            else:
-                padding_kwargs = kwargs_dict.pop("padding_kwargs")
-        return padding_kwargs
+    padding_kwargs: Optional[dict] = None
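The practical effect of this file is that the old VaeDatasetConfig, its inlined enums, and the padding_kwargs property collapse into a single DatasetConfig whose enums live in the new types.py and whose np.pad arguments are now an explicit padding_kwargs field. Because the model is declared with extra="forbid", misspelled keys now raise a validation error instead of being silently ignored. Below is a minimal, hypothetical construction of the new model; only the field names come from the hunks above, and the full set of required fields and the exact members exported by types.py are not visible in this diff, so additional arguments may be needed:

    from careamics.lvae_training.dataset.config import DatasetConfig
    from careamics.lvae_training.dataset.types import DataSplitType, DataType, TilingMode

    # Hypothetical values: only the field names are taken from the hunks above.
    config = DatasetConfig(
        data_type=DataType.SeparateTiffData,   # assumes this member still exists in types.py
        datasplit_type=DataSplitType.Train,
        tiling_mode=TilingMode.ShiftBoundary,  # the new default declared above
        mode_3D=False,
        multiscale_lowres_count=None,
        padding_kwargs={"mode": "constant", "constant_values": 0},  # values the removed property used to supply
    )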
careamics/lvae_training/dataset/lc_dataset.py

@@ -2,34 +2,37 @@
 A place for Datasets and Dataloaders.
 """
 
-from typing import Tuple, Union
+from typing import Tuple, Union, Callable
 
 import numpy as np
 from skimage.transform import resize
 
-from .
-from .
+from .config import DatasetConfig
+from .multich_dataset import MultiChDloader
 
 
 class LCMultiChDloader(MultiChDloader):
-
     def __init__(
         self,
-        data_config:
+        data_config: DatasetConfig,
         fpath: str,
+        load_data_fn: Callable,
         val_fraction=None,
         test_fraction=None,
     ):
-        """
-        Args:
-            num_scales: The number of resolutions at which we want the input. Note that the target is formed at the
-                highest resolution.
-        """
         self._padding_kwargs = (
             data_config.padding_kwargs  # mode=padding_mode, constant_values=constant_value
         )
         self._uncorrelated_channel_probab = data_config.uncorrelated_channel_probab
 
+        super().__init__(
+            data_config,
+            fpath,
+            load_data_fn=load_data_fn,
+            val_fraction=val_fraction,
+            test_fraction=test_fraction,
+        )
+
         if data_config.overlapping_padding_kwargs is not None:
             assert (
                 self._padding_kwargs == data_config.overlapping_padding_kwargs

@@ -37,21 +40,21 @@ class LCMultiChDloader(MultiChDloader):
                It should be so since we just use overlapping_padding_kwargs when it is not None"
 
         else:
-
+            self._overlapping_padding_kwargs = data_config.padding_kwargs
 
-
-
-            )
-        self.num_scales = data_config.num_scales
-        assert self.num_scales is not None
+        self.multiscale_lowres_count = data_config.multiscale_lowres_count
+        assert self.multiscale_lowres_count is not None
         self._scaled_data = [self._data]
         self._scaled_noise_data = [self._noise_data]
 
-        assert
+        assert (
+            isinstance(self.multiscale_lowres_count, int)
+            and self.multiscale_lowres_count >= 1
+        )
         assert isinstance(self._padding_kwargs, dict)
         assert "mode" in self._padding_kwargs
 
-        for _ in range(1, self.
+        for _ in range(1, self.multiscale_lowres_count):
             shape = self._scaled_data[-1].shape
             assert len(shape) == 4
             new_shape = (shape[0], shape[1] // 2, shape[2] // 2, shape[3])

@@ -173,7 +176,7 @@ class LCMultiChDloader(MultiChDloader):
         allres_versions = {
             i: [cropped_img_tuples[i]] for i in range(len(cropped_img_tuples))
         }
-        for scale_idx in range(1, self.
+        for scale_idx in range(1, self.multiscale_lowres_count):
             # Returning the image of the lower resolution
             scaled_img_tuples = self._load_scaled_img(scale_idx, index)
 

@@ -227,6 +230,9 @@ class LCMultiChDloader(MultiChDloader):
         factor = np.sqrt(2) if self._input_is_sum else 1.0
         input_tuples = []
         for x in img_tuples:
+            x = (
+                x.copy()
+            )  # to avoid changing the original image since it is later used for target
             # NOTE: other LC levels already have noise added. So, we just need to add noise to the highest resolution.
             x[0] = x[0] + noise_tuples[0] * factor
             input_tuples.append(x)

@@ -246,7 +252,9 @@ class LCMultiChDloader(MultiChDloader):
 
         target = self._compute_target(target_tuples, alpha)
 
-
+        norm_target = self.normalize_target(target)
+
+        output = [inp, norm_target]
 
         if self._return_alpha:
             output.append(alpha)
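Both dataset classes now require a load_data_fn: Callable instead of resolving a loader internally; MultiChDloader.load_data simply calls load_data_fn(data_config, fpath, datasplit_type, val_fraction=..., test_fraction=..., allow_generation=...) and keeps the returned array. Ready-made loaders presumably live in the new multifile_dataset.py and dataset/__init__.py listed above. A sketch of a custom loader under those assumptions; the (N, H, W, C) return shape is inferred from the shape checks elsewhere in these classes, not stated explicitly in the diff:

    from typing import Optional

    import numpy as np

    from careamics.lvae_training.dataset.lc_dataset import LCMultiChDloader


    def my_load_fn(data_config, fpath: str, datasplit_type,
                   val_fraction: Optional[float] = None,
                   test_fraction: Optional[float] = None,
                   allow_generation=None) -> np.ndarray:
        # Placeholder reader: a real loader would read TIFF/zarr/etc. from fpath and
        # apply datasplit_type and the fractions itself before returning an array.
        data = np.load(fpath)
        assert data.ndim == 4  # assumed (N, H, W, C); 5D stacks appear to be handled only when mode_3D is set
        return data


    loader = LCMultiChDloader(
        config,                     # a DatasetConfig such as the one sketched earlier
        "/path/to/train_data.npy",  # hypothetical path
        load_data_fn=my_load_fn,
        val_fraction=0.1,
        test_fraction=0.1,
    )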
careamics/lvae_training/dataset/{vae_dataset.py → multich_dataset.py}

@@ -2,39 +2,39 @@
 A place for Datasets and Dataloaders.
 """
 
-from typing import Tuple, Union
+from typing import Tuple, Union, Callable
 
 import numpy as np
 
-from .
-
-
-
-
-from .vae_data_config import VaeDatasetConfig, DataSplitType, GridAlignement
+from .utils.empty_patch_fetcher import EmptyPatchFetcher
+from .utils.index_manager import GridIndexManager
+from .utils.index_switcher import IndexSwitcher
+from .config import DatasetConfig
+from .types import DataSplitType, TilingMode
 
 
 class MultiChDloader:
     def __init__(
         self,
-        data_config:
+        data_config: DatasetConfig,
         fpath: str,
+        load_data_fn: Callable,
         val_fraction: float = None,
         test_fraction: float = None,
     ):
         """ """
         self._data_type = data_config.data_type
         self._fpath = fpath
-        self._data = self.
+        self._data = self._noise_data = None
         self.Z = 1
-        self.
-
-
+        self._5Ddata = False
+        self._tiling_mode = data_config.tiling_mode
         # by default, if the noise is present, add it to the input and target.
         self._disable_noise = False  # to add synthetic noise
         self._poisson_noise_factor = None
         self._train_index_switcher = None
         self._depth3D = data_config.depth3D
+        self._mode_3D = data_config.mode_3D
         # NOTE: Input is the sum of the different channels. It is not the average of the different channels.
         self._input_is_sum = data_config.input_is_sum
         self._num_channels = data_config.num_channels

@@ -42,20 +42,21 @@ class MultiChDloader:
         self._tar_idx_list = data_config.target_idx_list
 
         if data_config.datasplit_type == DataSplitType.Train:
-            self._datausage_fraction =
+            self._datausage_fraction = data_config.trainig_datausage_fraction
             # assert self._datausage_fraction == 1.0, 'Not supported. Use validtarget_random_fraction and training_validtarget_fraction to get the same effect'
-            self._validtarget_rand_fract =
+            self._validtarget_rand_fract = data_config.validtarget_random_fraction
             # self._validtarget_random_fraction_final = data_config.get('validtarget_random_fraction_final', None)
             # self._validtarget_random_fraction_stepepoch = data_config.get('validtarget_random_fraction_stepepoch', None)
             # self._idx_count = 0
         elif data_config.datasplit_type == DataSplitType.Val:
-            self._datausage_fraction =
+            self._datausage_fraction = data_config.validation_datausage_fraction
         else:
             self._datausage_fraction = 1.0
 
         self.load_data(
             data_config,
             data_config.datasplit_type,
+            load_data_fn=load_data_fn,
             val_fraction=val_fraction,
             test_fraction=test_fraction,
             allow_generation=data_config.allow_generation,

@@ -70,18 +71,8 @@ class MultiChDloader:
 
         self._background_values = None
 
-        self._grid_alignment = data_config.grid_alignment
         self._overlapping_padding_kwargs = data_config.overlapping_padding_kwargs
-        if self.
-            assert (
-                self._overlapping_padding_kwargs is None
-                or data_config.multiscale_lowres_count is not None
-            ), "Padding is not used with this alignement style"
-        elif self._grid_alignment == GridAlignement.Center:
-            assert (
-                self._overlapping_padding_kwargs is not None
-            ), "With Center grid alignment, padding is needed."
-        if self._trim_boundary:
+        if self._tiling_mode in [TilingMode.TrimBoundary, TilingMode.ShiftBoundary]:
             if (
                 self._overlapping_padding_kwargs is None
                 or data_config.multiscale_lowres_count is not None

@@ -144,7 +135,6 @@ class MultiChDloader:
             )
             data_frames = self._data[..., self._empty_patch_replacement_channel_idx]
             # NOTE: This is on the raw data. So, it must be called before removing the background.
-            # TODO: missing import, needs fixing asap!
             self._empty_patch_fetcher = EmptyPatchFetcher(
                 self.idx_manager,
                 self._img_sz,

@@ -161,14 +151,18 @@ class MultiChDloader:
         self._mean = None
         self._std = None
         self._use_one_mu_std = data_config.use_one_mu_std
-
-        self._target_separate_normalization =
+
+        self._target_separate_normalization = data_config.target_separate_normalization
 
         self._enable_rotation = data_config.enable_rotation_aug
+        flipz_3D = data_config.random_flip_z_3D
+        self._flipz_3D = flipz_3D and self._enable_rotation
+
         self._enable_random_cropping = data_config.enable_random_cropping
         self._uncorrelated_channels = (
             data_config.uncorrelated_channels and self._is_train
         )
+        self._uncorrelated_channel_probab = data_config.uncorrelated_channel_probab
         assert self._is_train or self._uncorrelated_channels is False
         assert (
             self._enable_random_cropping is True or self._uncorrelated_channels is False

@@ -177,9 +171,9 @@ class MultiChDloader:
 
         self._rotation_transform = None
         if self._enable_rotation:
-
-
-
+            # TODO: fix this import
+            import albumentations as A
+
             self._rotation_transform = A.Compose([A.Flip(), A.RandomRotate90()])
 
         # TODO: remove print log messages

@@ -203,11 +197,12 @@ class MultiChDloader:
         self,
         data_config,
         datasplit_type,
+        load_data_fn: Callable,
         val_fraction=None,
         test_fraction=None,
         allow_generation=None,
     ):
-        self._data =
+        self._data = load_data_fn(
             data_config,
             self._fpath,
             datasplit_type,

@@ -215,7 +210,9 @@ class MultiChDloader:
             test_fraction=test_fraction,
             allow_generation=allow_generation,
         )
+        self._loaded_data_preprocessing(data_config)
 
+    def _loaded_data_preprocessing(self, data_config):
         old_shape = self._data.shape
         if self._datausage_fraction < 1.0:
             framepixelcount = np.prod(self._data.shape[1:3])

@@ -239,10 +236,7 @@ class MultiChDloader:
         if data_config.poisson_noise_factor > 0:
             self._poisson_noise_factor = data_config.poisson_noise_factor
             msg += f"Adding Poisson noise with factor {self._poisson_noise_factor}.\t"
-            self._data = (
-                np.random.poisson(self._data / self._poisson_noise_factor)
-                * self._poisson_noise_factor
-            )
+            self._data = np.random.poisson(self._data / self._poisson_noise_factor)
 
         if data_config.enable_gaussian_noise:
             synthetic_scale = data_config.synthetic_gaussian_scale

@@ -257,7 +251,13 @@ class MultiChDloader:
                 self._noise_data[..., 0] = np.mean(self._noise_data[..., 1:], axis=-1)
         print(msg)
 
-
+        if len(self._data.shape) == 5:
+            if self._mode_3D:
+                self._5Ddata = True
+            else:
+                assert self._depth3D == 1, "Depth3D must be 1 for 2D training"
+                self._data = self._data.reshape(-1, *self._data.shape[2:])
+
         if self._5Ddata:
             self.Z = self._data.shape[1]
 

@@ -373,18 +373,28 @@ class MultiChDloader:
                 f"[{self.__class__.__name__}] Data reduced. New data shape: {self._data.shape}"
             )
 
-    def get_idx_manager_shapes(
+    def get_idx_manager_shapes(
+        self, patch_size: int, grid_size: Union[int, Tuple[int, int, int]]
+    ):
         numC = self._data.shape[-1]
         if self._5Ddata:
-            grid_shape = (1, 1, grid_size, grid_size, numC)
             patch_shape = (1, self._depth3D, patch_size, patch_size, numC)
+            if isinstance(grid_size, int):
+                grid_shape = (1, 1, grid_size, grid_size, numC)
+            else:
+                assert len(grid_size) == 3
+                assert all(
+                    [g <= p for g, p in zip(grid_size, patch_shape[1:-1])]
+                ), f"Grid size {grid_size} must be less than patch size {patch_shape[1:-1]}"
+                grid_shape = (1, grid_size[0], grid_size[1], grid_size[2], numC)
         else:
+            assert isinstance(grid_size, int)
             grid_shape = (1, grid_size, grid_size, numC)
             patch_shape = (1, patch_size, patch_size, numC)
 
         return patch_shape, grid_shape
 
-    def set_img_sz(self, image_size, grid_size):
+    def set_img_sz(self, image_size, grid_size: Union[int, Tuple[int, int, int]]):
         """
         If one wants to change the image size on the go, then this can be used.
         Args:

@@ -400,7 +410,7 @@ class MultiChDloader:
             self._img_sz, self._grid_sz
         )
         self.idx_manager = GridIndexManager(
-            shape, grid_shape, patch_shape, self.
+            shape, grid_shape, patch_shape, self._tiling_mode
         )
         # self.set_repeat_factor()
 

@@ -432,10 +442,13 @@ class MultiChDloader:
         dim_sizes = ",".join([str(x) for x in dim_sizes])
         msg += f" N:{self.N} NumPatchPerN:{self._repeat_factor}"
         msg += f"{self.idx_manager.total_grid_count()} DimSz:({dim_sizes})"
-        msg += f" TrimB:{self.
+        msg += f" TrimB:{self._tiling_mode}"
         # msg += f' NormInp:{self._normalized_input}'
         # msg += f' SingleNorm:{self._use_one_mu_std}'
         msg += f" Rot:{self._enable_rotation}"
+        if self._flipz_3D:
+            msg += f" FlipZ:{self._flipz_3D}"
+
         msg += f" RandCrop:{self._enable_random_cropping}"
         msg += f" Channel:{self._num_channels}"
         # msg += f' Q:{self._quantile}'

@@ -467,7 +480,7 @@ class MultiChDloader:
             patch_start_loc = self._get_random_hw(h, w)
             if self._5Ddata:
                 patch_start_loc = (
-                    np.random.choice(img_tuples[0].shape[-3] - self._depth3D),
+                    np.random.choice(1 + img_tuples[0].shape[-3] - self._depth3D),
                 ) + patch_start_loc
         else:
             patch_start_loc = self._get_deterministic_loc(index)

@@ -486,7 +499,7 @@ class MultiChDloader:
         )
 
     def _crop_img(self, img: np.ndarray, patch_start_loc: Tuple):
-        if self.
+        if self._tiling_mode in [TilingMode.TrimBoundary, TilingMode.ShiftBoundary]:
             # In training, this is used.
             # NOTE: It is my opinion that if I just use self._crop_img_with_padding, it will work perfectly fine.
             # The only benefit this if else loop provides is that it makes it easier to see what happens during training.

@@ -625,6 +638,18 @@ class MultiChDloader:
             normalized_imgs.append(img)
         return tuple(normalized_imgs)
 
+    def normalize_input(self, x):
+        mean_dict, std_dict = self.get_mean_std()
+        mean_ = mean_dict["input"].mean()
+        std_ = std_dict["input"].mean()
+        return (x - mean_) / std_
+
+    def normalize_target(self, target):
+        mean_dict, std_dict = self.get_mean_std()
+        mean_ = mean_dict["target"].squeeze(0)
+        std_ = std_dict["target"].squeeze(0)
+        return (target - mean_) / std_
+
     def get_grid_size(self):
         return self._grid_sz
 

@@ -906,28 +931,39 @@ class MultiChDloader:
         return rotated_img_tuples, rotated_noise_tuples
 
     def _rotate(self, img_tuples, noise_tuples):
-
+
+        if self._5Ddata:
             return self._rotate3D(img_tuples, noise_tuples)
         else:
             return self._rotate2D(img_tuples, noise_tuples)
 
     def _rotate3D(self, img_tuples, noise_tuples):
         img_kwargs = {}
+        # random flip in z direction
+        flip_z = self._flipz_3D and np.random.rand() < 0.5
         for i, img in enumerate(img_tuples):
             for j in range(self._depth3D):
                 for k in range(len(img)):
-
+                    if flip_z:
+                        z_idx = self._depth3D - 1 - j
+                    else:
+                        z_idx = j
+                    img_kwargs[f"img{i}_{z_idx}_{k}"] = img[k, j]
 
         noise_kwargs = {}
         for i, nimg in enumerate(noise_tuples):
             for j in range(self._depth3D):
                 for k in range(len(nimg)):
-
+                    if flip_z:
+                        z_idx = self._depth3D - 1 - j
+                    else:
+                        z_idx = j
+                    noise_kwargs[f"noise{i}_{z_idx}_{k}"] = nimg[k, j]
 
         keys = list(img_kwargs.keys()) + list(noise_kwargs.keys())
         self._rotation_transform.add_targets({k: "image" for k in keys})
         rot_dic = self._rotation_transform(
-            image=img_tuples[0][0], **img_kwargs, **noise_kwargs
+            image=img_tuples[0][0][0], **img_kwargs, **noise_kwargs
         )
         rotated_img_tuples = []
         for i, img in enumerate(img_tuples):

@@ -1006,7 +1042,10 @@ class MultiChDloader:
         if self._train_index_switcher is not None:
             index = self._get_index_from_valid_target_logic(index)
 
-        if
+        if (
+            self._uncorrelated_channels
+            and np.random.rand() < self._uncorrelated_channel_probab
+        ):
             img_tuples, noise_tuples = self.get_uncorrelated_img_tuples(index)
         else:
             img_tuples, noise_tuples = self._get_img(index)

@@ -1042,8 +1081,9 @@ class MultiChDloader:
             img_tuples = [x + noise for x, noise in zip(img_tuples, noise_tuples[1:])]
 
         target = self._compute_target(img_tuples, alpha)
+        norm_target = self.normalize_target(target)
 
-        output = [inp,
+        output = [inp, norm_target]
 
         if self._return_alpha:
             output.append(alpha)