careamics 0.0.15__py3-none-any.whl → 0.0.17__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.

Files changed (79)
  1. careamics/careamist.py +11 -14
  2. careamics/cli/conf.py +18 -3
  3. careamics/config/__init__.py +8 -0
  4. careamics/config/algorithms/__init__.py +4 -0
  5. careamics/config/algorithms/hdn_algorithm_model.py +103 -0
  6. careamics/config/algorithms/microsplit_algorithm_model.py +103 -0
  7. careamics/config/algorithms/n2v_algorithm_model.py +1 -2
  8. careamics/config/algorithms/vae_algorithm_model.py +51 -16
  9. careamics/config/architectures/lvae_model.py +12 -8
  10. careamics/config/callback_model.py +7 -3
  11. careamics/config/configuration.py +15 -63
  12. careamics/config/configuration_factories.py +853 -29
  13. careamics/config/data/data_model.py +50 -11
  14. careamics/config/data/ng_data_model.py +168 -4
  15. careamics/config/data/patch_filter/__init__.py +15 -0
  16. careamics/config/data/patch_filter/filter_model.py +16 -0
  17. careamics/config/data/patch_filter/mask_filter_model.py +17 -0
  18. careamics/config/data/patch_filter/max_filter_model.py +15 -0
  19. careamics/config/data/patch_filter/meanstd_filter_model.py +18 -0
  20. careamics/config/data/patch_filter/shannon_filter_model.py +15 -0
  21. careamics/config/inference_model.py +1 -2
  22. careamics/config/likelihood_model.py +2 -2
  23. careamics/config/loss_model.py +6 -2
  24. careamics/config/nm_model.py +26 -1
  25. careamics/config/optimizer_models.py +1 -2
  26. careamics/config/support/supported_algorithms.py +5 -3
  27. careamics/config/support/supported_filters.py +17 -0
  28. careamics/config/support/supported_losses.py +5 -2
  29. careamics/config/training_model.py +6 -36
  30. careamics/config/transformations/normalize_model.py +1 -2
  31. careamics/dataset_ng/dataset.py +57 -5
  32. careamics/dataset_ng/factory.py +101 -18
  33. careamics/dataset_ng/patch_extractor/demo_custom_image_stack_loader.py +4 -4
  34. careamics/dataset_ng/patch_extractor/image_stack/in_memory_image_stack.py +1 -2
  35. careamics/dataset_ng/patch_extractor/image_stack/zarr_image_stack.py +33 -7
  36. careamics/dataset_ng/patch_extractor/image_stack_loader.py +2 -2
  37. careamics/dataset_ng/patch_filter/__init__.py +20 -0
  38. careamics/dataset_ng/patch_filter/coordinate_filter_protocol.py +27 -0
  39. careamics/dataset_ng/patch_filter/filter_factory.py +94 -0
  40. careamics/dataset_ng/patch_filter/mask_filter.py +95 -0
  41. careamics/dataset_ng/patch_filter/max_filter.py +188 -0
  42. careamics/dataset_ng/patch_filter/mean_std_filter.py +218 -0
  43. careamics/dataset_ng/patch_filter/patch_filter_protocol.py +50 -0
  44. careamics/dataset_ng/patch_filter/shannon_filter.py +188 -0
  45. careamics/file_io/read/__init__.py +0 -1
  46. careamics/lightning/__init__.py +16 -2
  47. careamics/lightning/callbacks/__init__.py +2 -0
  48. careamics/lightning/callbacks/data_stats_callback.py +33 -0
  49. careamics/lightning/dataset_ng/data_module.py +79 -2
  50. careamics/lightning/lightning_module.py +162 -61
  51. careamics/lightning/microsplit_data_module.py +636 -0
  52. careamics/lightning/predict_data_module.py +8 -1
  53. careamics/lightning/train_data_module.py +19 -8
  54. careamics/losses/__init__.py +7 -1
  55. careamics/losses/loss_factory.py +9 -1
  56. careamics/losses/lvae/losses.py +85 -0
  57. careamics/lvae_training/dataset/__init__.py +8 -8
  58. careamics/lvae_training/dataset/config.py +56 -44
  59. careamics/lvae_training/dataset/lc_dataset.py +18 -12
  60. careamics/lvae_training/dataset/ms_dataset_ref.py +5 -5
  61. careamics/lvae_training/dataset/multich_dataset.py +24 -18
  62. careamics/lvae_training/dataset/multifile_dataset.py +6 -6
  63. careamics/lvae_training/eval_utils.py +46 -24
  64. careamics/model_io/bmz_io.py +9 -5
  65. careamics/models/lvae/likelihoods.py +31 -14
  66. careamics/models/lvae/lvae.py +2 -2
  67. careamics/models/lvae/noise_models.py +20 -14
  68. careamics/prediction_utils/__init__.py +8 -2
  69. careamics/prediction_utils/prediction_outputs.py +49 -3
  70. careamics/prediction_utils/stitch_prediction.py +83 -1
  71. careamics/transforms/xy_random_rotate90.py +1 -1
  72. careamics/utils/version.py +4 -4
  73. {careamics-0.0.15.dist-info → careamics-0.0.17.dist-info}/METADATA +19 -22
  74. {careamics-0.0.15.dist-info → careamics-0.0.17.dist-info}/RECORD +77 -60
  75. careamics/dataset/zarr_dataset.py +0 -151
  76. careamics/file_io/read/zarr.py +0 -60
  77. {careamics-0.0.15.dist-info → careamics-0.0.17.dist-info}/WHEEL +0 -0
  78. {careamics-0.0.15.dist-info → careamics-0.0.17.dist-info}/entry_points.txt +0 -0
  79. {careamics-0.0.15.dist-info → careamics-0.0.17.dist-info}/licenses/LICENSE +0 -0

careamics/lvae_training/dataset/multifile_dataset.py

@@ -4,7 +4,7 @@ from typing import Callable, Union
  import numpy as np
  from numpy.typing import NDArray

- from .config import DatasetConfig
+ from .config import MicroSplitDataConfig
  from .lc_dataset import LCMultiChDloader
  from .multich_dataset import MultiChDloader
  from .types import DataSplitType

@@ -82,7 +82,7 @@ class SingleFileLCDset(LCMultiChDloader):
  def __init__(
  self,
  preloaded_data: NDArray,
- data_config: DatasetConfig,
+ data_config: MicroSplitDataConfig,
  fpath: str,
  load_data_fn: Callable,
  val_fraction=None,

@@ -106,7 +106,7 @@ class SingleFileLCDset(LCMultiChDloader):

  def load_data(
  self,
- data_config: DatasetConfig,
+ data_config: MicroSplitDataConfig,
  datasplit_type: DataSplitType,
  load_data_fn: Callable,
  val_fraction=None,

@@ -124,7 +124,7 @@ class SingleFileDset(MultiChDloader):
  def __init__(
  self,
  preloaded_data: NDArray,
- data_config: DatasetConfig,
+ data_config: MicroSplitDataConfig,
  fpath: str,
  load_data_fn: Callable,
  val_fraction=None,

@@ -148,7 +148,7 @@ class SingleFileDset(MultiChDloader):

  def load_data(
  self,
- data_config: DatasetConfig,
+ data_config: MicroSplitDataConfig,
  datasplit_type: DataSplitType,
  load_data_fn: Callable[..., NDArray],
  val_fraction=None,

@@ -175,7 +175,7 @@ class MultiFileDset:

  def __init__(
  self,
- data_config: DatasetConfig,
+ data_config: MicroSplitDataConfig,
  fpath: str,
  load_data_fn: Callable[..., Union[TwoChannelData, MultiChannelData]],
  val_fraction=None,

careamics/lvae_training/eval_utils.py

@@ -32,7 +32,7 @@ class TilingMode:
  ShiftBoundary = 2


- # ------------------------------------------------------------------------------------------------
+ # ------------------------------------------------------------------------------------
  # Function of plotting: TODO -> moved them to another file, plot_utils.py
  def clean_ax(ax):
  """

@@ -68,7 +68,9 @@ def get_psnr_str(tar_hsnr, pred, col_idx):
  """
  Compute PSNR between the ground truth (`tar_hsnr`) and the predicted image (`pred`).
  """
- return f"{scale_invariant_psnr(tar_hsnr[col_idx][None], pred[col_idx][None]).item():.1f}"
+ psnr = scale_invariant_psnr(tar_hsnr[col_idx][None], pred[col_idx][None]).item()
+
+ return f"{psnr:.1f}"


  def add_psnr_str(ax_, psnr):

@@ -129,8 +131,10 @@ def show_for_one(
  baseline_preds=None,
  ):
  """
- Given an index, it plots the input, target, reconstructed images and the difference image.
- Note the the difference image is computed with respect to a ground truth image, obtained from the high SNR dataset.
+ Given an index, it plots the input, target, reconstructed images and the difference
+ image.
+ Note the the difference image is computed with respect to a ground truth image,
+ obtained from the high SNR dataset.
  """
  highsnr_val_dset.set_img_sz(patch_size, 64)
  highsnr_val_dset.disable_noise()

@@ -164,7 +168,8 @@ def plot_crops(
  for i in range(len(baseline_preds)):
  if baseline_preds[i].shape != tar_hsnr.shape:
  print(
- f"Baseline prediction {i} shape {baseline_preds[i].shape} does not match target shape {tar_hsnr.shape}"
+ f"Baseline prediction {i} shape {baseline_preds[i].shape} does not "
+ f"match target shape {tar_hsnr.shape}"
  )
  print("This happens when we want to predict the edges of the image.")
  return

@@ -333,14 +338,21 @@ def plot_crops(
  ax_temp.imshow(inp[0, 0].cpu().numpy(), cmap="magma")
  clean_ax(ax_temp)

- # line_ch1 = mlines.Line2D([0, 1], [0, 1], color=color_ch_list[0], linestyle='-', label='$C_1$')
- # line_ch2 = mlines.Line2D([0, 1], [0, 1], color=color_ch_list[1], linestyle='-', label='$C_2$')
- # line_pred = mlines.Line2D([0, 1], [0, 1], color=color_pred, linestyle='-', label='Pred')
- # line_noisych1 = mlines.Line2D([0, 1], [0, 1], color=color_ch_list[0], linestyle='--', label='$C^N_1$')
- # line_noisych2 = mlines.Line2D([0, 1], [0, 1], color=color_ch_list[1], linestyle='--', label='$C^N_2$')
- # legend_ch1 = legend_ch1_ax.legend(handles=[line_ch1, line_noisych1, line_pred], loc='upper right', frameon=False, labelcolor='white',
+ # line_ch1 = mlines.Line2D([0, 1], [0, 1], color=color_ch_list[0], linestyle='-',
+ # label='$C_1$')
+ # line_ch2 = mlines.Line2D([0, 1], [0, 1], color=color_ch_list[1], linestyle='-',
+ # label='$C_2$')
+ # line_pred = mlines.Line2D([0, 1], [0, 1], color=color_pred, linestyle='-',
+ # label='Pred')
+ # line_noisych1 = mlines.Line2D([0, 1], [0, 1], color=color_ch_list[0],
+ # linestyle='--', label='$C^N_1$')
+ # line_noisych2 = mlines.Line2D([0, 1], [0, 1], color=color_ch_list[1],
+ # linestyle='--', label='$C^N_2$')
+ # legend_ch1 = legend_ch1_ax.legend(handles=[line_ch1, line_noisych1, line_pred],
+ # loc='upper right', frameon=False, labelcolor='white',
  # prop={'size': 11})
- # legend_ch2 = legend_ch2_ax.legend(handles=[line_ch2, line_noisych2, line_pred], loc='upper right', frameon=False, labelcolor='white',
+ # legend_ch2 = legend_ch2_ax.legend(handles=[line_ch2, line_noisych2, line_pred],
+ # loc='upper right', frameon=False, labelcolor='white',
  # prop={'size': 11})

  if calibration_stats is not None:

@@ -383,7 +395,9 @@ def plot_calibration(ax, calibration_stats):

  def shiftedColorMap(cmap, start=0, midpoint=0.5, stop=1.0, name="shiftedcmap"):
  """
- Adapted from https://stackoverflow.com/questions/7404116/defining-the-midpoint-of-a-colormap-in-matplotlib
+ Adapted from
+ https://stackoverflow.com/questions/7404116/defining-the-midpoint-of-a-colormap-in-
+ matplotlib

  Function to offset the "center" of a colormap. Useful for
  data with a negative min and positive max and you want the

@@ -444,7 +458,8 @@ def get_fractional_change(target, prediction, max_val=None):

  def get_zero_centered_midval(error):
  """
- When done this way, the midval ensures that the colorbar is centered at 0. (Don't know how, but it works ;))
+ When done this way, the midval ensures that the colorbar is centered at 0. (Don't
+ know how, but it works ;))
  """
  vmax = error.max()
  vmin = error.min()

@@ -670,7 +685,7 @@ def get_single_file_mmse(
  return stitched_predictions, stitched_stds


- # ------------------------------------------------------------------------------------------
+ # ---------------------------------------------------------------------------------
  ### Classes and Functions used to stitch predictions
  class PatchLocation:
  """

@@ -698,9 +713,10 @@ def _get_location(extra_padding, hwt, pred_h, pred_w):

  def get_location_from_idx(dset, dset_input_idx, pred_h, pred_w):
  """
- For a given idx of the dataset, it returns where exactly in the dataset, does this prediction lies.
- Note that this prediction also has padded pixels and so a subset of it will be used in the final prediction.
- Which time frame, which spatial location (h_start, h_end, w_start,w_end)
+ For a given idx of the dataset, it returns where exactly in the dataset, does this
+ prediction lies. Note that this prediction also has padded pixels and so a subset of
+ it will be used in the final prediction. Which time frame, which spatial location
+ (h_start, h_end, w_start,w_end)
  Args:
  dset:
  dset_input_idx:

@@ -785,9 +801,11 @@ def stitch_predictions(predictions, dset, smoothening_pixelcount=0):
  # NOTE: don't need to compute it for every patch.
  assert (
  smoothening_pixelcount == 0
- ), "For smoothing,enable the get_smoothing_mask. It is disabled since I don't use it and it needs modification to work with non-square images"
+ ), "For smoothing,enable the get_smoothing_mask. It is disabled since I"
+ "don't use it and it needs modification to work with non-square images"
  mask = 1
- # mask = _get_smoothing_mask(cropped_pred_i.shape, smoothening_pixelcount, loc, frame_size)
+ # mask = _get_smoothing_mask(cropped_pred_i.shape,
+ # smoothening_pixelcount, loc, frame_size)

  cropped_pred_list.append(cropped_pred_i)

@@ -827,7 +845,8 @@ def stitch_predictions_new(predictions, dset):
  output = np.zeros(shape, dtype=predictions.dtype)
  # frame_shape = dset.get_data_shape()[:-1]
  for dset_idx in range(predictions.shape[0]):
- # loc = get_location_from_idx(dset, dset_idx, predictions.shape[-2], predictions.shape[-1])
+ # loc = get_location_from_idx(dset, dset_idx, predictions.shape[-2],
+ # predictions.shape[-1])
  # grid start, grid end
  gs = np.array(mng.get_location_from_dataset_idx(dset_idx), dtype=int)
  ge = gs + mng.grid_shape

@@ -843,7 +862,8 @@ def stitch_predictions_new(predictions, dset):
  vgs = np.array([max(0, x) for x in gs], dtype=int)
  vge = np.array([min(x, y) for x, y in zip(ge, mng.data_shape)], dtype=int)
  # assert np.all(vgs == gs)
- # assert np.all(vge == ge) # TODO comented out this shit cuz I have no interest to dig why it's failing at this point !
+ # assert np.all(vge == ge) # TODO comented out this shit cuz I have no interest
+ # to dig why it's failing at this point !
  # print('VGS')
  # print(gs)
  # print(ge)

@@ -898,7 +918,8 @@ def stitch_predictions_general(predictions, dset):
  # frame_shape = dset.get_data_shape()[:-1]
  for patch_idx in range(predictions.shape[0]):
  # grid start, grid end
- # channel_idx is 0 because during prediction we're only use one channel. # TODO revisit this
+ # channel_idx is 0 because during prediction we're only use one channel.
+ # # TODO revisit this
  # 0th dimension is sample index in the output list
  grid_coords = np.array(
  mng.get_location_from_patch_idx(channel_idx=0, patch_idx=patch_idx),

@@ -906,7 +927,8 @@ def stitch_predictions_general(predictions, dset):
  )
  sample_idx = grid_coords[0]
  grid_start = grid_coords[1:]
- # from here on, coordinates are relative to the sample(file in the list of inputs)
+ # from here on, coordinates are relative to the sample(file in the list of
+ # inputs)
  grid_end = grid_start + mng.grid_shape

  # patch start, patch end

careamics/model_io/bmz_io.py

@@ -186,11 +186,15 @@ def export_to_bmz(
  )

  # test model description
- test_kwargs = (
- model_description.config.bioimageio.model_dump()
- .get("test_kwargs", {})
- .get("pytorch_state_dict", {})
- )
+ test_kwargs = {}
+ if hasattr(model_description, "config") and isinstance(
+ model_description.config, dict
+ ):
+ bioimageio_config = model_description.config.get("bioimageio", {})
+ test_kwargs = bioimageio_config.get("test_kwargs", {}).get(
+ "pytorch_state_dict", {}
+ )
+
  summary: ValidationSummary = test_model(model_description, **test_kwargs)
  if summary.status == "failed":
  raise ValueError(f"Model description test failed: {summary}")

careamics/models/lvae/likelihoods.py

@@ -54,12 +54,8 @@ def likelihood_factory(
  )
  elif isinstance(config, NMLikelihoodConfig):
  return NoiseModelLikelihood(
- data_mean=config.data_mean,
- data_std=config.data_std,
  noise_model=noise_model,
  )
- else:
- raise ValueError(f"Invalid likelihood model type: {config.model_type}")


  # TODO: is it really worth to have this class? Or it just adds complexity? --> REFACTOR

@@ -290,27 +286,40 @@ class NoiseModelLikelihood(LikelihoodModule):

  def __init__(
  self,
- data_mean: Union[np.ndarray, torch.Tensor],
- data_std: Union[np.ndarray, torch.Tensor],
  noise_model: NoiseModel,
  ):
  """Constructor.

  Parameters
  ----------
- data_mean: Union[np.ndarray, torch.Tensor]
- The mean of the data, used to unnormalize data for noise model evaluation.
- data_std: Union[np.ndarray, torch.Tensor]
- The standard deviation of the data, used to unnormalize data for noise
- model evaluation.
  noiseModel: NoiseModel
  The noise model instance used to compute the likelihood.
  """
  super().__init__()
- self.data_mean = torch.Tensor(data_mean)
- self.data_std = torch.Tensor(data_std)
+ self.data_mean = None
+ self.data_std = None
  self.noiseModel = noise_model

+ def set_data_stats(
+ self,
+ data_mean: Union[np.ndarray, torch.Tensor],
+ data_std: Union[np.ndarray, torch.Tensor],
+ ) -> None:
+ """Set the data mean and std for denormalization.
+ # TODO check this !!
+ Parameters
+ ----------
+ data_mean : Union[np.ndarray, torch.Tensor]
+ Mean values for each channel. Will be reshaped to (1, C, 1, 1, 1) for broadcasting.
+ data_std : Union[np.ndarray, torch.Tensor]
+ Standard deviation values for each channel. Will be reshaped to (1, C, 1, 1, 1) for broadcasting.
+ """
+ # Convert to tensor if needed
+ self.data_mean = torch.as_tensor(data_mean, dtype=torch.float32)
+ self.data_std = torch.as_tensor(data_std, dtype=torch.float32)
+
+ # TODO add extra dim for 3D ?
+
  def _set_params_to_same_device_as(
  self, correct_device_tensor: torch.Tensor
  ) -> None:

@@ -321,7 +330,10 @@ class NoiseModelLikelihood(LikelihoodModule):
  correct_device_tensor: torch.Tensor
  The tensor whose device is used to set the parameters.
  """
- if self.data_mean.device != correct_device_tensor.device:
+ if (
+ self.data_mean is not None
+ and self.data_mean.device != correct_device_tensor.device
+ ):
  self.data_mean = self.data_mean.to(correct_device_tensor.device)
  self.data_std = self.data_std.to(correct_device_tensor.device)
  if correct_device_tensor.device != self.noiseModel.device:

@@ -367,6 +379,11 @@ class NoiseModelLikelihood(LikelihoodModule):
  torch.Tensor
  The log-likelihood tensor. Shape is (B, C, [Z], Y, X).
  """
+ if self.data_mean is None or self.data_std is None:
+ raise RuntimeError(
+ "NoiseModelLikelihood: data_mean and data_std must be set before"
+ "callinglog_likelihood."
+ )
  self._set_params_to_same_device_as(x)
  predicted_s_denormalized = params["mean"] * self.data_std + self.data_mean
  x_denormalized = x * self.data_std + self.data_mean
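
Taken together, the likelihoods hunks replace constructor-time statistics with a two-step API: build the likelihood with only the noise model, then supply the normalization statistics later. A minimal sketch of the new call order, assuming `noise_model` is an already-constructed NoiseModel and using placeholder statistics:

import numpy as np

likelihood = NoiseModelLikelihood(noise_model=noise_model)
likelihood.set_data_stats(
    data_mean=np.array([480.0]),  # placeholder per-channel mean
    data_std=np.array([210.0]),   # placeholder per-channel std
)
# Calling log_likelihood before set_data_stats now raises RuntimeError.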

careamics/models/lvae/lvae.py

@@ -6,7 +6,7 @@ and Artefact Removal, Prakash et al."
  """

  from collections.abc import Iterable
- from typing import Optional, Union
+ from typing import Union

  import numpy as np
  import torch

@@ -835,7 +835,7 @@ class LadderVAE(nn.Module):
  top_layer_shape = (n_imgs, mu_logvar, self._model_3D_depth, h, w)
  return top_layer_shape

- def reset_for_inference(self, tile_size: Optional[tuple[int, int]] = None):
+ def reset_for_inference(self, tile_size: tuple[int, int] | None = None):
  """Should be called if we want to predict for a different input/output size."""
  self.mode_pred = True
  if tile_size is None:

careamics/models/lvae/noise_models.py

@@ -3,10 +3,10 @@ from __future__ import annotations
  import os
  from typing import TYPE_CHECKING, Optional

- from numpy.typing import NDArray
  import numpy as np
  import torch
  import torch.nn as nn
+ from numpy.typing import NDArray

  if TYPE_CHECKING:
  from careamics.config import GaussianMixtureNMConfig, MultiChannelNMConfig

@@ -355,16 +355,16 @@ class GaussianMixtureNoiseModel(nn.Module):

  Parameters
  ----------
- x: Tensor
- Observations
- mean: Tensor
- Mean
- std: Tensor
- Standard-deviation
+ x: torch.Tensor
+ The ground-truth tensor. Shape is (batch, 1, dim1, dim2).
+ mean: torch.Tensor
+ The inferred mean of distribution. Shape is (batch, 1, dim1, dim2).
+ std: torch.Tensor
+ The inferred standard deviation of distribution. Shape is (batch, 1, dim1, dim2).

  Returns
  -------
- tmp: Tensor
+ tmp: torch.Tensor
  Normal probability density of `x` given `mean` and `std`
  """
  tmp = -((x - mean) ** 2)

@@ -382,9 +382,9 @@ class GaussianMixtureNoiseModel(nn.Module):
  Parameters
  ----------
  observations : Tensor
- Noisy observations
+ Noisy observations. Shape is (batch, 1, dim1, dim2).
  signals : Tensor
- Underlying signals
+ Underlying signals. Shape is (batch, 1, dim1, dim2).

  Returns
  -------

@@ -392,15 +392,21 @@ class GaussianMixtureNoiseModel(nn.Module):
  Likelihood of observations given the signals and the GMM noise model
  """
  gaussian_parameters: list[torch.Tensor] = self.get_gaussian_parameters(signals)
- p = 0
+ p = torch.zeros_like(observations)
  for gaussian in range(self.n_gaussian):
+ # Ensure all tensors have compatible shapes
+ mean = gaussian_parameters[gaussian]
+ std = gaussian_parameters[self.n_gaussian + gaussian]
+ weight = gaussian_parameters[2 * self.n_gaussian + gaussian]
+
+ # Compute normal density
  p += (
  self.normal_density(
  observations,
- gaussian_parameters[gaussian],
- gaussian_parameters[self.n_gaussian + gaussian],
+ mean,
+ std,
  )
- * gaussian_parameters[2 * self.n_gaussian + gaussian]
+ * weight
  )
  return p + self.tolerance
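
The rewritten loop is a standard Gaussian-mixture accumulation; starting from `torch.zeros_like(observations)` gives `p` a tensor shape, dtype and device from the first iteration instead of relying on scalar broadcasting. A self-contained sketch of the same pattern (a standalone helper for illustration, not the class method):

import torch

def gmm_likelihood(obs, means, stds, weights, tolerance=1e-10):
    """Sketch: p(obs) = sum_k weight_k * N(obs; mean_k, std_k) + tolerance."""
    p = torch.zeros_like(obs)  # tensor accumulator shaped like the observations
    for mean, std, weight in zip(means, stds, weights):
        density = torch.exp(-((obs - mean) ** 2) / (2 * std**2))
        density = density / (std * (2 * torch.pi) ** 0.5)
        p = p + weight * density
    return p + tolerance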

careamics/prediction_utils/__init__.py

@@ -2,9 +2,15 @@

  __all__ = [
  "convert_outputs",
+ "convert_outputs_microsplit",
  "stitch_prediction",
  "stitch_prediction_single",
+ "stitch_prediction_vae",
  ]

- from .prediction_outputs import convert_outputs
- from .stitch_prediction import stitch_prediction, stitch_prediction_single
+ from .prediction_outputs import convert_outputs, convert_outputs_microsplit
+ from .stitch_prediction import (
+ stitch_prediction,
+ stitch_prediction_single,
+ stitch_prediction_vae,
+ )

careamics/prediction_utils/prediction_outputs.py

@@ -6,7 +6,7 @@ import numpy as np
  from numpy.typing import NDArray

  from ..config.tile_information import TileInformation
- from .stitch_prediction import stitch_prediction
+ from .stitch_prediction import stitch_prediction, stitch_prediction_vae


  def convert_outputs(predictions: list[Any], tiled: bool) -> list[NDArray]:

@@ -41,6 +41,49 @@ def convert_outputs(predictions: list[Any], tiled: bool) -> list[NDArray]:
  return predictions_output


+ def convert_outputs_microsplit(
+ predictions: list[tuple[NDArray, NDArray]], dataset
+ ) -> tuple[NDArray, NDArray]:
+ """
+ Convert microsplit Lightning trainer outputs using eval_utils stitching functions.
+
+ This function processes microsplit predictions that return (tile_prediction,
+ tile_std) tuples and stitches them back together using the same logic as
+ get_single_file_mmse.
+
+ Parameters
+ ----------
+ predictions : list of tuple[NDArray, NDArray]
+ Predictions from Lightning trainer for microsplit. Each element is a tuple of
+ (tile_prediction, tile_std) where both are numpy arrays from predict_step.
+ dataset : Dataset
+ The dataset object used for prediction, needed for stitching function selection
+ and stitching process.
+
+ Returns
+ -------
+ tuple[NDArray, NDArray]
+ A tuple of (stitched_predictions, stitched_stds) representing the full
+ stitched predictions and standard deviations.
+ """
+ if len(predictions) == 0:
+ raise ValueError("No predictions provided")
+
+ # Separate predictions and stds from the list of tuples
+ tile_predictions = [pred for pred, _ in predictions]
+ tile_stds = [std for _, std in predictions]
+
+ # Concatenate all tiles exactly like get_single_file_mmse
+ tiles_arr = np.concatenate(tile_predictions, axis=0)
+ tile_stds_arr = np.concatenate(tile_stds, axis=0)
+
+ # Apply stitching using stitch_predictions_new
+ stitched_predictions = stitch_prediction_vae(tiles_arr, dataset)
+ stitched_stds = stitch_prediction_vae(tile_stds_arr, dataset)
+
+ return stitched_predictions, stitched_stds
+
+
  # for mypy
  @overload
  def combine_batches( # numpydoc ignore=GL08

@@ -68,6 +111,8 @@ def combine_batches(
  """
  If predictions are in batches, they will be combined.

+ # TODO improve description!
+
  Parameters
  ----------
  predictions : list

@@ -107,11 +152,12 @@ def _combine_tiled_batches(
  """
  # turn list of lists into single list
  tile_infos = [
- tile_info for _, tile_info_list in predictions for tile_info in tile_info_list
+ tile_info for *_, tile_info_list in predictions for tile_info in tile_info_list
  ]
  prediction_tiles: list[NDArray] = _combine_array_batches(
- [preds for preds, _ in predictions]
+ [preds for preds, *_ in predictions]
  )
+
  return prediction_tiles, tile_infos
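
A hypothetical end-to-end use of the new helper, following its docstring (`trainer`, `model`, `predict_dataloader` and `dataset` are placeholders):

tile_outputs = trainer.predict(model, dataloaders=predict_dataloader)
# each element of tile_outputs is a (tile_prediction, tile_std) tuple
stitched_preds, stitched_stds = convert_outputs_microsplit(tile_outputs, dataset)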

careamics/prediction_utils/stitch_prediction.py

@@ -9,6 +9,86 @@ from numpy.typing import NDArray
  from careamics.config.tile_information import TileInformation


+ class TilingMode:
+ """Enum for the tiling mode."""
+
+ TrimBoundary = 0
+ PadBoundary = 1
+ ShiftBoundary = 2
+
+
+ def stitch_prediction_vae(predictions, dset) -> NDArray:
+ """
+ Stitch predictions back together using dataset's index manager.
+
+ Parameters
+ ----------
+ predictions : np.ndarray
+ Array of predictions with shape (n_tiles, channels, height, width).
+ dset : Dataset
+ Dataset object with idx_manager containing tiling information.
+
+ Returns
+ -------
+ np.ndarray
+ Stitched array with shape matching the original data shape.
+ """
+ mng = dset.idx_manager
+
+ # if there are more channels, use all of them.
+ shape = list(dset.get_data_shape())
+ shape[-1] = max(shape[-1], predictions.shape[1])
+
+ output = np.zeros(shape, dtype=predictions.dtype)
+ # frame_shape = dset.get_data_shape()[:-1]
+ for dset_idx in range(predictions.shape[0]):
+ # loc = get_location_from_idx(dset, dset_idx, predictions.shape[-2],
+ # predictions.shape[-1])
+ # grid start, grid end
+ gs = np.array(mng.get_location_from_dataset_idx(dset_idx), dtype=int)
+ ge = gs + mng.grid_shape
+
+ # patch start, patch end
+ ps = gs - mng.patch_offset()
+ pe = ps + mng.patch_shape
+
+ # valid grid start, valid grid end
+ vgs = np.array([max(0, x) for x in gs], dtype=int)
+ vge = np.array(
+ [min(x, y) for x, y in zip(ge, mng.data_shape, strict=False)], dtype=int
+ )
+
+ if mng.tiling_mode == TilingMode.ShiftBoundary:
+ for dim in range(len(vgs)):
+ if ps[dim] == 0:
+ vgs[dim] = 0
+ if pe[dim] == mng.data_shape[dim]:
+ vge[dim] = mng.data_shape[dim]
+
+ # relative start, relative end. This will be used on pred_tiled
+ rs = vgs - ps
+ re = rs + (vge - vgs)
+
+ for ch_idx in range(predictions.shape[1]):
+ if len(output.shape) == 4:
+ # channel dimension is the last one.
+ output[vgs[0] : vge[0], vgs[1] : vge[1], vgs[2] : vge[2], ch_idx] = (
+ predictions[dset_idx][ch_idx, rs[1] : re[1], rs[2] : re[2]]
+ )
+ elif len(output.shape) == 5:
+ # channel dimension is the last one.
+ assert vge[0] - vgs[0] == 1, "Only one frame is supported"
+ output[
+ vgs[0], vgs[1] : vge[1], vgs[2] : vge[2], vgs[3] : vge[3], ch_idx
+ ] = predictions[dset_idx][
+ ch_idx, rs[1] : re[1], rs[2] : re[2], rs[3] : re[3]
+ ]
+ else:
+ raise ValueError(f"Unsupported shape {output.shape}")
+
+ return output
+
+
  # TODO: why not allow input and output of torch.tensor ?
  def stitch_prediction(
  tiles: list[np.ndarray],

@@ -107,6 +187,8 @@ def stitch_prediction_single(

  # Insert cropped tile into predicted image using stitch coordinates
  image_slices = (..., *[slice(c[0], c[1]) for c in tile_info.stitch_coords])
- predicted_image[image_slices] = cropped_tile.astype(np.float32)
+
+ # TODO fix mypy error here, potentially due to numpy 2
+ predicted_image[image_slices] = cropped_tile.astype(np.float32) # type: ignore

  return predicted_image
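
The index arithmetic in stitch_prediction_vae is easiest to follow on a worked 1D example (the numbers are illustrative; the code applies the same logic per axis):

# data length 100, grid_shape 8, patch_shape 12, patch_offset() == 2,
# and a tile whose grid starts at gs = -2:
#   ge  = gs + 8            ->  6   (grid end)
#   ps  = gs - 2            -> -4   (patch start, including padding)
#   vgs = max(0, gs)        ->  0   (valid grid start, clipped to the data)
#   vge = min(ge, 100)      ->  6   (valid grid end)
#   rs  = vgs - ps          ->  4   (offset of the valid region inside the tile)
#   re  = rs + (vge - vgs)  -> 10
# so output[0:6] receives predictions[tile_idx][..., 4:10]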

careamics/transforms/xy_random_rotate90.py

@@ -74,7 +74,7 @@ class XYRandomRotate90(Transform):
  return patch, target, additional_arrays

  # number of rotations
- n_rot = self.rng.integers(1, 4)
+ n_rot = int(self.rng.integers(1, 4))

  axes = (-2, -1)
  patch_transformed = self._apply(patch, n_rot, axes)
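
For context on the one-line change: NumPy's Generator.integers returns a numpy.int64 scalar rather than a builtin int, so the cast normalizes the type for downstream code. A quick illustration:

import numpy as np

rng = np.random.default_rng(0)
k = rng.integers(1, 4)          # numpy.int64 in [1, 3]
k_py = int(rng.integers(1, 4))  # builtin int, same range
print(type(k), type(k_py))      # <class 'numpy.int64'> <class 'int'>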

careamics/utils/version.py

@@ -21,9 +21,9 @@ def get_careamics_version() -> str:
  parts = __version__.split(".")

  # for local installs that do not detect the latest versions via tags
- # (typically our CI will install `0.1.devX<hash>` versions)
- if "dev" in parts[-1]:
- parts[-1] = "*"
+ # (typically our CI will install `0.X.devX<hash>.<other hash>` versions)
+ if "dev" in parts[2]:
+ parts[2] = "*"
  clean_version = ".".join(parts[:3])

  logger.warning(

@@ -34,5 +34,5 @@ def get_careamics_version() -> str:
  f"closest CAREamics version from PyPI or conda-forge."
  )

- # Remove any local version identifier)
+ # Remove any local version identifier
  return ".".join(parts[:3])