careamics 0.0.14__py3-none-any.whl → 0.0.16__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of careamics might be problematic. Click here for more details.

Files changed (92)
  1. careamics/careamist.py +55 -61
  2. careamics/cli/conf.py +24 -9
  3. careamics/cli/main.py +8 -8
  4. careamics/cli/utils.py +2 -4
  5. careamics/config/__init__.py +8 -0
  6. careamics/config/algorithms/__init__.py +4 -0
  7. careamics/config/algorithms/hdn_algorithm_model.py +103 -0
  8. careamics/config/algorithms/microsplit_algorithm_model.py +103 -0
  9. careamics/config/algorithms/n2v_algorithm_model.py +1 -2
  10. careamics/config/algorithms/vae_algorithm_model.py +53 -18
  11. careamics/config/architectures/lvae_model.py +12 -8
  12. careamics/config/callback_model.py +15 -11
  13. careamics/config/configuration.py +9 -8
  14. careamics/config/configuration_factories.py +892 -78
  15. careamics/config/data/data_model.py +7 -14
  16. careamics/config/data/ng_data_model.py +8 -15
  17. careamics/config/data/patching_strategies/_overlapping_patched_model.py +4 -5
  18. careamics/config/inference_model.py +6 -11
  19. careamics/config/likelihood_model.py +4 -4
  20. careamics/config/loss_model.py +6 -2
  21. careamics/config/nm_model.py +30 -7
  22. careamics/config/optimizer_models.py +1 -2
  23. careamics/config/support/supported_algorithms.py +5 -3
  24. careamics/config/support/supported_losses.py +5 -2
  25. careamics/config/training_model.py +8 -38
  26. careamics/config/transformations/normalize_model.py +3 -4
  27. careamics/config/transformations/xy_flip_model.py +2 -2
  28. careamics/config/transformations/xy_random_rotate90_model.py +2 -2
  29. careamics/config/validators/validator_utils.py +1 -2
  30. careamics/dataset/dataset_utils/iterate_over_files.py +3 -3
  31. careamics/dataset/in_memory_dataset.py +2 -2
  32. careamics/dataset/iterable_dataset.py +1 -2
  33. careamics/dataset/patching/random_patching.py +6 -6
  34. careamics/dataset/patching/sequential_patching.py +4 -4
  35. careamics/dataset/tiling/lvae_tiled_patching.py +2 -2
  36. careamics/dataset_ng/dataset.py +3 -3
  37. careamics/dataset_ng/factory.py +19 -19
  38. careamics/dataset_ng/patch_extractor/demo_custom_image_stack_loader.py +4 -4
  39. careamics/dataset_ng/patch_extractor/image_stack/in_memory_image_stack.py +1 -2
  40. careamics/dataset_ng/patch_extractor/image_stack/zarr_image_stack.py +33 -7
  41. careamics/dataset_ng/patch_extractor/image_stack_loader.py +2 -2
  42. careamics/dataset_ng/patching_strategies/random_patching.py +2 -3
  43. careamics/dataset_ng/patching_strategies/sequential_patching.py +1 -2
  44. careamics/file_io/read/__init__.py +0 -1
  45. careamics/lightning/__init__.py +16 -2
  46. careamics/lightning/callbacks/__init__.py +2 -0
  47. careamics/lightning/callbacks/data_stats_callback.py +23 -0
  48. careamics/lightning/callbacks/prediction_writer_callback/prediction_writer_callback.py +5 -5
  49. careamics/lightning/callbacks/prediction_writer_callback/write_strategy.py +5 -5
  50. careamics/lightning/callbacks/prediction_writer_callback/write_strategy_factory.py +8 -8
  51. careamics/lightning/dataset_ng/data_module.py +43 -43
  52. careamics/lightning/lightning_module.py +166 -68
  53. careamics/lightning/microsplit_data_module.py +631 -0
  54. careamics/lightning/predict_data_module.py +16 -9
  55. careamics/lightning/train_data_module.py +29 -18
  56. careamics/losses/__init__.py +7 -1
  57. careamics/losses/loss_factory.py +9 -1
  58. careamics/losses/lvae/losses.py +94 -9
  59. careamics/lvae_training/dataset/__init__.py +8 -8
  60. careamics/lvae_training/dataset/config.py +56 -44
  61. careamics/lvae_training/dataset/lc_dataset.py +18 -12
  62. careamics/lvae_training/dataset/ms_dataset_ref.py +5 -5
  63. careamics/lvae_training/dataset/multich_dataset.py +24 -18
  64. careamics/lvae_training/dataset/multifile_dataset.py +6 -6
  65. careamics/model_io/bioimage/model_description.py +12 -11
  66. careamics/model_io/bmz_io.py +12 -8
  67. careamics/models/layers.py +5 -5
  68. careamics/models/lvae/likelihoods.py +30 -14
  69. careamics/models/lvae/lvae.py +2 -2
  70. careamics/models/lvae/noise_models.py +20 -14
  71. careamics/prediction_utils/__init__.py +8 -2
  72. careamics/prediction_utils/lvae_prediction.py +5 -5
  73. careamics/prediction_utils/prediction_outputs.py +48 -3
  74. careamics/prediction_utils/stitch_prediction.py +71 -0
  75. careamics/transforms/compose.py +9 -9
  76. careamics/transforms/n2v_manipulate.py +3 -3
  77. careamics/transforms/n2v_manipulate_torch.py +4 -4
  78. careamics/transforms/normalize.py +4 -6
  79. careamics/transforms/pixel_manipulation.py +6 -8
  80. careamics/transforms/pixel_manipulation_torch.py +5 -7
  81. careamics/transforms/xy_flip.py +3 -5
  82. careamics/transforms/xy_random_rotate90.py +4 -6
  83. careamics/utils/logging.py +8 -8
  84. careamics/utils/metrics.py +2 -2
  85. careamics/utils/plotting.py +1 -3
  86. {careamics-0.0.14.dist-info → careamics-0.0.16.dist-info}/METADATA +18 -16
  87. {careamics-0.0.14.dist-info → careamics-0.0.16.dist-info}/RECORD +90 -88
  88. careamics/dataset/zarr_dataset.py +0 -151
  89. careamics/file_io/read/zarr.py +0 -60
  90. {careamics-0.0.14.dist-info → careamics-0.0.16.dist-info}/WHEEL +0 -0
  91. {careamics-0.0.14.dist-info → careamics-0.0.16.dist-info}/entry_points.txt +0 -0
  92. {careamics-0.0.14.dist-info → careamics-0.0.16.dist-info}/licenses/LICENSE +0 -0
@@ -2,29 +2,35 @@
2
2
  A place for Datasets and Dataloaders.
3
3
  """
4
4
 
5
- from typing import Tuple, Union, Callable
5
+ from pathlib import Path
6
+ from typing import Any, Callable, Optional, Union
6
7
 
7
8
  import numpy as np
9
+ import torch
10
+ from torch.utils.data import Dataset
8
11
 
9
12
  from .utils.empty_patch_fetcher import EmptyPatchFetcher
10
13
  from .utils.index_manager import GridIndexManager
11
14
  from .utils.index_switcher import IndexSwitcher
12
- from .config import DatasetConfig
15
+ from .config import MicroSplitDataConfig
13
16
  from .types import DataSplitType, TilingMode
14
17
 
15
18
 
16
- class MultiChDloader:
19
+ class MultiChDloader(Dataset):
20
+ """Multi-channel dataset loader."""
21
+
17
22
  def __init__(
18
23
  self,
19
- data_config: DatasetConfig,
20
- fpath: str,
21
- load_data_fn: Callable,
22
- val_fraction: float = None,
23
- test_fraction: float = None,
24
+ data_config: MicroSplitDataConfig,
25
+ datapath: Union[str, Path],
26
+ load_data_fn: Optional[Callable] = None,
27
+ val_fraction: float = 0.1,
28
+ test_fraction: float = 0.1,
29
+ allow_generation: bool = False,
24
30
  ):
25
31
  """ """
26
32
  self._data_type = data_config.data_type
27
- self._fpath = fpath
33
+ self._fpath = datapath
28
34
  self._data = self._noise_data = None
29
35
  self.Z = 1
30
36
  self._5Ddata = False
@@ -395,7 +401,7 @@ class MultiChDloader:
395
401
  )
396
402
 
397
403
  def get_idx_manager_shapes(
398
- self, patch_size: int, grid_size: Union[int, Tuple[int, int, int]]
404
+ self, patch_size: int, grid_size: Union[int, tuple[int, int, int]]
399
405
  ):
400
406
  numC = self._data.shape[-1]
401
407
  if self._5Ddata:
@@ -415,7 +421,7 @@ class MultiChDloader:
415
421
 
416
422
  return patch_shape, grid_shape
417
423
 
418
- def set_img_sz(self, image_size, grid_size: Union[int, Tuple[int, int, int]]):
424
+ def set_img_sz(self, image_size, grid_size: Union[int, tuple[int, int, int]]):
419
425
  """
420
426
  If one wants to change the image size on the go, then this can be used.
421
427
  Args:
@@ -519,7 +525,7 @@ class MultiChDloader:
519
525
  },
520
526
  )
521
527
 
522
- def _crop_img(self, img: np.ndarray, patch_start_loc: Tuple):
528
+ def _crop_img(self, img: np.ndarray, patch_start_loc: tuple):
523
529
  if self._tiling_mode in [TilingMode.TrimBoundary, TilingMode.ShiftBoundary]:
524
530
  # In training, this is used.
525
531
  # NOTE: It is my opinion that if I just use self._crop_img_with_padding, it will work perfectly fine.
@@ -600,7 +606,7 @@ class MultiChDloader:
600
606
  return new_img
601
607
 
602
608
  def _crop_flip_img(
603
- self, img: np.ndarray, patch_start_loc: Tuple, h_flip: bool, w_flip: bool
609
+ self, img: np.ndarray, patch_start_loc: tuple, h_flip: bool, w_flip: bool
604
610
  ):
605
611
  new_img = self._crop_img(img, patch_start_loc)
606
612
  if h_flip:
@@ -611,8 +617,8 @@ class MultiChDloader:
611
617
  return new_img.astype(np.float32)
612
618
 
613
619
  def _load_img(
614
- self, index: Union[int, Tuple[int, int]]
615
- ) -> Tuple[np.ndarray, np.ndarray]:
620
+ self, index: Union[int, tuple[int, int]]
621
+ ) -> tuple[np.ndarray, np.ndarray]:
616
622
  """
617
623
  Returns the channels and also the respective noise channels.
618
624
  """
@@ -806,7 +812,7 @@ class MultiChDloader:
806
812
  w_start = 0
807
813
  return h_start, w_start
808
814
 
809
- def _get_img(self, index: Union[int, Tuple[int, int]]):
815
+ def _get_img(self, index: Union[int, tuple[int, int]]):
810
816
  """
811
817
  Loads an image.
812
818
  Crops the image such that cropped image has content.
@@ -1056,8 +1062,8 @@ class MultiChDloader:
1056
1062
  return img_tuples, noise_tuples
1057
1063
 
1058
1064
  def __getitem__(
1059
- self, index: Union[int, Tuple[int, int]]
1060
- ) -> Tuple[np.ndarray, np.ndarray]:
1065
+ self, index: Union[int, tuple[int, int]]
1066
+ ) -> tuple[np.ndarray, np.ndarray]:
1061
1067
  # Vera: input can be both real microscopic image and two separate channels that are summed in the code
1062
1068
 
1063
1069
  if self._train_index_switcher is not None:
@@ -4,7 +4,7 @@ from typing import Callable, Union
4
4
  import numpy as np
5
5
  from numpy.typing import NDArray
6
6
 
7
- from .config import DatasetConfig
7
+ from .config import MicroSplitDataConfig
8
8
  from .lc_dataset import LCMultiChDloader
9
9
  from .multich_dataset import MultiChDloader
10
10
  from .types import DataSplitType
@@ -82,7 +82,7 @@ class SingleFileLCDset(LCMultiChDloader):
82
82
  def __init__(
83
83
  self,
84
84
  preloaded_data: NDArray,
85
- data_config: DatasetConfig,
85
+ data_config: MicroSplitDataConfig,
86
86
  fpath: str,
87
87
  load_data_fn: Callable,
88
88
  val_fraction=None,
@@ -106,7 +106,7 @@ class SingleFileLCDset(LCMultiChDloader):
106
106
 
107
107
  def load_data(
108
108
  self,
109
- data_config: DatasetConfig,
109
+ data_config: MicroSplitDataConfig,
110
110
  datasplit_type: DataSplitType,
111
111
  load_data_fn: Callable,
112
112
  val_fraction=None,
@@ -124,7 +124,7 @@ class SingleFileDset(MultiChDloader):
124
124
  def __init__(
125
125
  self,
126
126
  preloaded_data: NDArray,
127
- data_config: DatasetConfig,
127
+ data_config: MicroSplitDataConfig,
128
128
  fpath: str,
129
129
  load_data_fn: Callable,
130
130
  val_fraction=None,
@@ -148,7 +148,7 @@ class SingleFileDset(MultiChDloader):
148
148
 
149
149
  def load_data(
150
150
  self,
151
- data_config: DatasetConfig,
151
+ data_config: MicroSplitDataConfig,
152
152
  datasplit_type: DataSplitType,
153
153
  load_data_fn: Callable[..., NDArray],
154
154
  val_fraction=None,
@@ -175,7 +175,7 @@ class MultiFileDset:
175
175
 
176
176
  def __init__(
177
177
  self,
178
- data_config: DatasetConfig,
178
+ data_config: MicroSplitDataConfig,
179
179
  fpath: str,
180
180
  load_data_fn: Callable[..., Union[TwoChannelData, MultiChannelData]],
181
181
  val_fraction=None,
@@ -1,10 +1,10 @@
1
1
  """Module use to build BMZ model description."""
2
2
 
3
3
  from pathlib import Path
4
- from typing import Optional, Union
4
+ from typing import Union
5
5
 
6
6
  import numpy as np
7
- from bioimageio.spec._internal.io import resolve_and_extract
7
+ from bioimageio.spec._internal.io import extract
8
8
  from bioimageio.spec.model.v0_5 import (
9
9
  ArchitectureFromLibraryDescr,
10
10
  Author,
@@ -12,7 +12,6 @@ from bioimageio.spec.model.v0_5 import (
12
12
  AxisId,
13
13
  BatchAxis,
14
14
  ChannelAxis,
15
- EnvironmentFileDescr,
16
15
  FileDescr,
17
16
  FixedZeroMeanUnitVarianceAlongAxisKwargs,
18
17
  FixedZeroMeanUnitVarianceDescr,
@@ -36,7 +35,7 @@ from ._readme_factory import readme_factory
36
35
  def _create_axes(
37
36
  array: np.ndarray,
38
37
  data_config: DataConfig,
39
- channel_names: Optional[list[str]] = None,
38
+ channel_names: list[str] | None = None,
40
39
  is_input: bool = True,
41
40
  ) -> list[AxisBase]:
42
41
  """Create axes description.
@@ -105,7 +104,7 @@ def _create_inputs_ouputs(
105
104
  data_config: DataConfig,
106
105
  input_path: Union[Path, str],
107
106
  output_path: Union[Path, str],
108
- channel_names: Optional[list[str]] = None,
107
+ channel_names: list[str] | None = None,
109
108
  ) -> tuple[InputTensorDescr, OutputTensorDescr]:
110
109
  """Create input and output tensor description.
111
110
 
@@ -197,7 +196,7 @@ def create_model_description(
197
196
  config_path: Union[Path, str],
198
197
  env_path: Union[Path, str],
199
198
  covers: list[Union[Path, str]],
200
- channel_names: Optional[list[str]] = None,
199
+ channel_names: list[str] | None = None,
201
200
  model_version: str = "0.1.0",
202
201
  ) -> ModelDescr:
203
202
  """Create model description.
@@ -269,7 +268,7 @@ def create_model_description(
269
268
  source=weights_path,
270
269
  architecture=architecture_descr,
271
270
  pytorch_version=Version(torch_version),
272
- dependencies=EnvironmentFileDescr(source=env_path),
271
+ dependencies=FileDescr(source=Path(env_path)),
273
272
  ),
274
273
  )
275
274
 
@@ -322,9 +321,11 @@ def extract_model_path(model_desc: ModelDescr) -> tuple[Path, Path]:
322
321
  """
323
322
  if model_desc.weights.pytorch_state_dict is None:
324
323
  raise ValueError("No model weights found in model description.")
325
- weights_path = resolve_and_extract(
326
- model_desc.weights.pytorch_state_dict.source
327
- ).path
324
+
325
+ # extract the zip model and return the directory
326
+ model_dir = extract(model_desc.root)
327
+
328
+ weights_path = model_dir.joinpath(model_desc.weights.pytorch_state_dict.source.path)
328
329
 
329
330
  for file in model_desc.attachments:
330
331
  file_path = file.source if isinstance(file.source, Path) else file.source.path
@@ -332,7 +333,7 @@ def extract_model_path(model_desc: ModelDescr) -> tuple[Path, Path]:
332
333
  continue
333
334
  file_path = Path(file_path)
334
335
  if file_path.name == "careamics.yaml":
335
- config_path = resolve_and_extract(file.source).path
336
+ config_path = model_dir.joinpath(file.source.path)
336
337
  break
337
338
  else:
338
339
  raise ValueError("Configuration file not found.")
@@ -2,7 +2,7 @@
2
2
 
3
3
  import tempfile
4
4
  from pathlib import Path
5
- from typing import Optional, Union
5
+ from typing import Union
6
6
 
7
7
  import numpy as np
8
8
  from bioimageio.core import load_model_description, test_model
@@ -90,8 +90,8 @@ def export_to_bmz(
90
90
  authors: list[dict],
91
91
  input_array: np.ndarray,
92
92
  output_array: np.ndarray,
93
- covers: Optional[list[Union[Path, str]]] = None,
94
- channel_names: Optional[list[str]] = None,
93
+ covers: list[Union[Path, str]] | None = None,
94
+ channel_names: list[str] | None = None,
95
95
  model_version: str = "0.1.0",
96
96
  ) -> None:
97
97
  """Export the model to BioImage Model Zoo format.
@@ -186,11 +186,15 @@ def export_to_bmz(
186
186
  )
187
187
 
188
188
  # test model description
189
- test_kwargs = (
190
- model_description.config.get("bioimageio", {})
191
- .get("test_kwargs", {})
192
- .get("pytorch_state_dict", {})
193
- )
189
+ test_kwargs = {}
190
+ if hasattr(model_description, "config") and isinstance(
191
+ model_description.config, dict
192
+ ):
193
+ bioimageio_config = model_description.config.get("bioimageio", {})
194
+ test_kwargs = bioimageio_config.get("test_kwargs", {}).get(
195
+ "pytorch_state_dict", {}
196
+ )
197
+
194
198
  summary: ValidationSummary = test_model(model_description, **test_kwargs)
195
199
  if summary.status == "failed":
196
200
  raise ValueError(f"Model description test failed: {summary}")
@@ -4,7 +4,7 @@ Layer module.
4
4
  This submodule contains layers used in the CAREamics models.
5
5
  """
6
6
 
7
- from typing import Optional, Union
7
+ from typing import Union
8
8
 
9
9
  import torch
10
10
  import torch.nn as nn
@@ -207,8 +207,8 @@ def get_pascal_kernel_1d(
207
207
  kernel_size: int,
208
208
  norm: bool = False,
209
209
  *,
210
- device: Optional[torch.device] = None,
211
- dtype: Optional[torch.dtype] = None,
210
+ device: torch.device | None = None,
211
+ dtype: torch.dtype | None = None,
212
212
  ) -> torch.Tensor:
213
213
  """Generate Yang Hui triangle (Pascal's triangle) for a given number.
214
214
 
@@ -270,8 +270,8 @@ def _get_pascal_kernel_nd(
270
270
  norm: bool = True,
271
271
  dim: int = 2,
272
272
  *,
273
- device: Optional[torch.device] = None,
274
- dtype: Optional[torch.dtype] = None,
273
+ device: torch.device | None = None,
274
+ dtype: torch.dtype | None = None,
275
275
  ) -> torch.Tensor:
276
276
  """Generate pascal filter kernel by kernel size.
277
277
 
@@ -54,12 +54,8 @@ def likelihood_factory(
54
54
  )
55
55
  elif isinstance(config, NMLikelihoodConfig):
56
56
  return NoiseModelLikelihood(
57
- data_mean=config.data_mean,
58
- data_std=config.data_std,
59
57
  noise_model=noise_model,
60
58
  )
61
- else:
62
- raise ValueError(f"Invalid likelihood model type: {config.model_type}")
63
59
 
64
60
 
65
61
  # TODO: is it really worth to have this class? Or it just adds complexity? --> REFACTOR
@@ -290,27 +286,40 @@ class NoiseModelLikelihood(LikelihoodModule):
290
286
 
291
287
  def __init__(
292
288
  self,
293
- data_mean: Union[np.ndarray, torch.Tensor],
294
- data_std: Union[np.ndarray, torch.Tensor],
295
289
  noise_model: NoiseModel,
296
290
  ):
297
291
  """Constructor.
298
292
 
299
293
  Parameters
300
294
  ----------
301
- data_mean: Union[np.ndarray, torch.Tensor]
302
- The mean of the data, used to unnormalize data for noise model evaluation.
303
- data_std: Union[np.ndarray, torch.Tensor]
304
- The standard deviation of the data, used to unnormalize data for noise
305
- model evaluation.
306
295
  noiseModel: NoiseModel
307
296
  The noise model instance used to compute the likelihood.
308
297
  """
309
298
  super().__init__()
310
- self.data_mean = torch.Tensor(data_mean)
311
- self.data_std = torch.Tensor(data_std)
299
+ self.data_mean = None
300
+ self.data_std = None
312
301
  self.noiseModel = noise_model
313
302
 
303
+ def set_data_stats(
304
+ self,
305
+ data_mean: Union[np.ndarray, torch.Tensor],
306
+ data_std: Union[np.ndarray, torch.Tensor],
307
+ ) -> None:
308
+ """Set the data mean and std for denormalization.
309
+ # TODO check this !!
310
+ Parameters
311
+ ----------
312
+ data_mean : Union[np.ndarray, torch.Tensor]
313
+ Mean values for each channel. Will be reshaped to (1, C, 1, 1, 1) for broadcasting.
314
+ data_std : Union[np.ndarray, torch.Tensor]
315
+ Standard deviation values for each channel. Will be reshaped to (1, C, 1, 1, 1) for broadcasting.
316
+ """
317
+ # Convert to tensor if needed
318
+ self.data_mean = torch.as_tensor(data_mean, dtype=torch.float32)
319
+ self.data_std = torch.as_tensor(data_std, dtype=torch.float32)
320
+
321
+ # TODO add extra dim for 3D ?
322
+
314
323
  def _set_params_to_same_device_as(
315
324
  self, correct_device_tensor: torch.Tensor
316
325
  ) -> None:
@@ -321,7 +330,10 @@ class NoiseModelLikelihood(LikelihoodModule):
321
330
  correct_device_tensor: torch.Tensor
322
331
  The tensor whose device is used to set the parameters.
323
332
  """
324
- if self.data_mean.device != correct_device_tensor.device:
333
+ if (
334
+ self.data_mean is not None
335
+ and self.data_mean.device != correct_device_tensor.device
336
+ ):
325
337
  self.data_mean = self.data_mean.to(correct_device_tensor.device)
326
338
  self.data_std = self.data_std.to(correct_device_tensor.device)
327
339
  if correct_device_tensor.device != self.noiseModel.device:
@@ -367,6 +379,10 @@ class NoiseModelLikelihood(LikelihoodModule):
367
379
  torch.Tensor
368
380
  The log-likelihood tensor. Shape is (B, C, [Z], Y, X).
369
381
  """
382
+ if self.data_mean is None or self.data_std is None:
383
+ raise RuntimeError(
384
+ "NoiseModelLikelihood: data_mean and data_std must be set before calling log_likelihood."
385
+ )
370
386
  self._set_params_to_same_device_as(x)
371
387
  predicted_s_denormalized = params["mean"] * self.data_std + self.data_mean
372
388
  x_denormalized = x * self.data_std + self.data_mean
@@ -6,7 +6,7 @@ and Artefact Removal, Prakash et al."
6
6
  """
7
7
 
8
8
  from collections.abc import Iterable
9
- from typing import Optional, Union
9
+ from typing import Union
10
10
 
11
11
  import numpy as np
12
12
  import torch
@@ -835,7 +835,7 @@ class LadderVAE(nn.Module):
835
835
  top_layer_shape = (n_imgs, mu_logvar, self._model_3D_depth, h, w)
836
836
  return top_layer_shape
837
837
 
838
- def reset_for_inference(self, tile_size: Optional[tuple[int, int]] = None):
838
+ def reset_for_inference(self, tile_size: tuple[int, int] | None = None):
839
839
  """Should be called if we want to predict for a different input/output size."""
840
840
  self.mode_pred = True
841
841
  if tile_size is None:
@@ -3,10 +3,10 @@ from __future__ import annotations
3
3
  import os
4
4
  from typing import TYPE_CHECKING, Optional
5
5
 
6
- from numpy.typing import NDArray
7
6
  import numpy as np
8
7
  import torch
9
8
  import torch.nn as nn
9
+ from numpy.typing import NDArray
10
10
 
11
11
  if TYPE_CHECKING:
12
12
  from careamics.config import GaussianMixtureNMConfig, MultiChannelNMConfig
@@ -355,16 +355,16 @@ class GaussianMixtureNoiseModel(nn.Module):
355
355
 
356
356
  Parameters
357
357
  ----------
358
- x: Tensor
359
- Observations
360
- mean: Tensor
361
- Mean
362
- std: Tensor
363
- Standard-deviation
358
+ x: torch.Tensor
359
+ The ground-truth tensor. Shape is (batch, 1, dim1, dim2).
360
+ mean: torch.Tensor
361
+ The inferred mean of distribution. Shape is (batch, 1, dim1, dim2).
362
+ std: torch.Tensor
363
+ The inferred standard deviation of distribution. Shape is (batch, 1, dim1, dim2).
364
364
 
365
365
  Returns
366
366
  -------
367
- tmp: Tensor
367
+ tmp: torch.Tensor
368
368
  Normal probability density of `x` given `mean` and `std`
369
369
  """
370
370
  tmp = -((x - mean) ** 2)
@@ -382,9 +382,9 @@ class GaussianMixtureNoiseModel(nn.Module):
382
382
  Parameters
383
383
  ----------
384
384
  observations : Tensor
385
- Noisy observations
385
+ Noisy observations. Shape is (batch, 1, dim1, dim2).
386
386
  signals : Tensor
387
- Underlying signals
387
+ Underlying signals. Shape is (batch, 1, dim1, dim2).
388
388
 
389
389
  Returns
390
390
  -------
@@ -392,15 +392,21 @@ class GaussianMixtureNoiseModel(nn.Module):
392
392
  Likelihood of observations given the signals and the GMM noise model
393
393
  """
394
394
  gaussian_parameters: list[torch.Tensor] = self.get_gaussian_parameters(signals)
395
- p = 0
395
+ p = torch.zeros_like(observations)
396
396
  for gaussian in range(self.n_gaussian):
397
+ # Ensure all tensors have compatible shapes
398
+ mean = gaussian_parameters[gaussian]
399
+ std = gaussian_parameters[self.n_gaussian + gaussian]
400
+ weight = gaussian_parameters[2 * self.n_gaussian + gaussian]
401
+
402
+ # Compute normal density
397
403
  p += (
398
404
  self.normal_density(
399
405
  observations,
400
- gaussian_parameters[gaussian],
401
- gaussian_parameters[self.n_gaussian + gaussian],
406
+ mean,
407
+ std,
402
408
  )
403
- * gaussian_parameters[2 * self.n_gaussian + gaussian]
409
+ * weight
404
410
  )
405
411
  return p + self.tolerance
406
412
 
@@ -2,9 +2,15 @@
2
2
 
3
3
  __all__ = [
4
4
  "convert_outputs",
5
+ "convert_outputs_microsplit",
5
6
  "stitch_prediction",
6
7
  "stitch_prediction_single",
8
+ "stitch_prediction_vae",
7
9
  ]
8
10
 
9
- from .prediction_outputs import convert_outputs
10
- from .stitch_prediction import stitch_prediction, stitch_prediction_single
11
+ from .prediction_outputs import convert_outputs, convert_outputs_microsplit
12
+ from .stitch_prediction import (
13
+ stitch_prediction,
14
+ stitch_prediction_single,
15
+ stitch_prediction_vae,
16
+ )
@@ -1,6 +1,6 @@
1
1
  """Module containing pytorch implementations for obtaining predictions from an LVAE."""
2
2
 
3
- from typing import Any, Optional
3
+ from typing import Any
4
4
 
5
5
  import torch
6
6
 
@@ -18,7 +18,7 @@ def lvae_predict_single_sample(
18
18
  model: LVAE,
19
19
  likelihood_obj: LikelihoodModule,
20
20
  input: torch.Tensor,
21
- ) -> tuple[torch.Tensor, Optional[torch.Tensor]]:
21
+ ) -> tuple[torch.Tensor, torch.Tensor | None]:
22
22
  """
23
23
  Generate a single sample prediction from an LVAE model, for a given input.
24
24
 
@@ -57,7 +57,7 @@ def lvae_predict_tiled_batch(
57
57
  model: LVAE,
58
58
  likelihood_obj: LikelihoodModule,
59
59
  input: tuple[Any],
60
- ) -> tuple[tuple[Any], Optional[tuple[Any]]]:
60
+ ) -> tuple[tuple[Any], tuple[Any] | None]:
61
61
  # TODO: fix docstring return types, ... too many output options
62
62
  """
63
63
  Generate a single sample prediction from an LVAE model, for a given input.
@@ -98,7 +98,7 @@ def lvae_predict_mmse_tiled_batch(
98
98
  likelihood_obj: LikelihoodModule,
99
99
  input: tuple[Any],
100
100
  mmse_count: int,
101
- ) -> tuple[tuple[Any], tuple[Any], Optional[tuple[Any]]]:
101
+ ) -> tuple[tuple[Any], tuple[Any], tuple[Any] | None]:
102
102
  # TODO: fix docstring return types, ... hard to make readable
103
103
  """
104
104
  Generate the MMSE (minimum mean squared error) prediction, for a given input.
@@ -137,7 +137,7 @@ def lvae_predict_mmse_tiled_batch(
137
137
 
138
138
  input_shape = x.shape
139
139
  output_shape = (input_shape[0], model.target_ch, *input_shape[2:])
140
- log_var: Optional[torch.Tensor] = None
140
+ log_var: torch.Tensor | None = None
141
141
  # pre-declare empty array to fill with individual sample predictions
142
142
  sample_predictions = torch.zeros(size=(mmse_count, *output_shape))
143
143
  for mmse_idx in range(mmse_count):
@@ -6,7 +6,7 @@ import numpy as np
6
6
  from numpy.typing import NDArray
7
7
 
8
8
  from ..config.tile_information import TileInformation
9
- from .stitch_prediction import stitch_prediction
9
+ from .stitch_prediction import stitch_prediction, stitch_prediction_vae
10
10
 
11
11
 
12
12
  def convert_outputs(predictions: list[Any], tiled: bool) -> list[NDArray]:
@@ -41,6 +41,48 @@ def convert_outputs(predictions: list[Any], tiled: bool) -> list[NDArray]:
41
41
  return predictions_output
42
42
 
43
43
 
44
+ def convert_outputs_microsplit(
45
+ predictions: list[tuple[NDArray, NDArray]], dataset
46
+ ) -> tuple[NDArray, NDArray]:
47
+ """
48
+ Convert microsplit Lightning trainer outputs using eval_utils stitching functions.
49
+
50
+ This function processes microsplit predictions that return (tile_prediction, tile_std) tuples
51
+ and stitches them back together using the same logic as get_single_file_mmse.
52
+
53
+ Parameters
54
+ ----------
55
+ predictions : list of tuple[NDArray, NDArray]
56
+ Predictions from Lightning trainer for microsplit. Each element is a tuple of
57
+ (tile_prediction, tile_std) where both are numpy arrays from predict_step.
58
+ dataset : Dataset
59
+ The dataset object used for prediction, needed for stitching function selection
60
+ and stitching process.
61
+
62
+ Returns
63
+ -------
64
+ tuple[NDArray, NDArray]
65
+ A tuple of (stitched_predictions, stitched_stds) representing the full
66
+ stitched predictions and standard deviations.
67
+ """
68
+ if len(predictions) == 0:
69
+ raise ValueError("No predictions provided")
70
+
71
+ # Separate predictions and stds from the list of tuples
72
+ tile_predictions = [pred for pred, _ in predictions]
73
+ tile_stds = [std for _, std in predictions]
74
+
75
+ # Concatenate all tiles exactly like get_single_file_mmse
76
+ tiles_arr = np.concatenate(tile_predictions, axis=0)
77
+ tile_stds_arr = np.concatenate(tile_stds, axis=0)
78
+
79
+ # Apply stitching using stitch_predictions_new
80
+ stitched_predictions = stitch_prediction_vae(tiles_arr, dataset)
81
+ stitched_stds = stitch_prediction_vae(tile_stds_arr, dataset)
82
+
83
+ return stitched_predictions, stitched_stds
84
+
85
+
44
86
  # for mypy
45
87
  @overload
46
88
  def combine_batches( # numpydoc ignore=GL08
@@ -68,6 +110,8 @@ def combine_batches(
68
110
  """
69
111
  If predictions are in batches, they will be combined.
70
112
 
113
+ # TODO improve description!
114
+
71
115
  Parameters
72
116
  ----------
73
117
  predictions : list
@@ -107,11 +151,12 @@ def _combine_tiled_batches(
107
151
  """
108
152
  # turn list of lists into single list
109
153
  tile_infos = [
110
- tile_info for _, tile_info_list in predictions for tile_info in tile_info_list
154
+ tile_info for *_, tile_info_list in predictions for tile_info in tile_info_list
111
155
  ]
112
156
  prediction_tiles: list[NDArray] = _combine_array_batches(
113
- [preds for preds, _ in predictions]
157
+ [preds for preds, *_ in predictions]
114
158
  )
159
+
115
160
  return prediction_tiles, tile_infos
116
161
 
117
162