careamics 0.1.0rc5__py3-none-any.whl → 0.1.0rc7__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of careamics might be problematic. Click here for more details.
- careamics/callbacks/hyperparameters_callback.py +10 -3
- careamics/callbacks/progress_bar_callback.py +37 -4
- careamics/careamist.py +164 -231
- careamics/config/algorithm_model.py +5 -18
- careamics/config/architectures/architecture_model.py +7 -0
- careamics/config/architectures/custom_model.py +11 -4
- careamics/config/architectures/register_model.py +3 -1
- careamics/config/architectures/unet_model.py +2 -0
- careamics/config/architectures/vae_model.py +2 -0
- careamics/config/callback_model.py +3 -15
- careamics/config/configuration_example.py +4 -5
- careamics/config/configuration_factory.py +27 -41
- careamics/config/configuration_model.py +11 -11
- careamics/config/data_model.py +89 -63
- careamics/config/inference_model.py +28 -81
- careamics/config/optimizer_models.py +11 -11
- careamics/config/support/__init__.py +0 -2
- careamics/config/support/supported_activations.py +2 -0
- careamics/config/support/supported_algorithms.py +3 -1
- careamics/config/support/supported_architectures.py +2 -0
- careamics/config/support/supported_data.py +2 -0
- careamics/config/support/supported_loggers.py +2 -0
- careamics/config/support/supported_losses.py +2 -0
- careamics/config/support/supported_optimizers.py +2 -0
- careamics/config/support/supported_pixel_manipulations.py +3 -3
- careamics/config/support/supported_struct_axis.py +2 -0
- careamics/config/support/supported_transforms.py +4 -16
- careamics/config/tile_information.py +28 -58
- careamics/config/transformations/__init__.py +3 -2
- careamics/config/transformations/normalize_model.py +32 -4
- careamics/config/transformations/xy_flip_model.py +43 -0
- careamics/config/transformations/xy_random_rotate90_model.py +11 -3
- careamics/config/validators/validator_utils.py +1 -1
- careamics/conftest.py +12 -0
- careamics/dataset/__init__.py +12 -1
- careamics/dataset/dataset_utils/__init__.py +8 -1
- careamics/dataset/dataset_utils/dataset_utils.py +4 -4
- careamics/dataset/dataset_utils/file_utils.py +4 -3
- careamics/dataset/dataset_utils/iterate_over_files.py +83 -0
- careamics/dataset/dataset_utils/read_tiff.py +6 -11
- careamics/dataset/dataset_utils/read_utils.py +2 -0
- careamics/dataset/dataset_utils/read_zarr.py +11 -7
- careamics/dataset/dataset_utils/running_stats.py +186 -0
- careamics/dataset/in_memory_dataset.py +88 -154
- careamics/dataset/in_memory_pred_dataset.py +88 -0
- careamics/dataset/in_memory_tiled_pred_dataset.py +129 -0
- careamics/dataset/iterable_dataset.py +121 -191
- careamics/dataset/iterable_pred_dataset.py +121 -0
- careamics/dataset/iterable_tiled_pred_dataset.py +139 -0
- careamics/dataset/patching/patching.py +109 -39
- careamics/dataset/patching/random_patching.py +17 -6
- careamics/dataset/patching/sequential_patching.py +14 -8
- careamics/dataset/patching/validate_patch_dimension.py +7 -3
- careamics/dataset/tiling/__init__.py +10 -0
- careamics/dataset/tiling/collate_tiles.py +33 -0
- careamics/dataset/{patching → tiling}/tiled_patching.py +7 -5
- careamics/dataset/zarr_dataset.py +2 -0
- careamics/lightning_datamodule.py +46 -25
- careamics/lightning_module.py +19 -9
- careamics/lightning_prediction_datamodule.py +54 -84
- careamics/losses/__init__.py +2 -3
- careamics/losses/loss_factory.py +1 -1
- careamics/losses/losses.py +11 -7
- careamics/lvae_training/__init__.py +0 -0
- careamics/lvae_training/data_modules.py +1220 -0
- careamics/lvae_training/data_utils.py +618 -0
- careamics/lvae_training/eval_utils.py +905 -0
- careamics/lvae_training/get_config.py +84 -0
- careamics/lvae_training/lightning_module.py +701 -0
- careamics/lvae_training/metrics.py +214 -0
- careamics/lvae_training/train_lvae.py +339 -0
- careamics/lvae_training/train_utils.py +121 -0
- careamics/model_io/bioimage/model_description.py +40 -32
- careamics/model_io/bmz_io.py +3 -3
- careamics/model_io/model_io_utils.py +5 -2
- careamics/models/activation.py +2 -0
- careamics/models/layers.py +121 -25
- careamics/models/lvae/__init__.py +0 -0
- careamics/models/lvae/layers.py +1998 -0
- careamics/models/lvae/likelihoods.py +312 -0
- careamics/models/lvae/lvae.py +985 -0
- careamics/models/lvae/noise_models.py +409 -0
- careamics/models/lvae/utils.py +395 -0
- careamics/models/model_factory.py +1 -1
- careamics/models/unet.py +35 -14
- careamics/prediction_utils/__init__.py +12 -0
- careamics/prediction_utils/create_pred_datamodule.py +185 -0
- careamics/prediction_utils/prediction_outputs.py +165 -0
- careamics/prediction_utils/stitch_prediction.py +100 -0
- careamics/transforms/__init__.py +2 -2
- careamics/transforms/compose.py +33 -7
- careamics/transforms/n2v_manipulate.py +52 -14
- careamics/transforms/normalize.py +171 -48
- careamics/transforms/pixel_manipulation.py +35 -11
- careamics/transforms/struct_mask_parameters.py +3 -1
- careamics/transforms/transform.py +10 -19
- careamics/transforms/tta.py +43 -29
- careamics/transforms/xy_flip.py +123 -0
- careamics/transforms/xy_random_rotate90.py +38 -5
- careamics/utils/base_enum.py +28 -0
- careamics/utils/path_utils.py +2 -0
- careamics/utils/ram.py +4 -2
- careamics/utils/receptive_field.py +93 -87
- {careamics-0.1.0rc5.dist-info → careamics-0.1.0rc7.dist-info}/METADATA +8 -6
- careamics-0.1.0rc7.dist-info/RECORD +130 -0
- {careamics-0.1.0rc5.dist-info → careamics-0.1.0rc7.dist-info}/WHEEL +1 -1
- careamics/config/noise_models.py +0 -162
- careamics/config/support/supported_extraction_strategies.py +0 -25
- careamics/config/transformations/nd_flip_model.py +0 -27
- careamics/lightning_prediction_loop.py +0 -116
- careamics/losses/noise_model_factory.py +0 -40
- careamics/losses/noise_models.py +0 -524
- careamics/prediction/__init__.py +0 -7
- careamics/prediction/stitch_prediction.py +0 -74
- careamics/transforms/nd_flip.py +0 -67
- careamics/utils/running_stats.py +0 -43
- careamics-0.1.0rc5.dist-info/RECORD +0 -111
- {careamics-0.1.0rc5.dist-info → careamics-0.1.0rc7.dist-info}/licenses/LICENSE +0 -0
|
@@ -1,3 +1,5 @@
|
|
|
1
|
+
"""Callback saving CAREamics configuration as hyperparameters in the model."""
|
|
2
|
+
|
|
1
3
|
from pytorch_lightning import LightningModule, Trainer
|
|
2
4
|
from pytorch_lightning.callbacks import Callback
|
|
3
5
|
|
|
@@ -11,13 +13,18 @@ class HyperParametersCallback(Callback):
|
|
|
11
13
|
This allows saving the configuration as dictionnary in the checkpoints, and
|
|
12
14
|
loading it subsequently in a CAREamist instance.
|
|
13
15
|
|
|
16
|
+
Parameters
|
|
17
|
+
----------
|
|
18
|
+
config : Configuration
|
|
19
|
+
CAREamics configuration to be saved as hyperparameter in the model.
|
|
20
|
+
|
|
14
21
|
Attributes
|
|
15
22
|
----------
|
|
16
23
|
config : Configuration
|
|
17
24
|
CAREamics configuration to be saved as hyperparameter in the model.
|
|
18
25
|
"""
|
|
19
26
|
|
|
20
|
-
def __init__(self, config: Configuration):
|
|
27
|
+
def __init__(self, config: Configuration) -> None:
|
|
21
28
|
"""
|
|
22
29
|
Constructor.
|
|
23
30
|
|
|
@@ -28,14 +35,14 @@ class HyperParametersCallback(Callback):
|
|
|
28
35
|
"""
|
|
29
36
|
self.config = config
|
|
30
37
|
|
|
31
|
-
def on_train_start(self, trainer: Trainer, pl_module: LightningModule):
|
|
38
|
+
def on_train_start(self, trainer: Trainer, pl_module: LightningModule) -> None:
|
|
32
39
|
"""
|
|
33
40
|
Update the hyperparameters of the model with the configuration on train start.
|
|
34
41
|
|
|
35
42
|
Parameters
|
|
36
43
|
----------
|
|
37
44
|
trainer : Trainer
|
|
38
|
-
PyTorch Lightning trainer.
|
|
45
|
+
PyTorch Lightning trainer, unused.
|
|
39
46
|
pl_module : LightningModule
|
|
40
47
|
PyTorch Lightning module.
|
|
41
48
|
"""
|
|
@@ -1,3 +1,5 @@
|
|
|
1
|
+
"""Progressbar callback."""
|
|
2
|
+
|
|
1
3
|
import sys
|
|
2
4
|
from typing import Dict, Union
|
|
3
5
|
|
|
@@ -10,7 +12,13 @@ class ProgressBarCallback(TQDMProgressBar):
|
|
|
10
12
|
"""Progress bar for training and validation steps."""
|
|
11
13
|
|
|
12
14
|
def init_train_tqdm(self) -> tqdm:
|
|
13
|
-
"""Override this to customize the tqdm bar for training.
|
|
15
|
+
"""Override this to customize the tqdm bar for training.
|
|
16
|
+
|
|
17
|
+
Returns
|
|
18
|
+
-------
|
|
19
|
+
tqdm
|
|
20
|
+
A tqdm bar.
|
|
21
|
+
"""
|
|
14
22
|
bar = tqdm(
|
|
15
23
|
desc="Training",
|
|
16
24
|
position=(2 * self.process_position),
|
|
@@ -23,7 +31,13 @@ class ProgressBarCallback(TQDMProgressBar):
|
|
|
23
31
|
return bar
|
|
24
32
|
|
|
25
33
|
def init_validation_tqdm(self) -> tqdm:
|
|
26
|
-
"""Override this to customize the tqdm bar for validation.
|
|
34
|
+
"""Override this to customize the tqdm bar for validation.
|
|
35
|
+
|
|
36
|
+
Returns
|
|
37
|
+
-------
|
|
38
|
+
tqdm
|
|
39
|
+
A tqdm bar.
|
|
40
|
+
"""
|
|
27
41
|
# The main progress bar doesn't exist in `trainer.validate()`
|
|
28
42
|
has_main_bar = self.train_progress_bar is not None
|
|
29
43
|
bar = tqdm(
|
|
@@ -37,7 +51,13 @@ class ProgressBarCallback(TQDMProgressBar):
|
|
|
37
51
|
return bar
|
|
38
52
|
|
|
39
53
|
def init_test_tqdm(self) -> tqdm:
|
|
40
|
-
"""Override this to customize the tqdm bar for testing.
|
|
54
|
+
"""Override this to customize the tqdm bar for testing.
|
|
55
|
+
|
|
56
|
+
Returns
|
|
57
|
+
-------
|
|
58
|
+
tqdm
|
|
59
|
+
A tqdm bar.
|
|
60
|
+
"""
|
|
41
61
|
bar = tqdm(
|
|
42
62
|
desc="Testing",
|
|
43
63
|
position=(2 * self.process_position),
|
|
@@ -52,6 +72,19 @@ class ProgressBarCallback(TQDMProgressBar):
|
|
|
52
72
|
def get_metrics(
|
|
53
73
|
self, trainer: Trainer, pl_module: LightningModule
|
|
54
74
|
) -> Dict[str, Union[int, str, float, Dict[str, float]]]:
|
|
55
|
-
"""Override this to customize the metrics displayed in the progress bar.
|
|
75
|
+
"""Override this to customize the metrics displayed in the progress bar.
|
|
76
|
+
|
|
77
|
+
Parameters
|
|
78
|
+
----------
|
|
79
|
+
trainer : Trainer
|
|
80
|
+
The trainer object.
|
|
81
|
+
pl_module : LightningModule
|
|
82
|
+
The LightningModule object, unused.
|
|
83
|
+
|
|
84
|
+
Returns
|
|
85
|
+
-------
|
|
86
|
+
dict
|
|
87
|
+
A dictionary with the metrics to display in the progress bar.
|
|
88
|
+
"""
|
|
56
89
|
pbar_metrics = trainer.progress_bar_metrics
|
|
57
90
|
return {**pbar_metrics}
|