careamics 0.0.12__py3-none-any.whl → 0.0.13__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of careamics might be problematic. Click here for more details.
- careamics/careamist.py +4 -3
- careamics/cli/utils.py +1 -1
- careamics/config/algorithms/n2v_algorithm_model.py +1 -1
- careamics/config/architectures/unet_model.py +3 -0
- careamics/config/callback_model.py +23 -34
- careamics/config/configuration.py +47 -1
- careamics/config/configuration_factories.py +288 -23
- careamics/config/data/__init__.py +2 -0
- careamics/config/data/data_model.py +3 -3
- careamics/config/data/ng_data_model.py +381 -0
- careamics/config/data/patching_strategies/__init__.py +14 -0
- careamics/config/data/patching_strategies/_overlapping_patched_model.py +103 -0
- careamics/config/data/patching_strategies/_patched_model.py +56 -0
- careamics/config/data/patching_strategies/random_patching_model.py +21 -0
- careamics/config/data/patching_strategies/sequential_patching_model.py +25 -0
- careamics/config/data/patching_strategies/tiled_patching_model.py +40 -0
- careamics/config/data/patching_strategies/whole_patching_model.py +12 -0
- careamics/config/inference_model.py +6 -3
- careamics/config/support/supported_data.py +7 -0
- careamics/config/support/supported_patching_strategies.py +22 -0
- careamics/config/validators/validator_utils.py +4 -3
- careamics/dataset/dataset_utils/iterate_over_files.py +2 -2
- careamics/dataset/in_memory_dataset.py +2 -1
- careamics/dataset/iterable_dataset.py +2 -2
- careamics/dataset/iterable_pred_dataset.py +2 -2
- careamics/dataset/iterable_tiled_pred_dataset.py +2 -2
- careamics/dataset/patching/patching.py +3 -2
- careamics/dataset/tiling/lvae_tiled_patching.py +16 -6
- careamics/dataset/tiling/tiled_patching.py +2 -1
- careamics/dataset_ng/dataset.py +46 -50
- careamics/dataset_ng/demos/bsd68_demo.ipynb +28 -23
- careamics/dataset_ng/demos/care_U2OS_demo.ipynb +1 -1
- careamics/dataset_ng/demos/demo_custom_image_stack.ipynb +1 -1
- careamics/dataset_ng/demos/demo_datamodule.ipynb +50 -46
- careamics/dataset_ng/demos/demo_dataset.ipynb +32 -49
- careamics/dataset_ng/factory.py +58 -15
- careamics/dataset_ng/legacy_interoperability.py +3 -1
- careamics/dataset_ng/patch_extractor/demo_custom_image_stack_loader.py +1 -1
- careamics/dataset_ng/patch_extractor/image_stack/__init__.py +2 -0
- careamics/dataset_ng/patch_extractor/image_stack/czi_image_stack.py +360 -0
- careamics/dataset_ng/patch_extractor/image_stack/in_memory_image_stack.py +1 -1
- careamics/dataset_ng/patch_extractor/patch_extractor_factory.py +43 -1
- careamics/dataset_ng/patching_strategies/random_patching.py +3 -1
- careamics/dataset_ng/patching_strategies/sequential_patching.py +5 -5
- careamics/dataset_ng/patching_strategies/tiling_strategy.py +2 -1
- careamics/file_io/read/get_func.py +2 -1
- careamics/lightning/dataset_ng/__init__.py +1 -0
- careamics/lightning/dataset_ng/data_module.py +218 -28
- careamics/lightning/dataset_ng/lightning_modules/care_module.py +44 -5
- careamics/lightning/dataset_ng/lightning_modules/n2v_module.py +42 -3
- careamics/lightning/dataset_ng/lightning_modules/unet_module.py +73 -4
- careamics/lightning/lightning_module.py +2 -1
- careamics/lightning/predict_data_module.py +2 -1
- careamics/lightning/train_data_module.py +2 -1
- careamics/losses/loss_factory.py +2 -1
- careamics/lvae_training/dataset/multicrop_dset.py +1 -1
- careamics/model_io/bioimage/bioimage_utils.py +1 -1
- careamics/model_io/bioimage/model_description.py +1 -1
- careamics/model_io/bmz_io.py +1 -1
- careamics/model_io/model_io_utils.py +2 -2
- careamics/models/activation.py +2 -1
- careamics/prediction_utils/prediction_outputs.py +1 -1
- careamics/prediction_utils/stitch_prediction.py +1 -1
- careamics/transforms/n2v_manipulate_torch.py +15 -9
- careamics/transforms/pixel_manipulation_torch.py +59 -92
- careamics/utils/lightning_utils.py +2 -2
- careamics/utils/metrics.py +2 -1
- careamics/utils/torch_utils.py +23 -0
- {careamics-0.0.12.dist-info → careamics-0.0.13.dist-info}/METADATA +10 -9
- {careamics-0.0.12.dist-info → careamics-0.0.13.dist-info}/RECORD +73 -62
- {careamics-0.0.12.dist-info → careamics-0.0.13.dist-info}/WHEEL +0 -0
- {careamics-0.0.12.dist-info → careamics-0.0.13.dist-info}/entry_points.txt +0 -0
- {careamics-0.0.12.dist-info → careamics-0.0.13.dist-info}/licenses/LICENSE +0 -0
careamics/careamist.py
CHANGED
|
@@ -1,7 +1,8 @@
|
|
|
1
1
|
"""A class to train, predict and export models in CAREamics."""
|
|
2
2
|
|
|
3
|
+
from collections.abc import Callable
|
|
3
4
|
from pathlib import Path
|
|
4
|
-
from typing import Any,
|
|
5
|
+
from typing import Any, Literal, Optional, Union, overload
|
|
5
6
|
|
|
6
7
|
import numpy as np
|
|
7
8
|
from numpy.typing import NDArray
|
|
@@ -827,7 +828,7 @@ class CAREamist:
|
|
|
827
828
|
source_path = source.pred_data
|
|
828
829
|
source_data_type = source.data_type
|
|
829
830
|
extension_filter = source.extension_filter
|
|
830
|
-
elif isinstance(source,
|
|
831
|
+
elif isinstance(source, str | Path):
|
|
831
832
|
source_path = source
|
|
832
833
|
source_data_type = data_type or self.cfg.data_config.data_type
|
|
833
834
|
extension_filter = SupportedData.get_extension_pattern(
|
|
@@ -840,7 +841,7 @@ class CAREamist:
|
|
|
840
841
|
raise ValueError(
|
|
841
842
|
"Predicting to disk is not supported for input type 'array'."
|
|
842
843
|
)
|
|
843
|
-
assert isinstance(source_path,
|
|
844
|
+
assert isinstance(source_path, str | Path) # because data_type != "array"
|
|
844
845
|
source_path = Path(source_path)
|
|
845
846
|
|
|
846
847
|
file_paths = list_files(source_path, source_data_type, extension_filter)
|
careamics/cli/utils.py
CHANGED
|
@@ -63,6 +63,9 @@ class UNetModel(ArchitectureModel):
|
|
|
63
63
|
"""Whether information is processed independently in each channel, used to train
|
|
64
64
|
channels independently."""
|
|
65
65
|
|
|
66
|
+
use_batch_norm: bool = Field(default=True, validate_default=True)
|
|
67
|
+
"""Whether to use batch normalization in the model."""
|
|
68
|
+
|
|
66
69
|
@field_validator("num_channels_init")
|
|
67
70
|
@classmethod
|
|
68
71
|
def validate_num_channels_init(cls, num_channels_init: int) -> int:
|
|
@@ -22,52 +22,42 @@ class CheckpointModel(BaseModel):
|
|
|
22
22
|
https://lightning.ai/docs/pytorch/stable/api/lightning.pytorch.callbacks.ModelCheckpoint.html#modelcheckpoint
|
|
23
23
|
"""
|
|
24
24
|
|
|
25
|
-
model_config = ConfigDict(
|
|
26
|
-
validate_assignment=True,
|
|
27
|
-
)
|
|
25
|
+
model_config = ConfigDict(validate_assignment=True, validate_default=True)
|
|
28
26
|
|
|
29
|
-
monitor: Literal["val_loss"] = Field(default="val_loss"
|
|
30
|
-
"""Quantity to monitor
|
|
27
|
+
monitor: Literal["val_loss"] = Field(default="val_loss")
|
|
28
|
+
"""Quantity to monitor, currently only `val_loss`."""
|
|
31
29
|
|
|
32
|
-
verbose: bool = Field(default=False
|
|
30
|
+
verbose: bool = Field(default=False)
|
|
33
31
|
"""Verbosity mode."""
|
|
34
32
|
|
|
35
|
-
save_weights_only: bool = Field(default=False
|
|
33
|
+
save_weights_only: bool = Field(default=False)
|
|
36
34
|
"""When `True`, only the model's weights will be saved (model.save_weights)."""
|
|
37
35
|
|
|
38
|
-
save_last: Optional[Literal[True, False, "link"]] = Field(
|
|
39
|
-
default=True, validate_default=True
|
|
40
|
-
)
|
|
36
|
+
save_last: Optional[Literal[True, False, "link"]] = Field(default=True)
|
|
41
37
|
"""When `True`, saves a last.ckpt copy whenever a checkpoint file gets saved."""
|
|
42
38
|
|
|
43
|
-
save_top_k: int = Field(default=3, ge
|
|
39
|
+
save_top_k: int = Field(default=3, ge=-1, le=100)
|
|
44
40
|
"""If `save_top_k == kz, the best k models according to the quantity monitored
|
|
45
41
|
will be saved. If `save_top_k == 0`, no models are saved. if `save_top_k == -1`,
|
|
46
42
|
all models are saved."""
|
|
47
43
|
|
|
48
|
-
mode: Literal["min", "max"] = Field(default="min"
|
|
44
|
+
mode: Literal["min", "max"] = Field(default="min")
|
|
49
45
|
"""One of {min, max}. If `save_top_k != 0`, the decision to overwrite the current
|
|
50
46
|
save file is made based on either the maximization or the minimization of the
|
|
51
47
|
monitored quantity. For 'val_acc', this should be 'max', for 'val_loss' this should
|
|
52
48
|
be 'min', etc.
|
|
53
49
|
"""
|
|
54
50
|
|
|
55
|
-
auto_insert_metric_name: bool = Field(default=False
|
|
51
|
+
auto_insert_metric_name: bool = Field(default=False)
|
|
56
52
|
"""When `True`, the checkpoints filenames will contain the metric name."""
|
|
57
53
|
|
|
58
|
-
every_n_train_steps: Optional[int] = Field(
|
|
59
|
-
default=None, ge=1, le=10, validate_default=True
|
|
60
|
-
)
|
|
54
|
+
every_n_train_steps: Optional[int] = Field(default=None, ge=1, le=1000)
|
|
61
55
|
"""Number of training steps between checkpoints."""
|
|
62
56
|
|
|
63
|
-
train_time_interval: Optional[timedelta] = Field(
|
|
64
|
-
default=None, validate_default=True
|
|
65
|
-
)
|
|
57
|
+
train_time_interval: Optional[timedelta] = Field(default=None)
|
|
66
58
|
"""Checkpoints are monitored at the specified time interval."""
|
|
67
59
|
|
|
68
|
-
every_n_epochs: Optional[int] = Field(
|
|
69
|
-
default=None, ge=1, le=10, validate_default=True
|
|
70
|
-
)
|
|
60
|
+
every_n_epochs: Optional[int] = Field(default=None, ge=1, le=100)
|
|
71
61
|
"""Number of epochs between checkpoints."""
|
|
72
62
|
|
|
73
63
|
|
|
@@ -83,41 +73,40 @@ class EarlyStoppingModel(BaseModel):
|
|
|
83
73
|
|
|
84
74
|
model_config = ConfigDict(
|
|
85
75
|
validate_assignment=True,
|
|
76
|
+
validate_default=True,
|
|
86
77
|
)
|
|
87
78
|
|
|
88
|
-
monitor: Literal["val_loss"] = Field(default="val_loss"
|
|
79
|
+
monitor: Literal["val_loss"] = Field(default="val_loss")
|
|
89
80
|
"""Quantity to monitor."""
|
|
90
81
|
|
|
91
|
-
min_delta: float = Field(default=0.0, ge=0.0, le=1.0
|
|
82
|
+
min_delta: float = Field(default=0.0, ge=0.0, le=1.0)
|
|
92
83
|
"""Minimum change in the monitored quantity to qualify as an improvement, i.e. an
|
|
93
84
|
absolute change of less than or equal to min_delta, will count as no improvement."""
|
|
94
85
|
|
|
95
|
-
patience: int = Field(default=3, ge=1, le=10
|
|
86
|
+
patience: int = Field(default=3, ge=1, le=10)
|
|
96
87
|
"""Number of checks with no improvement after which training will be stopped."""
|
|
97
88
|
|
|
98
|
-
verbose: bool = Field(default=False
|
|
89
|
+
verbose: bool = Field(default=False)
|
|
99
90
|
"""Verbosity mode."""
|
|
100
91
|
|
|
101
|
-
mode: Literal["min", "max", "auto"] = Field(default="min"
|
|
92
|
+
mode: Literal["min", "max", "auto"] = Field(default="min")
|
|
102
93
|
"""One of {min, max, auto}."""
|
|
103
94
|
|
|
104
|
-
check_finite: bool = Field(default=True
|
|
95
|
+
check_finite: bool = Field(default=True)
|
|
105
96
|
"""When `True`, stops training when the monitored quantity becomes `NaN` or
|
|
106
97
|
`inf`."""
|
|
107
98
|
|
|
108
|
-
stopping_threshold: Optional[float] = Field(default=None
|
|
99
|
+
stopping_threshold: Optional[float] = Field(default=None)
|
|
109
100
|
"""Stop training immediately once the monitored quantity reaches this threshold."""
|
|
110
101
|
|
|
111
|
-
divergence_threshold: Optional[float] = Field(default=None
|
|
102
|
+
divergence_threshold: Optional[float] = Field(default=None)
|
|
112
103
|
"""Stop training as soon as the monitored quantity becomes worse than this
|
|
113
104
|
threshold."""
|
|
114
105
|
|
|
115
|
-
check_on_train_epoch_end: Optional[bool] = Field(
|
|
116
|
-
default=False, validate_default=True
|
|
117
|
-
)
|
|
106
|
+
check_on_train_epoch_end: Optional[bool] = Field(default=False)
|
|
118
107
|
"""Whether to run early stopping at the end of the training epoch. If this is
|
|
119
108
|
`False`, then the check runs at the end of the validation."""
|
|
120
109
|
|
|
121
|
-
log_rank_zero_only: bool = Field(default=False
|
|
110
|
+
log_rank_zero_only: bool = Field(default=False)
|
|
122
111
|
"""When set `True`, logs the status of the early stopping callback only for rank 0
|
|
123
112
|
process."""
|
|
@@ -3,9 +3,11 @@
|
|
|
3
3
|
from __future__ import annotations
|
|
4
4
|
|
|
5
5
|
import re
|
|
6
|
+
from collections.abc import Callable
|
|
6
7
|
from pprint import pformat
|
|
7
|
-
from typing import Any,
|
|
8
|
+
from typing import Any, Literal, Union
|
|
8
9
|
|
|
10
|
+
import numpy as np
|
|
9
11
|
from bioimageio.spec.generic.v0_3 import CiteEntry
|
|
10
12
|
from pydantic import BaseModel, ConfigDict, Field, field_validator, model_validator
|
|
11
13
|
from pydantic.main import IncEx
|
|
@@ -183,6 +185,50 @@ class Configuration(BaseModel):
|
|
|
183
185
|
|
|
184
186
|
return name
|
|
185
187
|
|
|
188
|
+
@model_validator(mode="after")
|
|
189
|
+
def validate_n2v_mask_pixel_perc(self: Self) -> Self:
|
|
190
|
+
"""
|
|
191
|
+
Validate that there will always be at least one blind-spot pixel in every patch.
|
|
192
|
+
|
|
193
|
+
The probability of creating a blind-spot pixel is a function of the chosen
|
|
194
|
+
masked pixel percentage and patch size.
|
|
195
|
+
|
|
196
|
+
Returns
|
|
197
|
+
-------
|
|
198
|
+
Self
|
|
199
|
+
Validated configuration.
|
|
200
|
+
|
|
201
|
+
Raises
|
|
202
|
+
------
|
|
203
|
+
ValueError
|
|
204
|
+
If the probability of masking a pixel within a patch is less than 1 for the
|
|
205
|
+
chosen masked pixel percentage and patch size.
|
|
206
|
+
"""
|
|
207
|
+
# No validation needed for non n2v algorithms
|
|
208
|
+
if not isinstance(self.algorithm_config, N2VAlgorithm):
|
|
209
|
+
return self
|
|
210
|
+
|
|
211
|
+
mask_pixel_perc = self.algorithm_config.n2v_config.masked_pixel_percentage
|
|
212
|
+
patch_size = self.data_config.patch_size
|
|
213
|
+
expected_area_per_pixel = 1 / (mask_pixel_perc / 100)
|
|
214
|
+
|
|
215
|
+
n_dims = 3 if self.algorithm_config.model.is_3D() else 2
|
|
216
|
+
patch_size_lower_bound = int(np.ceil(expected_area_per_pixel ** (1 / n_dims)))
|
|
217
|
+
required_patch_size = tuple(
|
|
218
|
+
2 ** int(np.ceil(np.log2(patch_size_lower_bound))) for _ in range(n_dims)
|
|
219
|
+
)
|
|
220
|
+
required_mask_pixel_perc = (1 / np.prod(patch_size)) * 100
|
|
221
|
+
if expected_area_per_pixel > np.prod(patch_size):
|
|
222
|
+
raise ValueError(
|
|
223
|
+
"The probability of creating a blind-spot pixel within a patch is "
|
|
224
|
+
f"below 1, for a patch size of {patch_size} with a masked pixel "
|
|
225
|
+
f"percentage of {mask_pixel_perc}%. Either increase the patch size to "
|
|
226
|
+
f"{required_patch_size} or increase the masked pixel percentage to "
|
|
227
|
+
f"at least {required_mask_pixel_perc}%."
|
|
228
|
+
)
|
|
229
|
+
|
|
230
|
+
return self
|
|
231
|
+
|
|
186
232
|
@model_validator(mode="after")
|
|
187
233
|
def validate_3D(self: Self) -> Self:
|
|
188
234
|
"""
|