careamics 0.1.0rc3__py3-none-any.whl → 0.1.0rc5__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of careamics might be problematic. Click here for more details.
- careamics/__init__.py +8 -6
- careamics/careamist.py +30 -29
- careamics/config/__init__.py +12 -9
- careamics/config/algorithm_model.py +5 -5
- careamics/config/architectures/unet_model.py +1 -0
- careamics/config/callback_model.py +1 -0
- careamics/config/configuration_example.py +87 -0
- careamics/config/configuration_factory.py +285 -78
- careamics/config/configuration_model.py +22 -23
- careamics/config/data_model.py +62 -160
- careamics/config/inference_model.py +20 -21
- careamics/config/references/algorithm_descriptions.py +1 -0
- careamics/config/references/references.py +1 -0
- careamics/config/support/supported_extraction_strategies.py +1 -0
- careamics/config/support/supported_optimizers.py +3 -3
- careamics/config/training_model.py +2 -1
- careamics/config/transformations/n2v_manipulate_model.py +2 -1
- careamics/config/transformations/nd_flip_model.py +7 -12
- careamics/config/transformations/normalize_model.py +2 -1
- careamics/config/transformations/transform_model.py +1 -0
- careamics/config/transformations/xy_random_rotate90_model.py +7 -9
- careamics/config/validators/validator_utils.py +1 -0
- careamics/conftest.py +1 -0
- careamics/dataset/dataset_utils/__init__.py +0 -1
- careamics/dataset/dataset_utils/dataset_utils.py +1 -0
- careamics/dataset/in_memory_dataset.py +17 -48
- careamics/dataset/iterable_dataset.py +16 -71
- careamics/dataset/patching/__init__.py +0 -7
- careamics/dataset/patching/patching.py +1 -0
- careamics/dataset/patching/sequential_patching.py +6 -6
- careamics/dataset/patching/tiled_patching.py +10 -6
- careamics/lightning_datamodule.py +123 -49
- careamics/lightning_module.py +7 -7
- careamics/lightning_prediction_datamodule.py +59 -48
- careamics/losses/__init__.py +0 -1
- careamics/losses/loss_factory.py +1 -0
- careamics/model_io/__init__.py +0 -1
- careamics/model_io/bioimage/_readme_factory.py +2 -1
- careamics/model_io/bioimage/bioimage_utils.py +1 -0
- careamics/model_io/bioimage/model_description.py +4 -3
- careamics/model_io/bmz_io.py +8 -7
- careamics/model_io/model_io_utils.py +4 -4
- careamics/models/layers.py +1 -0
- careamics/models/model_factory.py +1 -0
- careamics/models/unet.py +91 -17
- careamics/prediction/stitch_prediction.py +1 -0
- careamics/transforms/__init__.py +2 -23
- careamics/transforms/compose.py +98 -0
- careamics/transforms/n2v_manipulate.py +18 -23
- careamics/transforms/nd_flip.py +38 -64
- careamics/transforms/normalize.py +45 -34
- careamics/transforms/pixel_manipulation.py +2 -2
- careamics/transforms/transform.py +33 -0
- careamics/transforms/tta.py +2 -2
- careamics/transforms/xy_random_rotate90.py +41 -68
- careamics/utils/__init__.py +0 -1
- careamics/utils/context.py +1 -0
- careamics/utils/logging.py +1 -0
- careamics/utils/metrics.py +1 -0
- careamics/utils/torch_utils.py +1 -0
- {careamics-0.1.0rc3.dist-info → careamics-0.1.0rc5.dist-info}/METADATA +16 -61
- careamics-0.1.0rc5.dist-info/RECORD +111 -0
- careamics/dataset/patching/patch_transform.py +0 -44
- careamics-0.1.0rc3.dist-info/RECORD +0 -109
- {careamics-0.1.0rc3.dist-info → careamics-0.1.0rc5.dist-info}/WHEEL +0 -0
- {careamics-0.1.0rc3.dist-info → careamics-0.1.0rc5.dist-info}/licenses/LICENSE +0 -0
|
@@ -1,4 +1,5 @@
|
|
|
1
1
|
"""In-memory dataset module."""
|
|
2
|
+
|
|
2
3
|
from __future__ import annotations
|
|
3
4
|
|
|
4
5
|
import copy
|
|
@@ -8,11 +9,12 @@ from typing import Any, Callable, List, Optional, Tuple, Union
|
|
|
8
9
|
import numpy as np
|
|
9
10
|
from torch.utils.data import Dataset
|
|
10
11
|
|
|
11
|
-
from
|
|
12
|
+
from careamics.transforms import Compose
|
|
13
|
+
|
|
14
|
+
from ..config import DataConfig, InferenceConfig
|
|
12
15
|
from ..config.tile_information import TileInformation
|
|
13
16
|
from ..utils.logging import get_logger
|
|
14
17
|
from .dataset_utils import read_tiff, reshape_array
|
|
15
|
-
from .patching.patch_transform import get_patch_transform
|
|
16
18
|
from .patching.patching import (
|
|
17
19
|
prepare_patches_supervised,
|
|
18
20
|
prepare_patches_supervised_array,
|
|
@@ -29,7 +31,7 @@ class InMemoryDataset(Dataset):
|
|
|
29
31
|
|
|
30
32
|
def __init__(
|
|
31
33
|
self,
|
|
32
|
-
data_config:
|
|
34
|
+
data_config: DataConfig,
|
|
33
35
|
inputs: Union[np.ndarray, List[Path]],
|
|
34
36
|
data_target: Optional[Union[np.ndarray, List[Path]]] = None,
|
|
35
37
|
read_source_func: Callable = read_tiff,
|
|
@@ -60,18 +62,15 @@ class InMemoryDataset(Dataset):
|
|
|
60
62
|
self.mean, self.std = computed_mean, computed_std
|
|
61
63
|
logger.info(f"Computed dataset mean: {self.mean}, std: {self.std}")
|
|
62
64
|
|
|
63
|
-
#
|
|
64
|
-
|
|
65
|
-
|
|
66
|
-
# the object is mutable and should then be recorded in the CAREamist obj
|
|
67
|
-
self.data_config.set_mean_and_std(self.mean, self.std)
|
|
65
|
+
# update mean and std in configuration
|
|
66
|
+
# the object is mutable and should then be recorded in the CAREamist obj
|
|
67
|
+
self.data_config.set_mean_and_std(self.mean, self.std)
|
|
68
68
|
else:
|
|
69
69
|
self.mean, self.std = self.data_config.mean, self.data_config.std
|
|
70
70
|
|
|
71
71
|
# get transforms
|
|
72
|
-
self.patch_transform =
|
|
73
|
-
|
|
74
|
-
with_target=self.data_target is not None,
|
|
72
|
+
self.patch_transform = Compose(
|
|
73
|
+
transform_list=self.data_config.transforms,
|
|
75
74
|
)
|
|
76
75
|
|
|
77
76
|
def _prepare_patches(
|
|
@@ -166,33 +165,10 @@ class InMemoryDataset(Dataset):
|
|
|
166
165
|
# get target
|
|
167
166
|
target = self.data_targets[index]
|
|
168
167
|
|
|
169
|
-
|
|
170
|
-
c_patch = np.moveaxis(patch, 0, -1)
|
|
171
|
-
c_target = np.moveaxis(target, 0, -1)
|
|
172
|
-
|
|
173
|
-
# Apply transforms
|
|
174
|
-
transformed = self.patch_transform(image=c_patch, target=c_target)
|
|
175
|
-
|
|
176
|
-
# move axes back
|
|
177
|
-
patch = np.moveaxis(transformed["image"], -1, 0)
|
|
178
|
-
target = np.moveaxis(transformed["target"], -1, 0)
|
|
179
|
-
|
|
180
|
-
return patch, target
|
|
168
|
+
return self.patch_transform(patch=patch, target=target)
|
|
181
169
|
|
|
182
170
|
elif self.data_config.has_n2v_manipulate():
|
|
183
|
-
|
|
184
|
-
patch = np.moveaxis(patch, 0, -1)
|
|
185
|
-
|
|
186
|
-
# Apply transforms
|
|
187
|
-
transformed_patch = self.patch_transform(image=patch)["image"]
|
|
188
|
-
manip_patch, patch, mask = transformed_patch
|
|
189
|
-
|
|
190
|
-
# move C axes back
|
|
191
|
-
manip_patch = np.moveaxis(manip_patch, -1, 0)
|
|
192
|
-
patch = np.moveaxis(patch, -1, 0)
|
|
193
|
-
mask = np.moveaxis(mask, -1, 0)
|
|
194
|
-
|
|
195
|
-
return (manip_patch, patch, mask)
|
|
171
|
+
return self.patch_transform(patch=patch)
|
|
196
172
|
else:
|
|
197
173
|
raise ValueError(
|
|
198
174
|
"Something went wrong! No target provided (not supervised training) "
|
|
@@ -279,7 +255,7 @@ class InMemoryPredictionDataset(Dataset):
|
|
|
279
255
|
|
|
280
256
|
def __init__(
|
|
281
257
|
self,
|
|
282
|
-
prediction_config:
|
|
258
|
+
prediction_config: InferenceConfig,
|
|
283
259
|
inputs: np.ndarray,
|
|
284
260
|
data_target: Optional[np.ndarray] = None,
|
|
285
261
|
read_source_func: Optional[Callable] = read_tiff,
|
|
@@ -318,9 +294,8 @@ class InMemoryPredictionDataset(Dataset):
|
|
|
318
294
|
self.mean, self.std = self.pred_config.mean, self.pred_config.std
|
|
319
295
|
|
|
320
296
|
# get transforms
|
|
321
|
-
self.patch_transform =
|
|
322
|
-
|
|
323
|
-
with_target=self.data_target is not None,
|
|
297
|
+
self.patch_transform = Compose(
|
|
298
|
+
transform_list=self.pred_config.transforms,
|
|
324
299
|
)
|
|
325
300
|
|
|
326
301
|
def _prepare_tiles(self) -> List[Tuple[np.ndarray, TileInformation]]:
|
|
@@ -379,13 +354,7 @@ class InMemoryPredictionDataset(Dataset):
|
|
|
379
354
|
"""
|
|
380
355
|
tile_array, tile_info = self.data[index]
|
|
381
356
|
|
|
382
|
-
# Albumentations requires channel last, use the XArrayTile array
|
|
383
|
-
patch = np.moveaxis(tile_array, 0, -1)
|
|
384
|
-
|
|
385
357
|
# Apply transforms
|
|
386
|
-
|
|
387
|
-
|
|
388
|
-
# move C axes back
|
|
389
|
-
transformed_patch = np.moveaxis(transformed_patch, -1, 0)
|
|
358
|
+
transformed_tile, _ = self.patch_transform(patch=tile_array)
|
|
390
359
|
|
|
391
|
-
return
|
|
360
|
+
return transformed_tile, tile_info
|
|
@@ -7,13 +7,12 @@ from typing import Any, Callable, Generator, List, Optional, Tuple, Union
|
|
|
7
7
|
import numpy as np
|
|
8
8
|
from torch.utils.data import IterableDataset, get_worker_info
|
|
9
9
|
|
|
10
|
-
from
|
|
10
|
+
from careamics.transforms import Compose
|
|
11
|
+
|
|
12
|
+
from ..config import DataConfig, InferenceConfig
|
|
11
13
|
from ..config.tile_information import TileInformation
|
|
12
14
|
from ..utils.logging import get_logger
|
|
13
15
|
from .dataset_utils import read_tiff, reshape_array
|
|
14
|
-
from .patching import (
|
|
15
|
-
get_patch_transform,
|
|
16
|
-
)
|
|
17
16
|
from .patching.random_patching import extract_patches_random
|
|
18
17
|
from .patching.tiled_patching import extract_tiles
|
|
19
18
|
|
|
@@ -46,7 +45,7 @@ class PathIterableDataset(IterableDataset):
|
|
|
46
45
|
|
|
47
46
|
def __init__(
|
|
48
47
|
self,
|
|
49
|
-
data_config: Union[
|
|
48
|
+
data_config: Union[DataConfig, InferenceConfig],
|
|
50
49
|
src_files: List[Path],
|
|
51
50
|
target_files: Optional[List[Path]] = None,
|
|
52
51
|
read_source_func: Callable = read_tiff,
|
|
@@ -61,26 +60,15 @@ class PathIterableDataset(IterableDataset):
|
|
|
61
60
|
if not data_config.mean or not data_config.std:
|
|
62
61
|
self.mean, self.std = self._calculate_mean_and_std()
|
|
63
62
|
|
|
64
|
-
#
|
|
65
|
-
#
|
|
66
|
-
|
|
67
|
-
if hasattr(data_config, "has_transform_list"):
|
|
68
|
-
if data_config.has_transform_list():
|
|
69
|
-
# update mean and std in configuration
|
|
70
|
-
# the object is mutable and should then be recorded in the CAREamist
|
|
71
|
-
data_config.set_mean_and_std(self.mean, self.std)
|
|
72
|
-
else:
|
|
73
|
-
data_config.set_mean_and_std(self.mean, self.std)
|
|
74
|
-
|
|
63
|
+
# update mean and std in configuration
|
|
64
|
+
# the object is mutable and should then be recorded in the CAREamist
|
|
65
|
+
data_config.set_mean_and_std(self.mean, self.std)
|
|
75
66
|
else:
|
|
76
67
|
self.mean = data_config.mean
|
|
77
68
|
self.std = data_config.std
|
|
78
69
|
|
|
79
70
|
# get transforms
|
|
80
|
-
self.patch_transform =
|
|
81
|
-
patch_transforms=data_config.transforms,
|
|
82
|
-
with_target=target_files is not None,
|
|
83
|
-
)
|
|
71
|
+
self.patch_transform = Compose(transform_list=data_config.transforms)
|
|
84
72
|
|
|
85
73
|
def _calculate_mean_and_std(self) -> Tuple[float, float]:
|
|
86
74
|
"""
|
|
@@ -192,49 +180,10 @@ class PathIterableDataset(IterableDataset):
|
|
|
192
180
|
# or (patch, None) only if no target is available
|
|
193
181
|
# patch is of dimensions (C)ZYX
|
|
194
182
|
for patch_data in patches:
|
|
195
|
-
|
|
196
|
-
|
|
197
|
-
|
|
198
|
-
|
|
199
|
-
c_patch = np.moveaxis(patch_data[0], 0, -1)
|
|
200
|
-
c_target = np.moveaxis(patch_data[1], 0, -1)
|
|
201
|
-
|
|
202
|
-
# apply the transform to the patch and the target
|
|
203
|
-
transformed = self.patch_transform(
|
|
204
|
-
image=c_patch,
|
|
205
|
-
target=c_target,
|
|
206
|
-
)
|
|
207
|
-
|
|
208
|
-
# move the axes back to the original position
|
|
209
|
-
c_patch = np.moveaxis(transformed["image"], -1, 0)
|
|
210
|
-
c_target = np.moveaxis(transformed["target"], -1, 0)
|
|
211
|
-
|
|
212
|
-
yield (c_patch, c_target)
|
|
213
|
-
elif self.data_config.has_n2v_manipulate():
|
|
214
|
-
# Albumentations expects the channel dimension to be last
|
|
215
|
-
# Taking the first element because patch_data can include target
|
|
216
|
-
patch = np.moveaxis(patch_data[0], 0, -1)
|
|
217
|
-
|
|
218
|
-
# apply transform
|
|
219
|
-
transformed = self.patch_transform(image=patch)
|
|
220
|
-
|
|
221
|
-
# retrieve the output of ManipulateN2V
|
|
222
|
-
results = transformed["image"]
|
|
223
|
-
masked_patch: np.ndarray = results[0]
|
|
224
|
-
original_patch: np.ndarray = results[1]
|
|
225
|
-
mask: np.ndarray = results[2]
|
|
226
|
-
|
|
227
|
-
# move C axes back
|
|
228
|
-
masked_patch = np.moveaxis(masked_patch, -1, 0)
|
|
229
|
-
original_patch = np.moveaxis(original_patch, -1, 0)
|
|
230
|
-
mask = np.moveaxis(mask, -1, 0)
|
|
231
|
-
|
|
232
|
-
yield (masked_patch, original_patch, mask)
|
|
233
|
-
else:
|
|
234
|
-
raise ValueError(
|
|
235
|
-
"Something went wrong! Not target file (no supervised "
|
|
236
|
-
"training) and no N2V transform (no n2v training either)."
|
|
237
|
-
)
|
|
183
|
+
yield self.patch_transform(
|
|
184
|
+
patch=patch_data[0],
|
|
185
|
+
target=patch_data[1],
|
|
186
|
+
)
|
|
238
187
|
|
|
239
188
|
def get_number_of_files(self) -> int:
|
|
240
189
|
"""
|
|
@@ -346,7 +295,7 @@ class IterablePredictionDataset(PathIterableDataset):
|
|
|
346
295
|
|
|
347
296
|
def __init__(
|
|
348
297
|
self,
|
|
349
|
-
prediction_config:
|
|
298
|
+
prediction_config: InferenceConfig,
|
|
350
299
|
src_files: List[Path],
|
|
351
300
|
read_source_func: Callable = read_tiff,
|
|
352
301
|
**kwargs: Any,
|
|
@@ -367,9 +316,8 @@ class IterablePredictionDataset(PathIterableDataset):
|
|
|
367
316
|
self.tile = self.tile_size is not None and self.tile_overlap is not None
|
|
368
317
|
|
|
369
318
|
# get tta transforms
|
|
370
|
-
self.patch_transform =
|
|
371
|
-
|
|
372
|
-
with_target=False,
|
|
319
|
+
self.patch_transform = Compose(
|
|
320
|
+
transform_list=prediction_config.transforms,
|
|
373
321
|
)
|
|
374
322
|
|
|
375
323
|
def __iter__(
|
|
@@ -408,9 +356,6 @@ class IterablePredictionDataset(PathIterableDataset):
|
|
|
408
356
|
|
|
409
357
|
# apply transform to patches
|
|
410
358
|
for patch_array, tile_info in patch_gen:
|
|
411
|
-
|
|
412
|
-
patch = np.moveaxis(patch_array, 0, -1)
|
|
413
|
-
transformed_patch = self.patch_transform(image=patch)
|
|
414
|
-
transformed_patch = np.moveaxis(transformed_patch["image"], -1, 0)
|
|
359
|
+
transformed_patch, _ = self.patch_transform(patch=patch_array)
|
|
415
360
|
|
|
416
361
|
yield transformed_patch, tile_info
|
|
@@ -135,15 +135,12 @@ def _compute_patch_views(
|
|
|
135
135
|
arr = np.stack([arr, target], axis=0)
|
|
136
136
|
window_shape = [arr.shape[0], *window_shape]
|
|
137
137
|
step = (arr.shape[0], *step)
|
|
138
|
-
output_shape = [arr.shape[0],
|
|
138
|
+
output_shape = [-1, arr.shape[0], arr.shape[2], *output_shape[2:]]
|
|
139
139
|
|
|
140
140
|
patches = view_as_windows(arr, window_shape=window_shape, step=step).reshape(
|
|
141
141
|
*output_shape
|
|
142
142
|
)
|
|
143
|
-
|
|
144
|
-
rng.shuffle(patches, axis=1)
|
|
145
|
-
else:
|
|
146
|
-
rng.shuffle(patches, axis=0)
|
|
143
|
+
rng.shuffle(patches, axis=0)
|
|
147
144
|
return patches
|
|
148
145
|
|
|
149
146
|
|
|
@@ -201,6 +198,9 @@ def extract_patches_sequential(
|
|
|
201
198
|
|
|
202
199
|
if target is not None:
|
|
203
200
|
# target was concatenated to patches in _compute_reshaped_view
|
|
204
|
-
return (
|
|
201
|
+
return (
|
|
202
|
+
patches[:, 0, ...],
|
|
203
|
+
patches[:, 1, ...],
|
|
204
|
+
) # TODO in _compute_reshaped_view?
|
|
205
205
|
else:
|
|
206
206
|
return patches, None
|
|
@@ -43,9 +43,11 @@ def _compute_crop_and_stitch_coords_1d(
|
|
|
43
43
|
stitch_coords.append(
|
|
44
44
|
(
|
|
45
45
|
i + overlap // 2 if i > 0 else 0,
|
|
46
|
-
|
|
47
|
-
|
|
48
|
-
|
|
46
|
+
(
|
|
47
|
+
i + tile_size - overlap // 2
|
|
48
|
+
if crop_coords[-1][1] < axis_size
|
|
49
|
+
else axis_size
|
|
50
|
+
),
|
|
49
51
|
)
|
|
50
52
|
)
|
|
51
53
|
|
|
@@ -53,9 +55,11 @@ def _compute_crop_and_stitch_coords_1d(
|
|
|
53
55
|
overlap_crop_coords.append(
|
|
54
56
|
(
|
|
55
57
|
overlap // 2 if i > 0 else 0,
|
|
56
|
-
|
|
57
|
-
|
|
58
|
-
|
|
58
|
+
(
|
|
59
|
+
tile_size - overlap // 2
|
|
60
|
+
if crop_coords[-1][1] < axis_size
|
|
61
|
+
else tile_size
|
|
62
|
+
),
|
|
59
63
|
)
|
|
60
64
|
)
|
|
61
65
|
|