careamics 0.1.0rc4__py3-none-any.whl → 0.1.0rc5__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of careamics might be problematic. Click here for more details.
- careamics/careamist.py +12 -11
- careamics/config/__init__.py +0 -1
- careamics/config/architectures/unet_model.py +1 -0
- careamics/config/callback_model.py +1 -0
- careamics/config/configuration_example.py +0 -2
- careamics/config/configuration_factory.py +112 -42
- careamics/config/configuration_model.py +14 -16
- careamics/config/data_model.py +59 -157
- careamics/config/inference_model.py +19 -20
- careamics/config/references/algorithm_descriptions.py +1 -0
- careamics/config/references/references.py +1 -0
- careamics/config/support/supported_extraction_strategies.py +1 -0
- careamics/config/training_model.py +1 -0
- careamics/config/transformations/n2v_manipulate_model.py +1 -0
- careamics/config/transformations/nd_flip_model.py +6 -11
- careamics/config/transformations/normalize_model.py +1 -0
- careamics/config/transformations/transform_model.py +1 -0
- careamics/config/transformations/xy_random_rotate90_model.py +6 -8
- careamics/config/validators/validator_utils.py +1 -0
- careamics/conftest.py +1 -0
- careamics/dataset/dataset_utils/__init__.py +0 -1
- careamics/dataset/dataset_utils/dataset_utils.py +1 -0
- careamics/dataset/in_memory_dataset.py +14 -45
- careamics/dataset/iterable_dataset.py +13 -68
- careamics/dataset/patching/__init__.py +0 -7
- careamics/dataset/patching/patching.py +1 -0
- careamics/dataset/patching/sequential_patching.py +6 -6
- careamics/dataset/patching/tiled_patching.py +10 -6
- careamics/lightning_datamodule.py +20 -24
- careamics/lightning_module.py +1 -1
- careamics/lightning_prediction_datamodule.py +15 -10
- careamics/losses/__init__.py +0 -1
- careamics/losses/loss_factory.py +1 -0
- careamics/model_io/__init__.py +0 -1
- careamics/model_io/bioimage/_readme_factory.py +2 -1
- careamics/model_io/bioimage/bioimage_utils.py +1 -0
- careamics/model_io/bioimage/model_description.py +1 -0
- careamics/model_io/bmz_io.py +2 -1
- careamics/models/layers.py +1 -0
- careamics/models/model_factory.py +1 -0
- careamics/models/unet.py +91 -17
- careamics/prediction/stitch_prediction.py +1 -0
- careamics/transforms/__init__.py +2 -23
- careamics/transforms/compose.py +98 -0
- careamics/transforms/n2v_manipulate.py +18 -23
- careamics/transforms/nd_flip.py +38 -64
- careamics/transforms/normalize.py +45 -34
- careamics/transforms/pixel_manipulation.py +2 -2
- careamics/transforms/transform.py +33 -0
- careamics/transforms/tta.py +2 -2
- careamics/transforms/xy_random_rotate90.py +41 -68
- careamics/utils/__init__.py +0 -1
- careamics/utils/context.py +1 -0
- careamics/utils/logging.py +1 -0
- careamics/utils/metrics.py +1 -0
- careamics/utils/torch_utils.py +1 -0
- {careamics-0.1.0rc4.dist-info → careamics-0.1.0rc5.dist-info}/METADATA +16 -61
- careamics-0.1.0rc5.dist-info/RECORD +111 -0
- careamics/dataset/patching/patch_transform.py +0 -44
- careamics-0.1.0rc4.dist-info/RECORD +0 -110
- {careamics-0.1.0rc4.dist-info → careamics-0.1.0rc5.dist-info}/WHEEL +0 -0
- {careamics-0.1.0rc4.dist-info → careamics-0.1.0rc5.dist-info}/licenses/LICENSE +0 -0
|
@@ -1,4 +1,5 @@
|
|
|
1
1
|
"""In-memory dataset module."""
|
|
2
|
+
|
|
2
3
|
from __future__ import annotations
|
|
3
4
|
|
|
4
5
|
import copy
|
|
@@ -8,11 +9,12 @@ from typing import Any, Callable, List, Optional, Tuple, Union
|
|
|
8
9
|
import numpy as np
|
|
9
10
|
from torch.utils.data import Dataset
|
|
10
11
|
|
|
12
|
+
from careamics.transforms import Compose
|
|
13
|
+
|
|
11
14
|
from ..config import DataConfig, InferenceConfig
|
|
12
15
|
from ..config.tile_information import TileInformation
|
|
13
16
|
from ..utils.logging import get_logger
|
|
14
17
|
from .dataset_utils import read_tiff, reshape_array
|
|
15
|
-
from .patching.patch_transform import get_patch_transform
|
|
16
18
|
from .patching.patching import (
|
|
17
19
|
prepare_patches_supervised,
|
|
18
20
|
prepare_patches_supervised_array,
|
|
@@ -60,18 +62,15 @@ class InMemoryDataset(Dataset):
|
|
|
60
62
|
self.mean, self.std = computed_mean, computed_std
|
|
61
63
|
logger.info(f"Computed dataset mean: {self.mean}, std: {self.std}")
|
|
62
64
|
|
|
63
|
-
#
|
|
64
|
-
|
|
65
|
-
|
|
66
|
-
# the object is mutable and should then be recorded in the CAREamist obj
|
|
67
|
-
self.data_config.set_mean_and_std(self.mean, self.std)
|
|
65
|
+
# update mean and std in configuration
|
|
66
|
+
# the object is mutable and should then be recorded in the CAREamist obj
|
|
67
|
+
self.data_config.set_mean_and_std(self.mean, self.std)
|
|
68
68
|
else:
|
|
69
69
|
self.mean, self.std = self.data_config.mean, self.data_config.std
|
|
70
70
|
|
|
71
71
|
# get transforms
|
|
72
|
-
self.patch_transform =
|
|
73
|
-
|
|
74
|
-
with_target=self.data_target is not None,
|
|
72
|
+
self.patch_transform = Compose(
|
|
73
|
+
transform_list=self.data_config.transforms,
|
|
75
74
|
)
|
|
76
75
|
|
|
77
76
|
def _prepare_patches(
|
|
@@ -166,33 +165,10 @@ class InMemoryDataset(Dataset):
|
|
|
166
165
|
# get target
|
|
167
166
|
target = self.data_targets[index]
|
|
168
167
|
|
|
169
|
-
|
|
170
|
-
c_patch = np.moveaxis(patch, 0, -1)
|
|
171
|
-
c_target = np.moveaxis(target, 0, -1)
|
|
172
|
-
|
|
173
|
-
# Apply transforms
|
|
174
|
-
transformed = self.patch_transform(image=c_patch, target=c_target)
|
|
175
|
-
|
|
176
|
-
# move axes back
|
|
177
|
-
patch = np.moveaxis(transformed["image"], -1, 0)
|
|
178
|
-
target = np.moveaxis(transformed["target"], -1, 0)
|
|
179
|
-
|
|
180
|
-
return patch, target
|
|
168
|
+
return self.patch_transform(patch=patch, target=target)
|
|
181
169
|
|
|
182
170
|
elif self.data_config.has_n2v_manipulate():
|
|
183
|
-
|
|
184
|
-
patch = np.moveaxis(patch, 0, -1)
|
|
185
|
-
|
|
186
|
-
# Apply transforms
|
|
187
|
-
transformed_patch = self.patch_transform(image=patch)["image"]
|
|
188
|
-
manip_patch, patch, mask = transformed_patch
|
|
189
|
-
|
|
190
|
-
# move C axes back
|
|
191
|
-
manip_patch = np.moveaxis(manip_patch, -1, 0)
|
|
192
|
-
patch = np.moveaxis(patch, -1, 0)
|
|
193
|
-
mask = np.moveaxis(mask, -1, 0)
|
|
194
|
-
|
|
195
|
-
return (manip_patch, patch, mask)
|
|
171
|
+
return self.patch_transform(patch=patch)
|
|
196
172
|
else:
|
|
197
173
|
raise ValueError(
|
|
198
174
|
"Something went wrong! No target provided (not supervised training) "
|
|
@@ -318,9 +294,8 @@ class InMemoryPredictionDataset(Dataset):
|
|
|
318
294
|
self.mean, self.std = self.pred_config.mean, self.pred_config.std
|
|
319
295
|
|
|
320
296
|
# get transforms
|
|
321
|
-
self.patch_transform =
|
|
322
|
-
|
|
323
|
-
with_target=self.data_target is not None,
|
|
297
|
+
self.patch_transform = Compose(
|
|
298
|
+
transform_list=self.pred_config.transforms,
|
|
324
299
|
)
|
|
325
300
|
|
|
326
301
|
def _prepare_tiles(self) -> List[Tuple[np.ndarray, TileInformation]]:
|
|
@@ -379,13 +354,7 @@ class InMemoryPredictionDataset(Dataset):
|
|
|
379
354
|
"""
|
|
380
355
|
tile_array, tile_info = self.data[index]
|
|
381
356
|
|
|
382
|
-
# Albumentations requires channel last, use the XArrayTile array
|
|
383
|
-
patch = np.moveaxis(tile_array, 0, -1)
|
|
384
|
-
|
|
385
357
|
# Apply transforms
|
|
386
|
-
|
|
387
|
-
|
|
388
|
-
# move C axes back
|
|
389
|
-
transformed_patch = np.moveaxis(transformed_patch, -1, 0)
|
|
358
|
+
transformed_tile, _ = self.patch_transform(patch=tile_array)
|
|
390
359
|
|
|
391
|
-
return
|
|
360
|
+
return transformed_tile, tile_info
|
|
@@ -7,13 +7,12 @@ from typing import Any, Callable, Generator, List, Optional, Tuple, Union
|
|
|
7
7
|
import numpy as np
|
|
8
8
|
from torch.utils.data import IterableDataset, get_worker_info
|
|
9
9
|
|
|
10
|
+
from careamics.transforms import Compose
|
|
11
|
+
|
|
10
12
|
from ..config import DataConfig, InferenceConfig
|
|
11
13
|
from ..config.tile_information import TileInformation
|
|
12
14
|
from ..utils.logging import get_logger
|
|
13
15
|
from .dataset_utils import read_tiff, reshape_array
|
|
14
|
-
from .patching import (
|
|
15
|
-
get_patch_transform,
|
|
16
|
-
)
|
|
17
16
|
from .patching.random_patching import extract_patches_random
|
|
18
17
|
from .patching.tiled_patching import extract_tiles
|
|
19
18
|
|
|
@@ -61,26 +60,15 @@ class PathIterableDataset(IterableDataset):
|
|
|
61
60
|
if not data_config.mean or not data_config.std:
|
|
62
61
|
self.mean, self.std = self._calculate_mean_and_std()
|
|
63
62
|
|
|
64
|
-
#
|
|
65
|
-
#
|
|
66
|
-
|
|
67
|
-
if hasattr(data_config, "has_transform_list"):
|
|
68
|
-
if data_config.has_transform_list():
|
|
69
|
-
# update mean and std in configuration
|
|
70
|
-
# the object is mutable and should then be recorded in the CAREamist
|
|
71
|
-
data_config.set_mean_and_std(self.mean, self.std)
|
|
72
|
-
else:
|
|
73
|
-
data_config.set_mean_and_std(self.mean, self.std)
|
|
74
|
-
|
|
63
|
+
# update mean and std in configuration
|
|
64
|
+
# the object is mutable and should then be recorded in the CAREamist
|
|
65
|
+
data_config.set_mean_and_std(self.mean, self.std)
|
|
75
66
|
else:
|
|
76
67
|
self.mean = data_config.mean
|
|
77
68
|
self.std = data_config.std
|
|
78
69
|
|
|
79
70
|
# get transforms
|
|
80
|
-
self.patch_transform =
|
|
81
|
-
patch_transforms=data_config.transforms,
|
|
82
|
-
with_target=target_files is not None,
|
|
83
|
-
)
|
|
71
|
+
self.patch_transform = Compose(transform_list=data_config.transforms)
|
|
84
72
|
|
|
85
73
|
def _calculate_mean_and_std(self) -> Tuple[float, float]:
|
|
86
74
|
"""
|
|
@@ -192,49 +180,10 @@ class PathIterableDataset(IterableDataset):
|
|
|
192
180
|
# or (patch, None) only if no target is available
|
|
193
181
|
# patch is of dimensions (C)ZYX
|
|
194
182
|
for patch_data in patches:
|
|
195
|
-
|
|
196
|
-
|
|
197
|
-
|
|
198
|
-
|
|
199
|
-
c_patch = np.moveaxis(patch_data[0], 0, -1)
|
|
200
|
-
c_target = np.moveaxis(patch_data[1], 0, -1)
|
|
201
|
-
|
|
202
|
-
# apply the transform to the patch and the target
|
|
203
|
-
transformed = self.patch_transform(
|
|
204
|
-
image=c_patch,
|
|
205
|
-
target=c_target,
|
|
206
|
-
)
|
|
207
|
-
|
|
208
|
-
# move the axes back to the original position
|
|
209
|
-
c_patch = np.moveaxis(transformed["image"], -1, 0)
|
|
210
|
-
c_target = np.moveaxis(transformed["target"], -1, 0)
|
|
211
|
-
|
|
212
|
-
yield (c_patch, c_target)
|
|
213
|
-
elif self.data_config.has_n2v_manipulate():
|
|
214
|
-
# Albumentations expects the channel dimension to be last
|
|
215
|
-
# Taking the first element because patch_data can include target
|
|
216
|
-
patch = np.moveaxis(patch_data[0], 0, -1)
|
|
217
|
-
|
|
218
|
-
# apply transform
|
|
219
|
-
transformed = self.patch_transform(image=patch)
|
|
220
|
-
|
|
221
|
-
# retrieve the output of ManipulateN2V
|
|
222
|
-
results = transformed["image"]
|
|
223
|
-
masked_patch: np.ndarray = results[0]
|
|
224
|
-
original_patch: np.ndarray = results[1]
|
|
225
|
-
mask: np.ndarray = results[2]
|
|
226
|
-
|
|
227
|
-
# move C axes back
|
|
228
|
-
masked_patch = np.moveaxis(masked_patch, -1, 0)
|
|
229
|
-
original_patch = np.moveaxis(original_patch, -1, 0)
|
|
230
|
-
mask = np.moveaxis(mask, -1, 0)
|
|
231
|
-
|
|
232
|
-
yield (masked_patch, original_patch, mask)
|
|
233
|
-
else:
|
|
234
|
-
raise ValueError(
|
|
235
|
-
"Something went wrong! Not target file (no supervised "
|
|
236
|
-
"training) and no N2V transform (no n2v training either)."
|
|
237
|
-
)
|
|
183
|
+
yield self.patch_transform(
|
|
184
|
+
patch=patch_data[0],
|
|
185
|
+
target=patch_data[1],
|
|
186
|
+
)
|
|
238
187
|
|
|
239
188
|
def get_number_of_files(self) -> int:
|
|
240
189
|
"""
|
|
@@ -367,9 +316,8 @@ class IterablePredictionDataset(PathIterableDataset):
|
|
|
367
316
|
self.tile = self.tile_size is not None and self.tile_overlap is not None
|
|
368
317
|
|
|
369
318
|
# get tta transforms
|
|
370
|
-
self.patch_transform =
|
|
371
|
-
|
|
372
|
-
with_target=False,
|
|
319
|
+
self.patch_transform = Compose(
|
|
320
|
+
transform_list=prediction_config.transforms,
|
|
373
321
|
)
|
|
374
322
|
|
|
375
323
|
def __iter__(
|
|
@@ -408,9 +356,6 @@ class IterablePredictionDataset(PathIterableDataset):
|
|
|
408
356
|
|
|
409
357
|
# apply transform to patches
|
|
410
358
|
for patch_array, tile_info in patch_gen:
|
|
411
|
-
|
|
412
|
-
patch = np.moveaxis(patch_array, 0, -1)
|
|
413
|
-
transformed_patch = self.patch_transform(image=patch)
|
|
414
|
-
transformed_patch = np.moveaxis(transformed_patch["image"], -1, 0)
|
|
359
|
+
transformed_patch, _ = self.patch_transform(patch=patch_array)
|
|
415
360
|
|
|
416
361
|
yield transformed_patch, tile_info
|
|
@@ -135,15 +135,12 @@ def _compute_patch_views(
|
|
|
135
135
|
arr = np.stack([arr, target], axis=0)
|
|
136
136
|
window_shape = [arr.shape[0], *window_shape]
|
|
137
137
|
step = (arr.shape[0], *step)
|
|
138
|
-
output_shape = [arr.shape[0],
|
|
138
|
+
output_shape = [-1, arr.shape[0], arr.shape[2], *output_shape[2:]]
|
|
139
139
|
|
|
140
140
|
patches = view_as_windows(arr, window_shape=window_shape, step=step).reshape(
|
|
141
141
|
*output_shape
|
|
142
142
|
)
|
|
143
|
-
|
|
144
|
-
rng.shuffle(patches, axis=1)
|
|
145
|
-
else:
|
|
146
|
-
rng.shuffle(patches, axis=0)
|
|
143
|
+
rng.shuffle(patches, axis=0)
|
|
147
144
|
return patches
|
|
148
145
|
|
|
149
146
|
|
|
@@ -201,6 +198,9 @@ def extract_patches_sequential(
|
|
|
201
198
|
|
|
202
199
|
if target is not None:
|
|
203
200
|
# target was concatenated to patches in _compute_reshaped_view
|
|
204
|
-
return (
|
|
201
|
+
return (
|
|
202
|
+
patches[:, 0, ...],
|
|
203
|
+
patches[:, 1, ...],
|
|
204
|
+
) # TODO in _compute_reshaped_view?
|
|
205
205
|
else:
|
|
206
206
|
return patches, None
|
|
@@ -43,9 +43,11 @@ def _compute_crop_and_stitch_coords_1d(
|
|
|
43
43
|
stitch_coords.append(
|
|
44
44
|
(
|
|
45
45
|
i + overlap // 2 if i > 0 else 0,
|
|
46
|
-
|
|
47
|
-
|
|
48
|
-
|
|
46
|
+
(
|
|
47
|
+
i + tile_size - overlap // 2
|
|
48
|
+
if crop_coords[-1][1] < axis_size
|
|
49
|
+
else axis_size
|
|
50
|
+
),
|
|
49
51
|
)
|
|
50
52
|
)
|
|
51
53
|
|
|
@@ -53,9 +55,11 @@ def _compute_crop_and_stitch_coords_1d(
|
|
|
53
55
|
overlap_crop_coords.append(
|
|
54
56
|
(
|
|
55
57
|
overlap // 2 if i > 0 else 0,
|
|
56
|
-
|
|
57
|
-
|
|
58
|
-
|
|
58
|
+
(
|
|
59
|
+
tile_size - overlap // 2
|
|
60
|
+
if crop_coords[-1][1] < axis_size
|
|
61
|
+
else tile_size
|
|
62
|
+
),
|
|
59
63
|
)
|
|
60
64
|
)
|
|
61
65
|
|
|
@@ -1,10 +1,10 @@
|
|
|
1
1
|
"""Training and validation Lightning data modules."""
|
|
2
|
+
|
|
2
3
|
from pathlib import Path
|
|
3
4
|
from typing import Any, Callable, Dict, List, Literal, Optional, Union
|
|
4
5
|
|
|
5
6
|
import numpy as np
|
|
6
7
|
import pytorch_lightning as L
|
|
7
|
-
from albumentations import Compose
|
|
8
8
|
from torch.utils.data import DataLoader
|
|
9
9
|
|
|
10
10
|
from careamics.config import DataConfig
|
|
@@ -341,9 +341,9 @@ class CAREamicsTrainData(L.LightningDataModule):
|
|
|
341
341
|
self.train_dataset = InMemoryDataset(
|
|
342
342
|
data_config=self.data_config,
|
|
343
343
|
inputs=self.train_files,
|
|
344
|
-
data_target=
|
|
345
|
-
|
|
346
|
-
|
|
344
|
+
data_target=(
|
|
345
|
+
self.train_target_files if self.train_data_target else None
|
|
346
|
+
),
|
|
347
347
|
read_source_func=self.read_source_func,
|
|
348
348
|
)
|
|
349
349
|
|
|
@@ -352,9 +352,9 @@ class CAREamicsTrainData(L.LightningDataModule):
|
|
|
352
352
|
self.val_dataset = InMemoryDataset(
|
|
353
353
|
data_config=self.data_config,
|
|
354
354
|
inputs=self.val_files,
|
|
355
|
-
data_target=
|
|
356
|
-
|
|
357
|
-
|
|
355
|
+
data_target=(
|
|
356
|
+
self.val_target_files if self.val_data_target else None
|
|
357
|
+
),
|
|
358
358
|
read_source_func=self.read_source_func,
|
|
359
359
|
)
|
|
360
360
|
else:
|
|
@@ -370,9 +370,9 @@ class CAREamicsTrainData(L.LightningDataModule):
|
|
|
370
370
|
self.train_dataset = PathIterableDataset(
|
|
371
371
|
data_config=self.data_config,
|
|
372
372
|
src_files=self.train_files,
|
|
373
|
-
target_files=
|
|
374
|
-
|
|
375
|
-
|
|
373
|
+
target_files=(
|
|
374
|
+
self.train_target_files if self.train_data_target else None
|
|
375
|
+
),
|
|
376
376
|
read_source_func=self.read_source_func,
|
|
377
377
|
)
|
|
378
378
|
|
|
@@ -382,9 +382,9 @@ class CAREamicsTrainData(L.LightningDataModule):
|
|
|
382
382
|
self.val_dataset = PathIterableDataset(
|
|
383
383
|
data_config=self.data_config,
|
|
384
384
|
src_files=self.val_files,
|
|
385
|
-
target_files=
|
|
386
|
-
|
|
387
|
-
|
|
385
|
+
target_files=(
|
|
386
|
+
self.val_target_files if self.val_data_target else None
|
|
387
|
+
),
|
|
388
388
|
read_source_func=self.read_source_func,
|
|
389
389
|
)
|
|
390
390
|
elif len(self.train_files) <= self.val_minimum_split:
|
|
@@ -452,8 +452,7 @@ class TrainingDataWrapper(CAREamicsTrainData):
|
|
|
452
452
|
In particular, N2V requires a specific transformation (N2V manipulates), which is
|
|
453
453
|
not compatible with supervised training. The default transformations applied to the
|
|
454
454
|
training patches are defined in `careamics.config.data_model`. To use different
|
|
455
|
-
transformations, pass a list of transforms
|
|
456
|
-
`transforms` parameter. See examples for more details.
|
|
455
|
+
transformations, pass a list of transforms. See examples for more details.
|
|
457
456
|
|
|
458
457
|
By default, CAREamics only supports types defined in
|
|
459
458
|
`careamics.config.support.SupportedData`. To read custom data types, you can set
|
|
@@ -488,7 +487,7 @@ class TrainingDataWrapper(CAREamicsTrainData):
|
|
|
488
487
|
Batch size.
|
|
489
488
|
val_data : Optional[Union[str, Path]], optional
|
|
490
489
|
Validation data, by default None.
|
|
491
|
-
transforms :
|
|
490
|
+
transforms : List[TRANSFORMS_UNION], optional
|
|
492
491
|
List of transforms to apply to training patches. If None, default transforms
|
|
493
492
|
are applied.
|
|
494
493
|
train_target_data : Optional[Union[str, Path]], optional
|
|
@@ -584,7 +583,7 @@ class TrainingDataWrapper(CAREamicsTrainData):
|
|
|
584
583
|
axes: str,
|
|
585
584
|
batch_size: int,
|
|
586
585
|
val_data: Optional[Union[str, Path]] = None,
|
|
587
|
-
transforms: Optional[
|
|
586
|
+
transforms: Optional[List[TRANSFORMS_UNION]] = None,
|
|
588
587
|
train_target_data: Optional[Union[str, Path]] = None,
|
|
589
588
|
val_target_data: Optional[Union[str, Path]] = None,
|
|
590
589
|
read_source_func: Optional[Callable] = None,
|
|
@@ -617,8 +616,8 @@ class TrainingDataWrapper(CAREamicsTrainData):
|
|
|
617
616
|
In particular, N2V requires a specific transformation (N2V manipulates), which
|
|
618
617
|
is not compatible with supervised training. The default transformations applied
|
|
619
618
|
to the training patches are defined in `careamics.config.data_model`. To use
|
|
620
|
-
different transformations, pass a list of transforms
|
|
621
|
-
|
|
619
|
+
different transformations, pass a list of transforms. See examples for more
|
|
620
|
+
details.
|
|
622
621
|
|
|
623
622
|
By default, CAREamics only supports types defined in
|
|
624
623
|
`careamics.config.support.SupportedData`. To read custom data types, you can set
|
|
@@ -655,7 +654,7 @@ class TrainingDataWrapper(CAREamicsTrainData):
|
|
|
655
654
|
Batch size.
|
|
656
655
|
val_data : Optional[Union[str, Path]], optional
|
|
657
656
|
Validation data, by default None.
|
|
658
|
-
transforms : Optional[
|
|
657
|
+
transforms : Optional[List[TRANSFORMS_UNION]], optional
|
|
659
658
|
List of transforms to apply to training patches. If None, default transforms
|
|
660
659
|
are applied.
|
|
661
660
|
train_target_data : Optional[Union[str, Path]], optional
|
|
@@ -709,10 +708,7 @@ class TrainingDataWrapper(CAREamicsTrainData):
|
|
|
709
708
|
self.data_config = DataConfig(**data_dict)
|
|
710
709
|
|
|
711
710
|
# N2V specific checks, N2V, structN2V, and transforms
|
|
712
|
-
if (
|
|
713
|
-
self.data_config.has_transform_list()
|
|
714
|
-
and self.data_config.has_n2v_manipulate()
|
|
715
|
-
):
|
|
711
|
+
if self.data_config.has_n2v_manipulate():
|
|
716
712
|
# there is not target, n2v2 and structN2V can be changed
|
|
717
713
|
if train_target_data is None:
|
|
718
714
|
self.data_config.set_N2V2(use_n2v2)
|
careamics/lightning_module.py
CHANGED
|
@@ -162,7 +162,7 @@ class CAREamicsModule(L.LightningModule):
|
|
|
162
162
|
mean=self._trainer.datamodule.predict_dataset.mean,
|
|
163
163
|
std=self._trainer.datamodule.predict_dataset.std,
|
|
164
164
|
)
|
|
165
|
-
denormalized_output = denorm(
|
|
165
|
+
denormalized_output, _ = denorm(patch=output)
|
|
166
166
|
|
|
167
167
|
if len(aux) > 0:
|
|
168
168
|
return denormalized_output, aux
|
|
@@ -1,10 +1,10 @@
|
|
|
1
1
|
"""Prediction Lightning data modules."""
|
|
2
|
+
|
|
2
3
|
from pathlib import Path
|
|
3
4
|
from typing import Any, Callable, List, Literal, Optional, Tuple, Union
|
|
4
5
|
|
|
5
6
|
import numpy as np
|
|
6
7
|
import pytorch_lightning as L
|
|
7
|
-
from albumentations import Compose
|
|
8
8
|
from torch.utils.data import DataLoader
|
|
9
9
|
from torch.utils.data.dataloader import default_collate
|
|
10
10
|
|
|
@@ -39,7 +39,7 @@ def _collate_tiles(batch: List[Tuple[np.ndarray, TileInformation]]) -> Any:
|
|
|
39
39
|
|
|
40
40
|
Parameters
|
|
41
41
|
----------
|
|
42
|
-
batch :
|
|
42
|
+
batch : List[Tuple[np.ndarray, TileInformation], ...]
|
|
43
43
|
Batch of tiles.
|
|
44
44
|
|
|
45
45
|
Returns
|
|
@@ -257,14 +257,13 @@ class PredictDataWrapper(CAREamicsPredictData):
|
|
|
257
257
|
|
|
258
258
|
The default transformations applied to the images are defined in
|
|
259
259
|
`careamics.config.inference_model`. To use different transformations, pass a list
|
|
260
|
-
of transforms
|
|
260
|
+
of transforms. See examples
|
|
261
261
|
for more details.
|
|
262
262
|
|
|
263
263
|
The `mean` and `std` parameters are only used if Normalization is defined either
|
|
264
|
-
in the default transformations or in the `transforms` parameter
|
|
265
|
-
|
|
266
|
-
|
|
267
|
-
to this method.
|
|
264
|
+
in the default transformations or in the `transforms` parameter. If you pass a
|
|
265
|
+
`Normalization` transform in a list as `transforms`, then the mean and std
|
|
266
|
+
parameters will be overwritten by those passed to this method.
|
|
268
267
|
|
|
269
268
|
By default, CAREamics only supports types defined in
|
|
270
269
|
`careamics.config.support.SupportedData`. To read custom data types, you can set
|
|
@@ -276,6 +275,12 @@ class PredictDataWrapper(CAREamicsPredictData):
|
|
|
276
275
|
dataloaders, except for `batch_size`, which is set by the `batch_size`
|
|
277
276
|
parameter.
|
|
278
277
|
|
|
278
|
+
Note that if you are using a UNet model and tiling, the tile size must be
|
|
279
|
+
divisible in every dimension by 2**d, where d is the depth of the model. This
|
|
280
|
+
avoids artefacts arising from the broken shift invariance induced by the
|
|
281
|
+
pooling layers of the UNet. If your image has less dimensions, as it may
|
|
282
|
+
happen in the Z dimension, consider padding your image.
|
|
283
|
+
|
|
279
284
|
Parameters
|
|
280
285
|
----------
|
|
281
286
|
pred_data : Union[str, Path, np.ndarray]
|
|
@@ -298,7 +303,7 @@ class PredictDataWrapper(CAREamicsPredictData):
|
|
|
298
303
|
Batch size.
|
|
299
304
|
tta_transforms : bool, optional
|
|
300
305
|
Use test time augmentation, by default True.
|
|
301
|
-
transforms :
|
|
306
|
+
transforms : List, optional
|
|
302
307
|
List of transforms to apply to prediction patches. If None, default
|
|
303
308
|
transforms are applied.
|
|
304
309
|
read_source_func : Optional[Callable], optional
|
|
@@ -321,7 +326,7 @@ class PredictDataWrapper(CAREamicsPredictData):
|
|
|
321
326
|
axes: str = "YX",
|
|
322
327
|
batch_size: int = 1,
|
|
323
328
|
tta_transforms: bool = True,
|
|
324
|
-
transforms: Optional[
|
|
329
|
+
transforms: Optional[List] = None,
|
|
325
330
|
read_source_func: Optional[Callable] = None,
|
|
326
331
|
extension_filter: str = "",
|
|
327
332
|
dataloader_params: Optional[dict] = None,
|
|
@@ -351,7 +356,7 @@ class PredictDataWrapper(CAREamicsPredictData):
|
|
|
351
356
|
Batch size.
|
|
352
357
|
tta_transforms : bool, optional
|
|
353
358
|
Use test time augmentation, by default True.
|
|
354
|
-
transforms : Optional[
|
|
359
|
+
transforms : Optional[List], optional
|
|
355
360
|
List of transforms to apply to prediction patches. If None, default
|
|
356
361
|
transforms are applied.
|
|
357
362
|
read_source_func : Optional[Callable], optional
|
careamics/losses/__init__.py
CHANGED
careamics/losses/loss_factory.py
CHANGED
careamics/model_io/__init__.py
CHANGED
|
@@ -1,4 +1,5 @@
|
|
|
1
1
|
"""Functions used to create a README.md file for BMZ export."""
|
|
2
|
+
|
|
2
3
|
from pathlib import Path
|
|
3
4
|
from typing import Optional
|
|
4
5
|
|
|
@@ -117,4 +118,4 @@ def readme_factory(
|
|
|
117
118
|
|
|
118
119
|
readme.write_text("".join(description))
|
|
119
120
|
|
|
120
|
-
|
|
121
|
+
return readme.absolute()
|
careamics/model_io/bmz_io.py
CHANGED
|
@@ -1,4 +1,5 @@
|
|
|
1
1
|
"""Function to export to the BioImage Model Zoo format."""
|
|
2
|
+
|
|
2
3
|
import tempfile
|
|
3
4
|
from pathlib import Path
|
|
4
5
|
from typing import List, Optional, Tuple, Union
|
|
@@ -177,7 +178,7 @@ def export_to_bmz(
|
|
|
177
178
|
)
|
|
178
179
|
|
|
179
180
|
# test model description
|
|
180
|
-
summary: ValidationSummary = test_model(model_description)
|
|
181
|
+
summary: ValidationSummary = test_model(model_description, decimal=0)
|
|
181
182
|
if summary.status == "failed":
|
|
182
183
|
raise ValueError(f"Model description test failed: {summary}")
|
|
183
184
|
|
careamics/models/layers.py
CHANGED