careamics 0.0.14__py3-none-any.whl → 0.0.15__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of careamics might be problematic; see the package's page on the registry for more details.

Files changed (56):
  1. careamics/careamist.py +49 -49
  2. careamics/cli/conf.py +6 -6
  3. careamics/cli/main.py +8 -8
  4. careamics/cli/utils.py +2 -4
  5. careamics/config/algorithms/vae_algorithm_model.py +4 -4
  6. careamics/config/callback_model.py +8 -8
  7. careamics/config/configuration_factories.py +49 -49
  8. careamics/config/data/data_model.py +7 -13
  9. careamics/config/data/ng_data_model.py +8 -14
  10. careamics/config/data/patching_strategies/_overlapping_patched_model.py +4 -5
  11. careamics/config/inference_model.py +6 -10
  12. careamics/config/likelihood_model.py +2 -2
  13. careamics/config/nm_model.py +5 -7
  14. careamics/config/training_model.py +4 -4
  15. careamics/config/transformations/normalize_model.py +3 -3
  16. careamics/config/transformations/xy_flip_model.py +2 -2
  17. careamics/config/transformations/xy_random_rotate90_model.py +2 -2
  18. careamics/config/validators/validator_utils.py +1 -2
  19. careamics/dataset/dataset_utils/iterate_over_files.py +3 -3
  20. careamics/dataset/in_memory_dataset.py +2 -2
  21. careamics/dataset/iterable_dataset.py +1 -2
  22. careamics/dataset/patching/random_patching.py +6 -6
  23. careamics/dataset/patching/sequential_patching.py +4 -4
  24. careamics/dataset/tiling/lvae_tiled_patching.py +2 -2
  25. careamics/dataset_ng/dataset.py +3 -3
  26. careamics/dataset_ng/factory.py +19 -19
  27. careamics/dataset_ng/patching_strategies/random_patching.py +2 -3
  28. careamics/dataset_ng/patching_strategies/sequential_patching.py +1 -2
  29. careamics/lightning/callbacks/prediction_writer_callback/prediction_writer_callback.py +5 -5
  30. careamics/lightning/callbacks/prediction_writer_callback/write_strategy.py +5 -5
  31. careamics/lightning/callbacks/prediction_writer_callback/write_strategy_factory.py +8 -8
  32. careamics/lightning/dataset_ng/data_module.py +43 -43
  33. careamics/lightning/lightning_module.py +12 -14
  34. careamics/lightning/predict_data_module.py +8 -8
  35. careamics/lightning/train_data_module.py +11 -11
  36. careamics/losses/lvae/losses.py +9 -9
  37. careamics/model_io/bioimage/model_description.py +12 -11
  38. careamics/model_io/bmz_io.py +4 -4
  39. careamics/models/layers.py +5 -5
  40. careamics/prediction_utils/lvae_prediction.py +5 -5
  41. careamics/transforms/compose.py +9 -9
  42. careamics/transforms/n2v_manipulate.py +3 -3
  43. careamics/transforms/n2v_manipulate_torch.py +4 -4
  44. careamics/transforms/normalize.py +4 -6
  45. careamics/transforms/pixel_manipulation.py +6 -8
  46. careamics/transforms/pixel_manipulation_torch.py +5 -7
  47. careamics/transforms/xy_flip.py +3 -5
  48. careamics/transforms/xy_random_rotate90.py +3 -5
  49. careamics/utils/logging.py +8 -8
  50. careamics/utils/metrics.py +2 -2
  51. careamics/utils/plotting.py +1 -3
  52. {careamics-0.0.14.dist-info → careamics-0.0.15.dist-info}/METADATA +2 -3
  53. {careamics-0.0.14.dist-info → careamics-0.0.15.dist-info}/RECORD +56 -56
  54. {careamics-0.0.14.dist-info → careamics-0.0.15.dist-info}/WHEEL +0 -0
  55. {careamics-0.0.14.dist-info → careamics-0.0.15.dist-info}/entry_points.txt +0 -0
  56. {careamics-0.0.14.dist-info → careamics-0.0.15.dist-info}/licenses/LICENSE +0 -0
@@ -2,7 +2,7 @@
2
2
 
3
3
  from collections.abc import Sequence
4
4
  from pathlib import Path
5
- from typing import Any, Optional, Protocol, Union
5
+ from typing import Any, Protocol, Union
6
6
 
7
7
  import numpy as np
8
8
  from numpy.typing import NDArray
@@ -25,7 +25,7 @@ class WriteStrategy(Protocol):
25
25
  trainer: Trainer,
26
26
  pl_module: LightningModule,
27
27
  prediction: Any, # TODO: change to expected type
28
- batch_indices: Optional[Sequence[int]],
28
+ batch_indices: Sequence[int] | None,
29
29
  batch: Any, # TODO: change to expected type
30
30
  batch_idx: int,
31
31
  dataloader_idx: int,
@@ -133,7 +133,7 @@ class CacheTiles(WriteStrategy):
133
133
  trainer: Trainer,
134
134
  pl_module: LightningModule,
135
135
  prediction: tuple[NDArray, list[TileInformation]],
136
- batch_indices: Optional[Sequence[int]],
136
+ batch_indices: Sequence[int] | None,
137
137
  batch: tuple[NDArray, list[TileInformation]],
138
138
  batch_idx: int,
139
139
  dataloader_idx: int,
@@ -259,7 +259,7 @@ class WriteTilesZarr(WriteStrategy):
259
259
  trainer: Trainer,
260
260
  pl_module: LightningModule,
261
261
  prediction: Any,
262
- batch_indices: Optional[Sequence[int]],
262
+ batch_indices: Sequence[int] | None,
263
263
  batch: Any,
264
264
  batch_idx: int,
265
265
  dataloader_idx: int,
@@ -346,7 +346,7 @@ class WriteImage(WriteStrategy):
346
346
  trainer: Trainer,
347
347
  pl_module: LightningModule,
348
348
  prediction: NDArray,
349
- batch_indices: Optional[Sequence[int]],
349
+ batch_indices: Sequence[int] | None,
350
350
  batch: NDArray,
351
351
  batch_idx: int,
352
352
  dataloader_idx: int,
@@ -1,6 +1,6 @@
1
1
  """Module containing convenience function to create `WriteStrategy`."""
2
2
 
3
- from typing import Any, Optional
3
+ from typing import Any
4
4
 
5
5
  from careamics.config.support import SupportedData
6
6
  from careamics.file_io import SupportedWriteType, WriteFunc, get_write_func
@@ -11,9 +11,9 @@ from .write_strategy import CacheTiles, WriteImage, WriteStrategy
11
11
  def create_write_strategy(
12
12
  write_type: SupportedWriteType,
13
13
  tiled: bool,
14
- write_func: Optional[WriteFunc] = None,
15
- write_extension: Optional[str] = None,
16
- write_func_kwargs: Optional[dict[str, Any]] = None,
14
+ write_func: WriteFunc | None = None,
15
+ write_extension: str | None = None,
16
+ write_func_kwargs: dict[str, Any] | None = None,
17
17
  ) -> WriteStrategy:
18
18
  """
19
19
  Create a write strategy from convenient parameters.
@@ -78,8 +78,8 @@ def create_write_strategy(
78
78
 
79
79
  def _create_tiled_write_strategy(
80
80
  write_type: SupportedWriteType,
81
- write_func: Optional[WriteFunc],
82
- write_extension: Optional[str],
81
+ write_func: WriteFunc | None,
82
+ write_extension: str | None,
83
83
  write_func_kwargs: dict[str, Any],
84
84
  ) -> WriteStrategy:
85
85
  """
@@ -130,7 +130,7 @@ def _create_tiled_write_strategy(
130
130
 
131
131
 
132
132
  def select_write_func(
133
- write_type: SupportedWriteType, write_func: Optional[WriteFunc] = None
133
+ write_type: SupportedWriteType, write_func: WriteFunc | None = None
134
134
  ) -> WriteFunc:
135
135
  """
136
136
  Return a function to write images.
@@ -177,7 +177,7 @@ def select_write_func(
177
177
 
178
178
 
179
179
  def select_write_extension(
180
- write_type: SupportedWriteType, write_extension: Optional[str] = None
180
+ write_type: SupportedWriteType, write_extension: str | None = None
181
181
  ) -> str:
182
182
  """
183
183
  Return an extension to add to file paths.
@@ -2,7 +2,7 @@
2
2
 
3
3
  from collections.abc import Callable
4
4
  from pathlib import Path
5
- from typing import Any, Optional, Union, overload
5
+ from typing import Any, Union, overload
6
6
 
7
7
  import numpy as np
8
8
  import pytorch_lightning as L
@@ -124,14 +124,14 @@ class CareamicsDataModule(L.LightningDataModule):
124
124
  self,
125
125
  data_config: NGDataConfig,
126
126
  *,
127
- train_data: Optional[InputType] = None,
128
- train_data_target: Optional[InputType] = None,
129
- val_data: Optional[InputType] = None,
130
- val_data_target: Optional[InputType] = None,
131
- pred_data: Optional[InputType] = None,
132
- pred_data_target: Optional[InputType] = None,
127
+ train_data: InputType | None = None,
128
+ train_data_target: InputType | None = None,
129
+ val_data: InputType | None = None,
130
+ val_data_target: InputType | None = None,
131
+ pred_data: InputType | None = None,
132
+ pred_data_target: InputType | None = None,
133
133
  extension_filter: str = "",
134
- val_percentage: Optional[float] = None,
134
+ val_percentage: float | None = None,
135
135
  val_minimum_split: int = 5,
136
136
  use_in_memory: bool = True,
137
137
  ) -> None: ...
@@ -142,16 +142,16 @@ class CareamicsDataModule(L.LightningDataModule):
142
142
  self,
143
143
  data_config: NGDataConfig,
144
144
  *,
145
- train_data: Optional[InputType] = None,
146
- train_data_target: Optional[InputType] = None,
147
- val_data: Optional[InputType] = None,
148
- val_data_target: Optional[InputType] = None,
149
- pred_data: Optional[InputType] = None,
150
- pred_data_target: Optional[InputType] = None,
145
+ train_data: InputType | None = None,
146
+ train_data_target: InputType | None = None,
147
+ val_data: InputType | None = None,
148
+ val_data_target: InputType | None = None,
149
+ pred_data: InputType | None = None,
150
+ pred_data_target: InputType | None = None,
151
151
  read_source_func: Callable,
152
- read_kwargs: Optional[dict[str, Any]] = None,
152
+ read_kwargs: dict[str, Any] | None = None,
153
153
  extension_filter: str = "",
154
- val_percentage: Optional[float] = None,
154
+ val_percentage: float | None = None,
155
155
  val_minimum_split: int = 5,
156
156
  use_in_memory: bool = True,
157
157
  ) -> None: ...
@@ -161,16 +161,16 @@ class CareamicsDataModule(L.LightningDataModule):
161
161
  self,
162
162
  data_config: NGDataConfig,
163
163
  *,
164
- train_data: Optional[Any] = None,
165
- train_data_target: Optional[Any] = None,
166
- val_data: Optional[Any] = None,
167
- val_data_target: Optional[Any] = None,
168
- pred_data: Optional[Any] = None,
169
- pred_data_target: Optional[Any] = None,
164
+ train_data: Any | None = None,
165
+ train_data_target: Any | None = None,
166
+ val_data: Any | None = None,
167
+ val_data_target: Any | None = None,
168
+ pred_data: Any | None = None,
169
+ pred_data_target: Any | None = None,
170
170
  image_stack_loader: ImageStackLoader,
171
- image_stack_loader_kwargs: Optional[dict[str, Any]] = None,
171
+ image_stack_loader_kwargs: dict[str, Any] | None = None,
172
172
  extension_filter: str = "",
173
- val_percentage: Optional[float] = None,
173
+ val_percentage: float | None = None,
174
174
  val_minimum_split: int = 5,
175
175
  use_in_memory: bool = True,
176
176
  ) -> None: ...
@@ -179,18 +179,18 @@ class CareamicsDataModule(L.LightningDataModule):
179
179
  self,
180
180
  data_config: NGDataConfig,
181
181
  *,
182
- train_data: Optional[Any] = None,
183
- train_data_target: Optional[Any] = None,
184
- val_data: Optional[Any] = None,
185
- val_data_target: Optional[Any] = None,
186
- pred_data: Optional[Any] = None,
187
- pred_data_target: Optional[Any] = None,
188
- read_source_func: Optional[Callable] = None,
189
- read_kwargs: Optional[dict[str, Any]] = None,
190
- image_stack_loader: Optional[ImageStackLoader] = None,
191
- image_stack_loader_kwargs: Optional[dict[str, Any]] = None,
182
+ train_data: Any | None = None,
183
+ train_data_target: Any | None = None,
184
+ val_data: Any | None = None,
185
+ val_data_target: Any | None = None,
186
+ pred_data: Any | None = None,
187
+ pred_data_target: Any | None = None,
188
+ read_source_func: Callable | None = None,
189
+ read_kwargs: dict[str, Any] | None = None,
190
+ image_stack_loader: ImageStackLoader | None = None,
191
+ image_stack_loader_kwargs: dict[str, Any] | None = None,
192
192
  extension_filter: str = "",
193
- val_percentage: Optional[float] = None,
193
+ val_percentage: float | None = None,
194
194
  val_minimum_split: int = 5,
195
195
  use_in_memory: bool = True,
196
196
  ) -> None:
@@ -280,7 +280,7 @@ class CareamicsDataModule(L.LightningDataModule):
280
280
  def _validate_input_target_type_consistency(
281
281
  self,
282
282
  input_data: InputType,
283
- target_data: Optional[InputType],
283
+ target_data: InputType | None,
284
284
  ) -> None:
285
285
  """Validate if the input and target data types are consistent.
286
286
 
@@ -314,7 +314,7 @@ class CareamicsDataModule(L.LightningDataModule):
314
314
  self,
315
315
  input_data,
316
316
  target_data=None,
317
- ) -> tuple[list[Path], Optional[list[Path]]]:
317
+ ) -> tuple[list[Path], list[Path] | None]:
318
318
  """List files from input and target directories.
319
319
 
320
320
  Parameters
@@ -347,7 +347,7 @@ class CareamicsDataModule(L.LightningDataModule):
347
347
  self,
348
348
  input_data,
349
349
  target_data=None,
350
- ) -> tuple[list[Path], Optional[list[Path]]]:
350
+ ) -> tuple[list[Path], list[Path] | None]:
351
351
  """Create a list of file paths from the input and target data.
352
352
 
353
353
  Parameters
@@ -379,7 +379,7 @@ class CareamicsDataModule(L.LightningDataModule):
379
379
  def _validate_array_input(
380
380
  self,
381
381
  input_data: InputType,
382
- target_data: Optional[InputType],
382
+ target_data: InputType | None,
383
383
  ) -> tuple[Any, Any]:
384
384
  """Validate if the input data is a numpy array.
385
385
 
@@ -408,8 +408,8 @@ class CareamicsDataModule(L.LightningDataModule):
408
408
  )
409
409
 
410
410
  def _validate_path_input(
411
- self, input_data: InputType, target_data: Optional[InputType]
412
- ) -> tuple[list[Path], Optional[list[Path]]]:
411
+ self, input_data: InputType, target_data: InputType | None
412
+ ) -> tuple[list[Path], list[Path] | None]:
413
413
  """Validate if the input data is a path or a list of paths.
414
414
 
415
415
  Parameters
@@ -488,8 +488,8 @@ class CareamicsDataModule(L.LightningDataModule):
488
488
 
489
489
  def _initialize_data_pair(
490
490
  self,
491
- input_data: Optional[InputType],
492
- target_data: Optional[InputType],
491
+ input_data: InputType | None,
492
+ target_data: InputType | None,
493
493
  ) -> tuple[Any, Any]:
494
494
  """
495
495
  Initialize a pair of input and target data.
@@ -1,7 +1,7 @@
1
1
  """CAREamics Lightning module."""
2
2
 
3
3
  from collections.abc import Callable
4
- from typing import Any, Literal, Optional, Union
4
+ from typing import Any, Literal, Union
5
5
 
6
6
  import numpy as np
7
7
  import pytorch_lightning as L
@@ -90,7 +90,7 @@ class FCNModule(L.LightningModule):
90
90
  # create preprocessing, model and loss function
91
91
  if isinstance(algorithm_config, N2VAlgorithm):
92
92
  self.use_n2v = True
93
- self.n2v_preprocess: Optional[N2VManipulateTorch] = N2VManipulateTorch(
93
+ self.n2v_preprocess: N2VManipulateTorch | None = N2VManipulateTorch(
94
94
  n2v_manipulate_config=algorithm_config.n2v_config
95
95
  )
96
96
  else:
@@ -333,18 +333,16 @@ class VAEModule(L.LightningModule):
333
333
  self.model: nn.Module = model_factory(self.algorithm_config.model)
334
334
 
335
335
  # create loss function
336
- self.noise_model: Optional[NoiseModel] = noise_model_factory(
336
+ self.noise_model: NoiseModel | None = noise_model_factory(
337
337
  self.algorithm_config.noise_model
338
338
  )
339
339
 
340
- self.noise_model_likelihood: Optional[NoiseModelLikelihood] = (
341
- likelihood_factory(
342
- config=self.algorithm_config.noise_model_likelihood,
343
- noise_model=self.noise_model,
344
- )
340
+ self.noise_model_likelihood: NoiseModelLikelihood | None = likelihood_factory(
341
+ config=self.algorithm_config.noise_model_likelihood,
342
+ noise_model=self.noise_model,
345
343
  )
346
344
 
347
- self.gaussian_likelihood: Optional[GaussianLikelihood] = likelihood_factory(
345
+ self.gaussian_likelihood: GaussianLikelihood | None = likelihood_factory(
348
346
  self.algorithm_config.gaussian_likelihood
349
347
  )
350
348
 
@@ -380,7 +378,7 @@ class VAEModule(L.LightningModule):
380
378
 
381
379
  def training_step(
382
380
  self, batch: tuple[Tensor, Tensor], batch_idx: Any
383
- ) -> Optional[dict[str, Tensor]]:
381
+ ) -> dict[str, Tensor] | None:
384
382
  """Training step.
385
383
 
386
384
  Parameters
@@ -603,7 +601,7 @@ class VAEModule(L.LightningModule):
603
601
  for i in range(out_channels)
604
602
  ]
605
603
 
606
- def reduce_running_psnr(self) -> Optional[float]:
604
+ def reduce_running_psnr(self) -> float | None:
607
605
  """Reduce the running PSNR statistics and reset the running PSNR.
608
606
 
609
607
  Returns
@@ -634,11 +632,11 @@ def create_careamics_module(
634
632
  use_n2v2: bool = False,
635
633
  struct_n2v_axis: Literal["horizontal", "vertical", "none"] = "none",
636
634
  struct_n2v_span: int = 5,
637
- model_parameters: Optional[dict] = None,
635
+ model_parameters: dict | None = None,
638
636
  optimizer: Union[SupportedOptimizer, str] = "Adam",
639
- optimizer_parameters: Optional[dict] = None,
637
+ optimizer_parameters: dict | None = None,
640
638
  lr_scheduler: Union[SupportedScheduler, str] = "ReduceLROnPlateau",
641
- lr_scheduler_parameters: Optional[dict] = None,
639
+ lr_scheduler_parameters: dict | None = None,
642
640
  ) -> Union[FCNModule, VAEModule]:
643
641
  """Create a CAREamics Lightning module.
644
642
 
@@ -2,7 +2,7 @@
2
2
 
3
3
  from collections.abc import Callable
4
4
  from pathlib import Path
5
- from typing import Any, Literal, Optional, Union
5
+ from typing import Any, Literal, Union
6
6
 
7
7
  import numpy as np
8
8
  import pytorch_lightning as L
@@ -65,9 +65,9 @@ class PredictDataModule(L.LightningDataModule):
65
65
  self,
66
66
  pred_config: InferenceConfig,
67
67
  pred_data: Union[Path, str, NDArray],
68
- read_source_func: Optional[Callable] = None,
68
+ read_source_func: Callable | None = None,
69
69
  extension_filter: str = "",
70
- dataloader_params: Optional[dict] = None,
70
+ dataloader_params: dict | None = None,
71
71
  ) -> None:
72
72
  """
73
73
  Constructor.
@@ -173,7 +173,7 @@ class PredictDataModule(L.LightningDataModule):
173
173
  self.pred_data, self.data_type, self.extension_filter
174
174
  )
175
175
 
176
- def setup(self, stage: Optional[str] = None) -> None:
176
+ def setup(self, stage: str | None = None) -> None:
177
177
  """
178
178
  Hook called at the beginning of predict.
179
179
 
@@ -231,13 +231,13 @@ def create_predict_datamodule(
231
231
  axes: str,
232
232
  image_means: list[float],
233
233
  image_stds: list[float],
234
- tile_size: Optional[tuple[int, ...]] = None,
235
- tile_overlap: Optional[tuple[int, ...]] = None,
234
+ tile_size: tuple[int, ...] | None = None,
235
+ tile_overlap: tuple[int, ...] | None = None,
236
236
  batch_size: int = 1,
237
237
  tta_transforms: bool = True,
238
- read_source_func: Optional[Callable] = None,
238
+ read_source_func: Callable | None = None,
239
239
  extension_filter: str = "",
240
- dataloader_params: Optional[dict] = None,
240
+ dataloader_params: dict | None = None,
241
241
  ) -> PredictDataModule:
242
242
  """Create a CAREamics prediction Lightning datamodule.
243
243
 
@@ -2,7 +2,7 @@
2
2
 
3
3
  from collections.abc import Callable
4
4
  from pathlib import Path
5
- from typing import Any, Literal, Optional, Union
5
+ from typing import Any, Literal, Union
6
6
 
7
7
  import numpy as np
8
8
  import pytorch_lightning as L
@@ -121,10 +121,10 @@ class TrainDataModule(L.LightningDataModule):
121
121
  self,
122
122
  data_config: DataConfig,
123
123
  train_data: Union[Path, str, NDArray],
124
- val_data: Optional[Union[Path, str, NDArray]] = None,
125
- train_data_target: Optional[Union[Path, str, NDArray]] = None,
126
- val_data_target: Optional[Union[Path, str, NDArray]] = None,
127
- read_source_func: Optional[Callable] = None,
124
+ val_data: Union[Path, str, NDArray] | None = None,
125
+ train_data_target: Union[Path, str, NDArray] | None = None,
126
+ val_data_target: Union[Path, str, NDArray] | None = None,
127
+ read_source_func: Callable | None = None,
128
128
  extension_filter: str = "",
129
129
  val_percentage: float = 0.1,
130
130
  val_minimum_split: int = 5,
@@ -477,15 +477,15 @@ def create_train_datamodule(
477
477
  patch_size: list[int],
478
478
  axes: str,
479
479
  batch_size: int,
480
- val_data: Optional[Union[str, Path, NDArray]] = None,
481
- transforms: Optional[list[TransformModel]] = None,
482
- train_target_data: Optional[Union[str, Path, NDArray]] = None,
483
- val_target_data: Optional[Union[str, Path, NDArray]] = None,
484
- read_source_func: Optional[Callable] = None,
480
+ val_data: Union[str, Path, NDArray] | None = None,
481
+ transforms: list[TransformModel] | None = None,
482
+ train_target_data: Union[str, Path, NDArray] | None = None,
483
+ val_target_data: Union[str, Path, NDArray] | None = None,
484
+ read_source_func: Callable | None = None,
485
485
  extension_filter: str = "",
486
486
  val_percentage: float = 0.1,
487
487
  val_minimum_patches: int = 5,
488
- dataloader_params: Optional[dict] = None,
488
+ dataloader_params: dict | None = None,
489
489
  use_in_memory: bool = True,
490
490
  ) -> TrainDataModule:
491
491
  """Create a TrainDataModule.
@@ -2,7 +2,7 @@
2
2
 
3
3
  from __future__ import annotations
4
4
 
5
- from typing import TYPE_CHECKING, Any, Literal, Optional, Union
5
+ from typing import TYPE_CHECKING, Any, Literal, Union
6
6
 
7
7
  import numpy as np
8
8
  import torch
@@ -112,7 +112,7 @@ def get_kl_divergence_loss(
112
112
  rescaling: Literal["latent_dim", "image_dim"],
113
113
  aggregation: Literal["mean", "sum"],
114
114
  free_bits_coeff: float,
115
- img_shape: Optional[tuple[int]] = None,
115
+ img_shape: tuple[int] | None = None,
116
116
  ) -> torch.Tensor:
117
117
  """Compute the KL divergence loss.
118
118
 
@@ -273,9 +273,9 @@ def musplit_loss(
273
273
  model_outputs: tuple[torch.Tensor, dict[str, Any]],
274
274
  targets: torch.Tensor,
275
275
  config: LVAELossConfig,
276
- gaussian_likelihood: Optional[GaussianLikelihood],
277
- noise_model_likelihood: Optional[NoiseModelLikelihood] = None, # TODO: ugly
278
- ) -> Optional[dict[str, torch.Tensor]]:
276
+ gaussian_likelihood: GaussianLikelihood | None,
277
+ noise_model_likelihood: NoiseModelLikelihood | None = None, # TODO: ugly
278
+ ) -> dict[str, torch.Tensor] | None:
279
279
  """Loss function for muSplit.
280
280
 
281
281
  Parameters
@@ -351,9 +351,9 @@ def denoisplit_loss(
351
351
  model_outputs: tuple[torch.Tensor, dict[str, Any]],
352
352
  targets: torch.Tensor,
353
353
  config: LVAELossConfig,
354
- gaussian_likelihood: Optional[GaussianLikelihood] = None,
355
- noise_model_likelihood: Optional[NoiseModelLikelihood] = None,
356
- ) -> Optional[dict[str, torch.Tensor]]:
354
+ gaussian_likelihood: GaussianLikelihood | None = None,
355
+ noise_model_likelihood: NoiseModelLikelihood | None = None,
356
+ ) -> dict[str, torch.Tensor] | None:
357
357
  """Loss function for DenoiSplit.
358
358
 
359
359
  Parameters
@@ -430,7 +430,7 @@ def denoisplit_musplit_loss(
430
430
  config: LVAELossConfig,
431
431
  gaussian_likelihood: GaussianLikelihood,
432
432
  noise_model_likelihood: NoiseModelLikelihood,
433
- ) -> Optional[dict[str, torch.Tensor]]:
433
+ ) -> dict[str, torch.Tensor] | None:
434
434
  """Loss function for DenoiSplit.
435
435
 
436
436
  Parameters
@@ -1,10 +1,10 @@
1
1
  """Module use to build BMZ model description."""
2
2
 
3
3
  from pathlib import Path
4
- from typing import Optional, Union
4
+ from typing import Union
5
5
 
6
6
  import numpy as np
7
- from bioimageio.spec._internal.io import resolve_and_extract
7
+ from bioimageio.spec._internal.io import extract
8
8
  from bioimageio.spec.model.v0_5 import (
9
9
  ArchitectureFromLibraryDescr,
10
10
  Author,
@@ -12,7 +12,6 @@ from bioimageio.spec.model.v0_5 import (
12
12
  AxisId,
13
13
  BatchAxis,
14
14
  ChannelAxis,
15
- EnvironmentFileDescr,
16
15
  FileDescr,
17
16
  FixedZeroMeanUnitVarianceAlongAxisKwargs,
18
17
  FixedZeroMeanUnitVarianceDescr,
@@ -36,7 +35,7 @@ from ._readme_factory import readme_factory
36
35
  def _create_axes(
37
36
  array: np.ndarray,
38
37
  data_config: DataConfig,
39
- channel_names: Optional[list[str]] = None,
38
+ channel_names: list[str] | None = None,
40
39
  is_input: bool = True,
41
40
  ) -> list[AxisBase]:
42
41
  """Create axes description.
@@ -105,7 +104,7 @@ def _create_inputs_ouputs(
105
104
  data_config: DataConfig,
106
105
  input_path: Union[Path, str],
107
106
  output_path: Union[Path, str],
108
- channel_names: Optional[list[str]] = None,
107
+ channel_names: list[str] | None = None,
109
108
  ) -> tuple[InputTensorDescr, OutputTensorDescr]:
110
109
  """Create input and output tensor description.
111
110
 
@@ -197,7 +196,7 @@ def create_model_description(
197
196
  config_path: Union[Path, str],
198
197
  env_path: Union[Path, str],
199
198
  covers: list[Union[Path, str]],
200
- channel_names: Optional[list[str]] = None,
199
+ channel_names: list[str] | None = None,
201
200
  model_version: str = "0.1.0",
202
201
  ) -> ModelDescr:
203
202
  """Create model description.
@@ -269,7 +268,7 @@ def create_model_description(
269
268
  source=weights_path,
270
269
  architecture=architecture_descr,
271
270
  pytorch_version=Version(torch_version),
272
- dependencies=EnvironmentFileDescr(source=env_path),
271
+ dependencies=FileDescr(source=Path(env_path)),
273
272
  ),
274
273
  )
275
274
 
@@ -322,9 +321,11 @@ def extract_model_path(model_desc: ModelDescr) -> tuple[Path, Path]:
322
321
  """
323
322
  if model_desc.weights.pytorch_state_dict is None:
324
323
  raise ValueError("No model weights found in model description.")
325
- weights_path = resolve_and_extract(
326
- model_desc.weights.pytorch_state_dict.source
327
- ).path
324
+
325
+ # extract the zip model and return the directory
326
+ model_dir = extract(model_desc.root)
327
+
328
+ weights_path = model_dir.joinpath(model_desc.weights.pytorch_state_dict.source.path)
328
329
 
329
330
  for file in model_desc.attachments:
330
331
  file_path = file.source if isinstance(file.source, Path) else file.source.path
@@ -332,7 +333,7 @@ def extract_model_path(model_desc: ModelDescr) -> tuple[Path, Path]:
332
333
  continue
333
334
  file_path = Path(file_path)
334
335
  if file_path.name == "careamics.yaml":
335
- config_path = resolve_and_extract(file.source).path
336
+ config_path = model_dir.joinpath(file.source.path)
336
337
  break
337
338
  else:
338
339
  raise ValueError("Configuration file not found.")
@@ -2,7 +2,7 @@
2
2
 
3
3
  import tempfile
4
4
  from pathlib import Path
5
- from typing import Optional, Union
5
+ from typing import Union
6
6
 
7
7
  import numpy as np
8
8
  from bioimageio.core import load_model_description, test_model
@@ -90,8 +90,8 @@ def export_to_bmz(
90
90
  authors: list[dict],
91
91
  input_array: np.ndarray,
92
92
  output_array: np.ndarray,
93
- covers: Optional[list[Union[Path, str]]] = None,
94
- channel_names: Optional[list[str]] = None,
93
+ covers: list[Union[Path, str]] | None = None,
94
+ channel_names: list[str] | None = None,
95
95
  model_version: str = "0.1.0",
96
96
  ) -> None:
97
97
  """Export the model to BioImage Model Zoo format.
@@ -187,7 +187,7 @@ def export_to_bmz(
187
187
 
188
188
  # test model description
189
189
  test_kwargs = (
190
- model_description.config.get("bioimageio", {})
190
+ model_description.config.bioimageio.model_dump()
191
191
  .get("test_kwargs", {})
192
192
  .get("pytorch_state_dict", {})
193
193
  )
@@ -4,7 +4,7 @@ Layer module.
4
4
  This submodule contains layers used in the CAREamics models.
5
5
  """
6
6
 
7
- from typing import Optional, Union
7
+ from typing import Union
8
8
 
9
9
  import torch
10
10
  import torch.nn as nn
@@ -207,8 +207,8 @@ def get_pascal_kernel_1d(
207
207
  kernel_size: int,
208
208
  norm: bool = False,
209
209
  *,
210
- device: Optional[torch.device] = None,
211
- dtype: Optional[torch.dtype] = None,
210
+ device: torch.device | None = None,
211
+ dtype: torch.dtype | None = None,
212
212
  ) -> torch.Tensor:
213
213
  """Generate Yang Hui triangle (Pascal's triangle) for a given number.
214
214
 
@@ -270,8 +270,8 @@ def _get_pascal_kernel_nd(
270
270
  norm: bool = True,
271
271
  dim: int = 2,
272
272
  *,
273
- device: Optional[torch.device] = None,
274
- dtype: Optional[torch.dtype] = None,
273
+ device: torch.device | None = None,
274
+ dtype: torch.dtype | None = None,
275
275
  ) -> torch.Tensor:
276
276
  """Generate pascal filter kernel by kernel size.
277
277
 
@@ -1,6 +1,6 @@
1
1
  """Module containing pytorch implementations for obtaining predictions from an LVAE."""
2
2
 
3
- from typing import Any, Optional
3
+ from typing import Any
4
4
 
5
5
  import torch
6
6
 
@@ -18,7 +18,7 @@ def lvae_predict_single_sample(
18
18
  model: LVAE,
19
19
  likelihood_obj: LikelihoodModule,
20
20
  input: torch.Tensor,
21
- ) -> tuple[torch.Tensor, Optional[torch.Tensor]]:
21
+ ) -> tuple[torch.Tensor, torch.Tensor | None]:
22
22
  """
23
23
  Generate a single sample prediction from an LVAE model, for a given input.
24
24
 
@@ -57,7 +57,7 @@ def lvae_predict_tiled_batch(
57
57
  model: LVAE,
58
58
  likelihood_obj: LikelihoodModule,
59
59
  input: tuple[Any],
60
- ) -> tuple[tuple[Any], Optional[tuple[Any]]]:
60
+ ) -> tuple[tuple[Any], tuple[Any] | None]:
61
61
  # TODO: fix docstring return types, ... too many output options
62
62
  """
63
63
  Generate a single sample prediction from an LVAE model, for a given input.
@@ -98,7 +98,7 @@ def lvae_predict_mmse_tiled_batch(
98
98
  likelihood_obj: LikelihoodModule,
99
99
  input: tuple[Any],
100
100
  mmse_count: int,
101
- ) -> tuple[tuple[Any], tuple[Any], Optional[tuple[Any]]]:
101
+ ) -> tuple[tuple[Any], tuple[Any], tuple[Any] | None]:
102
102
  # TODO: fix docstring return types, ... hard to make readable
103
103
  """
104
104
  Generate the MMSE (minimum mean squared error) prediction, for a given input.
@@ -137,7 +137,7 @@ def lvae_predict_mmse_tiled_batch(
137
137
 
138
138
  input_shape = x.shape
139
139
  output_shape = (input_shape[0], model.target_ch, *input_shape[2:])
140
- log_var: Optional[torch.Tensor] = None
140
+ log_var: torch.Tensor | None = None
141
141
  # pre-declare empty array to fill with individual sample predictions
142
142
  sample_predictions = torch.zeros(size=(mmse_count, *output_shape))
143
143
  for mmse_idx in range(mmse_count):