careamics 0.1.0rc6__py3-none-any.whl → 0.1.0rc7__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of careamics might be problematic. Click here for more details.
- careamics/careamist.py +163 -266
- careamics/config/algorithm_model.py +0 -15
- careamics/config/architectures/custom_model.py +3 -3
- careamics/config/configuration_example.py +0 -3
- careamics/config/configuration_factory.py +23 -25
- careamics/config/configuration_model.py +11 -11
- careamics/config/data_model.py +80 -50
- careamics/config/inference_model.py +29 -17
- careamics/config/optimizer_models.py +7 -7
- careamics/config/support/supported_transforms.py +0 -1
- careamics/config/tile_information.py +26 -58
- careamics/config/transformations/normalize_model.py +32 -4
- careamics/config/validators/validator_utils.py +1 -1
- careamics/dataset/__init__.py +12 -1
- careamics/dataset/dataset_utils/__init__.py +8 -1
- careamics/dataset/dataset_utils/file_utils.py +1 -1
- careamics/dataset/dataset_utils/iterate_over_files.py +83 -0
- careamics/dataset/dataset_utils/read_tiff.py +0 -9
- careamics/dataset/dataset_utils/running_stats.py +186 -0
- careamics/dataset/in_memory_dataset.py +66 -171
- careamics/dataset/in_memory_pred_dataset.py +88 -0
- careamics/dataset/in_memory_tiled_pred_dataset.py +129 -0
- careamics/dataset/iterable_dataset.py +92 -249
- careamics/dataset/iterable_pred_dataset.py +121 -0
- careamics/dataset/iterable_tiled_pred_dataset.py +139 -0
- careamics/dataset/patching/patching.py +54 -25
- careamics/dataset/patching/random_patching.py +9 -4
- careamics/dataset/patching/validate_patch_dimension.py +5 -3
- careamics/dataset/tiling/__init__.py +10 -0
- careamics/dataset/tiling/collate_tiles.py +33 -0
- careamics/dataset/{patching → tiling}/tiled_patching.py +4 -4
- careamics/lightning_datamodule.py +1 -6
- careamics/lightning_module.py +11 -7
- careamics/lightning_prediction_datamodule.py +52 -72
- careamics/lvae_training/__init__.py +0 -0
- careamics/lvae_training/data_modules.py +1220 -0
- careamics/lvae_training/data_utils.py +618 -0
- careamics/lvae_training/eval_utils.py +905 -0
- careamics/lvae_training/get_config.py +84 -0
- careamics/lvae_training/lightning_module.py +701 -0
- careamics/lvae_training/metrics.py +214 -0
- careamics/lvae_training/train_lvae.py +339 -0
- careamics/lvae_training/train_utils.py +121 -0
- careamics/model_io/bioimage/model_description.py +40 -32
- careamics/model_io/bmz_io.py +1 -1
- careamics/model_io/model_io_utils.py +5 -2
- careamics/models/lvae/__init__.py +0 -0
- careamics/models/lvae/layers.py +1998 -0
- careamics/models/lvae/likelihoods.py +312 -0
- careamics/models/lvae/lvae.py +985 -0
- careamics/models/lvae/noise_models.py +409 -0
- careamics/models/lvae/utils.py +395 -0
- careamics/prediction_utils/__init__.py +12 -0
- careamics/prediction_utils/create_pred_datamodule.py +185 -0
- careamics/prediction_utils/prediction_outputs.py +165 -0
- careamics/prediction_utils/stitch_prediction.py +100 -0
- careamics/transforms/n2v_manipulate.py +3 -1
- careamics/transforms/normalize.py +139 -68
- careamics/transforms/pixel_manipulation.py +33 -9
- careamics/transforms/tta.py +43 -29
- careamics/utils/ram.py +2 -2
- {careamics-0.1.0rc6.dist-info → careamics-0.1.0rc7.dist-info}/METADATA +7 -6
- {careamics-0.1.0rc6.dist-info → careamics-0.1.0rc7.dist-info}/RECORD +65 -42
- {careamics-0.1.0rc6.dist-info → careamics-0.1.0rc7.dist-info}/WHEEL +1 -1
- careamics/lightning_prediction_loop.py +0 -118
- careamics/prediction/__init__.py +0 -7
- careamics/prediction/stitch_prediction.py +0 -70
- careamics/utils/running_stats.py +0 -43
- {careamics-0.1.0rc6.dist-info → careamics-0.1.0rc7.dist-info}/licenses/LICENSE +0 -0
|
@@ -13,8 +13,8 @@ from bioimageio.spec.model.v0_5 import (
|
|
|
13
13
|
ChannelAxis,
|
|
14
14
|
EnvironmentFileDescr,
|
|
15
15
|
FileDescr,
|
|
16
|
+
FixedZeroMeanUnitVarianceAlongAxisKwargs,
|
|
16
17
|
FixedZeroMeanUnitVarianceDescr,
|
|
17
|
-
FixedZeroMeanUnitVarianceKwargs,
|
|
18
18
|
Identifier,
|
|
19
19
|
InputTensorDescr,
|
|
20
20
|
ModelDescr,
|
|
@@ -134,44 +134,52 @@ def _create_inputs_ouputs(
|
|
|
134
134
|
output_axes = _create_axes(output_array, data_config, channel_names, False)
|
|
135
135
|
|
|
136
136
|
# mean and std
|
|
137
|
-
assert data_config.
|
|
138
|
-
assert data_config.
|
|
139
|
-
|
|
140
|
-
|
|
137
|
+
assert data_config.image_means is not None, "Mean cannot be None."
|
|
138
|
+
assert data_config.image_means is not None, "Std cannot be None."
|
|
139
|
+
means = data_config.image_means
|
|
140
|
+
stds = data_config.image_stds
|
|
141
141
|
|
|
142
142
|
# and the mean and std required to invert the normalization
|
|
143
143
|
# CAREamics denormalization: x = y * (std + eps) + mean
|
|
144
144
|
# BMZ normalization : x = (y - mean') / (std' + eps)
|
|
145
145
|
# to apply the BMZ normalization as a denormalization step, we need:
|
|
146
146
|
eps = 1e-6
|
|
147
|
-
|
|
148
|
-
|
|
149
|
-
|
|
150
|
-
|
|
151
|
-
|
|
152
|
-
|
|
153
|
-
|
|
154
|
-
|
|
155
|
-
|
|
156
|
-
|
|
157
|
-
|
|
158
|
-
)
|
|
159
|
-
|
|
160
|
-
|
|
161
|
-
|
|
162
|
-
|
|
163
|
-
|
|
164
|
-
test_tensor=FileDescr(source=output_path),
|
|
165
|
-
postprocessing=[
|
|
166
|
-
FixedZeroMeanUnitVarianceDescr(
|
|
167
|
-
kwargs=FixedZeroMeanUnitVarianceKwargs( # invert normalization
|
|
168
|
-
mean=inv_mean, std=inv_std
|
|
147
|
+
inv_means = []
|
|
148
|
+
inv_stds = []
|
|
149
|
+
if means and stds:
|
|
150
|
+
for mean, std in zip(means, stds):
|
|
151
|
+
inv_means.append(-mean / (std + eps))
|
|
152
|
+
inv_stds.append(1 / (std + eps) - eps)
|
|
153
|
+
|
|
154
|
+
# create input/output descriptions
|
|
155
|
+
input_descr = InputTensorDescr(
|
|
156
|
+
id=TensorId("input"),
|
|
157
|
+
axes=input_axes,
|
|
158
|
+
test_tensor=FileDescr(source=input_path),
|
|
159
|
+
preprocessing=[
|
|
160
|
+
FixedZeroMeanUnitVarianceDescr(
|
|
161
|
+
kwargs=FixedZeroMeanUnitVarianceAlongAxisKwargs(
|
|
162
|
+
mean=means, std=stds, axis="channel"
|
|
163
|
+
)
|
|
169
164
|
)
|
|
170
|
-
|
|
171
|
-
|
|
172
|
-
|
|
165
|
+
],
|
|
166
|
+
)
|
|
167
|
+
output_descr = OutputTensorDescr(
|
|
168
|
+
id=TensorId("prediction"),
|
|
169
|
+
axes=output_axes,
|
|
170
|
+
test_tensor=FileDescr(source=output_path),
|
|
171
|
+
postprocessing=[
|
|
172
|
+
FixedZeroMeanUnitVarianceDescr(
|
|
173
|
+
kwargs=FixedZeroMeanUnitVarianceAlongAxisKwargs( # invert norm
|
|
174
|
+
mean=inv_means, std=inv_stds, axis="channel"
|
|
175
|
+
)
|
|
176
|
+
)
|
|
177
|
+
],
|
|
178
|
+
)
|
|
173
179
|
|
|
174
|
-
|
|
180
|
+
return input_descr, output_descr
|
|
181
|
+
else:
|
|
182
|
+
raise ValueError("Mean and std cannot be None.")
|
|
175
183
|
|
|
176
184
|
|
|
177
185
|
def create_model_description(
|
|
@@ -280,7 +288,7 @@ def create_model_description(
|
|
|
280
288
|
"bioimageio": {
|
|
281
289
|
"test_kwargs": {
|
|
282
290
|
"pytorch_state_dict": {
|
|
283
|
-
"decimals":
|
|
291
|
+
"decimals": 0, # ...so we relax the constraints on the decimals
|
|
284
292
|
}
|
|
285
293
|
}
|
|
286
294
|
}
|
careamics/model_io/bmz_io.py
CHANGED
|
@@ -178,7 +178,7 @@ def export_to_bmz(
|
|
|
178
178
|
)
|
|
179
179
|
|
|
180
180
|
# test model description
|
|
181
|
-
summary: ValidationSummary = test_model(model_description, decimal=
|
|
181
|
+
summary: ValidationSummary = test_model(model_description, decimal=1)
|
|
182
182
|
if summary.status == "failed":
|
|
183
183
|
raise ValueError(f"Model description test failed: {summary}")
|
|
184
184
|
|
|
@@ -3,7 +3,7 @@
|
|
|
3
3
|
from pathlib import Path
|
|
4
4
|
from typing import Tuple, Union
|
|
5
5
|
|
|
6
|
-
|
|
6
|
+
import torch
|
|
7
7
|
|
|
8
8
|
from careamics.config import Configuration
|
|
9
9
|
from careamics.lightning_module import CAREamicsModule
|
|
@@ -64,7 +64,10 @@ def _load_checkpoint(path: Union[Path, str]) -> Tuple[CAREamicsModule, Configura
|
|
|
64
64
|
If the checkpoint file does not contain hyper parameters (configuration).
|
|
65
65
|
"""
|
|
66
66
|
# load checkpoint
|
|
67
|
-
|
|
67
|
+
# here we might run into issues between devices
|
|
68
|
+
# see https://pytorch.org/tutorials/recipes/recipes/save_load_across_devices.html
|
|
69
|
+
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
|
|
70
|
+
checkpoint: dict = torch.load(path, map_location=device)
|
|
68
71
|
|
|
69
72
|
# attempt to load configuration
|
|
70
73
|
try:
|
|
File without changes
|