careamics 0.0.8__py3-none-any.whl → 0.0.10__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of careamics might be problematic. Click here for more details.
- careamics/__init__.py +0 -4
- careamics/careamist.py +0 -1
- careamics/config/__init__.py +1 -13
- careamics/config/algorithms/care_algorithm_model.py +84 -0
- careamics/config/algorithms/n2n_algorithm_model.py +85 -0
- careamics/config/algorithms/n2v_algorithm_model.py +269 -1
- careamics/config/configuration.py +21 -13
- careamics/config/configuration_factories.py +179 -187
- careamics/config/configuration_io.py +2 -2
- careamics/config/data/__init__.py +1 -4
- careamics/config/data/data_model.py +46 -62
- careamics/config/support/supported_transforms.py +1 -1
- careamics/config/transformations/__init__.py +0 -2
- careamics/config/transformations/n2v_manipulate_model.py +15 -0
- careamics/config/transformations/transform_unions.py +0 -13
- careamics/dataset/dataset_utils/iterate_over_files.py +2 -2
- careamics/dataset/in_memory_dataset.py +3 -10
- careamics/dataset/in_memory_pred_dataset.py +3 -5
- careamics/dataset/in_memory_tiled_pred_dataset.py +2 -2
- careamics/dataset/iterable_dataset.py +2 -2
- careamics/dataset/iterable_pred_dataset.py +3 -5
- careamics/dataset/iterable_tiled_pred_dataset.py +3 -3
- careamics/dataset_ng/dataset/__init__.py +3 -0
- careamics/dataset_ng/dataset/dataset.py +184 -0
- careamics/dataset_ng/demo_dataset.ipynb +271 -0
- careamics/dataset_ng/demo_patch_extractor.py +53 -0
- careamics/dataset_ng/demo_patch_extractor_factory.py +37 -0
- careamics/dataset_ng/patch_extractor/__init__.py +10 -0
- careamics/dataset_ng/patch_extractor/demo_custom_image_stack_loader.py +111 -0
- careamics/dataset_ng/patch_extractor/image_stack/__init__.py +9 -0
- careamics/dataset_ng/patch_extractor/image_stack/image_stack_protocol.py +53 -0
- careamics/dataset_ng/patch_extractor/image_stack/in_memory_image_stack.py +55 -0
- careamics/dataset_ng/patch_extractor/image_stack/zarr_image_stack.py +163 -0
- careamics/dataset_ng/patch_extractor/image_stack_loader.py +140 -0
- careamics/dataset_ng/patch_extractor/patch_extractor.py +29 -0
- careamics/dataset_ng/patch_extractor/patch_extractor_factory.py +208 -0
- careamics/dataset_ng/patching_strategies/__init__.py +11 -0
- careamics/dataset_ng/patching_strategies/patching_strategy_protocol.py +82 -0
- careamics/dataset_ng/patching_strategies/random_patching.py +338 -0
- careamics/dataset_ng/patching_strategies/sequential_patching.py +75 -0
- careamics/lightning/lightning_module.py +78 -27
- careamics/lightning/train_data_module.py +8 -39
- careamics/losses/fcn/losses.py +17 -10
- careamics/lvae_training/eval_utils.py +21 -8
- careamics/model_io/bioimage/bioimage_utils.py +5 -3
- careamics/model_io/bioimage/model_description.py +3 -3
- careamics/model_io/bmz_io.py +2 -2
- careamics/model_io/model_io_utils.py +2 -2
- careamics/transforms/__init__.py +2 -1
- careamics/transforms/compose.py +5 -15
- careamics/transforms/n2v_manipulate_torch.py +143 -0
- careamics/transforms/pixel_manipulation.py +1 -0
- careamics/transforms/pixel_manipulation_torch.py +418 -0
- careamics/utils/version.py +38 -0
- {careamics-0.0.8.dist-info → careamics-0.0.10.dist-info}/METADATA +7 -8
- {careamics-0.0.8.dist-info → careamics-0.0.10.dist-info}/RECORD +59 -42
- careamics/config/care_configuration.py +0 -100
- careamics/config/data/n2v_data_model.py +0 -193
- careamics/config/n2n_configuration.py +0 -101
- careamics/config/n2v_configuration.py +0 -266
- {careamics-0.0.8.dist-info → careamics-0.0.10.dist-info}/WHEEL +0 -0
- {careamics-0.0.8.dist-info → careamics-0.0.10.dist-info}/entry_points.txt +0 -0
- {careamics-0.0.8.dist-info → careamics-0.0.10.dist-info}/licenses/LICENSE +0 -0
|
@@ -8,7 +8,7 @@ import pytorch_lightning as L
|
|
|
8
8
|
from numpy.typing import NDArray
|
|
9
9
|
from torch.utils.data import DataLoader, IterableDataset
|
|
10
10
|
|
|
11
|
-
from careamics.config.data import DataConfig
|
|
11
|
+
from careamics.config.data import DataConfig
|
|
12
12
|
from careamics.config.support import SupportedData
|
|
13
13
|
from careamics.config.transformations import TransformModel
|
|
14
14
|
from careamics.dataset.dataset_utils import (
|
|
@@ -118,7 +118,7 @@ class TrainDataModule(L.LightningDataModule):
|
|
|
118
118
|
|
|
119
119
|
def __init__(
|
|
120
120
|
self,
|
|
121
|
-
data_config:
|
|
121
|
+
data_config: DataConfig,
|
|
122
122
|
train_data: Union[Path, str, NDArray],
|
|
123
123
|
val_data: Optional[Union[Path, str, NDArray]] = None,
|
|
124
124
|
train_data_target: Optional[Union[Path, str, NDArray]] = None,
|
|
@@ -218,7 +218,7 @@ class TrainDataModule(L.LightningDataModule):
|
|
|
218
218
|
)
|
|
219
219
|
|
|
220
220
|
# configuration
|
|
221
|
-
self.data_config:
|
|
221
|
+
self.data_config: DataConfig = data_config
|
|
222
222
|
self.data_type: str = data_config.data_type
|
|
223
223
|
self.batch_size: int = data_config.batch_size
|
|
224
224
|
self.use_in_memory: bool = use_in_memory
|
|
@@ -486,9 +486,6 @@ def create_train_datamodule(
|
|
|
486
486
|
val_minimum_patches: int = 5,
|
|
487
487
|
dataloader_params: Optional[dict] = None,
|
|
488
488
|
use_in_memory: bool = True,
|
|
489
|
-
use_n2v2: bool = False,
|
|
490
|
-
struct_n2v_axis: Literal["horizontal", "vertical", "none"] = "none",
|
|
491
|
-
struct_n2v_span: int = 5,
|
|
492
489
|
) -> TrainDataModule:
|
|
493
490
|
"""Create a TrainDataModule.
|
|
494
491
|
|
|
@@ -499,16 +496,8 @@ def create_train_datamodule(
|
|
|
499
496
|
parameters passed to the datamodule are consistent with the model's requirements and
|
|
500
497
|
are coherent.
|
|
501
498
|
|
|
502
|
-
|
|
503
|
-
|
|
504
|
-
list of augmentations. The default augmentations are XY flip, XY rotation, and N2V
|
|
505
|
-
pixel manipulation. If you pass a training target data, the default behaviour is to
|
|
506
|
-
train a supervised model. It will use the default XY flip and rotation
|
|
507
|
-
augmentations.
|
|
508
|
-
|
|
509
|
-
To use a different set of transformations, you can pass a list of transforms to
|
|
510
|
-
`transforms`. Note that if you intend to use Noise2Void, you should add
|
|
511
|
-
`N2VManipulateModel` as the last transform in the list of transformations.
|
|
499
|
+
The default augmentations are XY flip and XY rotation. To use a different set of
|
|
500
|
+
transformations, you can pass a list of transforms to `transforms`.
|
|
512
501
|
|
|
513
502
|
The data module can be used with Path, str or numpy arrays. In the case of
|
|
514
503
|
numpy arrays, it loads and computes all the patches in memory. For Path and str
|
|
@@ -534,11 +523,6 @@ def create_train_datamodule(
|
|
|
534
523
|
In `dataloader_params`, you can pass any parameter accepted by PyTorch dataloaders,
|
|
535
524
|
except for `batch_size`, which is set by the `batch_size` parameter.
|
|
536
525
|
|
|
537
|
-
Finally, if you intend to use N2V family of algorithms, you can set `use_n2v2` to
|
|
538
|
-
use N2V2, and set the `struct_n2v_axis` and `struct_n2v_span` parameters to define
|
|
539
|
-
the axis and span of the structN2V mask. These parameters are without effect if
|
|
540
|
-
a `train_target_data` or if `transforms` are provided.
|
|
541
|
-
|
|
542
526
|
Parameters
|
|
543
527
|
----------
|
|
544
528
|
train_data : pathlib.Path or str or numpy.ndarray
|
|
@@ -575,13 +559,6 @@ def create_train_datamodule(
|
|
|
575
559
|
Pytorch dataloader parameters, by default {}.
|
|
576
560
|
use_in_memory : bool, optional
|
|
577
561
|
Use in memory dataset if possible, by default True.
|
|
578
|
-
use_n2v2 : bool, optional
|
|
579
|
-
Use N2V2 transformation during training, by default False.
|
|
580
|
-
struct_n2v_axis : {"horizontal", "vertical", "none"}, optional
|
|
581
|
-
Axis for the structN2V mask, only applied if `struct_n2v_axis` is `none`, by
|
|
582
|
-
default "none".
|
|
583
|
-
struct_n2v_span : int, optional
|
|
584
|
-
Span for the structN2V mask, by default 5.
|
|
585
562
|
|
|
586
563
|
Returns
|
|
587
564
|
-------
|
|
@@ -624,12 +601,11 @@ def create_train_datamodule(
|
|
|
624
601
|
transforms:
|
|
625
602
|
>>> import numpy as np
|
|
626
603
|
>>> from careamics.lightning import create_train_datamodule
|
|
627
|
-
>>> from careamics.config.transformations import XYFlipModel
|
|
604
|
+
>>> from careamics.config.transformations import XYFlipModel
|
|
628
605
|
>>> from careamics.config.support import SupportedTransform
|
|
629
606
|
>>> my_array = np.arange(256).reshape(16, 16)
|
|
630
607
|
>>> my_transforms = [
|
|
631
608
|
... XYFlipModel(flip_y=False),
|
|
632
|
-
... N2VManipulateModel()
|
|
633
609
|
... ]
|
|
634
610
|
>>> data_module = create_train_datamodule(
|
|
635
611
|
... train_data=my_array,
|
|
@@ -656,15 +632,8 @@ def create_train_datamodule(
|
|
|
656
632
|
if transforms is not None:
|
|
657
633
|
data_dict["transforms"] = transforms
|
|
658
634
|
|
|
659
|
-
#
|
|
660
|
-
|
|
661
|
-
data_config: GeneralDataConfig = N2VDataConfig(**data_dict)
|
|
662
|
-
assert isinstance(data_config, N2VDataConfig)
|
|
663
|
-
|
|
664
|
-
data_config.set_n2v2(use_n2v2)
|
|
665
|
-
data_config.set_structN2V_mask(struct_n2v_axis, struct_n2v_span)
|
|
666
|
-
else:
|
|
667
|
-
data_config = DataConfig(**data_dict)
|
|
635
|
+
# instantiate data configuration
|
|
636
|
+
data_config = DataConfig(**data_dict)
|
|
668
637
|
|
|
669
638
|
# sanity check on the dataloader parameters
|
|
670
639
|
if "batch_size" in dataloader_params:
|
careamics/losses/fcn/losses.py
CHANGED
|
@@ -8,7 +8,7 @@ import torch
|
|
|
8
8
|
from torch.nn import L1Loss, MSELoss
|
|
9
9
|
|
|
10
10
|
|
|
11
|
-
def mse_loss(source: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
|
|
11
|
+
def mse_loss(source: torch.Tensor, target: torch.Tensor, *args) -> torch.Tensor:
|
|
12
12
|
"""
|
|
13
13
|
Mean squared error loss.
|
|
14
14
|
|
|
@@ -18,6 +18,8 @@ def mse_loss(source: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
|
|
|
18
18
|
Source patches.
|
|
19
19
|
target : torch.Tensor
|
|
20
20
|
Target patches.
|
|
21
|
+
*args : Any
|
|
22
|
+
Additional arguments.
|
|
21
23
|
|
|
22
24
|
Returns
|
|
23
25
|
-------
|
|
@@ -29,34 +31,37 @@ def mse_loss(source: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
|
|
|
29
31
|
|
|
30
32
|
|
|
31
33
|
def n2v_loss(
|
|
32
|
-
|
|
33
|
-
|
|
34
|
+
manipulated_batch: torch.Tensor,
|
|
35
|
+
original_batch: torch.Tensor,
|
|
34
36
|
masks: torch.Tensor,
|
|
37
|
+
*args,
|
|
35
38
|
) -> torch.Tensor:
|
|
36
39
|
"""
|
|
37
40
|
N2V Loss function described in A Krull et al 2018.
|
|
38
41
|
|
|
39
42
|
Parameters
|
|
40
43
|
----------
|
|
41
|
-
|
|
42
|
-
|
|
43
|
-
|
|
44
|
-
|
|
44
|
+
manipulated_batch : torch.Tensor
|
|
45
|
+
Batch after manipulation function applied.
|
|
46
|
+
original_batch : torch.Tensor
|
|
47
|
+
Original images.
|
|
45
48
|
masks : torch.Tensor
|
|
46
|
-
|
|
49
|
+
Coordinates of changed pixels.
|
|
50
|
+
*args : Any
|
|
51
|
+
Additional arguments.
|
|
47
52
|
|
|
48
53
|
Returns
|
|
49
54
|
-------
|
|
50
55
|
torch.Tensor
|
|
51
56
|
Loss value.
|
|
52
57
|
"""
|
|
53
|
-
errors = (
|
|
58
|
+
errors = (original_batch - manipulated_batch) ** 2
|
|
54
59
|
# Average over pixels and batch
|
|
55
60
|
loss = torch.sum(errors * masks) / torch.sum(masks)
|
|
56
61
|
return loss # TODO change output to dict ?
|
|
57
62
|
|
|
58
63
|
|
|
59
|
-
def mae_loss(samples: torch.Tensor, labels: torch.Tensor) -> torch.Tensor:
|
|
64
|
+
def mae_loss(samples: torch.Tensor, labels: torch.Tensor, *args) -> torch.Tensor:
|
|
60
65
|
"""
|
|
61
66
|
N2N Loss function described in J Lehtinen et al 2018.
|
|
62
67
|
|
|
@@ -66,6 +71,8 @@ def mae_loss(samples: torch.Tensor, labels: torch.Tensor) -> torch.Tensor:
|
|
|
66
71
|
Raw patches.
|
|
67
72
|
labels : torch.Tensor
|
|
68
73
|
Different subset of noisy patches.
|
|
74
|
+
*args : Any
|
|
75
|
+
Additional arguments.
|
|
69
76
|
|
|
70
77
|
Returns
|
|
71
78
|
-------
|
|
@@ -107,6 +107,15 @@ def get_first_index(bin_count, quantile):
|
|
|
107
107
|
return None
|
|
108
108
|
|
|
109
109
|
|
|
110
|
+
def get_device():
|
|
111
|
+
if torch.cuda.is_available():
|
|
112
|
+
return "cuda"
|
|
113
|
+
elif torch.backends.mps.is_available():
|
|
114
|
+
return "mps"
|
|
115
|
+
else:
|
|
116
|
+
return "cpu"
|
|
117
|
+
|
|
118
|
+
|
|
110
119
|
def show_for_one(
|
|
111
120
|
idx,
|
|
112
121
|
val_dset,
|
|
@@ -516,7 +525,7 @@ def get_predictions(
|
|
|
516
525
|
num_workers=num_workers,
|
|
517
526
|
)
|
|
518
527
|
# get filename without extension and path
|
|
519
|
-
filename =
|
|
528
|
+
filename = d._fpath.name
|
|
520
529
|
multifile_stitched_predictions[filename] = stitched_predictions
|
|
521
530
|
multifile_stitched_stds[filename] = stitched_stds
|
|
522
531
|
return (
|
|
@@ -534,7 +543,7 @@ def get_predictions(
|
|
|
534
543
|
num_workers=num_workers,
|
|
535
544
|
)
|
|
536
545
|
# get filename without extension and path
|
|
537
|
-
filename =
|
|
546
|
+
filename = dset._fpath.name
|
|
538
547
|
return (
|
|
539
548
|
{filename: stitched_predictions},
|
|
540
549
|
{filename: stitched_stds},
|
|
@@ -553,6 +562,8 @@ def get_single_file_predictions(
|
|
|
553
562
|
if tile_size and grid_size:
|
|
554
563
|
dset.set_img_sz(tile_size, grid_size)
|
|
555
564
|
|
|
565
|
+
device = get_device()
|
|
566
|
+
|
|
556
567
|
dloader = DataLoader(
|
|
557
568
|
dset,
|
|
558
569
|
pin_memory=False,
|
|
@@ -561,14 +572,14 @@ def get_single_file_predictions(
|
|
|
561
572
|
batch_size=batch_size,
|
|
562
573
|
)
|
|
563
574
|
model.eval()
|
|
564
|
-
model.
|
|
575
|
+
model.to(device)
|
|
565
576
|
tiles = []
|
|
566
577
|
logvar_arr = []
|
|
567
578
|
with torch.no_grad():
|
|
568
579
|
for batch in tqdm(dloader, desc="Predicting tiles"):
|
|
569
580
|
inp, tar = batch
|
|
570
|
-
inp = inp.
|
|
571
|
-
tar = tar.
|
|
581
|
+
inp = inp.to(device)
|
|
582
|
+
tar = tar.to(device)
|
|
572
583
|
|
|
573
584
|
# get model output
|
|
574
585
|
rec, _ = model(inp)
|
|
@@ -597,6 +608,8 @@ def get_single_file_mmse(
|
|
|
597
608
|
num_workers: int = 4,
|
|
598
609
|
) -> tuple[np.ndarray, np.ndarray]:
|
|
599
610
|
"""Get patch-wise predictions from a model for a single file dataset."""
|
|
611
|
+
device = get_device()
|
|
612
|
+
|
|
600
613
|
dloader = DataLoader(
|
|
601
614
|
dset,
|
|
602
615
|
pin_memory=False,
|
|
@@ -608,15 +621,15 @@ def get_single_file_mmse(
|
|
|
608
621
|
dset.set_img_sz(tile_size, grid_size)
|
|
609
622
|
|
|
610
623
|
model.eval()
|
|
611
|
-
model.
|
|
624
|
+
model.to(device)
|
|
612
625
|
tile_mmse = []
|
|
613
626
|
tile_stds = []
|
|
614
627
|
logvar_arr = []
|
|
615
628
|
with torch.no_grad():
|
|
616
629
|
for batch in tqdm(dloader, desc="Predicting tiles"):
|
|
617
630
|
inp, tar = batch
|
|
618
|
-
inp = inp.
|
|
619
|
-
tar = tar.
|
|
631
|
+
inp = inp.to(device)
|
|
632
|
+
tar = tar.to(device)
|
|
620
633
|
|
|
621
634
|
rec_img_list = []
|
|
622
635
|
for _ in range(mmse_count):
|
|
@@ -3,6 +3,8 @@
|
|
|
3
3
|
from pathlib import Path
|
|
4
4
|
from typing import Union
|
|
5
5
|
|
|
6
|
+
from careamics.utils.version import get_careamics_version
|
|
7
|
+
|
|
6
8
|
|
|
7
9
|
def get_unzip_path(zip_path: Union[Path, str]) -> Path:
|
|
8
10
|
"""Generate unzipped folder path from the bioimage.io model path.
|
|
@@ -44,11 +46,11 @@ def create_env_text(pytorch_version: str, torchvision_version: str) -> str:
|
|
|
44
46
|
f"name: careamics\n"
|
|
45
47
|
f"dependencies:\n"
|
|
46
48
|
f" - python=3.10\n"
|
|
47
|
-
f" - pytorch={pytorch_version}\n"
|
|
48
|
-
f" - torchvision={torchvision_version}\n"
|
|
49
49
|
f" - pip\n"
|
|
50
50
|
f" - pip:\n"
|
|
51
|
-
f" -
|
|
51
|
+
f" - torch=={pytorch_version}\n"
|
|
52
|
+
f" - torchvision=={torchvision_version}\n"
|
|
53
|
+
f" - careamics=={get_careamics_version()}"
|
|
52
54
|
)
|
|
53
55
|
|
|
54
56
|
return env
|
|
@@ -28,14 +28,14 @@ from bioimageio.spec.model.v0_5 import (
|
|
|
28
28
|
WeightsDescr,
|
|
29
29
|
)
|
|
30
30
|
|
|
31
|
-
from careamics.config import Configuration,
|
|
31
|
+
from careamics.config import Configuration, DataConfig
|
|
32
32
|
|
|
33
33
|
from ._readme_factory import readme_factory
|
|
34
34
|
|
|
35
35
|
|
|
36
36
|
def _create_axes(
|
|
37
37
|
array: np.ndarray,
|
|
38
|
-
data_config:
|
|
38
|
+
data_config: DataConfig,
|
|
39
39
|
channel_names: Optional[list[str]] = None,
|
|
40
40
|
is_input: bool = True,
|
|
41
41
|
) -> list[AxisBase]:
|
|
@@ -102,7 +102,7 @@ def _create_axes(
|
|
|
102
102
|
def _create_inputs_ouputs(
|
|
103
103
|
input_array: np.ndarray,
|
|
104
104
|
output_array: np.ndarray,
|
|
105
|
-
data_config:
|
|
105
|
+
data_config: DataConfig,
|
|
106
106
|
input_path: Union[Path, str],
|
|
107
107
|
output_path: Union[Path, str],
|
|
108
108
|
channel_names: Optional[list[str]] = None,
|
careamics/model_io/bmz_io.py
CHANGED
|
@@ -5,7 +5,6 @@ from pathlib import Path
|
|
|
5
5
|
from typing import Optional, Union
|
|
6
6
|
|
|
7
7
|
import numpy as np
|
|
8
|
-
import pkg_resources
|
|
9
8
|
from bioimageio.core import load_model_description, test_model
|
|
10
9
|
from bioimageio.spec import ValidationSummary, save_bioimageio_package
|
|
11
10
|
from pydantic import HttpUrl
|
|
@@ -16,6 +15,7 @@ from torchvision import __version__ as TORCHVISION_VERSION
|
|
|
16
15
|
from careamics.config import Configuration, load_configuration, save_configuration
|
|
17
16
|
from careamics.config.support import SupportedArchitecture
|
|
18
17
|
from careamics.lightning.lightning_module import FCNModule, VAEModule
|
|
18
|
+
from careamics.utils.version import get_careamics_version
|
|
19
19
|
|
|
20
20
|
from .bioimage import (
|
|
21
21
|
create_env_text,
|
|
@@ -139,7 +139,7 @@ def export_to_bmz(
|
|
|
139
139
|
path_to_archive.parent.mkdir(parents=True, exist_ok=True)
|
|
140
140
|
|
|
141
141
|
# versions
|
|
142
|
-
careamics_version =
|
|
142
|
+
careamics_version = get_careamics_version()
|
|
143
143
|
|
|
144
144
|
# save files in temporary folder
|
|
145
145
|
with tempfile.TemporaryDirectory() as tmpdirname:
|
|
@@ -5,7 +5,7 @@ from typing import Union
|
|
|
5
5
|
|
|
6
6
|
import torch
|
|
7
7
|
|
|
8
|
-
from careamics.config import Configuration
|
|
8
|
+
from careamics.config import Configuration
|
|
9
9
|
from careamics.lightning.lightning_module import FCNModule, VAEModule
|
|
10
10
|
from careamics.model_io.bmz_io import load_from_bmz
|
|
11
11
|
from careamics.utils import check_path_exists
|
|
@@ -92,4 +92,4 @@ def _load_checkpoint(
|
|
|
92
92
|
f"{cfg_dict['algorithm_config']['model']['architecture']}"
|
|
93
93
|
)
|
|
94
94
|
|
|
95
|
-
return model,
|
|
95
|
+
return model, Configuration(**cfg_dict)
|
careamics/transforms/__init__.py
CHANGED
|
@@ -5,15 +5,16 @@ __all__ = [
|
|
|
5
5
|
"Denormalize",
|
|
6
6
|
"ImageRestorationTTA",
|
|
7
7
|
"N2VManipulate",
|
|
8
|
+
"N2VManipulateTorch",
|
|
8
9
|
"Normalize",
|
|
9
10
|
"XYFlip",
|
|
10
11
|
"XYRandomRotate90",
|
|
11
12
|
"get_all_transforms",
|
|
12
13
|
]
|
|
13
14
|
|
|
14
|
-
|
|
15
15
|
from .compose import Compose, get_all_transforms
|
|
16
16
|
from .n2v_manipulate import N2VManipulate
|
|
17
|
+
from .n2v_manipulate_torch import N2VManipulateTorch
|
|
17
18
|
from .normalize import Denormalize, Normalize
|
|
18
19
|
from .tta import ImageRestorationTTA
|
|
19
20
|
from .xy_flip import XYFlip
|
careamics/transforms/compose.py
CHANGED
|
@@ -6,7 +6,6 @@ from numpy.typing import NDArray
|
|
|
6
6
|
|
|
7
7
|
from careamics.config.transformations import NORM_AND_SPATIAL_UNION
|
|
8
8
|
|
|
9
|
-
from .n2v_manipulate import N2VManipulate
|
|
10
9
|
from .normalize import Normalize
|
|
11
10
|
from .transform import Transform
|
|
12
11
|
from .xy_flip import XYFlip
|
|
@@ -14,7 +13,6 @@ from .xy_random_rotate90 import XYRandomRotate90
|
|
|
14
13
|
|
|
15
14
|
ALL_TRANSFORMS = {
|
|
16
15
|
"Normalize": Normalize,
|
|
17
|
-
"N2VManipulate": N2VManipulate,
|
|
18
16
|
"XYFlip": XYFlip,
|
|
19
17
|
"XYRandomRotate90": XYRandomRotate90,
|
|
20
18
|
}
|
|
@@ -82,21 +80,13 @@ class Compose:
|
|
|
82
80
|
tuple[np.ndarray, Optional[np.ndarray]]
|
|
83
81
|
The output of the transformations.
|
|
84
82
|
"""
|
|
85
|
-
params: Union[
|
|
86
|
-
tuple[NDArray, Optional[NDArray]],
|
|
87
|
-
tuple[NDArray, NDArray, NDArray], # N2VManiupulate output
|
|
88
|
-
] = (patch, target)
|
|
83
|
+
params: Union[tuple[NDArray, Optional[NDArray]],] = (patch, target)
|
|
89
84
|
|
|
90
85
|
for t in self.transforms:
|
|
91
|
-
|
|
92
|
-
|
|
93
|
-
|
|
94
|
-
|
|
95
|
-
params = t(patch=patch)
|
|
96
|
-
else:
|
|
97
|
-
*params, _ = t(*params) # ignore additional_arrays dict
|
|
98
|
-
|
|
99
|
-
return params
|
|
86
|
+
*params, _ = t(*params) # ignore additional_arrays dict
|
|
87
|
+
|
|
88
|
+
# avoid None values that create problems for collating
|
|
89
|
+
return tuple(p for p in params if p is not None)
|
|
100
90
|
|
|
101
91
|
def _chain_transforms_additional_arrays(
|
|
102
92
|
self,
|
|
@@ -0,0 +1,143 @@
|
|
|
1
|
+
"""N2V manipulation transform for PyTorch."""
|
|
2
|
+
|
|
3
|
+
import platform
|
|
4
|
+
from typing import Any, Optional
|
|
5
|
+
|
|
6
|
+
import torch
|
|
7
|
+
|
|
8
|
+
from careamics.config.support import SupportedPixelManipulation, SupportedStructAxis
|
|
9
|
+
from careamics.config.transformations import N2VManipulateModel
|
|
10
|
+
|
|
11
|
+
from .pixel_manipulation_torch import (
|
|
12
|
+
median_manipulate_torch,
|
|
13
|
+
uniform_manipulate_torch,
|
|
14
|
+
)
|
|
15
|
+
from .struct_mask_parameters import StructMaskParameters
|
|
16
|
+
|
|
17
|
+
|
|
18
|
+
class N2VManipulateTorch:
|
|
19
|
+
"""
|
|
20
|
+
Default augmentation for the N2V model.
|
|
21
|
+
|
|
22
|
+
This transform expects BC(Z)YX dimensions.
|
|
23
|
+
|
|
24
|
+
Parameters
|
|
25
|
+
----------
|
|
26
|
+
n2v_manipulate_config : N2VManipulateModel
|
|
27
|
+
N2V manipulation configuration.
|
|
28
|
+
seed : Optional[int], optional
|
|
29
|
+
Random seed, by default None.
|
|
30
|
+
|
|
31
|
+
Attributes
|
|
32
|
+
----------
|
|
33
|
+
masked_pixel_percentage : float
|
|
34
|
+
Percentage of pixels to mask.
|
|
35
|
+
roi_size : int
|
|
36
|
+
Size of the replacement area.
|
|
37
|
+
strategy : Literal[ "uniform", "median" ]
|
|
38
|
+
Replacement strategy, uniform or median.
|
|
39
|
+
remove_center : bool
|
|
40
|
+
Whether to remove central pixel from patch.
|
|
41
|
+
struct_mask : Optional[StructMaskParameters]
|
|
42
|
+
StructN2V mask parameters.
|
|
43
|
+
rng : Generator
|
|
44
|
+
Random number generator.
|
|
45
|
+
"""
|
|
46
|
+
|
|
47
|
+
def __init__(
|
|
48
|
+
self,
|
|
49
|
+
n2v_manipulate_config: N2VManipulateModel,
|
|
50
|
+
seed: Optional[int] = None,
|
|
51
|
+
):
|
|
52
|
+
"""Constructor.
|
|
53
|
+
|
|
54
|
+
Parameters
|
|
55
|
+
----------
|
|
56
|
+
n2v_manipulate_config : N2VManipulateModel
|
|
57
|
+
N2V manipulation configuration.
|
|
58
|
+
seed : Optional[int], optional
|
|
59
|
+
Random seed, by default None.
|
|
60
|
+
"""
|
|
61
|
+
self.masked_pixel_percentage = n2v_manipulate_config.masked_pixel_percentage
|
|
62
|
+
self.roi_size = n2v_manipulate_config.roi_size
|
|
63
|
+
self.strategy = n2v_manipulate_config.strategy
|
|
64
|
+
self.remove_center = n2v_manipulate_config.remove_center
|
|
65
|
+
|
|
66
|
+
if n2v_manipulate_config.struct_mask_axis == SupportedStructAxis.NONE:
|
|
67
|
+
self.struct_mask: Optional[StructMaskParameters] = None
|
|
68
|
+
else:
|
|
69
|
+
self.struct_mask = StructMaskParameters(
|
|
70
|
+
axis=(
|
|
71
|
+
0
|
|
72
|
+
if n2v_manipulate_config.struct_mask_axis
|
|
73
|
+
== SupportedStructAxis.HORIZONTAL
|
|
74
|
+
else 1
|
|
75
|
+
),
|
|
76
|
+
span=n2v_manipulate_config.struct_mask_span,
|
|
77
|
+
)
|
|
78
|
+
|
|
79
|
+
# PyTorch random generator
|
|
80
|
+
# TODO refactor into careamics.utils.torch_utils.get_device
|
|
81
|
+
if torch.cuda.is_available():
|
|
82
|
+
device = "cuda"
|
|
83
|
+
elif torch.backends.mps.is_available() and platform.processor() in (
|
|
84
|
+
"arm",
|
|
85
|
+
"arm64",
|
|
86
|
+
):
|
|
87
|
+
device = "mps"
|
|
88
|
+
else:
|
|
89
|
+
device = "cpu"
|
|
90
|
+
|
|
91
|
+
self.rng = (
|
|
92
|
+
torch.Generator(device=device).manual_seed(seed)
|
|
93
|
+
if seed is not None
|
|
94
|
+
else torch.Generator(device=device)
|
|
95
|
+
)
|
|
96
|
+
|
|
97
|
+
def __call__(
|
|
98
|
+
self, batch: torch.Tensor, *args: Any, **kwargs: Any
|
|
99
|
+
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
|
|
100
|
+
"""Apply the transform to the image.
|
|
101
|
+
|
|
102
|
+
Parameters
|
|
103
|
+
----------
|
|
104
|
+
batch : torch.Tensor
|
|
105
|
+
Batch of image patches, 2D or 3D, shape BC(Z)YX.
|
|
106
|
+
*args : Any
|
|
107
|
+
Additional arguments, unused.
|
|
108
|
+
**kwargs : Any
|
|
109
|
+
Additional keyword arguments, unused.
|
|
110
|
+
|
|
111
|
+
Returns
|
|
112
|
+
-------
|
|
113
|
+
tuple[torch.Tensor, torch.Tensor, torch.Tensor]
|
|
114
|
+
Masked patch, original patch, and mask.
|
|
115
|
+
"""
|
|
116
|
+
masked = torch.zeros_like(batch)
|
|
117
|
+
mask = torch.zeros_like(batch, dtype=torch.uint8)
|
|
118
|
+
|
|
119
|
+
if self.strategy == SupportedPixelManipulation.UNIFORM:
|
|
120
|
+
# Iterate over the channels to apply manipulation separately
|
|
121
|
+
for c in range(batch.shape[1]):
|
|
122
|
+
masked[:, c, ...], mask[:, c, ...] = uniform_manipulate_torch(
|
|
123
|
+
patch=batch[:, c, ...],
|
|
124
|
+
mask_pixel_percentage=self.masked_pixel_percentage,
|
|
125
|
+
subpatch_size=self.roi_size,
|
|
126
|
+
remove_center=self.remove_center,
|
|
127
|
+
struct_params=self.struct_mask,
|
|
128
|
+
rng=self.rng,
|
|
129
|
+
)
|
|
130
|
+
elif self.strategy == SupportedPixelManipulation.MEDIAN:
|
|
131
|
+
# Iterate over the channels to apply manipulation separately
|
|
132
|
+
for c in range(batch.shape[1]):
|
|
133
|
+
masked[:, c, ...], mask[:, c, ...] = median_manipulate_torch(
|
|
134
|
+
batch=batch[:, c, ...],
|
|
135
|
+
mask_pixel_percentage=self.masked_pixel_percentage,
|
|
136
|
+
subpatch_size=self.roi_size,
|
|
137
|
+
struct_params=self.struct_mask,
|
|
138
|
+
rng=self.rng,
|
|
139
|
+
)
|
|
140
|
+
else:
|
|
141
|
+
raise ValueError(f"Unknown masking strategy ({self.strategy}).")
|
|
142
|
+
|
|
143
|
+
return masked, batch, mask
|