careamics 0.0.14 (py3-none-any.whl) → 0.0.16 (py3-none-any.whl)

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release: this version of careamics might be problematic. See the original registry diff page for more details.

Files changed (92):
  1. careamics/careamist.py +55 -61
  2. careamics/cli/conf.py +24 -9
  3. careamics/cli/main.py +8 -8
  4. careamics/cli/utils.py +2 -4
  5. careamics/config/__init__.py +8 -0
  6. careamics/config/algorithms/__init__.py +4 -0
  7. careamics/config/algorithms/hdn_algorithm_model.py +103 -0
  8. careamics/config/algorithms/microsplit_algorithm_model.py +103 -0
  9. careamics/config/algorithms/n2v_algorithm_model.py +1 -2
  10. careamics/config/algorithms/vae_algorithm_model.py +53 -18
  11. careamics/config/architectures/lvae_model.py +12 -8
  12. careamics/config/callback_model.py +15 -11
  13. careamics/config/configuration.py +9 -8
  14. careamics/config/configuration_factories.py +892 -78
  15. careamics/config/data/data_model.py +7 -14
  16. careamics/config/data/ng_data_model.py +8 -15
  17. careamics/config/data/patching_strategies/_overlapping_patched_model.py +4 -5
  18. careamics/config/inference_model.py +6 -11
  19. careamics/config/likelihood_model.py +4 -4
  20. careamics/config/loss_model.py +6 -2
  21. careamics/config/nm_model.py +30 -7
  22. careamics/config/optimizer_models.py +1 -2
  23. careamics/config/support/supported_algorithms.py +5 -3
  24. careamics/config/support/supported_losses.py +5 -2
  25. careamics/config/training_model.py +8 -38
  26. careamics/config/transformations/normalize_model.py +3 -4
  27. careamics/config/transformations/xy_flip_model.py +2 -2
  28. careamics/config/transformations/xy_random_rotate90_model.py +2 -2
  29. careamics/config/validators/validator_utils.py +1 -2
  30. careamics/dataset/dataset_utils/iterate_over_files.py +3 -3
  31. careamics/dataset/in_memory_dataset.py +2 -2
  32. careamics/dataset/iterable_dataset.py +1 -2
  33. careamics/dataset/patching/random_patching.py +6 -6
  34. careamics/dataset/patching/sequential_patching.py +4 -4
  35. careamics/dataset/tiling/lvae_tiled_patching.py +2 -2
  36. careamics/dataset_ng/dataset.py +3 -3
  37. careamics/dataset_ng/factory.py +19 -19
  38. careamics/dataset_ng/patch_extractor/demo_custom_image_stack_loader.py +4 -4
  39. careamics/dataset_ng/patch_extractor/image_stack/in_memory_image_stack.py +1 -2
  40. careamics/dataset_ng/patch_extractor/image_stack/zarr_image_stack.py +33 -7
  41. careamics/dataset_ng/patch_extractor/image_stack_loader.py +2 -2
  42. careamics/dataset_ng/patching_strategies/random_patching.py +2 -3
  43. careamics/dataset_ng/patching_strategies/sequential_patching.py +1 -2
  44. careamics/file_io/read/__init__.py +0 -1
  45. careamics/lightning/__init__.py +16 -2
  46. careamics/lightning/callbacks/__init__.py +2 -0
  47. careamics/lightning/callbacks/data_stats_callback.py +23 -0
  48. careamics/lightning/callbacks/prediction_writer_callback/prediction_writer_callback.py +5 -5
  49. careamics/lightning/callbacks/prediction_writer_callback/write_strategy.py +5 -5
  50. careamics/lightning/callbacks/prediction_writer_callback/write_strategy_factory.py +8 -8
  51. careamics/lightning/dataset_ng/data_module.py +43 -43
  52. careamics/lightning/lightning_module.py +166 -68
  53. careamics/lightning/microsplit_data_module.py +631 -0
  54. careamics/lightning/predict_data_module.py +16 -9
  55. careamics/lightning/train_data_module.py +29 -18
  56. careamics/losses/__init__.py +7 -1
  57. careamics/losses/loss_factory.py +9 -1
  58. careamics/losses/lvae/losses.py +94 -9
  59. careamics/lvae_training/dataset/__init__.py +8 -8
  60. careamics/lvae_training/dataset/config.py +56 -44
  61. careamics/lvae_training/dataset/lc_dataset.py +18 -12
  62. careamics/lvae_training/dataset/ms_dataset_ref.py +5 -5
  63. careamics/lvae_training/dataset/multich_dataset.py +24 -18
  64. careamics/lvae_training/dataset/multifile_dataset.py +6 -6
  65. careamics/model_io/bioimage/model_description.py +12 -11
  66. careamics/model_io/bmz_io.py +12 -8
  67. careamics/models/layers.py +5 -5
  68. careamics/models/lvae/likelihoods.py +30 -14
  69. careamics/models/lvae/lvae.py +2 -2
  70. careamics/models/lvae/noise_models.py +20 -14
  71. careamics/prediction_utils/__init__.py +8 -2
  72. careamics/prediction_utils/lvae_prediction.py +5 -5
  73. careamics/prediction_utils/prediction_outputs.py +48 -3
  74. careamics/prediction_utils/stitch_prediction.py +71 -0
  75. careamics/transforms/compose.py +9 -9
  76. careamics/transforms/n2v_manipulate.py +3 -3
  77. careamics/transforms/n2v_manipulate_torch.py +4 -4
  78. careamics/transforms/normalize.py +4 -6
  79. careamics/transforms/pixel_manipulation.py +6 -8
  80. careamics/transforms/pixel_manipulation_torch.py +5 -7
  81. careamics/transforms/xy_flip.py +3 -5
  82. careamics/transforms/xy_random_rotate90.py +4 -6
  83. careamics/utils/logging.py +8 -8
  84. careamics/utils/metrics.py +2 -2
  85. careamics/utils/plotting.py +1 -3
  86. {careamics-0.0.14.dist-info → careamics-0.0.16.dist-info}/METADATA +18 -16
  87. {careamics-0.0.14.dist-info → careamics-0.0.16.dist-info}/RECORD +90 -88
  88. careamics/dataset/zarr_dataset.py +0 -151
  89. careamics/file_io/read/zarr.py +0 -60
  90. {careamics-0.0.14.dist-info → careamics-0.0.16.dist-info}/WHEEL +0 -0
  91. {careamics-0.0.14.dist-info → careamics-0.0.16.dist-info}/entry_points.txt +0 -0
  92. {careamics-0.0.14.dist-info → careamics-0.0.16.dist-info}/licenses/LICENSE +0 -0
@@ -1,7 +1,7 @@
1
1
  from collections.abc import Sequence
2
2
  from enum import Enum
3
3
  from pathlib import Path
4
- from typing import Any, Generic, Literal, NamedTuple, Optional, Union
4
+ from typing import Any, Generic, Literal, NamedTuple, Union
5
5
 
6
6
  import numpy as np
7
7
  from numpy.typing import NDArray
@@ -51,7 +51,7 @@ class CareamicsDataset(Dataset, Generic[GenericImageStack]):
51
51
  data_config: NGDataConfig,
52
52
  mode: Mode,
53
53
  input_extractor: PatchExtractor[GenericImageStack],
54
- target_extractor: Optional[PatchExtractor[GenericImageStack]] = None,
54
+ target_extractor: PatchExtractor[GenericImageStack] | None = None,
55
55
  ):
56
56
  self.config = data_config
57
57
  self.mode = mode
@@ -115,7 +115,7 @@ class CareamicsDataset(Dataset, Generic[GenericImageStack]):
115
115
 
116
116
  return patching_strategy
117
117
 
118
- def _initialize_transforms(self) -> Optional[Compose]:
118
+ def _initialize_transforms(self) -> Compose | None:
119
119
  normalize = NormalizeModel(
120
120
  image_means=self.input_stats.means,
121
121
  image_stds=self.input_stats.stds,
@@ -1,7 +1,7 @@
1
1
  from collections.abc import Sequence
2
2
  from enum import Enum
3
3
  from pathlib import Path
4
- from typing import Any, Optional
4
+ from typing import Any
5
5
 
6
6
  from numpy.typing import NDArray
7
7
  from typing_extensions import ParamSpec
@@ -48,8 +48,8 @@ class DatasetType(Enum):
48
48
  def determine_dataset_type(
49
49
  data_type: SupportedData,
50
50
  in_memory: bool,
51
- read_func: Optional[ReadFunc] = None,
52
- image_stack_loader: Optional[ImageStackLoader] = None,
51
+ read_func: ReadFunc | None = None,
52
+ image_stack_loader: ImageStackLoader | None = None,
53
53
  ) -> DatasetType:
54
54
  """Determine what the dataset type should be based on the input arguments.
55
55
 
@@ -121,10 +121,10 @@ def create_dataset(
121
121
  inputs: Any,
122
122
  targets: Any,
123
123
  in_memory: bool,
124
- read_func: Optional[ReadFunc] = None,
125
- read_kwargs: Optional[dict[str, Any]] = None,
126
- image_stack_loader: Optional[ImageStackLoader] = None,
127
- image_stack_loader_kwargs: Optional[dict[str, Any]] = None,
124
+ read_func: ReadFunc | None = None,
125
+ read_kwargs: dict[str, Any] | None = None,
126
+ image_stack_loader: ImageStackLoader | None = None,
127
+ image_stack_loader_kwargs: dict[str, Any] | None = None,
128
128
  ) -> CareamicsDataset[ImageStack]:
129
129
  """
130
130
  Convenience function to create the CAREamicsDataset.
@@ -201,7 +201,7 @@ def create_array_dataset(
201
201
  config: NGDataConfig,
202
202
  mode: Mode,
203
203
  inputs: Sequence[NDArray[Any]],
204
- targets: Optional[Sequence[NDArray[Any]]],
204
+ targets: Sequence[NDArray[Any]] | None,
205
205
  ) -> CareamicsDataset[InMemoryImageStack]:
206
206
  """
207
207
  Create a CAREamicsDataset from array data.
@@ -223,7 +223,7 @@ def create_array_dataset(
223
223
  A CAREamicsDataset.
224
224
  """
225
225
  input_extractor = create_array_extractor(source=inputs, axes=config.axes)
226
- target_extractor: Optional[PatchExtractor[InMemoryImageStack]]
226
+ target_extractor: PatchExtractor[InMemoryImageStack] | None
227
227
  if targets is not None:
228
228
  target_extractor = create_array_extractor(source=targets, axes=config.axes)
229
229
  else:
@@ -235,7 +235,7 @@ def create_tiff_dataset(
235
235
  config: NGDataConfig,
236
236
  mode: Mode,
237
237
  inputs: Sequence[Path],
238
- targets: Optional[Sequence[Path]],
238
+ targets: Sequence[Path] | None,
239
239
  ) -> CareamicsDataset[InMemoryImageStack]:
240
240
  """
241
241
  Create a CAREamicsDataset from tiff files that will be all loaded into memory.
@@ -260,7 +260,7 @@ def create_tiff_dataset(
260
260
  source=inputs,
261
261
  axes=config.axes,
262
262
  )
263
- target_extractor: Optional[PatchExtractor[InMemoryImageStack]]
263
+ target_extractor: PatchExtractor[InMemoryImageStack] | None
264
264
  if targets is not None:
265
265
  target_extractor = create_tiff_extractor(source=targets, axes=config.axes)
266
266
  else:
@@ -273,7 +273,7 @@ def create_czi_dataset(
273
273
  config: NGDataConfig,
274
274
  mode: Mode,
275
275
  inputs: Sequence[Path],
276
- targets: Optional[Sequence[Path]],
276
+ targets: Sequence[Path] | None,
277
277
  ) -> CareamicsDataset[CziImageStack]:
278
278
  """
279
279
  Create a dataset from CZI files.
@@ -296,7 +296,7 @@ def create_czi_dataset(
296
296
  """
297
297
 
298
298
  input_extractor = create_czi_extractor(source=inputs, axes=config.axes)
299
- target_extractor: Optional[PatchExtractor[CziImageStack]]
299
+ target_extractor: PatchExtractor[CziImageStack] | None
300
300
  if targets is not None:
301
301
  target_extractor = create_czi_extractor(source=targets, axes=config.axes)
302
302
  else:
@@ -309,7 +309,7 @@ def create_ome_zarr_dataset(
309
309
  config: NGDataConfig,
310
310
  mode: Mode,
311
311
  inputs: Sequence[Path],
312
- targets: Optional[Sequence[Path]],
312
+ targets: Sequence[Path] | None,
313
313
  ) -> CareamicsDataset[ZarrImageStack]:
314
314
  """
315
315
  Create a dataset from OME ZARR files.
@@ -332,7 +332,7 @@ def create_ome_zarr_dataset(
332
332
  """
333
333
 
334
334
  input_extractor = create_ome_zarr_extractor(source=inputs, axes=config.axes)
335
- target_extractor: Optional[PatchExtractor[ZarrImageStack]]
335
+ target_extractor: PatchExtractor[ZarrImageStack] | None
336
336
  if targets is not None:
337
337
  target_extractor = create_ome_zarr_extractor(source=targets, axes=config.axes)
338
338
  else:
@@ -345,7 +345,7 @@ def create_custom_file_dataset(
345
345
  config: NGDataConfig,
346
346
  mode: Mode,
347
347
  inputs: Sequence[Path],
348
- targets: Optional[Sequence[Path]],
348
+ targets: Sequence[Path] | None,
349
349
  *,
350
350
  read_func: ReadFunc,
351
351
  read_kwargs: dict[str, Any],
@@ -378,7 +378,7 @@ def create_custom_file_dataset(
378
378
  input_extractor = create_custom_file_extractor(
379
379
  source=inputs, axes=config.axes, read_func=read_func, read_kwargs=read_kwargs
380
380
  )
381
- target_extractor: Optional[PatchExtractor[InMemoryImageStack]]
381
+ target_extractor: PatchExtractor[InMemoryImageStack] | None
382
382
  if targets is not None:
383
383
  target_extractor = create_custom_file_extractor(
384
384
  source=targets,
@@ -396,7 +396,7 @@ def create_custom_image_stack_dataset(
396
396
  config: NGDataConfig,
397
397
  mode: Mode,
398
398
  inputs: Any,
399
- targets: Optional[Any],
399
+ targets: Any | None,
400
400
  image_stack_loader: ImageStackLoader[P, GenericImageStack],
401
401
  *args: P.args,
402
402
  **kwargs: P.kwargs,
@@ -436,7 +436,7 @@ def create_custom_image_stack_dataset(
436
436
  *args,
437
437
  **kwargs,
438
438
  )
439
- target_extractor: Optional[PatchExtractor[GenericImageStack]]
439
+ target_extractor: PatchExtractor[GenericImageStack] | None
440
440
  if targets is not None:
441
441
  target_extractor = create_custom_image_stack_extractor(
442
442
  targets,
@@ -7,7 +7,7 @@ import matplotlib.pyplot as plt
7
7
  import numpy as np
8
8
  import zarr
9
9
  from numpy.typing import NDArray
10
- from zarr.storage import FSStore
10
+ from zarr.storage import FsspecStore
11
11
 
12
12
  from careamics.config import DataConfig
13
13
  from careamics.config.support import SupportedData
@@ -20,7 +20,7 @@ from careamics.dataset_ng.patch_extractor.patch_extractor_factory import (
20
20
 
21
21
  # %%
22
22
  def create_zarr_array(file_path: Path, data_path: str, data: NDArray):
23
- store = FSStore(url=file_path.resolve())
23
+ store = FsspecStore.from_url(url=file_path.resolve())
24
24
  # create array
25
25
  array = zarr.create(
26
26
  store=store,
@@ -61,7 +61,7 @@ if not file_path.is_file() and not file_path.is_dir():
61
61
  # ### Make sure file exists
62
62
 
63
63
  # %%
64
- store = FSStore(url=file_path.resolve(), mode="r")
64
+ store = FsspecStore.from_url(url=file_path.resolve(), mode="r")
65
65
 
66
66
  # %%
67
67
  list(store.keys())
@@ -72,7 +72,7 @@ list(store.keys())
72
72
 
73
73
  # %%
74
74
  class ZarrSource(TypedDict):
75
- store: FSStore
75
+ store: FsspecStore
76
76
  data_paths: Sequence[str]
77
77
 
78
78
 
@@ -1,9 +1,8 @@
1
1
  from collections.abc import Sequence
2
2
  from pathlib import Path
3
- from typing import Any, Literal, Union
3
+ from typing import Any, Literal, Self, Union
4
4
 
5
5
  from numpy.typing import DTypeLike, NDArray
6
- from typing_extensions import Self
7
6
 
8
7
  from careamics.dataset.dataset_utils import reshape_array
9
8
  from careamics.file_io.read import ReadFunc, read_tiff
@@ -1,11 +1,11 @@
1
1
  from collections.abc import Sequence
2
2
  from pathlib import Path
3
- from typing import Union
3
+ from typing import Self, Union
4
4
 
5
+ import validators
5
6
  import zarr
6
- import zarr.storage
7
7
  from numpy.typing import NDArray
8
- from typing_extensions import Self
8
+ from zarr.storage import FsspecStore, LocalStore
9
9
 
10
10
  from careamics.dataset.dataset_utils import reshape_array
11
11
 
@@ -15,9 +15,10 @@ class ZarrImageStack:
15
15
  A class for extracting patches from an image stack that is stored as a zarr array.
16
16
  """
17
17
 
18
- # TODO: keeping store type narrow so that it has the path attribute
19
- # base zarr store is zarr.storage.Store, includes MemoryStore
20
- def __init__(self, store: zarr.storage.FSStore, data_path: str, axes: str):
18
+ # TODO: We should keep store type narrow
19
+ # - in zarr v3, does zarr.storage.Store exists and has the path attribute?
20
+ # - can we declare a narrow type rather than a union?
21
+ def __init__(self, store: LocalStore | FsspecStore, data_path: str, axes: str):
21
22
  self._store = store
22
23
  self._array = zarr.open_array(store=self._store, path=data_path, mode="r")
23
24
  # TODO: validate axes
@@ -46,8 +47,33 @@ class ZarrImageStack:
46
47
  Assumes the path only contains 1 image.
47
48
 
48
49
  Path can be to a local file, or it can be a URL to a zarr stored in the cloud.
50
+
51
+ Parameters
52
+ ----------
53
+ path : Union[Path, str]
54
+ Path to the root of the OME-Zarr, local file or url.
55
+
56
+ Returns
57
+ -------
58
+ ZarrImageStack
59
+ Initialised ZarrImageStack.
60
+
61
+ Raises
62
+ ------
63
+ ValueError
64
+ If the path does not exist or is not a valid URL.
65
+ ValueError
66
+ If the OME-Zarr at the path does not contain the attribute 'multiscales'.
49
67
  """
50
- store = zarr.storage.FSStore(url=path)
68
+ if Path(path).is_file():
69
+ store = zarr.storage.LocalStore(root=Path(path).resolve())
70
+ elif validators.url(path):
71
+ store = zarr.storage.FsspecStore.from_url(url=path)
72
+ else:
73
+ raise ValueError(
74
+ f"Path '{path}' is neither an existing file nor a valid URL."
75
+ )
76
+
51
77
  group = zarr.open_group(store=store, mode="r")
52
78
  if "multiscales" not in group.attrs:
53
79
  raise ValueError(
@@ -38,7 +38,7 @@ class ImageStackLoader(Protocol[P, GenericImageStack]):
38
38
 
39
39
  >>> from typing import TypedDict
40
40
 
41
- >>> from zarr.storage import FSStore
41
+ >>> from zarr.storage import FsspecStore
42
42
 
43
43
  >>> from careamics.config import DataConfig
44
44
  >>> from careamics.dataset_ng.patch_extractor.image_stack import ZarrImageStack
@@ -46,7 +46,7 @@ class ImageStackLoader(Protocol[P, GenericImageStack]):
46
46
  >>> # Define a zarr source
47
47
  >>> # It encompasses multiple arguments that determine what data will be loaded
48
48
  >>> class ZarrSource(TypedDict):
49
- ... store: FSStore
49
+ ... store: FsspecStore
50
50
  ... data_paths: Sequence[str]
51
51
 
52
52
  >>> def custom_image_stack_loader(
@@ -1,7 +1,6 @@
1
1
  """A module for random patching strategies."""
2
2
 
3
3
  from collections.abc import Sequence
4
- from typing import Optional
5
4
 
6
5
  import numpy as np
7
6
 
@@ -31,7 +30,7 @@ class RandomPatchingStrategy:
31
30
  self,
32
31
  data_shapes: Sequence[Sequence[int]],
33
32
  patch_size: Sequence[int],
34
- seed: Optional[int] = None,
33
+ seed: int | None = None,
35
34
  ):
36
35
  """
37
36
  A patching strategy for sampling random patches.
@@ -193,7 +192,7 @@ class FixedRandomPatchingStrategy:
193
192
  self,
194
193
  data_shapes: Sequence[Sequence[int]],
195
194
  patch_size: Sequence[int],
196
- seed: Optional[int] = None,
195
+ seed: int | None = None,
197
196
  ):
198
197
  """A patching strategy for sampling random patches.
199
198
 
@@ -1,6 +1,5 @@
1
1
  import itertools
2
2
  from collections.abc import Sequence
3
- from typing import Optional
4
3
 
5
4
  import numpy as np
6
5
  from typing_extensions import ParamSpec
@@ -18,7 +17,7 @@ class SequentialPatchingStrategy:
18
17
  self,
19
18
  data_shapes: Sequence[Sequence[int]],
20
19
  patch_size: Sequence[int],
21
- overlaps: Optional[Sequence[int]] = None,
20
+ overlaps: Sequence[int] | None = None,
22
21
  ):
23
22
  self.data_shapes = data_shapes
24
23
  self.patch_size = patch_size
@@ -9,4 +9,3 @@ __all__ = [
9
9
 
10
10
  from .get_func import ReadFunc, get_read_func
11
11
  from .tiff import read_tiff
12
- from .zarr import read_zarr
@@ -1,18 +1,32 @@
1
1
  """CAREamics PyTorch Lightning modules."""
2
2
 
3
3
  __all__ = [
4
+ "DataStatsCallback",
4
5
  "FCNModule",
5
6
  "HyperParametersCallback",
7
+ "MicroSplitDataModule",
6
8
  "PredictDataModule",
7
9
  "ProgressBarCallback",
8
10
  "TrainDataModule",
9
11
  "VAEModule",
10
12
  "create_careamics_module",
13
+ "create_microsplit_predict_datamodule",
14
+ "create_microsplit_train_datamodule",
11
15
  "create_predict_datamodule",
12
16
  "create_train_datamodule",
17
+ "create_unet_based_module",
18
+ "create_vae_based_module",
13
19
  ]
14
20
 
15
- from .callbacks import HyperParametersCallback, ProgressBarCallback
21
+ from .callbacks import DataStatsCallback, HyperParametersCallback, ProgressBarCallback
16
22
  from .lightning_module import FCNModule, VAEModule, create_careamics_module
23
+ from .microsplit_data_module import (
24
+ MicroSplitDataModule,
25
+ create_microsplit_predict_datamodule,
26
+ create_microsplit_train_datamodule,
27
+ )
17
28
  from .predict_data_module import PredictDataModule, create_predict_datamodule
18
- from .train_data_module import TrainDataModule, create_train_datamodule
29
+ from .train_data_module import (
30
+ TrainDataModule,
31
+ create_train_datamodule,
32
+ )
@@ -1,11 +1,13 @@
1
1
  """Callbacks module."""
2
2
 
3
3
  __all__ = [
4
+ "DataStatsCallback",
4
5
  "HyperParametersCallback",
5
6
  "PredictionWriterCallback",
6
7
  "ProgressBarCallback",
7
8
  ]
8
9
 
10
+ from .data_stats_callback import DataStatsCallback
9
11
  from .hyperparameters_callback import HyperParametersCallback
10
12
  from .prediction_writer_callback import PredictionWriterCallback
11
13
  from .progress_bar_callback import ProgressBarCallback
@@ -0,0 +1,23 @@
1
+ """Data statistics callback."""
2
+
3
+ import pytorch_lightning as L
4
+ from pytorch_lightning.callbacks import Callback
5
+
6
+
7
+ class DataStatsCallback(Callback):
8
+ """Callback to update model's data statistics from datamodule.
9
+
10
+ This callback ensures that the model has access to the data statistics (mean and std)
11
+ calculated by the datamodule before training starts.
12
+ """
13
+
14
+ def setup(self, trainer: L.Trainer, module: L.LightningModule, stage: str) -> None:
15
+ """Called when trainer is setting up."""
16
+ if stage == "fit":
17
+ # Get data statistics from datamodule
18
+ (data_mean, data_std), _ = trainer.datamodule.get_data_stats()
19
+
20
+ # Set data statistics in the model's likelihood module
21
+ module.noise_model_likelihood.set_data_stats(
22
+ data_mean=data_mean["target"], data_std=data_std["target"]
23
+ )
@@ -4,7 +4,7 @@ from __future__ import annotations
4
4
 
5
5
  from collections.abc import Sequence
6
6
  from pathlib import Path
7
- from typing import Any, Optional, Union
7
+ from typing import Any, Union
8
8
 
9
9
  from pytorch_lightning import LightningModule, Trainer
10
10
  from pytorch_lightning.callbacks import BasePredictionWriter
@@ -84,9 +84,9 @@ class PredictionWriterCallback(BasePredictionWriter):
84
84
  cls,
85
85
  write_type: SupportedWriteType,
86
86
  tiled: bool,
87
- write_func: Optional[WriteFunc] = None,
88
- write_extension: Optional[str] = None,
89
- write_func_kwargs: Optional[dict[str, Any]] = None,
87
+ write_func: WriteFunc | None = None,
88
+ write_extension: str | None = None,
89
+ write_func_kwargs: dict[str, Any] | None = None,
90
90
  dirpath: Union[Path, str] = "predictions",
91
91
  ) -> PredictionWriterCallback: # TODO: change type hint to self (find out how)
92
92
  """
@@ -172,7 +172,7 @@ class PredictionWriterCallback(BasePredictionWriter):
172
172
  trainer: Trainer,
173
173
  pl_module: LightningModule,
174
174
  prediction: Any, # TODO: change to expected type
175
- batch_indices: Optional[Sequence[int]],
175
+ batch_indices: Sequence[int] | None,
176
176
  batch: Any, # TODO: change to expected type
177
177
  batch_idx: int,
178
178
  dataloader_idx: int,
@@ -2,7 +2,7 @@
2
2
 
3
3
  from collections.abc import Sequence
4
4
  from pathlib import Path
5
- from typing import Any, Optional, Protocol, Union
5
+ from typing import Any, Protocol, Union
6
6
 
7
7
  import numpy as np
8
8
  from numpy.typing import NDArray
@@ -25,7 +25,7 @@ class WriteStrategy(Protocol):
25
25
  trainer: Trainer,
26
26
  pl_module: LightningModule,
27
27
  prediction: Any, # TODO: change to expected type
28
- batch_indices: Optional[Sequence[int]],
28
+ batch_indices: Sequence[int] | None,
29
29
  batch: Any, # TODO: change to expected type
30
30
  batch_idx: int,
31
31
  dataloader_idx: int,
@@ -133,7 +133,7 @@ class CacheTiles(WriteStrategy):
133
133
  trainer: Trainer,
134
134
  pl_module: LightningModule,
135
135
  prediction: tuple[NDArray, list[TileInformation]],
136
- batch_indices: Optional[Sequence[int]],
136
+ batch_indices: Sequence[int] | None,
137
137
  batch: tuple[NDArray, list[TileInformation]],
138
138
  batch_idx: int,
139
139
  dataloader_idx: int,
@@ -259,7 +259,7 @@ class WriteTilesZarr(WriteStrategy):
259
259
  trainer: Trainer,
260
260
  pl_module: LightningModule,
261
261
  prediction: Any,
262
- batch_indices: Optional[Sequence[int]],
262
+ batch_indices: Sequence[int] | None,
263
263
  batch: Any,
264
264
  batch_idx: int,
265
265
  dataloader_idx: int,
@@ -346,7 +346,7 @@ class WriteImage(WriteStrategy):
346
346
  trainer: Trainer,
347
347
  pl_module: LightningModule,
348
348
  prediction: NDArray,
349
- batch_indices: Optional[Sequence[int]],
349
+ batch_indices: Sequence[int] | None,
350
350
  batch: NDArray,
351
351
  batch_idx: int,
352
352
  dataloader_idx: int,
@@ -1,6 +1,6 @@
1
1
  """Module containing convenience function to create `WriteStrategy`."""
2
2
 
3
- from typing import Any, Optional
3
+ from typing import Any
4
4
 
5
5
  from careamics.config.support import SupportedData
6
6
  from careamics.file_io import SupportedWriteType, WriteFunc, get_write_func
@@ -11,9 +11,9 @@ from .write_strategy import CacheTiles, WriteImage, WriteStrategy
11
11
  def create_write_strategy(
12
12
  write_type: SupportedWriteType,
13
13
  tiled: bool,
14
- write_func: Optional[WriteFunc] = None,
15
- write_extension: Optional[str] = None,
16
- write_func_kwargs: Optional[dict[str, Any]] = None,
14
+ write_func: WriteFunc | None = None,
15
+ write_extension: str | None = None,
16
+ write_func_kwargs: dict[str, Any] | None = None,
17
17
  ) -> WriteStrategy:
18
18
  """
19
19
  Create a write strategy from convenient parameters.
@@ -78,8 +78,8 @@ def create_write_strategy(
78
78
 
79
79
  def _create_tiled_write_strategy(
80
80
  write_type: SupportedWriteType,
81
- write_func: Optional[WriteFunc],
82
- write_extension: Optional[str],
81
+ write_func: WriteFunc | None,
82
+ write_extension: str | None,
83
83
  write_func_kwargs: dict[str, Any],
84
84
  ) -> WriteStrategy:
85
85
  """
@@ -130,7 +130,7 @@ def _create_tiled_write_strategy(
130
130
 
131
131
 
132
132
  def select_write_func(
133
- write_type: SupportedWriteType, write_func: Optional[WriteFunc] = None
133
+ write_type: SupportedWriteType, write_func: WriteFunc | None = None
134
134
  ) -> WriteFunc:
135
135
  """
136
136
  Return a function to write images.
@@ -177,7 +177,7 @@ def select_write_func(
177
177
 
178
178
 
179
179
  def select_write_extension(
180
- write_type: SupportedWriteType, write_extension: Optional[str] = None
180
+ write_type: SupportedWriteType, write_extension: str | None = None
181
181
  ) -> str:
182
182
  """
183
183
  Return an extension to add to file paths.