careamics 0.0.5__py3-none-any.whl → 0.0.7__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of careamics might be problematic; see the registry's release page for more details.

Files changed (111)
  1. careamics/__init__.py +17 -2
  2. careamics/careamist.py +4 -3
  3. careamics/cli/conf.py +1 -2
  4. careamics/cli/main.py +1 -2
  5. careamics/cli/utils.py +3 -3
  6. careamics/config/__init__.py +47 -25
  7. careamics/config/algorithms/__init__.py +15 -0
  8. careamics/config/algorithms/care_algorithm_model.py +38 -0
  9. careamics/config/algorithms/n2n_algorithm_model.py +30 -0
  10. careamics/config/algorithms/n2v_algorithm_model.py +29 -0
  11. careamics/config/algorithms/unet_algorithm_model.py +88 -0
  12. careamics/config/{vae_algorithm_model.py → algorithms/vae_algorithm_model.py} +14 -12
  13. careamics/config/architectures/__init__.py +1 -11
  14. careamics/config/architectures/architecture_model.py +3 -3
  15. careamics/config/architectures/lvae_model.py +6 -1
  16. careamics/config/architectures/unet_model.py +1 -0
  17. careamics/config/care_configuration.py +100 -0
  18. careamics/config/configuration.py +354 -0
  19. careamics/config/{configuration_factory.py → configuration_factories.py} +185 -57
  20. careamics/config/configuration_io.py +85 -0
  21. careamics/config/data/__init__.py +10 -0
  22. careamics/config/{data_model.py → data/data_model.py} +91 -186
  23. careamics/config/data/n2v_data_model.py +193 -0
  24. careamics/config/likelihood_model.py +1 -2
  25. careamics/config/n2n_configuration.py +101 -0
  26. careamics/config/n2v_configuration.py +266 -0
  27. careamics/config/nm_model.py +1 -2
  28. careamics/config/support/__init__.py +7 -7
  29. careamics/config/support/supported_algorithms.py +5 -4
  30. careamics/config/support/supported_architectures.py +0 -4
  31. careamics/config/transformations/__init__.py +10 -4
  32. careamics/config/transformations/transform_model.py +3 -3
  33. careamics/config/transformations/transform_unions.py +42 -0
  34. careamics/config/validators/__init__.py +12 -1
  35. careamics/config/validators/model_validators.py +84 -0
  36. careamics/config/validators/validator_utils.py +3 -3
  37. careamics/dataset/__init__.py +2 -2
  38. careamics/dataset/dataset_utils/__init__.py +3 -3
  39. careamics/dataset/dataset_utils/dataset_utils.py +4 -6
  40. careamics/dataset/dataset_utils/file_utils.py +9 -9
  41. careamics/dataset/dataset_utils/iterate_over_files.py +4 -3
  42. careamics/dataset/in_memory_dataset.py +11 -12
  43. careamics/dataset/iterable_dataset.py +4 -4
  44. careamics/dataset/iterable_pred_dataset.py +2 -1
  45. careamics/dataset/iterable_tiled_pred_dataset.py +2 -1
  46. careamics/dataset/patching/random_patching.py +11 -10
  47. careamics/dataset/patching/sequential_patching.py +26 -26
  48. careamics/dataset/patching/validate_patch_dimension.py +3 -3
  49. careamics/dataset/tiling/__init__.py +2 -2
  50. careamics/dataset/tiling/collate_tiles.py +3 -3
  51. careamics/dataset/tiling/lvae_tiled_patching.py +2 -1
  52. careamics/dataset/tiling/tiled_patching.py +11 -10
  53. careamics/file_io/__init__.py +5 -5
  54. careamics/file_io/read/__init__.py +1 -1
  55. careamics/file_io/read/get_func.py +2 -2
  56. careamics/file_io/write/__init__.py +2 -2
  57. careamics/lightning/__init__.py +5 -5
  58. careamics/lightning/callbacks/__init__.py +1 -1
  59. careamics/lightning/callbacks/prediction_writer_callback/__init__.py +3 -3
  60. careamics/lightning/callbacks/prediction_writer_callback/prediction_writer_callback.py +2 -1
  61. careamics/lightning/callbacks/prediction_writer_callback/write_strategy.py +2 -1
  62. careamics/lightning/callbacks/progress_bar_callback.py +3 -3
  63. careamics/lightning/lightning_module.py +11 -7
  64. careamics/lightning/train_data_module.py +36 -45
  65. careamics/losses/__init__.py +3 -3
  66. careamics/lvae_training/calibration.py +64 -57
  67. careamics/lvae_training/dataset/lc_dataset.py +2 -1
  68. careamics/lvae_training/dataset/multich_dataset.py +2 -2
  69. careamics/lvae_training/dataset/types.py +1 -1
  70. careamics/lvae_training/eval_utils.py +123 -128
  71. careamics/model_io/__init__.py +1 -1
  72. careamics/model_io/bioimage/__init__.py +1 -1
  73. careamics/model_io/bioimage/_readme_factory.py +1 -1
  74. careamics/model_io/bioimage/model_description.py +17 -17
  75. careamics/model_io/bmz_io.py +6 -17
  76. careamics/model_io/model_io_utils.py +9 -9
  77. careamics/models/layers.py +16 -16
  78. careamics/models/lvae/likelihoods.py +2 -0
  79. careamics/models/lvae/lvae.py +13 -4
  80. careamics/models/lvae/noise_models.py +280 -217
  81. careamics/models/lvae/stochastic.py +1 -0
  82. careamics/models/model_factory.py +2 -15
  83. careamics/models/unet.py +8 -8
  84. careamics/prediction_utils/__init__.py +1 -1
  85. careamics/prediction_utils/prediction_outputs.py +15 -15
  86. careamics/prediction_utils/stitch_prediction.py +6 -6
  87. careamics/transforms/__init__.py +5 -5
  88. careamics/transforms/compose.py +13 -13
  89. careamics/transforms/n2v_manipulate.py +3 -3
  90. careamics/transforms/pixel_manipulation.py +9 -9
  91. careamics/transforms/xy_random_rotate90.py +4 -4
  92. careamics/utils/__init__.py +5 -5
  93. careamics/utils/context.py +2 -1
  94. careamics/utils/logging.py +11 -10
  95. careamics/utils/metrics.py +25 -0
  96. careamics/utils/plotting.py +78 -0
  97. careamics/utils/torch_utils.py +7 -7
  98. {careamics-0.0.5.dist-info → careamics-0.0.7.dist-info}/METADATA +13 -11
  99. careamics-0.0.7.dist-info/RECORD +178 -0
  100. careamics/config/architectures/custom_model.py +0 -162
  101. careamics/config/architectures/register_model.py +0 -103
  102. careamics/config/configuration_model.py +0 -603
  103. careamics/config/fcn_algorithm_model.py +0 -152
  104. careamics/config/references/__init__.py +0 -45
  105. careamics/config/references/algorithm_descriptions.py +0 -132
  106. careamics/config/references/references.py +0 -39
  107. careamics/config/transformations/transform_union.py +0 -20
  108. careamics-0.0.5.dist-info/RECORD +0 -171
  109. {careamics-0.0.5.dist-info → careamics-0.0.7.dist-info}/WHEEL +0 -0
  110. {careamics-0.0.5.dist-info → careamics-0.0.7.dist-info}/entry_points.txt +0 -0
  111. {careamics-0.0.5.dist-info → careamics-0.0.7.dist-info}/licenses/LICENSE +0 -0
@@ -7,23 +7,18 @@ It includes functions to:
7
7
  """
8
8
 
9
9
  import os
10
- from typing import List, Literal, Union
10
+ from typing import Optional
11
11
 
12
12
  import matplotlib
13
13
  import matplotlib.pyplot as plt
14
14
  import numpy as np
15
- from scipy import stats
16
15
  import torch
17
- from torch import nn
18
- from torch.utils.data import Dataset
19
16
  from matplotlib.gridspec import GridSpec
20
- from torch.utils.data import DataLoader
17
+ from torch.utils.data import DataLoader, Dataset, Subset
21
18
  from tqdm import tqdm
22
19
 
23
20
  from careamics.lightning import VAEModule
24
-
25
- from careamics.models.lvae.utils import ModelType
26
- from careamics.utils.metrics import scale_invariant_psnr, RunningPSNR
21
+ from careamics.utils.metrics import scale_invariant_psnr
27
22
 
28
23
 
29
24
  class TilingMode:
@@ -149,11 +144,10 @@ def plot_crops(
149
144
  tar,
150
145
  tar_hsnr,
151
146
  recon_img_list,
152
- calibration_stats,
147
+ calibration_stats=None,
153
148
  num_samples=2,
154
149
  baseline_preds=None,
155
150
  ):
156
- """ """
157
151
  if baseline_preds is None:
158
152
  baseline_preds = []
159
153
  if len(baseline_preds) > 0:
@@ -164,15 +158,13 @@ def plot_crops(
164
158
  )
165
159
  print("This happens when we want to predict the edges of the image.")
166
160
  return
161
+ color_ch_list = ["goldenrod", "cyan"]
162
+ color_pred = "red"
163
+ insetplot_xmax_value = 10000
164
+ insetplot_xmin_value = -1000
165
+ inset_min_labelsize = 10
166
+ inset_rect = [0.05, 0.05, 0.4, 0.2]
167
167
 
168
- # color_ch_list = ['goldenrod', 'cyan']
169
- # color_pred = 'red'
170
- # insetplot_xmax_value = 10000
171
- # insetplot_xmin_value = -1000
172
- # inset_min_labelsize = 10
173
- # inset_rect = [0.05, 0.05, 0.4, 0.2]
174
-
175
- # Set plot attributes
176
168
  img_sz = 3
177
169
  ncols = num_samples + len(baseline_preds) + 1 + 1 + 1 + 1 + 1 * (num_samples > 1)
178
170
  grid_factor = 5
@@ -191,7 +183,6 @@ def plot_crops(
191
183
  )
192
184
  params = {"mathtext.default": "regular"}
193
185
  plt.rcParams.update(params)
194
-
195
186
  # plot baselines
196
187
  for i in range(2, 2 + len(baseline_preds)):
197
188
  for col_idx in range(baseline_preds[0].shape[0]):
@@ -471,52 +462,17 @@ def plot_error(target, prediction, cmap=matplotlib.cm.coolwarm, ax=None, max_val
471
462
  plt.colorbar(img_err, ax=ax)
472
463
 
473
464
 
474
- # ------------------------------------------------------------------------------------------------
475
-
476
-
477
- def get_predictions(idx, val_dset, model, mmse_count=50, patch_size=256):
478
- """
479
- Given an index and a validation/test set, it returns the input, target and the reconstructed images for that index.
480
- """
481
- print(f"Predicting for {idx}")
482
- val_dset.set_img_sz(patch_size, 64)
483
-
484
- with torch.no_grad():
485
- # val_dset.enable_noise()
486
- inp, tar = val_dset[idx]
487
- # val_dset.disable_noise()
488
-
489
- inp = torch.Tensor(inp[None])
490
- tar = torch.Tensor(tar[None])
491
- inp = inp.cuda()
492
- x_normalized = model.normalize_input(inp)
493
- tar = tar.cuda()
494
- tar_normalized = model.normalize_target(tar)
495
-
496
- recon_img_list = []
497
- for _ in range(mmse_count):
498
- recon_normalized, td_data = model(x_normalized)
499
- rec_loss, imgs = model.get_reconstruction_loss(
500
- recon_normalized,
501
- x_normalized,
502
- tar_normalized,
503
- return_predicted_img=True,
504
- )
505
- imgs = model.unnormalize_target(imgs)
506
- recon_img_list.append(imgs.cpu().numpy()[0])
507
-
508
- recon_img_list = np.array(recon_img_list)
509
- return inp, tar, recon_img_list
465
+ # -------------------------------------------------------------------------------------
510
466
 
511
467
 
512
- def get_dset_predictions(
468
+ def get_predictions(
513
469
  model: VAEModule,
514
470
  dset: Dataset,
515
471
  batch_size: int,
516
- loss_type: Literal["musplit", "denoisplit", "denoisplit_musplit"],
472
+ tile_size: Optional[tuple[int, int]] = None,
517
473
  mmse_count: int = 1,
518
474
  num_workers: int = 4,
519
- ) -> tuple[np.ndarray, np.ndarray, np.ndarray, np.ndarray, List[float]]:
475
+ ) -> tuple[dict, dict, dict]:
520
476
  """Get patch-wise predictions from a model for the entire dataset.
521
477
 
522
478
  Parameters
@@ -545,6 +501,55 @@ def get_dset_predictions(
545
501
  - losses: Reconstruction losses for the predictions.
546
502
  - psnr: PSNR values for the predictions.
547
503
  """
504
+ if hasattr(dset, "dsets"):
505
+ multifile_stitched_predictions = {}
506
+ multifile_stitched_stds = {}
507
+ for d in dset.dsets:
508
+ stitched_predictions, stitched_stds = get_single_file_mmse(
509
+ model=model,
510
+ dset=d,
511
+ batch_size=batch_size,
512
+ tile_size=tile_size,
513
+ mmse_count=mmse_count,
514
+ num_workers=num_workers,
515
+ )
516
+ # get filename without extension and path
517
+ filename = str(d._fpath).split("/")[-1].split(".")[0]
518
+ multifile_stitched_predictions[filename] = stitched_predictions
519
+ multifile_stitched_stds[filename] = stitched_stds
520
+ return (
521
+ multifile_stitched_predictions,
522
+ multifile_stitched_stds,
523
+ )
524
+ else:
525
+ stitched_predictions, stitched_stds = get_single_file_mmse(
526
+ model=model,
527
+ dset=dset,
528
+ batch_size=batch_size,
529
+ tile_size=tile_size,
530
+ mmse_count=mmse_count,
531
+ num_workers=num_workers,
532
+ )
533
+ # get filename without extension and path
534
+ filename = str(dset._fpath).split("/")[-1].split(".")[0]
535
+ return (
536
+ {filename: stitched_predictions},
537
+ {filename: stitched_stds},
538
+ )
539
+
540
+
541
+ def get_single_file_predictions(
542
+ model: VAEModule,
543
+ dset: Dataset,
544
+ batch_size: int,
545
+ tile_size: Optional[tuple[int, int]] = None,
546
+ grid_size: Optional[int] = None,
547
+ num_workers: int = 4,
548
+ ) -> tuple[np.ndarray, np.ndarray]:
549
+ """Get patch-wise predictions from a model for a single file dataset."""
550
+ if tile_size and grid_size:
551
+ dset.set_img_sz(tile_size, grid_size)
552
+
548
553
  dloader = DataLoader(
549
554
  dset,
550
555
  pin_memory=False,
@@ -552,43 +557,64 @@ def get_dset_predictions(
552
557
  shuffle=False,
553
558
  batch_size=batch_size,
554
559
  )
560
+ model.eval()
561
+ model.cuda()
562
+ tiles = []
563
+ logvar_arr = []
564
+ with torch.no_grad():
565
+ for batch in tqdm(dloader, desc="Predicting tiles"):
566
+ inp, tar = batch
567
+ inp = inp.cuda()
568
+ tar = tar.cuda()
569
+
570
+ # get model output
571
+ rec, _ = model(inp)
555
572
 
556
- gauss_likelihood = model.gaussian_likelihood
557
- nm_likelihood = model.noise_model_likelihood
573
+ # get reconstructed img
574
+ if model.model.predict_logvar is None:
575
+ rec_img = rec
576
+ logvar = torch.tensor([-1])
577
+ else:
578
+ rec_img, logvar = torch.chunk(rec, chunks=2, dim=1)
579
+ logvar_arr.append(logvar.cpu().numpy()) # Why do we need this ?
580
+
581
+ tiles.append(rec_img.cpu().numpy())
582
+
583
+ tile_samples = np.concatenate(tiles, axis=0)
584
+ return stitch_predictions_new(tile_samples, dset)
558
585
 
559
- predictions = []
560
- predictions_std = []
561
- losses = []
586
+
587
+ def get_single_file_mmse(
588
+ model: VAEModule,
589
+ dset: Dataset,
590
+ batch_size: int,
591
+ tile_size: Optional[tuple[int, int]] = None,
592
+ mmse_count: int = 1,
593
+ num_workers: int = 4,
594
+ ) -> tuple[np.ndarray, np.ndarray]:
595
+ """Get patch-wise predictions from a model for a single file dataset."""
596
+ dloader = DataLoader(
597
+ dset,
598
+ pin_memory=False,
599
+ num_workers=num_workers,
600
+ shuffle=False,
601
+ batch_size=batch_size,
602
+ )
603
+ if tile_size:
604
+ dset.set_img_sz(tile_size, tile_size[-1] // 2)
605
+ model.eval()
606
+ model.cuda()
607
+ tile_mmse = []
608
+ tile_stds = []
562
609
  logvar_arr = []
563
- num_channels = dset[0][1].shape[0]
564
- patch_psnr_channels = [RunningPSNR() for _ in range(num_channels)]
565
610
  with torch.no_grad():
566
- for batch in tqdm(dloader, desc="Predicting patches"):
611
+ for batch in tqdm(dloader, desc="Predicting tiles"):
567
612
  inp, tar = batch
568
613
  inp = inp.cuda()
569
614
  tar = tar.cuda()
570
615
 
571
616
  rec_img_list = []
572
- for mmse_idx in range(mmse_count):
573
-
574
- # TODO: case of HDN left for future refactoring
575
- # if model_type == ModelType.Denoiser:
576
- # assert model.denoise_channel in [
577
- # "Ch1",
578
- # "Ch2",
579
- # "input",
580
- # ], '"all" denoise channel not supported for evaluation. Pick one of "Ch1", "Ch2", "input"'
581
-
582
- # x_normalized_new, tar_new = model.get_new_input_target(
583
- # (inp, tar, *batch[2:])
584
- # )
585
- # rec, _ = model(x_normalized_new)
586
- # rec_loss, imgs = model.get_reconstruction_loss(
587
- # rec,
588
- # tar,
589
- # x_normalized_new,
590
- # return_predicted_img=True,
591
- # )
617
+ for _ in range(mmse_count):
592
618
 
593
619
  # get model output
594
620
  rec, _ = model(inp)
@@ -600,52 +626,21 @@ def get_dset_predictions(
600
626
  else:
601
627
  rec_img, logvar = torch.chunk(rec, chunks=2, dim=1)
602
628
  rec_img_list.append(rec_img.cpu().unsqueeze(0)) # add MMSE dim
603
- logvar_arr.append(logvar.cpu().numpy())
604
-
605
- # compute reconstruction loss
606
- # if loss_type == "musplit":
607
- # rec_loss = get_reconstruction_loss(
608
- # reconstruction=rec, target=tar, likelihood_obj=gauss_likelihood
609
- # )
610
- # elif loss_type == "denoisplit":
611
- # rec_loss = get_reconstruction_loss(
612
- # reconstruction=rec, target=tar, likelihood_obj=nm_likelihood
613
- # )
614
- # elif loss_type == "denoisplit_musplit":
615
- # rec_loss = reconstruction_loss_musplit_denoisplit(
616
- # predictions=rec,
617
- # targets=tar,
618
- # gaussian_likelihood=gauss_likelihood,
619
- # nm_likelihood=nm_likelihood,
620
- # nm_weight=model.loss_parameters.denoisplit_weight,
621
- # gaussian_weight=model.loss_parameters.musplit_weight,
622
- # )
623
- # rec_loss = {"loss": rec_loss} # hacky, but ok for now
624
-
625
- # # store rec loss values for first pred
626
- # if mmse_idx == 0:
627
- # try:
628
- # losses.append(rec_loss["loss"].cpu().numpy())
629
- # except:
630
- # losses.append(rec_loss["loss"])
631
-
632
- # update running PSNR
633
- # for i in range(num_channels):
634
- # patch_psnr_channels[i].update(rec_img[:, i], tar[:, i])
629
+ logvar_arr.append(logvar.cpu().numpy()) # Why do we need this ?
635
630
 
636
631
  # aggregate results
637
632
  samples = torch.cat(rec_img_list, dim=0)
638
633
  mmse_imgs = torch.mean(samples, dim=0) # avg over MMSE dim
639
- # mmse_std = torch.std(samples, dim=0)
640
- predictions.append(mmse_imgs.cpu().numpy())
641
- # predictions_std.append(mmse_std.cpu().numpy())
642
-
643
- # psnr = [x.get() for x in patch_psnr_channels]
644
- return np.concatenate(predictions, axis=0)
645
- # np.concatenate(predictions_std, axis=0),
646
- # np.concatenate(logvar_arr),
647
- # np.array(losses),
648
- # psnr, # TODO revisit !
634
+ std_imgs = torch.std(samples, dim=0) # std over MMSE dim
635
+
636
+ tile_mmse.append(mmse_imgs.cpu().numpy())
637
+ tile_stds.append(std_imgs.cpu().numpy())
638
+
639
+ tiles_arr = np.concatenate(tile_mmse, axis=0)
640
+ tile_stds = np.concatenate(tile_stds, axis=0)
641
+ stitched_predictions = stitch_predictions_new(tiles_arr, dset)
642
+ stitched_stds = stitch_predictions_new(tile_stds, dset)
643
+ return stitched_predictions, stitched_stds
649
644
 
650
645
 
651
646
  # ------------------------------------------------------------------------------------------
@@ -1,6 +1,6 @@
1
1
  """Model I/O utilities."""
2
2
 
3
- __all__ = ["load_pretrained", "export_to_bmz"]
3
+ __all__ = ["export_to_bmz", "load_pretrained"]
4
4
 
5
5
 
6
6
  from .bmz_io import export_to_bmz
@@ -1,10 +1,10 @@
1
1
  """Bioimage Model Zoo format functions."""
2
2
 
3
3
  __all__ = [
4
+ "create_env_text",
4
5
  "create_model_description",
5
6
  "extract_model_path",
6
7
  "get_unzip_path",
7
- "create_env_text",
8
8
  ]
9
9
 
10
10
  from .bioimage_utils import create_env_text, get_unzip_path
@@ -55,7 +55,7 @@ def readme_factory(
55
55
  readme.touch()
56
56
 
57
57
  # algorithm pretty name
58
- algorithm_flavour = config.get_algorithm_flavour()
58
+ algorithm_flavour = config.get_algorithm_friendly_name()
59
59
  algorithm_pretty_name = algorithm_flavour + " - CAREamics"
60
60
 
61
61
  description = [f"# {algorithm_pretty_name}\n\n"]
@@ -1,7 +1,7 @@
1
1
  """Module use to build BMZ model description."""
2
2
 
3
3
  from pathlib import Path
4
- from typing import List, Optional, Tuple, Union
4
+ from typing import Optional, Union
5
5
 
6
6
  import numpy as np
7
7
  from bioimageio.spec._internal.io import resolve_and_extract
@@ -28,17 +28,17 @@ from bioimageio.spec.model.v0_5 import (
28
28
  WeightsDescr,
29
29
  )
30
30
 
31
- from careamics.config import Configuration, DataConfig
31
+ from careamics.config import Configuration, GeneralDataConfig
32
32
 
33
33
  from ._readme_factory import readme_factory
34
34
 
35
35
 
36
36
  def _create_axes(
37
37
  array: np.ndarray,
38
- data_config: DataConfig,
39
- channel_names: Optional[List[str]] = None,
38
+ data_config: GeneralDataConfig,
39
+ channel_names: Optional[list[str]] = None,
40
40
  is_input: bool = True,
41
- ) -> List[AxisBase]:
41
+ ) -> list[AxisBase]:
42
42
  """Create axes description.
43
43
 
44
44
  Array shape is expected to be SC(Z)YX.
@@ -49,15 +49,15 @@ def _create_axes(
49
49
  Array.
50
50
  data_config : DataModel
51
51
  CAREamics data configuration.
52
- channel_names : Optional[List[str]], optional
52
+ channel_names : Optional[list[str]], optional
53
53
  Channel names, by default None.
54
54
  is_input : bool, optional
55
55
  Whether the axes are input axes, by default True.
56
56
 
57
57
  Returns
58
58
  -------
59
- List[AxisBase]
60
- List of axes description.
59
+ list[AxisBase]
60
+ list of axes description.
61
61
 
62
62
  Raises
63
63
  ------
@@ -102,11 +102,11 @@ def _create_axes(
102
102
  def _create_inputs_ouputs(
103
103
  input_array: np.ndarray,
104
104
  output_array: np.ndarray,
105
- data_config: DataConfig,
105
+ data_config: GeneralDataConfig,
106
106
  input_path: Union[Path, str],
107
107
  output_path: Union[Path, str],
108
- channel_names: Optional[List[str]] = None,
109
- ) -> Tuple[InputTensorDescr, OutputTensorDescr]:
108
+ channel_names: Optional[list[str]] = None,
109
+ ) -> tuple[InputTensorDescr, OutputTensorDescr]:
110
110
  """Create input and output tensor description.
111
111
 
112
112
  Input and output paths must point to a `.npy` file.
@@ -123,12 +123,12 @@ def _create_inputs_ouputs(
123
123
  Path to input .npy file.
124
124
  output_path : Union[Path, str]
125
125
  Path to output .npy file.
126
- channel_names : Optional[List[str]], optional
126
+ channel_names : Optional[list[str]], optional
127
127
  Channel names, by default None.
128
128
 
129
129
  Returns
130
130
  -------
131
- Tuple[InputTensorDescr, OutputTensorDescr]
131
+ tuple[InputTensorDescr, OutputTensorDescr]
132
132
  Input and output tensor descriptions.
133
133
  """
134
134
  input_axes = _create_axes(input_array, data_config, channel_names)
@@ -188,7 +188,7 @@ def create_model_description(
188
188
  name: str,
189
189
  general_description: str,
190
190
  data_description: str,
191
- authors: List[Author],
191
+ authors: list[Author],
192
192
  inputs: Union[Path, str],
193
193
  outputs: Union[Path, str],
194
194
  weights_path: Union[Path, str],
@@ -197,7 +197,7 @@ def create_model_description(
197
197
  config_path: Union[Path, str],
198
198
  env_path: Union[Path, str],
199
199
  covers: list[Union[Path, str]],
200
- channel_names: Optional[List[str]] = None,
200
+ channel_names: Optional[list[str]] = None,
201
201
  model_version: str = "0.1.0",
202
202
  ) -> ModelDescr:
203
203
  """Create model description.
@@ -212,7 +212,7 @@ def create_model_description(
212
212
  General description of the model.
213
213
  data_description : str
214
214
  Description of the data the model was trained on.
215
- authors : List[Author]
215
+ authors : list[Author]
216
216
  Authors of the model.
217
217
  inputs : Union[Path, str]
218
218
  Path to input .npy file.
@@ -230,7 +230,7 @@ def create_model_description(
230
230
  Path to environment file.
231
231
  covers : list of pathlib.Path or str
232
232
  Paths to cover images.
233
- channel_names : Optional[List[str]], optional
233
+ channel_names : Optional[list[str]], optional
234
234
  Channel names, by default None.
235
235
  model_version : str, default "0.1.0"
236
236
  Model version.
@@ -2,7 +2,7 @@
2
2
 
3
3
  import tempfile
4
4
  from pathlib import Path
5
- from typing import List, Optional, Tuple, Union
5
+ from typing import Optional, Union
6
6
 
7
7
  import numpy as np
8
8
  import pkg_resources
@@ -87,11 +87,11 @@ def export_to_bmz(
87
87
  model_name: str,
88
88
  general_description: str,
89
89
  data_description: str,
90
- authors: List[dict],
90
+ authors: list[dict],
91
91
  input_array: np.ndarray,
92
92
  output_array: np.ndarray,
93
93
  covers: Optional[list[Union[Path, str]]] = None,
94
- channel_names: Optional[List[str]] = None,
94
+ channel_names: Optional[list[str]] = None,
95
95
  model_version: str = "0.1.0",
96
96
  ) -> None:
97
97
  """Export the model to BioImage Model Zoo format.
@@ -115,7 +115,7 @@ def export_to_bmz(
115
115
  General description of the model.
116
116
  data_description : str
117
117
  Description of the data the model was trained on.
118
- authors : List[dict]
118
+ authors : list[dict]
119
119
  Authors of the model.
120
120
  input_array : np.ndarray
121
121
  Input array, should not have been normalized.
@@ -123,24 +123,13 @@ def export_to_bmz(
123
123
  Output array, should have been denormalized.
124
124
  covers : list of pathlib.Path or str, default=None
125
125
  Paths to the cover images.
126
- channel_names : Optional[List[str]], optional
126
+ channel_names : Optional[list[str]], optional
127
127
  Channel names, by default None.
128
128
  model_version : str, default="0.1.0"
129
129
  Model version.
130
-
131
- Raises
132
- ------
133
- ValueError
134
- If the model is a Custom model.
135
130
  """
136
131
  path_to_archive = Path(path_to_archive)
137
132
 
138
- # method is not compatible with Custom models
139
- if config.algorithm_config.model.architecture == SupportedArchitecture.CUSTOM:
140
- raise ValueError(
141
- "Exporting Custom models to BioImage Model Zoo format is not supported."
142
- )
143
-
144
133
  if path_to_archive.suffix != ".zip":
145
134
  raise ValueError(
146
135
  f"Path to archive must point to a zip file, got {path_to_archive}."
@@ -212,7 +201,7 @@ def export_to_bmz(
212
201
 
213
202
  def load_from_bmz(
214
203
  path: Union[Path, str, HttpUrl]
215
- ) -> Tuple[Union[FCNModule, VAEModule], Configuration]:
204
+ ) -> tuple[Union[FCNModule, VAEModule], Configuration]:
216
205
  """Load a model from a BioImage Model Zoo archive.
217
206
 
218
207
  Parameters
@@ -1,11 +1,11 @@
1
1
  """Utility functions to load pretrained models."""
2
2
 
3
3
  from pathlib import Path
4
- from typing import Tuple, Union
4
+ from typing import Union
5
5
 
6
6
  import torch
7
7
 
8
- from careamics.config import Configuration
8
+ from careamics.config import Configuration, configuration_factory
9
9
  from careamics.lightning.lightning_module import FCNModule, VAEModule
10
10
  from careamics.model_io.bmz_io import load_from_bmz
11
11
  from careamics.utils import check_path_exists
@@ -13,7 +13,7 @@ from careamics.utils import check_path_exists
13
13
 
14
14
  def load_pretrained(
15
15
  path: Union[Path, str]
16
- ) -> Tuple[Union[FCNModule, VAEModule], Configuration]:
16
+ ) -> tuple[Union[FCNModule, VAEModule], Configuration]:
17
17
  """
18
18
  Load a pretrained model from a checkpoint or a BioImage Model Zoo model.
19
19
 
@@ -26,8 +26,8 @@ def load_pretrained(
26
26
 
27
27
  Returns
28
28
  -------
29
- Tuple[CAREamicsKiln, Configuration]
30
- Tuple of CAREamics model and its configuration.
29
+ tuple[CAREamicsKiln, Configuration]
30
+ tuple of CAREamics model and its configuration.
31
31
 
32
32
  Raises
33
33
  ------
@@ -48,7 +48,7 @@ def load_pretrained(
48
48
 
49
49
  def _load_checkpoint(
50
50
  path: Union[Path, str]
51
- ) -> Tuple[Union[FCNModule, VAEModule], Configuration]:
51
+ ) -> tuple[Union[FCNModule, VAEModule], Configuration]:
52
52
  """
53
53
  Load a model from a checkpoint and return both model and configuration.
54
54
 
@@ -59,8 +59,8 @@ def _load_checkpoint(
59
59
 
60
60
  Returns
61
61
  -------
62
- Tuple[CAREamicsKiln, Configuration]
63
- Tuple of CAREamics model and its configuration.
62
+ tuple[CAREamicsKiln, Configuration]
63
+ tuple of CAREamics model and its configuration.
64
64
 
65
65
  Raises
66
66
  ------
@@ -92,4 +92,4 @@ def _load_checkpoint(
92
92
  f"{cfg_dict['algorithm_config']['model']['architecture']}"
93
93
  )
94
94
 
95
- return model, Configuration(**cfg_dict)
95
+ return model, configuration_factory(cfg_dict)