careamics 0.0.15__py3-none-any.whl → 0.0.17__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of careamics might be problematic. Click here for more details.
- careamics/careamist.py +11 -14
- careamics/cli/conf.py +18 -3
- careamics/config/__init__.py +8 -0
- careamics/config/algorithms/__init__.py +4 -0
- careamics/config/algorithms/hdn_algorithm_model.py +103 -0
- careamics/config/algorithms/microsplit_algorithm_model.py +103 -0
- careamics/config/algorithms/n2v_algorithm_model.py +1 -2
- careamics/config/algorithms/vae_algorithm_model.py +51 -16
- careamics/config/architectures/lvae_model.py +12 -8
- careamics/config/callback_model.py +7 -3
- careamics/config/configuration.py +15 -63
- careamics/config/configuration_factories.py +853 -29
- careamics/config/data/data_model.py +50 -11
- careamics/config/data/ng_data_model.py +168 -4
- careamics/config/data/patch_filter/__init__.py +15 -0
- careamics/config/data/patch_filter/filter_model.py +16 -0
- careamics/config/data/patch_filter/mask_filter_model.py +17 -0
- careamics/config/data/patch_filter/max_filter_model.py +15 -0
- careamics/config/data/patch_filter/meanstd_filter_model.py +18 -0
- careamics/config/data/patch_filter/shannon_filter_model.py +15 -0
- careamics/config/inference_model.py +1 -2
- careamics/config/likelihood_model.py +2 -2
- careamics/config/loss_model.py +6 -2
- careamics/config/nm_model.py +26 -1
- careamics/config/optimizer_models.py +1 -2
- careamics/config/support/supported_algorithms.py +5 -3
- careamics/config/support/supported_filters.py +17 -0
- careamics/config/support/supported_losses.py +5 -2
- careamics/config/training_model.py +6 -36
- careamics/config/transformations/normalize_model.py +1 -2
- careamics/dataset_ng/dataset.py +57 -5
- careamics/dataset_ng/factory.py +101 -18
- careamics/dataset_ng/patch_extractor/demo_custom_image_stack_loader.py +4 -4
- careamics/dataset_ng/patch_extractor/image_stack/in_memory_image_stack.py +1 -2
- careamics/dataset_ng/patch_extractor/image_stack/zarr_image_stack.py +33 -7
- careamics/dataset_ng/patch_extractor/image_stack_loader.py +2 -2
- careamics/dataset_ng/patch_filter/__init__.py +20 -0
- careamics/dataset_ng/patch_filter/coordinate_filter_protocol.py +27 -0
- careamics/dataset_ng/patch_filter/filter_factory.py +94 -0
- careamics/dataset_ng/patch_filter/mask_filter.py +95 -0
- careamics/dataset_ng/patch_filter/max_filter.py +188 -0
- careamics/dataset_ng/patch_filter/mean_std_filter.py +218 -0
- careamics/dataset_ng/patch_filter/patch_filter_protocol.py +50 -0
- careamics/dataset_ng/patch_filter/shannon_filter.py +188 -0
- careamics/file_io/read/__init__.py +0 -1
- careamics/lightning/__init__.py +16 -2
- careamics/lightning/callbacks/__init__.py +2 -0
- careamics/lightning/callbacks/data_stats_callback.py +33 -0
- careamics/lightning/dataset_ng/data_module.py +79 -2
- careamics/lightning/lightning_module.py +162 -61
- careamics/lightning/microsplit_data_module.py +636 -0
- careamics/lightning/predict_data_module.py +8 -1
- careamics/lightning/train_data_module.py +19 -8
- careamics/losses/__init__.py +7 -1
- careamics/losses/loss_factory.py +9 -1
- careamics/losses/lvae/losses.py +85 -0
- careamics/lvae_training/dataset/__init__.py +8 -8
- careamics/lvae_training/dataset/config.py +56 -44
- careamics/lvae_training/dataset/lc_dataset.py +18 -12
- careamics/lvae_training/dataset/ms_dataset_ref.py +5 -5
- careamics/lvae_training/dataset/multich_dataset.py +24 -18
- careamics/lvae_training/dataset/multifile_dataset.py +6 -6
- careamics/lvae_training/eval_utils.py +46 -24
- careamics/model_io/bmz_io.py +9 -5
- careamics/models/lvae/likelihoods.py +31 -14
- careamics/models/lvae/lvae.py +2 -2
- careamics/models/lvae/noise_models.py +20 -14
- careamics/prediction_utils/__init__.py +8 -2
- careamics/prediction_utils/prediction_outputs.py +49 -3
- careamics/prediction_utils/stitch_prediction.py +83 -1
- careamics/transforms/xy_random_rotate90.py +1 -1
- careamics/utils/version.py +4 -4
- {careamics-0.0.15.dist-info → careamics-0.0.17.dist-info}/METADATA +19 -22
- {careamics-0.0.15.dist-info → careamics-0.0.17.dist-info}/RECORD +77 -60
- careamics/dataset/zarr_dataset.py +0 -151
- careamics/file_io/read/zarr.py +0 -60
- {careamics-0.0.15.dist-info → careamics-0.0.17.dist-info}/WHEEL +0 -0
- {careamics-0.0.15.dist-info → careamics-0.0.17.dist-info}/entry_points.txt +0 -0
- {careamics-0.0.15.dist-info → careamics-0.0.17.dist-info}/licenses/LICENSE +0 -0
|
@@ -4,7 +4,7 @@ from typing import Callable, Union
|
|
|
4
4
|
import numpy as np
|
|
5
5
|
from numpy.typing import NDArray
|
|
6
6
|
|
|
7
|
-
from .config import
|
|
7
|
+
from .config import MicroSplitDataConfig
|
|
8
8
|
from .lc_dataset import LCMultiChDloader
|
|
9
9
|
from .multich_dataset import MultiChDloader
|
|
10
10
|
from .types import DataSplitType
|
|
@@ -82,7 +82,7 @@ class SingleFileLCDset(LCMultiChDloader):
|
|
|
82
82
|
def __init__(
|
|
83
83
|
self,
|
|
84
84
|
preloaded_data: NDArray,
|
|
85
|
-
data_config:
|
|
85
|
+
data_config: MicroSplitDataConfig,
|
|
86
86
|
fpath: str,
|
|
87
87
|
load_data_fn: Callable,
|
|
88
88
|
val_fraction=None,
|
|
@@ -106,7 +106,7 @@ class SingleFileLCDset(LCMultiChDloader):
|
|
|
106
106
|
|
|
107
107
|
def load_data(
|
|
108
108
|
self,
|
|
109
|
-
data_config:
|
|
109
|
+
data_config: MicroSplitDataConfig,
|
|
110
110
|
datasplit_type: DataSplitType,
|
|
111
111
|
load_data_fn: Callable,
|
|
112
112
|
val_fraction=None,
|
|
@@ -124,7 +124,7 @@ class SingleFileDset(MultiChDloader):
|
|
|
124
124
|
def __init__(
|
|
125
125
|
self,
|
|
126
126
|
preloaded_data: NDArray,
|
|
127
|
-
data_config:
|
|
127
|
+
data_config: MicroSplitDataConfig,
|
|
128
128
|
fpath: str,
|
|
129
129
|
load_data_fn: Callable,
|
|
130
130
|
val_fraction=None,
|
|
@@ -148,7 +148,7 @@ class SingleFileDset(MultiChDloader):
|
|
|
148
148
|
|
|
149
149
|
def load_data(
|
|
150
150
|
self,
|
|
151
|
-
data_config:
|
|
151
|
+
data_config: MicroSplitDataConfig,
|
|
152
152
|
datasplit_type: DataSplitType,
|
|
153
153
|
load_data_fn: Callable[..., NDArray],
|
|
154
154
|
val_fraction=None,
|
|
@@ -175,7 +175,7 @@ class MultiFileDset:
|
|
|
175
175
|
|
|
176
176
|
def __init__(
|
|
177
177
|
self,
|
|
178
|
-
data_config:
|
|
178
|
+
data_config: MicroSplitDataConfig,
|
|
179
179
|
fpath: str,
|
|
180
180
|
load_data_fn: Callable[..., Union[TwoChannelData, MultiChannelData]],
|
|
181
181
|
val_fraction=None,
|
|
@@ -32,7 +32,7 @@ class TilingMode:
|
|
|
32
32
|
ShiftBoundary = 2
|
|
33
33
|
|
|
34
34
|
|
|
35
|
-
#
|
|
35
|
+
# ------------------------------------------------------------------------------------
|
|
36
36
|
# Function of plotting: TODO -> moved them to another file, plot_utils.py
|
|
37
37
|
def clean_ax(ax):
|
|
38
38
|
"""
|
|
@@ -68,7 +68,9 @@ def get_psnr_str(tar_hsnr, pred, col_idx):
|
|
|
68
68
|
"""
|
|
69
69
|
Compute PSNR between the ground truth (`tar_hsnr`) and the predicted image (`pred`).
|
|
70
70
|
"""
|
|
71
|
-
|
|
71
|
+
psnr = scale_invariant_psnr(tar_hsnr[col_idx][None], pred[col_idx][None]).item()
|
|
72
|
+
|
|
73
|
+
return f"{psnr:.1f}"
|
|
72
74
|
|
|
73
75
|
|
|
74
76
|
def add_psnr_str(ax_, psnr):
|
|
@@ -129,8 +131,10 @@ def show_for_one(
|
|
|
129
131
|
baseline_preds=None,
|
|
130
132
|
):
|
|
131
133
|
"""
|
|
132
|
-
Given an index, it plots the input, target, reconstructed images and the difference
|
|
133
|
-
|
|
134
|
+
Given an index, it plots the input, target, reconstructed images and the difference
|
|
135
|
+
image.
|
|
136
|
+
Note the the difference image is computed with respect to a ground truth image,
|
|
137
|
+
obtained from the high SNR dataset.
|
|
134
138
|
"""
|
|
135
139
|
highsnr_val_dset.set_img_sz(patch_size, 64)
|
|
136
140
|
highsnr_val_dset.disable_noise()
|
|
@@ -164,7 +168,8 @@ def plot_crops(
|
|
|
164
168
|
for i in range(len(baseline_preds)):
|
|
165
169
|
if baseline_preds[i].shape != tar_hsnr.shape:
|
|
166
170
|
print(
|
|
167
|
-
f"Baseline prediction {i} shape {baseline_preds[i].shape} does not
|
|
171
|
+
f"Baseline prediction {i} shape {baseline_preds[i].shape} does not "
|
|
172
|
+
f"match target shape {tar_hsnr.shape}"
|
|
168
173
|
)
|
|
169
174
|
print("This happens when we want to predict the edges of the image.")
|
|
170
175
|
return
|
|
@@ -333,14 +338,21 @@ def plot_crops(
|
|
|
333
338
|
ax_temp.imshow(inp[0, 0].cpu().numpy(), cmap="magma")
|
|
334
339
|
clean_ax(ax_temp)
|
|
335
340
|
|
|
336
|
-
# line_ch1 = mlines.Line2D([0, 1], [0, 1], color=color_ch_list[0], linestyle='-',
|
|
337
|
-
#
|
|
338
|
-
#
|
|
339
|
-
#
|
|
340
|
-
#
|
|
341
|
-
#
|
|
341
|
+
# line_ch1 = mlines.Line2D([0, 1], [0, 1], color=color_ch_list[0], linestyle='-',
|
|
342
|
+
# label='$C_1$')
|
|
343
|
+
# line_ch2 = mlines.Line2D([0, 1], [0, 1], color=color_ch_list[1], linestyle='-',
|
|
344
|
+
# label='$C_2$')
|
|
345
|
+
# line_pred = mlines.Line2D([0, 1], [0, 1], color=color_pred, linestyle='-',
|
|
346
|
+
# label='Pred')
|
|
347
|
+
# line_noisych1 = mlines.Line2D([0, 1], [0, 1], color=color_ch_list[0],
|
|
348
|
+
# linestyle='--', label='$C^N_1$')
|
|
349
|
+
# line_noisych2 = mlines.Line2D([0, 1], [0, 1], color=color_ch_list[1],
|
|
350
|
+
# linestyle='--', label='$C^N_2$')
|
|
351
|
+
# legend_ch1 = legend_ch1_ax.legend(handles=[line_ch1, line_noisych1, line_pred],
|
|
352
|
+
# loc='upper right', frameon=False, labelcolor='white',
|
|
342
353
|
# prop={'size': 11})
|
|
343
|
-
# legend_ch2 = legend_ch2_ax.legend(handles=[line_ch2, line_noisych2, line_pred],
|
|
354
|
+
# legend_ch2 = legend_ch2_ax.legend(handles=[line_ch2, line_noisych2, line_pred],
|
|
355
|
+
# loc='upper right', frameon=False, labelcolor='white',
|
|
344
356
|
# prop={'size': 11})
|
|
345
357
|
|
|
346
358
|
if calibration_stats is not None:
|
|
@@ -383,7 +395,9 @@ def plot_calibration(ax, calibration_stats):
|
|
|
383
395
|
|
|
384
396
|
def shiftedColorMap(cmap, start=0, midpoint=0.5, stop=1.0, name="shiftedcmap"):
|
|
385
397
|
"""
|
|
386
|
-
Adapted from
|
|
398
|
+
Adapted from
|
|
399
|
+
https://stackoverflow.com/questions/7404116/defining-the-midpoint-of-a-colormap-in-
|
|
400
|
+
matplotlib
|
|
387
401
|
|
|
388
402
|
Function to offset the "center" of a colormap. Useful for
|
|
389
403
|
data with a negative min and positive max and you want the
|
|
@@ -444,7 +458,8 @@ def get_fractional_change(target, prediction, max_val=None):
|
|
|
444
458
|
|
|
445
459
|
def get_zero_centered_midval(error):
|
|
446
460
|
"""
|
|
447
|
-
When done this way, the midval ensures that the colorbar is centered at 0. (Don't
|
|
461
|
+
When done this way, the midval ensures that the colorbar is centered at 0. (Don't
|
|
462
|
+
know how, but it works ;))
|
|
448
463
|
"""
|
|
449
464
|
vmax = error.max()
|
|
450
465
|
vmin = error.min()
|
|
@@ -670,7 +685,7 @@ def get_single_file_mmse(
|
|
|
670
685
|
return stitched_predictions, stitched_stds
|
|
671
686
|
|
|
672
687
|
|
|
673
|
-
#
|
|
688
|
+
# ---------------------------------------------------------------------------------
|
|
674
689
|
### Classes and Functions used to stitch predictions
|
|
675
690
|
class PatchLocation:
|
|
676
691
|
"""
|
|
@@ -698,9 +713,10 @@ def _get_location(extra_padding, hwt, pred_h, pred_w):
|
|
|
698
713
|
|
|
699
714
|
def get_location_from_idx(dset, dset_input_idx, pred_h, pred_w):
|
|
700
715
|
"""
|
|
701
|
-
For a given idx of the dataset, it returns where exactly in the dataset, does this
|
|
702
|
-
Note that this prediction also has padded pixels and so a subset of
|
|
703
|
-
Which time frame, which spatial location
|
|
716
|
+
For a given idx of the dataset, it returns where exactly in the dataset, does this
|
|
717
|
+
prediction lies. Note that this prediction also has padded pixels and so a subset of
|
|
718
|
+
it will be used in the final prediction. Which time frame, which spatial location
|
|
719
|
+
(h_start, h_end, w_start,w_end)
|
|
704
720
|
Args:
|
|
705
721
|
dset:
|
|
706
722
|
dset_input_idx:
|
|
@@ -785,9 +801,11 @@ def stitch_predictions(predictions, dset, smoothening_pixelcount=0):
|
|
|
785
801
|
# NOTE: don't need to compute it for every patch.
|
|
786
802
|
assert (
|
|
787
803
|
smoothening_pixelcount == 0
|
|
788
|
-
), "For smoothing,enable the get_smoothing_mask. It is disabled since I
|
|
804
|
+
), "For smoothing,enable the get_smoothing_mask. It is disabled since I"
|
|
805
|
+
"don't use it and it needs modification to work with non-square images"
|
|
789
806
|
mask = 1
|
|
790
|
-
# mask = _get_smoothing_mask(cropped_pred_i.shape,
|
|
807
|
+
# mask = _get_smoothing_mask(cropped_pred_i.shape,
|
|
808
|
+
# smoothening_pixelcount, loc, frame_size)
|
|
791
809
|
|
|
792
810
|
cropped_pred_list.append(cropped_pred_i)
|
|
793
811
|
|
|
@@ -827,7 +845,8 @@ def stitch_predictions_new(predictions, dset):
|
|
|
827
845
|
output = np.zeros(shape, dtype=predictions.dtype)
|
|
828
846
|
# frame_shape = dset.get_data_shape()[:-1]
|
|
829
847
|
for dset_idx in range(predictions.shape[0]):
|
|
830
|
-
# loc = get_location_from_idx(dset, dset_idx, predictions.shape[-2],
|
|
848
|
+
# loc = get_location_from_idx(dset, dset_idx, predictions.shape[-2],
|
|
849
|
+
# predictions.shape[-1])
|
|
831
850
|
# grid start, grid end
|
|
832
851
|
gs = np.array(mng.get_location_from_dataset_idx(dset_idx), dtype=int)
|
|
833
852
|
ge = gs + mng.grid_shape
|
|
@@ -843,7 +862,8 @@ def stitch_predictions_new(predictions, dset):
|
|
|
843
862
|
vgs = np.array([max(0, x) for x in gs], dtype=int)
|
|
844
863
|
vge = np.array([min(x, y) for x, y in zip(ge, mng.data_shape)], dtype=int)
|
|
845
864
|
# assert np.all(vgs == gs)
|
|
846
|
-
# assert np.all(vge == ge) # TODO comented out this shit cuz I have no interest
|
|
865
|
+
# assert np.all(vge == ge) # TODO comented out this shit cuz I have no interest
|
|
866
|
+
# to dig why it's failing at this point !
|
|
847
867
|
# print('VGS')
|
|
848
868
|
# print(gs)
|
|
849
869
|
# print(ge)
|
|
@@ -898,7 +918,8 @@ def stitch_predictions_general(predictions, dset):
|
|
|
898
918
|
# frame_shape = dset.get_data_shape()[:-1]
|
|
899
919
|
for patch_idx in range(predictions.shape[0]):
|
|
900
920
|
# grid start, grid end
|
|
901
|
-
# channel_idx is 0 because during prediction we're only use one channel.
|
|
921
|
+
# channel_idx is 0 because during prediction we're only use one channel.
|
|
922
|
+
# # TODO revisit this
|
|
902
923
|
# 0th dimension is sample index in the output list
|
|
903
924
|
grid_coords = np.array(
|
|
904
925
|
mng.get_location_from_patch_idx(channel_idx=0, patch_idx=patch_idx),
|
|
@@ -906,7 +927,8 @@ def stitch_predictions_general(predictions, dset):
|
|
|
906
927
|
)
|
|
907
928
|
sample_idx = grid_coords[0]
|
|
908
929
|
grid_start = grid_coords[1:]
|
|
909
|
-
# from here on, coordinates are relative to the sample(file in the list of
|
|
930
|
+
# from here on, coordinates are relative to the sample(file in the list of
|
|
931
|
+
# inputs)
|
|
910
932
|
grid_end = grid_start + mng.grid_shape
|
|
911
933
|
|
|
912
934
|
# patch start, patch end
|
careamics/model_io/bmz_io.py
CHANGED
|
@@ -186,11 +186,15 @@ def export_to_bmz(
|
|
|
186
186
|
)
|
|
187
187
|
|
|
188
188
|
# test model description
|
|
189
|
-
test_kwargs =
|
|
190
|
-
|
|
191
|
-
.
|
|
192
|
-
|
|
193
|
-
|
|
189
|
+
test_kwargs = {}
|
|
190
|
+
if hasattr(model_description, "config") and isinstance(
|
|
191
|
+
model_description.config, dict
|
|
192
|
+
):
|
|
193
|
+
bioimageio_config = model_description.config.get("bioimageio", {})
|
|
194
|
+
test_kwargs = bioimageio_config.get("test_kwargs", {}).get(
|
|
195
|
+
"pytorch_state_dict", {}
|
|
196
|
+
)
|
|
197
|
+
|
|
194
198
|
summary: ValidationSummary = test_model(model_description, **test_kwargs)
|
|
195
199
|
if summary.status == "failed":
|
|
196
200
|
raise ValueError(f"Model description test failed: {summary}")
|
|
@@ -54,12 +54,8 @@ def likelihood_factory(
|
|
|
54
54
|
)
|
|
55
55
|
elif isinstance(config, NMLikelihoodConfig):
|
|
56
56
|
return NoiseModelLikelihood(
|
|
57
|
-
data_mean=config.data_mean,
|
|
58
|
-
data_std=config.data_std,
|
|
59
57
|
noise_model=noise_model,
|
|
60
58
|
)
|
|
61
|
-
else:
|
|
62
|
-
raise ValueError(f"Invalid likelihood model type: {config.model_type}")
|
|
63
59
|
|
|
64
60
|
|
|
65
61
|
# TODO: is it really worth to have this class? Or it just adds complexity? --> REFACTOR
|
|
@@ -290,27 +286,40 @@ class NoiseModelLikelihood(LikelihoodModule):
|
|
|
290
286
|
|
|
291
287
|
def __init__(
|
|
292
288
|
self,
|
|
293
|
-
data_mean: Union[np.ndarray, torch.Tensor],
|
|
294
|
-
data_std: Union[np.ndarray, torch.Tensor],
|
|
295
289
|
noise_model: NoiseModel,
|
|
296
290
|
):
|
|
297
291
|
"""Constructor.
|
|
298
292
|
|
|
299
293
|
Parameters
|
|
300
294
|
----------
|
|
301
|
-
data_mean: Union[np.ndarray, torch.Tensor]
|
|
302
|
-
The mean of the data, used to unnormalize data for noise model evaluation.
|
|
303
|
-
data_std: Union[np.ndarray, torch.Tensor]
|
|
304
|
-
The standard deviation of the data, used to unnormalize data for noise
|
|
305
|
-
model evaluation.
|
|
306
295
|
noiseModel: NoiseModel
|
|
307
296
|
The noise model instance used to compute the likelihood.
|
|
308
297
|
"""
|
|
309
298
|
super().__init__()
|
|
310
|
-
self.data_mean =
|
|
311
|
-
self.data_std =
|
|
299
|
+
self.data_mean = None
|
|
300
|
+
self.data_std = None
|
|
312
301
|
self.noiseModel = noise_model
|
|
313
302
|
|
|
303
|
+
def set_data_stats(
|
|
304
|
+
self,
|
|
305
|
+
data_mean: Union[np.ndarray, torch.Tensor],
|
|
306
|
+
data_std: Union[np.ndarray, torch.Tensor],
|
|
307
|
+
) -> None:
|
|
308
|
+
"""Set the data mean and std for denormalization.
|
|
309
|
+
# TODO check this !!
|
|
310
|
+
Parameters
|
|
311
|
+
----------
|
|
312
|
+
data_mean : Union[np.ndarray, torch.Tensor]
|
|
313
|
+
Mean values for each channel. Will be reshaped to (1, C, 1, 1, 1) for broadcasting.
|
|
314
|
+
data_std : Union[np.ndarray, torch.Tensor]
|
|
315
|
+
Standard deviation values for each channel. Will be reshaped to (1, C, 1, 1, 1) for broadcasting.
|
|
316
|
+
"""
|
|
317
|
+
# Convert to tensor if needed
|
|
318
|
+
self.data_mean = torch.as_tensor(data_mean, dtype=torch.float32)
|
|
319
|
+
self.data_std = torch.as_tensor(data_std, dtype=torch.float32)
|
|
320
|
+
|
|
321
|
+
# TODO add extra dim for 3D ?
|
|
322
|
+
|
|
314
323
|
def _set_params_to_same_device_as(
|
|
315
324
|
self, correct_device_tensor: torch.Tensor
|
|
316
325
|
) -> None:
|
|
@@ -321,7 +330,10 @@ class NoiseModelLikelihood(LikelihoodModule):
|
|
|
321
330
|
correct_device_tensor: torch.Tensor
|
|
322
331
|
The tensor whose device is used to set the parameters.
|
|
323
332
|
"""
|
|
324
|
-
if
|
|
333
|
+
if (
|
|
334
|
+
self.data_mean is not None
|
|
335
|
+
and self.data_mean.device != correct_device_tensor.device
|
|
336
|
+
):
|
|
325
337
|
self.data_mean = self.data_mean.to(correct_device_tensor.device)
|
|
326
338
|
self.data_std = self.data_std.to(correct_device_tensor.device)
|
|
327
339
|
if correct_device_tensor.device != self.noiseModel.device:
|
|
@@ -367,6 +379,11 @@ class NoiseModelLikelihood(LikelihoodModule):
|
|
|
367
379
|
torch.Tensor
|
|
368
380
|
The log-likelihood tensor. Shape is (B, C, [Z], Y, X).
|
|
369
381
|
"""
|
|
382
|
+
if self.data_mean is None or self.data_std is None:
|
|
383
|
+
raise RuntimeError(
|
|
384
|
+
"NoiseModelLikelihood: data_mean and data_std must be set before"
|
|
385
|
+
"callinglog_likelihood."
|
|
386
|
+
)
|
|
370
387
|
self._set_params_to_same_device_as(x)
|
|
371
388
|
predicted_s_denormalized = params["mean"] * self.data_std + self.data_mean
|
|
372
389
|
x_denormalized = x * self.data_std + self.data_mean
|
careamics/models/lvae/lvae.py
CHANGED
|
@@ -6,7 +6,7 @@ and Artefact Removal, Prakash et al."
|
|
|
6
6
|
"""
|
|
7
7
|
|
|
8
8
|
from collections.abc import Iterable
|
|
9
|
-
from typing import
|
|
9
|
+
from typing import Union
|
|
10
10
|
|
|
11
11
|
import numpy as np
|
|
12
12
|
import torch
|
|
@@ -835,7 +835,7 @@ class LadderVAE(nn.Module):
|
|
|
835
835
|
top_layer_shape = (n_imgs, mu_logvar, self._model_3D_depth, h, w)
|
|
836
836
|
return top_layer_shape
|
|
837
837
|
|
|
838
|
-
def reset_for_inference(self, tile_size:
|
|
838
|
+
def reset_for_inference(self, tile_size: tuple[int, int] | None = None):
|
|
839
839
|
"""Should be called if we want to predict for a different input/output size."""
|
|
840
840
|
self.mode_pred = True
|
|
841
841
|
if tile_size is None:
|
|
@@ -3,10 +3,10 @@ from __future__ import annotations
|
|
|
3
3
|
import os
|
|
4
4
|
from typing import TYPE_CHECKING, Optional
|
|
5
5
|
|
|
6
|
-
from numpy.typing import NDArray
|
|
7
6
|
import numpy as np
|
|
8
7
|
import torch
|
|
9
8
|
import torch.nn as nn
|
|
9
|
+
from numpy.typing import NDArray
|
|
10
10
|
|
|
11
11
|
if TYPE_CHECKING:
|
|
12
12
|
from careamics.config import GaussianMixtureNMConfig, MultiChannelNMConfig
|
|
@@ -355,16 +355,16 @@ class GaussianMixtureNoiseModel(nn.Module):
|
|
|
355
355
|
|
|
356
356
|
Parameters
|
|
357
357
|
----------
|
|
358
|
-
x: Tensor
|
|
359
|
-
|
|
360
|
-
mean: Tensor
|
|
361
|
-
|
|
362
|
-
std: Tensor
|
|
363
|
-
|
|
358
|
+
x: torch.Tensor
|
|
359
|
+
The ground-truth tensor. Shape is (batch, 1, dim1, dim2).
|
|
360
|
+
mean: torch.Tensor
|
|
361
|
+
The inferred mean of distribution. Shape is (batch, 1, dim1, dim2).
|
|
362
|
+
std: torch.Tensor
|
|
363
|
+
The inferred standard deviation of distribution. Shape is (batch, 1, dim1, dim2).
|
|
364
364
|
|
|
365
365
|
Returns
|
|
366
366
|
-------
|
|
367
|
-
tmp: Tensor
|
|
367
|
+
tmp: torch.Tensor
|
|
368
368
|
Normal probability density of `x` given `mean` and `std`
|
|
369
369
|
"""
|
|
370
370
|
tmp = -((x - mean) ** 2)
|
|
@@ -382,9 +382,9 @@ class GaussianMixtureNoiseModel(nn.Module):
|
|
|
382
382
|
Parameters
|
|
383
383
|
----------
|
|
384
384
|
observations : Tensor
|
|
385
|
-
Noisy observations
|
|
385
|
+
Noisy observations. Shape is (batch, 1, dim1, dim2).
|
|
386
386
|
signals : Tensor
|
|
387
|
-
Underlying signals
|
|
387
|
+
Underlying signals. Shape is (batch, 1, dim1, dim2).
|
|
388
388
|
|
|
389
389
|
Returns
|
|
390
390
|
-------
|
|
@@ -392,15 +392,21 @@ class GaussianMixtureNoiseModel(nn.Module):
|
|
|
392
392
|
Likelihood of observations given the signals and the GMM noise model
|
|
393
393
|
"""
|
|
394
394
|
gaussian_parameters: list[torch.Tensor] = self.get_gaussian_parameters(signals)
|
|
395
|
-
p =
|
|
395
|
+
p = torch.zeros_like(observations)
|
|
396
396
|
for gaussian in range(self.n_gaussian):
|
|
397
|
+
# Ensure all tensors have compatible shapes
|
|
398
|
+
mean = gaussian_parameters[gaussian]
|
|
399
|
+
std = gaussian_parameters[self.n_gaussian + gaussian]
|
|
400
|
+
weight = gaussian_parameters[2 * self.n_gaussian + gaussian]
|
|
401
|
+
|
|
402
|
+
# Compute normal density
|
|
397
403
|
p += (
|
|
398
404
|
self.normal_density(
|
|
399
405
|
observations,
|
|
400
|
-
|
|
401
|
-
|
|
406
|
+
mean,
|
|
407
|
+
std,
|
|
402
408
|
)
|
|
403
|
-
*
|
|
409
|
+
* weight
|
|
404
410
|
)
|
|
405
411
|
return p + self.tolerance
|
|
406
412
|
|
|
@@ -2,9 +2,15 @@
|
|
|
2
2
|
|
|
3
3
|
__all__ = [
|
|
4
4
|
"convert_outputs",
|
|
5
|
+
"convert_outputs_microsplit",
|
|
5
6
|
"stitch_prediction",
|
|
6
7
|
"stitch_prediction_single",
|
|
8
|
+
"stitch_prediction_vae",
|
|
7
9
|
]
|
|
8
10
|
|
|
9
|
-
from .prediction_outputs import convert_outputs
|
|
10
|
-
from .stitch_prediction import
|
|
11
|
+
from .prediction_outputs import convert_outputs, convert_outputs_microsplit
|
|
12
|
+
from .stitch_prediction import (
|
|
13
|
+
stitch_prediction,
|
|
14
|
+
stitch_prediction_single,
|
|
15
|
+
stitch_prediction_vae,
|
|
16
|
+
)
|
|
@@ -6,7 +6,7 @@ import numpy as np
|
|
|
6
6
|
from numpy.typing import NDArray
|
|
7
7
|
|
|
8
8
|
from ..config.tile_information import TileInformation
|
|
9
|
-
from .stitch_prediction import stitch_prediction
|
|
9
|
+
from .stitch_prediction import stitch_prediction, stitch_prediction_vae
|
|
10
10
|
|
|
11
11
|
|
|
12
12
|
def convert_outputs(predictions: list[Any], tiled: bool) -> list[NDArray]:
|
|
@@ -41,6 +41,49 @@ def convert_outputs(predictions: list[Any], tiled: bool) -> list[NDArray]:
|
|
|
41
41
|
return predictions_output
|
|
42
42
|
|
|
43
43
|
|
|
44
|
+
def convert_outputs_microsplit(
|
|
45
|
+
predictions: list[tuple[NDArray, NDArray]], dataset
|
|
46
|
+
) -> tuple[NDArray, NDArray]:
|
|
47
|
+
"""
|
|
48
|
+
Convert microsplit Lightning trainer outputs using eval_utils stitching functions.
|
|
49
|
+
|
|
50
|
+
This function processes microsplit predictions that return (tile_prediction,
|
|
51
|
+
tile_std) tuples and stitches them back together using the same logic as
|
|
52
|
+
get_single_file_mmse.
|
|
53
|
+
|
|
54
|
+
Parameters
|
|
55
|
+
----------
|
|
56
|
+
predictions : list of tuple[NDArray, NDArray]
|
|
57
|
+
Predictions from Lightning trainer for microsplit. Each element is a tuple of
|
|
58
|
+
(tile_prediction, tile_std) where both are numpy arrays from predict_step.
|
|
59
|
+
dataset : Dataset
|
|
60
|
+
The dataset object used for prediction, needed for stitching function selection
|
|
61
|
+
and stitching process.
|
|
62
|
+
|
|
63
|
+
Returns
|
|
64
|
+
-------
|
|
65
|
+
tuple[NDArray, NDArray]
|
|
66
|
+
A tuple of (stitched_predictions, stitched_stds) representing the full
|
|
67
|
+
stitched predictions and standard deviations.
|
|
68
|
+
"""
|
|
69
|
+
if len(predictions) == 0:
|
|
70
|
+
raise ValueError("No predictions provided")
|
|
71
|
+
|
|
72
|
+
# Separate predictions and stds from the list of tuples
|
|
73
|
+
tile_predictions = [pred for pred, _ in predictions]
|
|
74
|
+
tile_stds = [std for _, std in predictions]
|
|
75
|
+
|
|
76
|
+
# Concatenate all tiles exactly like get_single_file_mmse
|
|
77
|
+
tiles_arr = np.concatenate(tile_predictions, axis=0)
|
|
78
|
+
tile_stds_arr = np.concatenate(tile_stds, axis=0)
|
|
79
|
+
|
|
80
|
+
# Apply stitching using stitch_predictions_new
|
|
81
|
+
stitched_predictions = stitch_prediction_vae(tiles_arr, dataset)
|
|
82
|
+
stitched_stds = stitch_prediction_vae(tile_stds_arr, dataset)
|
|
83
|
+
|
|
84
|
+
return stitched_predictions, stitched_stds
|
|
85
|
+
|
|
86
|
+
|
|
44
87
|
# for mypy
|
|
45
88
|
@overload
|
|
46
89
|
def combine_batches( # numpydoc ignore=GL08
|
|
@@ -68,6 +111,8 @@ def combine_batches(
|
|
|
68
111
|
"""
|
|
69
112
|
If predictions are in batches, they will be combined.
|
|
70
113
|
|
|
114
|
+
# TODO improve description!
|
|
115
|
+
|
|
71
116
|
Parameters
|
|
72
117
|
----------
|
|
73
118
|
predictions : list
|
|
@@ -107,11 +152,12 @@ def _combine_tiled_batches(
|
|
|
107
152
|
"""
|
|
108
153
|
# turn list of lists into single list
|
|
109
154
|
tile_infos = [
|
|
110
|
-
tile_info for _, tile_info_list in predictions for tile_info in tile_info_list
|
|
155
|
+
tile_info for *_, tile_info_list in predictions for tile_info in tile_info_list
|
|
111
156
|
]
|
|
112
157
|
prediction_tiles: list[NDArray] = _combine_array_batches(
|
|
113
|
-
[preds for preds, _ in predictions]
|
|
158
|
+
[preds for preds, *_ in predictions]
|
|
114
159
|
)
|
|
160
|
+
|
|
115
161
|
return prediction_tiles, tile_infos
|
|
116
162
|
|
|
117
163
|
|
|
@@ -9,6 +9,86 @@ from numpy.typing import NDArray
|
|
|
9
9
|
from careamics.config.tile_information import TileInformation
|
|
10
10
|
|
|
11
11
|
|
|
12
|
+
class TilingMode:
|
|
13
|
+
"""Enum for the tiling mode."""
|
|
14
|
+
|
|
15
|
+
TrimBoundary = 0
|
|
16
|
+
PadBoundary = 1
|
|
17
|
+
ShiftBoundary = 2
|
|
18
|
+
|
|
19
|
+
|
|
20
|
+
def stitch_prediction_vae(predictions, dset) -> NDArray:
|
|
21
|
+
"""
|
|
22
|
+
Stitch predictions back together using dataset's index manager.
|
|
23
|
+
|
|
24
|
+
Parameters
|
|
25
|
+
----------
|
|
26
|
+
predictions : np.ndarray
|
|
27
|
+
Array of predictions with shape (n_tiles, channels, height, width).
|
|
28
|
+
dset : Dataset
|
|
29
|
+
Dataset object with idx_manager containing tiling information.
|
|
30
|
+
|
|
31
|
+
Returns
|
|
32
|
+
-------
|
|
33
|
+
np.ndarray
|
|
34
|
+
Stitched array with shape matching the original data shape.
|
|
35
|
+
"""
|
|
36
|
+
mng = dset.idx_manager
|
|
37
|
+
|
|
38
|
+
# if there are more channels, use all of them.
|
|
39
|
+
shape = list(dset.get_data_shape())
|
|
40
|
+
shape[-1] = max(shape[-1], predictions.shape[1])
|
|
41
|
+
|
|
42
|
+
output = np.zeros(shape, dtype=predictions.dtype)
|
|
43
|
+
# frame_shape = dset.get_data_shape()[:-1]
|
|
44
|
+
for dset_idx in range(predictions.shape[0]):
|
|
45
|
+
# loc = get_location_from_idx(dset, dset_idx, predictions.shape[-2],
|
|
46
|
+
# predictions.shape[-1])
|
|
47
|
+
# grid start, grid end
|
|
48
|
+
gs = np.array(mng.get_location_from_dataset_idx(dset_idx), dtype=int)
|
|
49
|
+
ge = gs + mng.grid_shape
|
|
50
|
+
|
|
51
|
+
# patch start, patch end
|
|
52
|
+
ps = gs - mng.patch_offset()
|
|
53
|
+
pe = ps + mng.patch_shape
|
|
54
|
+
|
|
55
|
+
# valid grid start, valid grid end
|
|
56
|
+
vgs = np.array([max(0, x) for x in gs], dtype=int)
|
|
57
|
+
vge = np.array(
|
|
58
|
+
[min(x, y) for x, y in zip(ge, mng.data_shape, strict=False)], dtype=int
|
|
59
|
+
)
|
|
60
|
+
|
|
61
|
+
if mng.tiling_mode == TilingMode.ShiftBoundary:
|
|
62
|
+
for dim in range(len(vgs)):
|
|
63
|
+
if ps[dim] == 0:
|
|
64
|
+
vgs[dim] = 0
|
|
65
|
+
if pe[dim] == mng.data_shape[dim]:
|
|
66
|
+
vge[dim] = mng.data_shape[dim]
|
|
67
|
+
|
|
68
|
+
# relative start, relative end. This will be used on pred_tiled
|
|
69
|
+
rs = vgs - ps
|
|
70
|
+
re = rs + (vge - vgs)
|
|
71
|
+
|
|
72
|
+
for ch_idx in range(predictions.shape[1]):
|
|
73
|
+
if len(output.shape) == 4:
|
|
74
|
+
# channel dimension is the last one.
|
|
75
|
+
output[vgs[0] : vge[0], vgs[1] : vge[1], vgs[2] : vge[2], ch_idx] = (
|
|
76
|
+
predictions[dset_idx][ch_idx, rs[1] : re[1], rs[2] : re[2]]
|
|
77
|
+
)
|
|
78
|
+
elif len(output.shape) == 5:
|
|
79
|
+
# channel dimension is the last one.
|
|
80
|
+
assert vge[0] - vgs[0] == 1, "Only one frame is supported"
|
|
81
|
+
output[
|
|
82
|
+
vgs[0], vgs[1] : vge[1], vgs[2] : vge[2], vgs[3] : vge[3], ch_idx
|
|
83
|
+
] = predictions[dset_idx][
|
|
84
|
+
ch_idx, rs[1] : re[1], rs[2] : re[2], rs[3] : re[3]
|
|
85
|
+
]
|
|
86
|
+
else:
|
|
87
|
+
raise ValueError(f"Unsupported shape {output.shape}")
|
|
88
|
+
|
|
89
|
+
return output
|
|
90
|
+
|
|
91
|
+
|
|
12
92
|
# TODO: why not allow input and output of torch.tensor ?
|
|
13
93
|
def stitch_prediction(
|
|
14
94
|
tiles: list[np.ndarray],
|
|
@@ -107,6 +187,8 @@ def stitch_prediction_single(
|
|
|
107
187
|
|
|
108
188
|
# Insert cropped tile into predicted image using stitch coordinates
|
|
109
189
|
image_slices = (..., *[slice(c[0], c[1]) for c in tile_info.stitch_coords])
|
|
110
|
-
|
|
190
|
+
|
|
191
|
+
# TODO fix mypy error here, potentially due to numpy 2
|
|
192
|
+
predicted_image[image_slices] = cropped_tile.astype(np.float32) # type: ignore
|
|
111
193
|
|
|
112
194
|
return predicted_image
|
careamics/utils/version.py
CHANGED
|
@@ -21,9 +21,9 @@ def get_careamics_version() -> str:
|
|
|
21
21
|
parts = __version__.split(".")
|
|
22
22
|
|
|
23
23
|
# for local installs that do not detect the latest versions via tags
|
|
24
|
-
# (typically our CI will install `0.
|
|
25
|
-
if "dev" in parts[
|
|
26
|
-
parts[
|
|
24
|
+
# (typically our CI will install `0.X.devX<hash>.<other hash>` versions)
|
|
25
|
+
if "dev" in parts[2]:
|
|
26
|
+
parts[2] = "*"
|
|
27
27
|
clean_version = ".".join(parts[:3])
|
|
28
28
|
|
|
29
29
|
logger.warning(
|
|
@@ -34,5 +34,5 @@ def get_careamics_version() -> str:
|
|
|
34
34
|
f"closest CAREamics version from PyPI or conda-forge."
|
|
35
35
|
)
|
|
36
36
|
|
|
37
|
-
# Remove any local version identifier
|
|
37
|
+
# Remove any local version identifier
|
|
38
38
|
return ".".join(parts[:3])
|