careamics 0.0.5__py3-none-any.whl → 0.0.7__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of careamics might be problematic. Click here for more details.
- careamics/__init__.py +17 -2
- careamics/careamist.py +4 -3
- careamics/cli/conf.py +1 -2
- careamics/cli/main.py +1 -2
- careamics/cli/utils.py +3 -3
- careamics/config/__init__.py +47 -25
- careamics/config/algorithms/__init__.py +15 -0
- careamics/config/algorithms/care_algorithm_model.py +38 -0
- careamics/config/algorithms/n2n_algorithm_model.py +30 -0
- careamics/config/algorithms/n2v_algorithm_model.py +29 -0
- careamics/config/algorithms/unet_algorithm_model.py +88 -0
- careamics/config/{vae_algorithm_model.py → algorithms/vae_algorithm_model.py} +14 -12
- careamics/config/architectures/__init__.py +1 -11
- careamics/config/architectures/architecture_model.py +3 -3
- careamics/config/architectures/lvae_model.py +6 -1
- careamics/config/architectures/unet_model.py +1 -0
- careamics/config/care_configuration.py +100 -0
- careamics/config/configuration.py +354 -0
- careamics/config/{configuration_factory.py → configuration_factories.py} +185 -57
- careamics/config/configuration_io.py +85 -0
- careamics/config/data/__init__.py +10 -0
- careamics/config/{data_model.py → data/data_model.py} +91 -186
- careamics/config/data/n2v_data_model.py +193 -0
- careamics/config/likelihood_model.py +1 -2
- careamics/config/n2n_configuration.py +101 -0
- careamics/config/n2v_configuration.py +266 -0
- careamics/config/nm_model.py +1 -2
- careamics/config/support/__init__.py +7 -7
- careamics/config/support/supported_algorithms.py +5 -4
- careamics/config/support/supported_architectures.py +0 -4
- careamics/config/transformations/__init__.py +10 -4
- careamics/config/transformations/transform_model.py +3 -3
- careamics/config/transformations/transform_unions.py +42 -0
- careamics/config/validators/__init__.py +12 -1
- careamics/config/validators/model_validators.py +84 -0
- careamics/config/validators/validator_utils.py +3 -3
- careamics/dataset/__init__.py +2 -2
- careamics/dataset/dataset_utils/__init__.py +3 -3
- careamics/dataset/dataset_utils/dataset_utils.py +4 -6
- careamics/dataset/dataset_utils/file_utils.py +9 -9
- careamics/dataset/dataset_utils/iterate_over_files.py +4 -3
- careamics/dataset/in_memory_dataset.py +11 -12
- careamics/dataset/iterable_dataset.py +4 -4
- careamics/dataset/iterable_pred_dataset.py +2 -1
- careamics/dataset/iterable_tiled_pred_dataset.py +2 -1
- careamics/dataset/patching/random_patching.py +11 -10
- careamics/dataset/patching/sequential_patching.py +26 -26
- careamics/dataset/patching/validate_patch_dimension.py +3 -3
- careamics/dataset/tiling/__init__.py +2 -2
- careamics/dataset/tiling/collate_tiles.py +3 -3
- careamics/dataset/tiling/lvae_tiled_patching.py +2 -1
- careamics/dataset/tiling/tiled_patching.py +11 -10
- careamics/file_io/__init__.py +5 -5
- careamics/file_io/read/__init__.py +1 -1
- careamics/file_io/read/get_func.py +2 -2
- careamics/file_io/write/__init__.py +2 -2
- careamics/lightning/__init__.py +5 -5
- careamics/lightning/callbacks/__init__.py +1 -1
- careamics/lightning/callbacks/prediction_writer_callback/__init__.py +3 -3
- careamics/lightning/callbacks/prediction_writer_callback/prediction_writer_callback.py +2 -1
- careamics/lightning/callbacks/prediction_writer_callback/write_strategy.py +2 -1
- careamics/lightning/callbacks/progress_bar_callback.py +3 -3
- careamics/lightning/lightning_module.py +11 -7
- careamics/lightning/train_data_module.py +36 -45
- careamics/losses/__init__.py +3 -3
- careamics/lvae_training/calibration.py +64 -57
- careamics/lvae_training/dataset/lc_dataset.py +2 -1
- careamics/lvae_training/dataset/multich_dataset.py +2 -2
- careamics/lvae_training/dataset/types.py +1 -1
- careamics/lvae_training/eval_utils.py +123 -128
- careamics/model_io/__init__.py +1 -1
- careamics/model_io/bioimage/__init__.py +1 -1
- careamics/model_io/bioimage/_readme_factory.py +1 -1
- careamics/model_io/bioimage/model_description.py +17 -17
- careamics/model_io/bmz_io.py +6 -17
- careamics/model_io/model_io_utils.py +9 -9
- careamics/models/layers.py +16 -16
- careamics/models/lvae/likelihoods.py +2 -0
- careamics/models/lvae/lvae.py +13 -4
- careamics/models/lvae/noise_models.py +280 -217
- careamics/models/lvae/stochastic.py +1 -0
- careamics/models/model_factory.py +2 -15
- careamics/models/unet.py +8 -8
- careamics/prediction_utils/__init__.py +1 -1
- careamics/prediction_utils/prediction_outputs.py +15 -15
- careamics/prediction_utils/stitch_prediction.py +6 -6
- careamics/transforms/__init__.py +5 -5
- careamics/transforms/compose.py +13 -13
- careamics/transforms/n2v_manipulate.py +3 -3
- careamics/transforms/pixel_manipulation.py +9 -9
- careamics/transforms/xy_random_rotate90.py +4 -4
- careamics/utils/__init__.py +5 -5
- careamics/utils/context.py +2 -1
- careamics/utils/logging.py +11 -10
- careamics/utils/metrics.py +25 -0
- careamics/utils/plotting.py +78 -0
- careamics/utils/torch_utils.py +7 -7
- {careamics-0.0.5.dist-info → careamics-0.0.7.dist-info}/METADATA +13 -11
- careamics-0.0.7.dist-info/RECORD +178 -0
- careamics/config/architectures/custom_model.py +0 -162
- careamics/config/architectures/register_model.py +0 -103
- careamics/config/configuration_model.py +0 -603
- careamics/config/fcn_algorithm_model.py +0 -152
- careamics/config/references/__init__.py +0 -45
- careamics/config/references/algorithm_descriptions.py +0 -132
- careamics/config/references/references.py +0 -39
- careamics/config/transformations/transform_union.py +0 -20
- careamics-0.0.5.dist-info/RECORD +0 -171
- {careamics-0.0.5.dist-info → careamics-0.0.7.dist-info}/WHEEL +0 -0
- {careamics-0.0.5.dist-info → careamics-0.0.7.dist-info}/entry_points.txt +0 -0
- {careamics-0.0.5.dist-info → careamics-0.0.7.dist-info}/licenses/LICENSE +0 -0
|
@@ -7,23 +7,18 @@ It includes functions to:
|
|
|
7
7
|
"""
|
|
8
8
|
|
|
9
9
|
import os
|
|
10
|
-
from typing import
|
|
10
|
+
from typing import Optional
|
|
11
11
|
|
|
12
12
|
import matplotlib
|
|
13
13
|
import matplotlib.pyplot as plt
|
|
14
14
|
import numpy as np
|
|
15
|
-
from scipy import stats
|
|
16
15
|
import torch
|
|
17
|
-
from torch import nn
|
|
18
|
-
from torch.utils.data import Dataset
|
|
19
16
|
from matplotlib.gridspec import GridSpec
|
|
20
|
-
from torch.utils.data import DataLoader
|
|
17
|
+
from torch.utils.data import DataLoader, Dataset, Subset
|
|
21
18
|
from tqdm import tqdm
|
|
22
19
|
|
|
23
20
|
from careamics.lightning import VAEModule
|
|
24
|
-
|
|
25
|
-
from careamics.models.lvae.utils import ModelType
|
|
26
|
-
from careamics.utils.metrics import scale_invariant_psnr, RunningPSNR
|
|
21
|
+
from careamics.utils.metrics import scale_invariant_psnr
|
|
27
22
|
|
|
28
23
|
|
|
29
24
|
class TilingMode:
|
|
@@ -149,11 +144,10 @@ def plot_crops(
|
|
|
149
144
|
tar,
|
|
150
145
|
tar_hsnr,
|
|
151
146
|
recon_img_list,
|
|
152
|
-
calibration_stats,
|
|
147
|
+
calibration_stats=None,
|
|
153
148
|
num_samples=2,
|
|
154
149
|
baseline_preds=None,
|
|
155
150
|
):
|
|
156
|
-
""" """
|
|
157
151
|
if baseline_preds is None:
|
|
158
152
|
baseline_preds = []
|
|
159
153
|
if len(baseline_preds) > 0:
|
|
@@ -164,15 +158,13 @@ def plot_crops(
|
|
|
164
158
|
)
|
|
165
159
|
print("This happens when we want to predict the edges of the image.")
|
|
166
160
|
return
|
|
161
|
+
color_ch_list = ["goldenrod", "cyan"]
|
|
162
|
+
color_pred = "red"
|
|
163
|
+
insetplot_xmax_value = 10000
|
|
164
|
+
insetplot_xmin_value = -1000
|
|
165
|
+
inset_min_labelsize = 10
|
|
166
|
+
inset_rect = [0.05, 0.05, 0.4, 0.2]
|
|
167
167
|
|
|
168
|
-
# color_ch_list = ['goldenrod', 'cyan']
|
|
169
|
-
# color_pred = 'red'
|
|
170
|
-
# insetplot_xmax_value = 10000
|
|
171
|
-
# insetplot_xmin_value = -1000
|
|
172
|
-
# inset_min_labelsize = 10
|
|
173
|
-
# inset_rect = [0.05, 0.05, 0.4, 0.2]
|
|
174
|
-
|
|
175
|
-
# Set plot attributes
|
|
176
168
|
img_sz = 3
|
|
177
169
|
ncols = num_samples + len(baseline_preds) + 1 + 1 + 1 + 1 + 1 * (num_samples > 1)
|
|
178
170
|
grid_factor = 5
|
|
@@ -191,7 +183,6 @@ def plot_crops(
|
|
|
191
183
|
)
|
|
192
184
|
params = {"mathtext.default": "regular"}
|
|
193
185
|
plt.rcParams.update(params)
|
|
194
|
-
|
|
195
186
|
# plot baselines
|
|
196
187
|
for i in range(2, 2 + len(baseline_preds)):
|
|
197
188
|
for col_idx in range(baseline_preds[0].shape[0]):
|
|
@@ -471,52 +462,17 @@ def plot_error(target, prediction, cmap=matplotlib.cm.coolwarm, ax=None, max_val
|
|
|
471
462
|
plt.colorbar(img_err, ax=ax)
|
|
472
463
|
|
|
473
464
|
|
|
474
|
-
#
|
|
475
|
-
|
|
476
|
-
|
|
477
|
-
def get_predictions(idx, val_dset, model, mmse_count=50, patch_size=256):
|
|
478
|
-
"""
|
|
479
|
-
Given an index and a validation/test set, it returns the input, target and the reconstructed images for that index.
|
|
480
|
-
"""
|
|
481
|
-
print(f"Predicting for {idx}")
|
|
482
|
-
val_dset.set_img_sz(patch_size, 64)
|
|
483
|
-
|
|
484
|
-
with torch.no_grad():
|
|
485
|
-
# val_dset.enable_noise()
|
|
486
|
-
inp, tar = val_dset[idx]
|
|
487
|
-
# val_dset.disable_noise()
|
|
488
|
-
|
|
489
|
-
inp = torch.Tensor(inp[None])
|
|
490
|
-
tar = torch.Tensor(tar[None])
|
|
491
|
-
inp = inp.cuda()
|
|
492
|
-
x_normalized = model.normalize_input(inp)
|
|
493
|
-
tar = tar.cuda()
|
|
494
|
-
tar_normalized = model.normalize_target(tar)
|
|
495
|
-
|
|
496
|
-
recon_img_list = []
|
|
497
|
-
for _ in range(mmse_count):
|
|
498
|
-
recon_normalized, td_data = model(x_normalized)
|
|
499
|
-
rec_loss, imgs = model.get_reconstruction_loss(
|
|
500
|
-
recon_normalized,
|
|
501
|
-
x_normalized,
|
|
502
|
-
tar_normalized,
|
|
503
|
-
return_predicted_img=True,
|
|
504
|
-
)
|
|
505
|
-
imgs = model.unnormalize_target(imgs)
|
|
506
|
-
recon_img_list.append(imgs.cpu().numpy()[0])
|
|
507
|
-
|
|
508
|
-
recon_img_list = np.array(recon_img_list)
|
|
509
|
-
return inp, tar, recon_img_list
|
|
465
|
+
# -------------------------------------------------------------------------------------
|
|
510
466
|
|
|
511
467
|
|
|
512
|
-
def
|
|
468
|
+
def get_predictions(
|
|
513
469
|
model: VAEModule,
|
|
514
470
|
dset: Dataset,
|
|
515
471
|
batch_size: int,
|
|
516
|
-
|
|
472
|
+
tile_size: Optional[tuple[int, int]] = None,
|
|
517
473
|
mmse_count: int = 1,
|
|
518
474
|
num_workers: int = 4,
|
|
519
|
-
) -> tuple[
|
|
475
|
+
) -> tuple[dict, dict, dict]:
|
|
520
476
|
"""Get patch-wise predictions from a model for the entire dataset.
|
|
521
477
|
|
|
522
478
|
Parameters
|
|
@@ -545,6 +501,55 @@ def get_dset_predictions(
|
|
|
545
501
|
- losses: Reconstruction losses for the predictions.
|
|
546
502
|
- psnr: PSNR values for the predictions.
|
|
547
503
|
"""
|
|
504
|
+
if hasattr(dset, "dsets"):
|
|
505
|
+
multifile_stitched_predictions = {}
|
|
506
|
+
multifile_stitched_stds = {}
|
|
507
|
+
for d in dset.dsets:
|
|
508
|
+
stitched_predictions, stitched_stds = get_single_file_mmse(
|
|
509
|
+
model=model,
|
|
510
|
+
dset=d,
|
|
511
|
+
batch_size=batch_size,
|
|
512
|
+
tile_size=tile_size,
|
|
513
|
+
mmse_count=mmse_count,
|
|
514
|
+
num_workers=num_workers,
|
|
515
|
+
)
|
|
516
|
+
# get filename without extension and path
|
|
517
|
+
filename = str(d._fpath).split("/")[-1].split(".")[0]
|
|
518
|
+
multifile_stitched_predictions[filename] = stitched_predictions
|
|
519
|
+
multifile_stitched_stds[filename] = stitched_stds
|
|
520
|
+
return (
|
|
521
|
+
multifile_stitched_predictions,
|
|
522
|
+
multifile_stitched_stds,
|
|
523
|
+
)
|
|
524
|
+
else:
|
|
525
|
+
stitched_predictions, stitched_stds = get_single_file_mmse(
|
|
526
|
+
model=model,
|
|
527
|
+
dset=dset,
|
|
528
|
+
batch_size=batch_size,
|
|
529
|
+
tile_size=tile_size,
|
|
530
|
+
mmse_count=mmse_count,
|
|
531
|
+
num_workers=num_workers,
|
|
532
|
+
)
|
|
533
|
+
# get filename without extension and path
|
|
534
|
+
filename = str(dset._fpath).split("/")[-1].split(".")[0]
|
|
535
|
+
return (
|
|
536
|
+
{filename: stitched_predictions},
|
|
537
|
+
{filename: stitched_stds},
|
|
538
|
+
)
|
|
539
|
+
|
|
540
|
+
|
|
541
|
+
def get_single_file_predictions(
|
|
542
|
+
model: VAEModule,
|
|
543
|
+
dset: Dataset,
|
|
544
|
+
batch_size: int,
|
|
545
|
+
tile_size: Optional[tuple[int, int]] = None,
|
|
546
|
+
grid_size: Optional[int] = None,
|
|
547
|
+
num_workers: int = 4,
|
|
548
|
+
) -> tuple[np.ndarray, np.ndarray]:
|
|
549
|
+
"""Get patch-wise predictions from a model for a single file dataset."""
|
|
550
|
+
if tile_size and grid_size:
|
|
551
|
+
dset.set_img_sz(tile_size, grid_size)
|
|
552
|
+
|
|
548
553
|
dloader = DataLoader(
|
|
549
554
|
dset,
|
|
550
555
|
pin_memory=False,
|
|
@@ -552,43 +557,64 @@ def get_dset_predictions(
|
|
|
552
557
|
shuffle=False,
|
|
553
558
|
batch_size=batch_size,
|
|
554
559
|
)
|
|
560
|
+
model.eval()
|
|
561
|
+
model.cuda()
|
|
562
|
+
tiles = []
|
|
563
|
+
logvar_arr = []
|
|
564
|
+
with torch.no_grad():
|
|
565
|
+
for batch in tqdm(dloader, desc="Predicting tiles"):
|
|
566
|
+
inp, tar = batch
|
|
567
|
+
inp = inp.cuda()
|
|
568
|
+
tar = tar.cuda()
|
|
569
|
+
|
|
570
|
+
# get model output
|
|
571
|
+
rec, _ = model(inp)
|
|
555
572
|
|
|
556
|
-
|
|
557
|
-
|
|
573
|
+
# get reconstructed img
|
|
574
|
+
if model.model.predict_logvar is None:
|
|
575
|
+
rec_img = rec
|
|
576
|
+
logvar = torch.tensor([-1])
|
|
577
|
+
else:
|
|
578
|
+
rec_img, logvar = torch.chunk(rec, chunks=2, dim=1)
|
|
579
|
+
logvar_arr.append(logvar.cpu().numpy()) # Why do we need this ?
|
|
580
|
+
|
|
581
|
+
tiles.append(rec_img.cpu().numpy())
|
|
582
|
+
|
|
583
|
+
tile_samples = np.concatenate(tiles, axis=0)
|
|
584
|
+
return stitch_predictions_new(tile_samples, dset)
|
|
558
585
|
|
|
559
|
-
|
|
560
|
-
|
|
561
|
-
|
|
586
|
+
|
|
587
|
+
def get_single_file_mmse(
|
|
588
|
+
model: VAEModule,
|
|
589
|
+
dset: Dataset,
|
|
590
|
+
batch_size: int,
|
|
591
|
+
tile_size: Optional[tuple[int, int]] = None,
|
|
592
|
+
mmse_count: int = 1,
|
|
593
|
+
num_workers: int = 4,
|
|
594
|
+
) -> tuple[np.ndarray, np.ndarray]:
|
|
595
|
+
"""Get patch-wise predictions from a model for a single file dataset."""
|
|
596
|
+
dloader = DataLoader(
|
|
597
|
+
dset,
|
|
598
|
+
pin_memory=False,
|
|
599
|
+
num_workers=num_workers,
|
|
600
|
+
shuffle=False,
|
|
601
|
+
batch_size=batch_size,
|
|
602
|
+
)
|
|
603
|
+
if tile_size:
|
|
604
|
+
dset.set_img_sz(tile_size, tile_size[-1] // 2)
|
|
605
|
+
model.eval()
|
|
606
|
+
model.cuda()
|
|
607
|
+
tile_mmse = []
|
|
608
|
+
tile_stds = []
|
|
562
609
|
logvar_arr = []
|
|
563
|
-
num_channels = dset[0][1].shape[0]
|
|
564
|
-
patch_psnr_channels = [RunningPSNR() for _ in range(num_channels)]
|
|
565
610
|
with torch.no_grad():
|
|
566
|
-
for batch in tqdm(dloader, desc="Predicting
|
|
611
|
+
for batch in tqdm(dloader, desc="Predicting tiles"):
|
|
567
612
|
inp, tar = batch
|
|
568
613
|
inp = inp.cuda()
|
|
569
614
|
tar = tar.cuda()
|
|
570
615
|
|
|
571
616
|
rec_img_list = []
|
|
572
|
-
for
|
|
573
|
-
|
|
574
|
-
# TODO: case of HDN left for future refactoring
|
|
575
|
-
# if model_type == ModelType.Denoiser:
|
|
576
|
-
# assert model.denoise_channel in [
|
|
577
|
-
# "Ch1",
|
|
578
|
-
# "Ch2",
|
|
579
|
-
# "input",
|
|
580
|
-
# ], '"all" denoise channel not supported for evaluation. Pick one of "Ch1", "Ch2", "input"'
|
|
581
|
-
|
|
582
|
-
# x_normalized_new, tar_new = model.get_new_input_target(
|
|
583
|
-
# (inp, tar, *batch[2:])
|
|
584
|
-
# )
|
|
585
|
-
# rec, _ = model(x_normalized_new)
|
|
586
|
-
# rec_loss, imgs = model.get_reconstruction_loss(
|
|
587
|
-
# rec,
|
|
588
|
-
# tar,
|
|
589
|
-
# x_normalized_new,
|
|
590
|
-
# return_predicted_img=True,
|
|
591
|
-
# )
|
|
617
|
+
for _ in range(mmse_count):
|
|
592
618
|
|
|
593
619
|
# get model output
|
|
594
620
|
rec, _ = model(inp)
|
|
@@ -600,52 +626,21 @@ def get_dset_predictions(
|
|
|
600
626
|
else:
|
|
601
627
|
rec_img, logvar = torch.chunk(rec, chunks=2, dim=1)
|
|
602
628
|
rec_img_list.append(rec_img.cpu().unsqueeze(0)) # add MMSE dim
|
|
603
|
-
logvar_arr.append(logvar.cpu().numpy())
|
|
604
|
-
|
|
605
|
-
# compute reconstruction loss
|
|
606
|
-
# if loss_type == "musplit":
|
|
607
|
-
# rec_loss = get_reconstruction_loss(
|
|
608
|
-
# reconstruction=rec, target=tar, likelihood_obj=gauss_likelihood
|
|
609
|
-
# )
|
|
610
|
-
# elif loss_type == "denoisplit":
|
|
611
|
-
# rec_loss = get_reconstruction_loss(
|
|
612
|
-
# reconstruction=rec, target=tar, likelihood_obj=nm_likelihood
|
|
613
|
-
# )
|
|
614
|
-
# elif loss_type == "denoisplit_musplit":
|
|
615
|
-
# rec_loss = reconstruction_loss_musplit_denoisplit(
|
|
616
|
-
# predictions=rec,
|
|
617
|
-
# targets=tar,
|
|
618
|
-
# gaussian_likelihood=gauss_likelihood,
|
|
619
|
-
# nm_likelihood=nm_likelihood,
|
|
620
|
-
# nm_weight=model.loss_parameters.denoisplit_weight,
|
|
621
|
-
# gaussian_weight=model.loss_parameters.musplit_weight,
|
|
622
|
-
# )
|
|
623
|
-
# rec_loss = {"loss": rec_loss} # hacky, but ok for now
|
|
624
|
-
|
|
625
|
-
# # store rec loss values for first pred
|
|
626
|
-
# if mmse_idx == 0:
|
|
627
|
-
# try:
|
|
628
|
-
# losses.append(rec_loss["loss"].cpu().numpy())
|
|
629
|
-
# except:
|
|
630
|
-
# losses.append(rec_loss["loss"])
|
|
631
|
-
|
|
632
|
-
# update running PSNR
|
|
633
|
-
# for i in range(num_channels):
|
|
634
|
-
# patch_psnr_channels[i].update(rec_img[:, i], tar[:, i])
|
|
629
|
+
logvar_arr.append(logvar.cpu().numpy()) # Why do we need this ?
|
|
635
630
|
|
|
636
631
|
# aggregate results
|
|
637
632
|
samples = torch.cat(rec_img_list, dim=0)
|
|
638
633
|
mmse_imgs = torch.mean(samples, dim=0) # avg over MMSE dim
|
|
639
|
-
|
|
640
|
-
|
|
641
|
-
|
|
642
|
-
|
|
643
|
-
|
|
644
|
-
|
|
645
|
-
|
|
646
|
-
|
|
647
|
-
|
|
648
|
-
|
|
634
|
+
std_imgs = torch.std(samples, dim=0) # std over MMSE dim
|
|
635
|
+
|
|
636
|
+
tile_mmse.append(mmse_imgs.cpu().numpy())
|
|
637
|
+
tile_stds.append(std_imgs.cpu().numpy())
|
|
638
|
+
|
|
639
|
+
tiles_arr = np.concatenate(tile_mmse, axis=0)
|
|
640
|
+
tile_stds = np.concatenate(tile_stds, axis=0)
|
|
641
|
+
stitched_predictions = stitch_predictions_new(tiles_arr, dset)
|
|
642
|
+
stitched_stds = stitch_predictions_new(tile_stds, dset)
|
|
643
|
+
return stitched_predictions, stitched_stds
|
|
649
644
|
|
|
650
645
|
|
|
651
646
|
# ------------------------------------------------------------------------------------------
|
careamics/model_io/__init__.py
CHANGED
|
@@ -55,7 +55,7 @@ def readme_factory(
|
|
|
55
55
|
readme.touch()
|
|
56
56
|
|
|
57
57
|
# algorithm pretty name
|
|
58
|
-
algorithm_flavour = config.
|
|
58
|
+
algorithm_flavour = config.get_algorithm_friendly_name()
|
|
59
59
|
algorithm_pretty_name = algorithm_flavour + " - CAREamics"
|
|
60
60
|
|
|
61
61
|
description = [f"# {algorithm_pretty_name}\n\n"]
|
|
@@ -1,7 +1,7 @@
|
|
|
1
1
|
"""Module use to build BMZ model description."""
|
|
2
2
|
|
|
3
3
|
from pathlib import Path
|
|
4
|
-
from typing import
|
|
4
|
+
from typing import Optional, Union
|
|
5
5
|
|
|
6
6
|
import numpy as np
|
|
7
7
|
from bioimageio.spec._internal.io import resolve_and_extract
|
|
@@ -28,17 +28,17 @@ from bioimageio.spec.model.v0_5 import (
|
|
|
28
28
|
WeightsDescr,
|
|
29
29
|
)
|
|
30
30
|
|
|
31
|
-
from careamics.config import Configuration,
|
|
31
|
+
from careamics.config import Configuration, GeneralDataConfig
|
|
32
32
|
|
|
33
33
|
from ._readme_factory import readme_factory
|
|
34
34
|
|
|
35
35
|
|
|
36
36
|
def _create_axes(
|
|
37
37
|
array: np.ndarray,
|
|
38
|
-
data_config:
|
|
39
|
-
channel_names: Optional[
|
|
38
|
+
data_config: GeneralDataConfig,
|
|
39
|
+
channel_names: Optional[list[str]] = None,
|
|
40
40
|
is_input: bool = True,
|
|
41
|
-
) ->
|
|
41
|
+
) -> list[AxisBase]:
|
|
42
42
|
"""Create axes description.
|
|
43
43
|
|
|
44
44
|
Array shape is expected to be SC(Z)YX.
|
|
@@ -49,15 +49,15 @@ def _create_axes(
|
|
|
49
49
|
Array.
|
|
50
50
|
data_config : DataModel
|
|
51
51
|
CAREamics data configuration.
|
|
52
|
-
channel_names : Optional[
|
|
52
|
+
channel_names : Optional[list[str]], optional
|
|
53
53
|
Channel names, by default None.
|
|
54
54
|
is_input : bool, optional
|
|
55
55
|
Whether the axes are input axes, by default True.
|
|
56
56
|
|
|
57
57
|
Returns
|
|
58
58
|
-------
|
|
59
|
-
|
|
60
|
-
|
|
59
|
+
list[AxisBase]
|
|
60
|
+
list of axes description.
|
|
61
61
|
|
|
62
62
|
Raises
|
|
63
63
|
------
|
|
@@ -102,11 +102,11 @@ def _create_axes(
|
|
|
102
102
|
def _create_inputs_ouputs(
|
|
103
103
|
input_array: np.ndarray,
|
|
104
104
|
output_array: np.ndarray,
|
|
105
|
-
data_config:
|
|
105
|
+
data_config: GeneralDataConfig,
|
|
106
106
|
input_path: Union[Path, str],
|
|
107
107
|
output_path: Union[Path, str],
|
|
108
|
-
channel_names: Optional[
|
|
109
|
-
) ->
|
|
108
|
+
channel_names: Optional[list[str]] = None,
|
|
109
|
+
) -> tuple[InputTensorDescr, OutputTensorDescr]:
|
|
110
110
|
"""Create input and output tensor description.
|
|
111
111
|
|
|
112
112
|
Input and output paths must point to a `.npy` file.
|
|
@@ -123,12 +123,12 @@ def _create_inputs_ouputs(
|
|
|
123
123
|
Path to input .npy file.
|
|
124
124
|
output_path : Union[Path, str]
|
|
125
125
|
Path to output .npy file.
|
|
126
|
-
channel_names : Optional[
|
|
126
|
+
channel_names : Optional[list[str]], optional
|
|
127
127
|
Channel names, by default None.
|
|
128
128
|
|
|
129
129
|
Returns
|
|
130
130
|
-------
|
|
131
|
-
|
|
131
|
+
tuple[InputTensorDescr, OutputTensorDescr]
|
|
132
132
|
Input and output tensor descriptions.
|
|
133
133
|
"""
|
|
134
134
|
input_axes = _create_axes(input_array, data_config, channel_names)
|
|
@@ -188,7 +188,7 @@ def create_model_description(
|
|
|
188
188
|
name: str,
|
|
189
189
|
general_description: str,
|
|
190
190
|
data_description: str,
|
|
191
|
-
authors:
|
|
191
|
+
authors: list[Author],
|
|
192
192
|
inputs: Union[Path, str],
|
|
193
193
|
outputs: Union[Path, str],
|
|
194
194
|
weights_path: Union[Path, str],
|
|
@@ -197,7 +197,7 @@ def create_model_description(
|
|
|
197
197
|
config_path: Union[Path, str],
|
|
198
198
|
env_path: Union[Path, str],
|
|
199
199
|
covers: list[Union[Path, str]],
|
|
200
|
-
channel_names: Optional[
|
|
200
|
+
channel_names: Optional[list[str]] = None,
|
|
201
201
|
model_version: str = "0.1.0",
|
|
202
202
|
) -> ModelDescr:
|
|
203
203
|
"""Create model description.
|
|
@@ -212,7 +212,7 @@ def create_model_description(
|
|
|
212
212
|
General description of the model.
|
|
213
213
|
data_description : str
|
|
214
214
|
Description of the data the model was trained on.
|
|
215
|
-
authors :
|
|
215
|
+
authors : list[Author]
|
|
216
216
|
Authors of the model.
|
|
217
217
|
inputs : Union[Path, str]
|
|
218
218
|
Path to input .npy file.
|
|
@@ -230,7 +230,7 @@ def create_model_description(
|
|
|
230
230
|
Path to environment file.
|
|
231
231
|
covers : list of pathlib.Path or str
|
|
232
232
|
Paths to cover images.
|
|
233
|
-
channel_names : Optional[
|
|
233
|
+
channel_names : Optional[list[str]], optional
|
|
234
234
|
Channel names, by default None.
|
|
235
235
|
model_version : str, default "0.1.0"
|
|
236
236
|
Model version.
|
careamics/model_io/bmz_io.py
CHANGED
|
@@ -2,7 +2,7 @@
|
|
|
2
2
|
|
|
3
3
|
import tempfile
|
|
4
4
|
from pathlib import Path
|
|
5
|
-
from typing import
|
|
5
|
+
from typing import Optional, Union
|
|
6
6
|
|
|
7
7
|
import numpy as np
|
|
8
8
|
import pkg_resources
|
|
@@ -87,11 +87,11 @@ def export_to_bmz(
|
|
|
87
87
|
model_name: str,
|
|
88
88
|
general_description: str,
|
|
89
89
|
data_description: str,
|
|
90
|
-
authors:
|
|
90
|
+
authors: list[dict],
|
|
91
91
|
input_array: np.ndarray,
|
|
92
92
|
output_array: np.ndarray,
|
|
93
93
|
covers: Optional[list[Union[Path, str]]] = None,
|
|
94
|
-
channel_names: Optional[
|
|
94
|
+
channel_names: Optional[list[str]] = None,
|
|
95
95
|
model_version: str = "0.1.0",
|
|
96
96
|
) -> None:
|
|
97
97
|
"""Export the model to BioImage Model Zoo format.
|
|
@@ -115,7 +115,7 @@ def export_to_bmz(
|
|
|
115
115
|
General description of the model.
|
|
116
116
|
data_description : str
|
|
117
117
|
Description of the data the model was trained on.
|
|
118
|
-
authors :
|
|
118
|
+
authors : list[dict]
|
|
119
119
|
Authors of the model.
|
|
120
120
|
input_array : np.ndarray
|
|
121
121
|
Input array, should not have been normalized.
|
|
@@ -123,24 +123,13 @@ def export_to_bmz(
|
|
|
123
123
|
Output array, should have been denormalized.
|
|
124
124
|
covers : list of pathlib.Path or str, default=None
|
|
125
125
|
Paths to the cover images.
|
|
126
|
-
channel_names : Optional[
|
|
126
|
+
channel_names : Optional[list[str]], optional
|
|
127
127
|
Channel names, by default None.
|
|
128
128
|
model_version : str, default="0.1.0"
|
|
129
129
|
Model version.
|
|
130
|
-
|
|
131
|
-
Raises
|
|
132
|
-
------
|
|
133
|
-
ValueError
|
|
134
|
-
If the model is a Custom model.
|
|
135
130
|
"""
|
|
136
131
|
path_to_archive = Path(path_to_archive)
|
|
137
132
|
|
|
138
|
-
# method is not compatible with Custom models
|
|
139
|
-
if config.algorithm_config.model.architecture == SupportedArchitecture.CUSTOM:
|
|
140
|
-
raise ValueError(
|
|
141
|
-
"Exporting Custom models to BioImage Model Zoo format is not supported."
|
|
142
|
-
)
|
|
143
|
-
|
|
144
133
|
if path_to_archive.suffix != ".zip":
|
|
145
134
|
raise ValueError(
|
|
146
135
|
f"Path to archive must point to a zip file, got {path_to_archive}."
|
|
@@ -212,7 +201,7 @@ def export_to_bmz(
|
|
|
212
201
|
|
|
213
202
|
def load_from_bmz(
|
|
214
203
|
path: Union[Path, str, HttpUrl]
|
|
215
|
-
) ->
|
|
204
|
+
) -> tuple[Union[FCNModule, VAEModule], Configuration]:
|
|
216
205
|
"""Load a model from a BioImage Model Zoo archive.
|
|
217
206
|
|
|
218
207
|
Parameters
|
|
@@ -1,11 +1,11 @@
|
|
|
1
1
|
"""Utility functions to load pretrained models."""
|
|
2
2
|
|
|
3
3
|
from pathlib import Path
|
|
4
|
-
from typing import
|
|
4
|
+
from typing import Union
|
|
5
5
|
|
|
6
6
|
import torch
|
|
7
7
|
|
|
8
|
-
from careamics.config import Configuration
|
|
8
|
+
from careamics.config import Configuration, configuration_factory
|
|
9
9
|
from careamics.lightning.lightning_module import FCNModule, VAEModule
|
|
10
10
|
from careamics.model_io.bmz_io import load_from_bmz
|
|
11
11
|
from careamics.utils import check_path_exists
|
|
@@ -13,7 +13,7 @@ from careamics.utils import check_path_exists
|
|
|
13
13
|
|
|
14
14
|
def load_pretrained(
|
|
15
15
|
path: Union[Path, str]
|
|
16
|
-
) ->
|
|
16
|
+
) -> tuple[Union[FCNModule, VAEModule], Configuration]:
|
|
17
17
|
"""
|
|
18
18
|
Load a pretrained model from a checkpoint or a BioImage Model Zoo model.
|
|
19
19
|
|
|
@@ -26,8 +26,8 @@ def load_pretrained(
|
|
|
26
26
|
|
|
27
27
|
Returns
|
|
28
28
|
-------
|
|
29
|
-
|
|
30
|
-
|
|
29
|
+
tuple[CAREamicsKiln, Configuration]
|
|
30
|
+
tuple of CAREamics model and its configuration.
|
|
31
31
|
|
|
32
32
|
Raises
|
|
33
33
|
------
|
|
@@ -48,7 +48,7 @@ def load_pretrained(
|
|
|
48
48
|
|
|
49
49
|
def _load_checkpoint(
|
|
50
50
|
path: Union[Path, str]
|
|
51
|
-
) ->
|
|
51
|
+
) -> tuple[Union[FCNModule, VAEModule], Configuration]:
|
|
52
52
|
"""
|
|
53
53
|
Load a model from a checkpoint and return both model and configuration.
|
|
54
54
|
|
|
@@ -59,8 +59,8 @@ def _load_checkpoint(
|
|
|
59
59
|
|
|
60
60
|
Returns
|
|
61
61
|
-------
|
|
62
|
-
|
|
63
|
-
|
|
62
|
+
tuple[CAREamicsKiln, Configuration]
|
|
63
|
+
tuple of CAREamics model and its configuration.
|
|
64
64
|
|
|
65
65
|
Raises
|
|
66
66
|
------
|
|
@@ -92,4 +92,4 @@ def _load_checkpoint(
|
|
|
92
92
|
f"{cfg_dict['algorithm_config']['model']['architecture']}"
|
|
93
93
|
)
|
|
94
94
|
|
|
95
|
-
return model,
|
|
95
|
+
return model, configuration_factory(cfg_dict)
|