careamics 0.0.3__py3-none-any.whl → 0.0.4.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.

Files changed (56)
  1. careamics/careamist.py +25 -17
  2. careamics/cli/__init__.py +5 -0
  3. careamics/cli/conf.py +391 -0
  4. careamics/cli/main.py +134 -0
  5. careamics/config/architectures/lvae_model.py +0 -4
  6. careamics/config/configuration_factory.py +480 -177
  7. careamics/config/configuration_model.py +1 -2
  8. careamics/config/data_model.py +1 -15
  9. careamics/config/fcn_algorithm_model.py +14 -9
  10. careamics/config/likelihood_model.py +21 -4
  11. careamics/config/nm_model.py +31 -5
  12. careamics/config/optimizer_models.py +3 -1
  13. careamics/config/support/supported_optimizers.py +1 -1
  14. careamics/config/support/supported_transforms.py +1 -0
  15. careamics/config/training_model.py +35 -6
  16. careamics/config/transformations/__init__.py +4 -1
  17. careamics/config/transformations/transform_union.py +20 -0
  18. careamics/config/vae_algorithm_model.py +2 -36
  19. careamics/dataset/tiling/lvae_tiled_patching.py +90 -8
  20. careamics/lightning/lightning_module.py +10 -8
  21. careamics/lightning/train_data_module.py +2 -2
  22. careamics/losses/loss_factory.py +3 -3
  23. careamics/losses/lvae/losses.py +2 -2
  24. careamics/lvae_training/dataset/__init__.py +15 -0
  25. careamics/lvae_training/dataset/{vae_data_config.py → config.py} +25 -81
  26. careamics/lvae_training/dataset/lc_dataset.py +28 -20
  27. careamics/lvae_training/dataset/{vae_dataset.py → multich_dataset.py} +91 -51
  28. careamics/lvae_training/dataset/multifile_dataset.py +334 -0
  29. careamics/lvae_training/dataset/types.py +43 -0
  30. careamics/lvae_training/dataset/utils/__init__.py +0 -0
  31. careamics/lvae_training/dataset/utils/data_utils.py +114 -0
  32. careamics/lvae_training/dataset/utils/empty_patch_fetcher.py +65 -0
  33. careamics/lvae_training/dataset/utils/index_manager.py +232 -0
  34. careamics/lvae_training/dataset/utils/index_switcher.py +165 -0
  35. careamics/lvae_training/eval_utils.py +109 -64
  36. careamics/lvae_training/get_config.py +1 -1
  37. careamics/lvae_training/train_lvae.py +1 -1
  38. careamics/model_io/bioimage/bioimage_utils.py +4 -2
  39. careamics/model_io/bmz_io.py +6 -5
  40. careamics/models/lvae/likelihoods.py +18 -9
  41. careamics/models/lvae/lvae.py +12 -16
  42. careamics/models/lvae/noise_models.py +1 -1
  43. careamics/transforms/compose.py +90 -15
  44. careamics/transforms/n2v_manipulate.py +6 -2
  45. careamics/transforms/normalize.py +14 -3
  46. careamics/transforms/xy_flip.py +16 -6
  47. careamics/transforms/xy_random_rotate90.py +16 -7
  48. careamics/utils/metrics.py +204 -24
  49. careamics/utils/serializers.py +60 -0
  50. {careamics-0.0.3.dist-info → careamics-0.0.4.1.dist-info}/METADATA +4 -3
  51. {careamics-0.0.3.dist-info → careamics-0.0.4.1.dist-info}/RECORD +54 -43
  52. careamics-0.0.4.1.dist-info/entry_points.txt +2 -0
  53. careamics/lvae_training/dataset/data_utils.py +0 -701
  54. careamics/lvae_training/dataset/lc_dataset_config.py +0 -13
  55. {careamics-0.0.3.dist-info → careamics-0.0.4.1.dist-info}/WHEEL +0 -0
  56. {careamics-0.0.3.dist-info → careamics-0.0.4.1.dist-info}/licenses/LICENSE +0 -0
careamics/config/configuration_model.py

@@ -124,7 +124,6 @@ class Configuration(BaseModel):
     >>> config_dict = {
     ...     "experiment_name": "N2V_experiment",
     ...     "algorithm_config": {
-    ...         "algorithm_type": "fcn",
     ...         "algorithm": "n2v",
     ...         "loss": "n2v",
     ...         "model": {
@@ -158,7 +157,7 @@ class Configuration(BaseModel):

     # Sub-configurations
     algorithm_config: Union[FCNAlgorithmConfig, VAEAlgorithmConfig] = Field(
-        discriminator="algorithm_type"
+        discriminator="algorithm"
     )
     """Algorithm configuration, holding all parameters required to configure the
     model."""
careamics/config/data_model.py

@@ -10,7 +10,6 @@ from numpy.typing import NDArray
 from pydantic import (
     BaseModel,
     ConfigDict,
-    Discriminator,
     Field,
     PlainSerializer,
     field_validator,
@@ -19,9 +18,7 @@ from pydantic import (
 from typing_extensions import Annotated, Self

 from .support import SupportedTransform
-from .transformations.n2v_manipulate_model import N2VManipulateModel
-from .transformations.xy_flip_model import XYFlipModel
-from .transformations.xy_random_rotate90_model import XYRandomRotate90Model
+from .transformations import TRANSFORMS_UNION, N2VManipulateModel
 from .validators import check_axes_validity, patch_size_ge_than_8_power_of_2


@@ -48,17 +45,6 @@ Float = Annotated[float, PlainSerializer(np_float_to_scientific_str, return_type
 """Annotated float type, used to serialize floats to strings."""


-TRANSFORMS_UNION = Annotated[
-    Union[
-        XYFlipModel,
-        XYRandomRotate90Model,
-        N2VManipulateModel,
-    ],
-    Discriminator("name"),  # used to tell the different transform models apart
-]
-"""Available transforms in CAREamics."""
-
-
 class DataConfig(BaseModel):
     """
     Data configuration.
careamics/config/fcn_algorithm_model.py

@@ -24,11 +24,11 @@ class FCNAlgorithmConfig(BaseModel):

     Attributes
     ----------
-    algorithm : Literal["n2v", "custom"]
+    algorithm : {"n2v", "care", "n2n", "custom"}
        Algorithm to use.
-    loss : Literal["n2v", "mae", "mse"]
+    loss : {"n2v", "mae", "mse"}
        Loss function to use.
-    model : Union[UNetModel, LVAEModel, CustomModel]
+    model : UNetModel or CustomModel
        Model architecture to use.
     optimizer : OptimizerModel, optional
         Optimizer to use.
@@ -48,7 +48,6 @@ class FCNAlgorithmConfig(BaseModel):
     >>> from careamics.config import FCNAlgorithmConfig
     >>> config_dict = {
     ...     "algorithm": "n2v",
-    ...     "algorithm_type": "fcn",
     ...     "loss": "n2v",
     ...     "model": {
     ...         "architecture": "UNet",
@@ -65,11 +64,6 @@ class FCNAlgorithmConfig(BaseModel):
     )

     # Mandatory fields
-    # defined in SupportedAlgorithm
-    algorithm_type: Literal["fcn"]
-    """Algorithm type must be `fcn` (fully convolutional network) to differentiate this
-    configuration from LVAE."""
-
     algorithm: Literal["n2v", "care", "n2n", "custom"]
     """Name of the algorithm, as defined in SupportedAlgorithm. Use `custom` for custom
     model architecture."""
@@ -145,3 +139,14 @@ class FCNAlgorithmConfig(BaseModel):
         Pretty string.
         """
         return pformat(self.model_dump())
+
+    @classmethod
+    def get_compatible_algorithms(cls) -> list[str]:
+        """Get the list of compatible algorithms.
+
+        Returns
+        -------
+        list of str
+            List of compatible algorithms.
+        """
+        return ["n2v", "care", "n2n"]
careamics/config/likelihood_model.py

@@ -2,16 +2,30 @@

 from typing import Literal, Optional, Union

+import numpy as np
 import torch
-from pydantic import BaseModel, ConfigDict
+from pydantic import BaseModel, ConfigDict, Field, PlainSerializer, PlainValidator
+from typing_extensions import Annotated

 from careamics.models.lvae.noise_models import (
     GaussianMixtureNoiseModel,
     MultiChannelNoiseModel,
 )
+from careamics.utils.serializers import _array_to_json, _to_torch

 NoiseModel = Union[GaussianMixtureNoiseModel, MultiChannelNoiseModel]

+# TODO: this is a temporary solution to serialize and deserialize tensor fields
+# in pydantic models. Specifically, the aim is to enable saving and loading configs
+# with such tensors to/from JSON files during, resp., training and evaluation.
+Tensor = Annotated[
+    Union[np.ndarray, torch.Tensor],
+    PlainSerializer(_array_to_json, return_type=str),
+    PlainValidator(_to_torch),
+]
+"""Annotated tensor type, used to serialize arrays or tensors to JSON strings
+and deserialize them back to tensors."""
+

 class GaussianLikelihoodConfig(BaseModel):
     """Gaussian likelihood configuration."""
@@ -31,13 +45,16 @@ class NMLikelihoodConfig(BaseModel):

     model_config = ConfigDict(validate_assignment=True, arbitrary_types_allowed=True)

-    data_mean: Union[torch.Tensor] = torch.zeros(1)
+    # TODO remove and use as parameters to the likelihood functions?
+    data_mean: Tensor = torch.zeros(1)
     """The mean of the data, used to unnormalize data for noise model evaluation.
     Shape is (target_ch,) (or (1, target_ch, [1], 1, 1))."""

-    data_std: Union[torch.Tensor] = torch.ones(1)
+    # TODO remove and use as parameters to the likelihood functions?
+    data_std: Tensor = torch.ones(1)
     """The standard deviation of the data, used to unnormalize data for noise
     model evaluation. Shape is (target_ch,) (or (1, target_ch, [1], 1, 1))."""

-    noise_model: Union[NoiseModel, None] = None
+    # TODO: serialization/deserialization for this
+    noise_model: Optional[NoiseModel] = Field(default=None, exclude=True)
     """The noise model instance used to compute the likelihood."""
careamics/config/nm_model.py

@@ -4,8 +4,30 @@ from pathlib import Path
 from typing import Literal, Optional, Union

 import numpy as np
-from pydantic import BaseModel, ConfigDict, Field, model_validator
-from typing_extensions import Self
+import torch
+from pydantic import (
+    BaseModel,
+    ConfigDict,
+    Field,
+    PlainSerializer,
+    PlainValidator,
+    model_validator,
+)
+from typing_extensions import Annotated, Self
+
+from careamics.utils.serializers import _array_to_json, _to_numpy
+
+# TODO: this is a temporary solution to serialize and deserialize array fields
+# in pydantic models. Specifically, the aim is to enable saving and loading configs
+# with such arrays to/from JSON files during, resp., training and evaluation.
+Array = Annotated[
+    Union[np.ndarray, torch.Tensor],
+    PlainSerializer(_array_to_json, return_type=str),
+    PlainValidator(_to_numpy),
+]
+"""Annotated array type, used to serialize arrays or tensors to JSON strings
+and deserialize them back to arrays."""
+

 # TODO: add histogram-based noise model

@@ -26,13 +48,17 @@ class GaussianMixtureNMConfig(BaseModel):
     """Path to the directory where the trained noise model (*.npz) is saved in the
     `train` method."""

-    signal: Optional[Union[str, Path, np.ndarray]] = None
+    # TODO remove and use as parameters to the NM functions?
+    signal: Optional[Union[str, Path, np.ndarray]] = Field(default=None, exclude=True)
     """Path to the file containing signal or respective numpy array."""

-    observation: Optional[Union[str, Path, np.ndarray]] = None
+    # TODO remove and use as parameters to the NM functions?
+    observation: Optional[Union[str, Path, np.ndarray]] = Field(
+        default=None, exclude=True
+    )
     """Path to the file containing observation or respective numpy array."""

-    weight: Optional[np.ndarray] = None
+    weight: Optional[Array] = None
     """A [3*n_gaussian, n_coeff] sized array containing the values of the weights
     describing the GMM noise model, with each row corresponding to one
     parameter of each gaussian, namely [mean, standard deviation and weight].
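
Note: `signal` and `observation` are now declared with `Field(exclude=True)`, so they are dropped from serialized output. A minimal illustrative sketch of that pydantic behavior, using a placeholder model rather than the real `GaussianMixtureNMConfig`:

```python
# Sketch: Field(exclude=True) removes a field from model_dump()/JSON output.
from typing import Optional

from pydantic import BaseModel, Field


class NMLike(BaseModel):
    path: Optional[str] = None
    signal: Optional[str] = Field(default=None, exclude=True)


print(NMLike(path="noise_model.npz", signal="signal.npy").model_dump())
# {'path': 'noise_model.npz'}
```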
careamics/config/optimizer_models.py

@@ -44,7 +44,9 @@ class OptimizerModel(BaseModel):
     )

     # Mandatory field
-    name: Literal["Adam", "SGD"] = Field(default="Adam", validate_default=True)
+    name: Literal["Adam", "SGD", "Adamax"] = Field(
+        default="Adam", validate_default=True
+    )
     """Name of the optimizer, supported optimizers are defined in SupportedOptimizer."""

     # Optional parameters, empty dict default value to allow filtering dictionary
careamics/config/support/supported_optimizers.py

@@ -19,7 +19,7 @@ class SupportedOptimizer(str, BaseEnum):
     # Adagrad = "Adagrad"
     ADAM = "Adam"
     # AdamW = "AdamW"
-    # Adamax = "Adamax"
+    ADAMAX = "Adamax"
     # LBFGS = "LBFGS"
     # NAdam = "NAdam"
     # RAdam = "RAdam"
careamics/config/support/supported_transforms.py

@@ -9,3 +9,4 @@ class SupportedTransform(str, BaseEnum):
     XY_FLIP = "XYFlip"
     XY_RANDOM_ROTATE90 = "XYRandomRotate90"
     N2V_MANIPULATE = "N2VManipulate"
+    NORMALIZE = "Normalize"
careamics/config/training_model.py

@@ -3,13 +3,9 @@
 from __future__ import annotations

 from pprint import pformat
-from typing import Literal, Optional
+from typing import Literal, Optional, Union

-from pydantic import (
-    BaseModel,
-    ConfigDict,
-    Field,
-)
+from pydantic import BaseModel, ConfigDict, Field, field_validator

 from .callback_model import CheckpointModel, EarlyStoppingModel

@@ -37,6 +33,20 @@ class TrainingConfig(BaseModel):
     num_epochs: int = Field(default=20, ge=1)
     """Number of epochs, greater than 0."""

+    precision: Literal["64", "32", "16-mixed", "bf16-mixed"] = Field(default="32")
+    """Numerical precision"""
+    max_steps: int = Field(default=-1, ge=-1)
+    """Maximum number of steps to train for. -1 means no limit."""
+    check_val_every_n_epoch: int = Field(default=1, ge=1)
+    """Validation step frequency."""
+    enable_progress_bar: bool = Field(default=True)
+    """Whether to enable the progress bar."""
+    accumulate_grad_batches: int = Field(default=1, ge=1)
+    """Number of batches to accumulate gradients over before stepping the optimizer."""
+    gradient_clip_val: Optional[Union[int, float]] = None
+    """The value to which to clip the gradient"""
+    gradient_clip_algorithm: Literal["value", "norm"] = "norm"
+    """The algorithm to use for gradient clipping (see lightning `Trainer`)."""
     logger: Optional[Literal["wandb", "tensorboard"]] = None
     """Logger to use during training. If None, no logger will be used. Available
     loggers are defined in SupportedLogger."""
@@ -70,3 +80,22 @@ class TrainingConfig(BaseModel):
         Whether the logger is defined or not.
         """
         return self.logger is not None
+
+    @field_validator("max_steps")
+    @classmethod
+    def validate_max_steps(cls, max_steps: int) -> int:
+        """Validate the max_steps parameter.
+
+        Parameters
+        ----------
+        max_steps : int
+            Maximum number of steps to train for. -1 means no limit.
+
+        Returns
+        -------
+        int
+            Validated max_steps.
+        """
+        if max_steps == 0:
+            raise ValueError("max_steps must be greater than 0. Use -1 for no limit.")
+        return max_steps
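
Note: the new fields mirror pytorch_lightning `Trainer` arguments, and the validator rejects `max_steps == 0`. A minimal sketch; the import path is assumed from the package layout listed above.

```python
# Sketch: exercising the new TrainingConfig fields and the max_steps validator.
from careamics.config.training_model import TrainingConfig

cfg = TrainingConfig(num_epochs=50, precision="16-mixed", accumulate_grad_batches=2)
print(cfg.max_steps)  # -1 by default, i.e. no step limit

try:
    TrainingConfig(max_steps=0)
except ValueError as err:
    print(err)  # validation error pointing at the -1 sentinel
```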
careamics/config/transformations/__init__.py

@@ -5,11 +5,14 @@ __all__ = [
     "XYFlipModel",
     "NormalizeModel",
     "XYRandomRotate90Model",
-    "XorYFlipModel",
+    "TransformModel",
+    "TRANSFORMS_UNION",
 ]


 from .n2v_manipulate_model import N2VManipulateModel
 from .normalize_model import NormalizeModel
+from .transform_model import TransformModel
+from .transform_union import TRANSFORMS_UNION
 from .xy_flip_model import XYFlipModel
 from .xy_random_rotate90_model import XYRandomRotate90Model
careamics/config/transformations/transform_union.py

@@ -0,0 +1,20 @@
+"""Type used to represent all transformations users can create."""
+
+from typing import Union
+
+from pydantic import Discriminator
+from typing_extensions import Annotated
+
+from .n2v_manipulate_model import N2VManipulateModel
+from .xy_flip_model import XYFlipModel
+from .xy_random_rotate90_model import XYRandomRotate90Model
+
+TRANSFORMS_UNION = Annotated[
+    Union[
+        XYFlipModel,
+        XYRandomRotate90Model,
+        N2VManipulateModel,
+    ],
+    Discriminator("name"),  # used to tell the different transform models apart
+]
+"""Available transforms in CAREamics."""
careamics/config/vae_algorithm_model.py

@@ -19,40 +19,7 @@ from .optimizer_models import LrSchedulerModel, OptimizerModel
 class VAEAlgorithmConfig(BaseModel):
     """Algorithm configuration.

-    This Pydantic model validates the parameters governing the components of the
-    training algorithm: which algorithm, loss function, model architecture, optimizer,
-    and learning rate scheduler to use.
-
-    Currently, we only support N2V, CARE, N2N and custom models. The `n2v` algorithm is
-    only compatible with `n2v` loss and `UNet` architecture. The `custom` algorithm
-    allows you to register your own architecture and select it using its name as
-    `name` in the custom pydantic model.
-
-    Attributes
-    ----------
-    algorithm : algorithm: Literal["musplit", "denoisplit", "custom"]
-        Algorithm to use.
-    loss : Literal["musplit", "denoisplit", "denoisplit_musplit"]
-        Loss function to use.
-    model : Union[LVAEModel, CustomModel]
-        Model architecture to use.
-    noise_model: Optional[MultiChannelNmModel]
-        Noise model to use.
-    noise_model_likelihood_model: Optional[NMLikelihoodModel]
-        Noise model likelihood model to use.
-    gaussian_likelihood_model: Optional[GaussianLikelihoodModel]
-        Gaussian likelihood model to use.
-    optimizer : OptimizerModel, optional
-        Optimizer to use.
-    lr_scheduler : LrSchedulerModel, optional
-        Learning rate scheduler to use.
-
-    Raises
-    ------
-    ValueError
-        Algorithm parameter type validation errors.
-    ValueError
-        If the algorithm, loss and model are not compatible.
+    # TODO

     Examples
     --------
@@ -70,8 +37,7 @@ class VAEAlgorithmConfig(BaseModel):
     # defined in SupportedAlgorithm
     # TODO: Use supported Enum classes for typing?
     # - values can still be passed as strings and they will be cast to Enum
-    algorithm_type: Literal["vae"]
-    algorithm: Literal["musplit", "denoisplit", "custom"]
+    algorithm: Literal["musplit", "denoisplit"]
     loss: Literal["musplit", "denoisplit", "denoisplit_musplit"]
     model: Union[LVAEModel, CustomModel] = Field(discriminator="architecture")

careamics/dataset/tiling/lvae_tiled_patching.py

@@ -8,6 +8,7 @@ import numpy as np
 from numpy.typing import NDArray

 from careamics.config.tile_information import TileInformation
+from careamics.lvae_training.dataset.utils.index_manager import GridIndexManager


 def extract_tiles(
@@ -66,10 +67,12 @@ def extract_tiles(
         # itertools.product is equivalent of nested loops

         stitch_size = tile_size - overlaps
-        for tile_grid_coords in itertools.product(*[range(n) for n in tile_grid_shape]):
+        for tile_grid_indices in itertools.product(
+            *[range(n) for n in tile_grid_shape]
+        ):

             # calculate crop coordinates
-            crop_coords_start = np.array(tile_grid_coords) * stitch_size
+            crop_coords_start = np.array(tile_grid_indices) * stitch_size
             crop_slices: tuple[Union[builtins.ellipsis, slice], ...] = (
                 ...,
                 *[
@@ -80,7 +83,7 @@ def extract_tiles(
             tile = sample[crop_slices]

             tile_info = compute_tile_info(
-                np.array(tile_grid_coords),
+                np.array(tile_grid_indices),
                 np.array(data_shape),
                 np.array(tile_size),
                 np.array(overlaps),
@@ -93,19 +96,98 @@ def extract_tiles(
             yield tile, tile_info


+def compute_tile_info_legacy(
+    grid_index_manager: GridIndexManager, index: int
+) -> TileInformation:
+    """
+    Compute the tile information for a tile at a given dataset index.
+
+    Parameters
+    ----------
+    grid_index_manager : GridIndexManager
+        The grid index manager that keeps track of tile locations.
+    index : int
+        The dataset index.
+
+    Returns
+    -------
+    TileInformation
+        Information that describes how to crop and stitch a tile to create a full image.
+
+    Raises
+    ------
+    ValueError
+        If `grid_index_manager.data_shape` does not have 4 or 5 dimensions.
+    """
+    data_shape = np.array(grid_index_manager.data_shape)
+    if len(data_shape) == 5:
+        n_spatial_dims = 3
+    elif len(data_shape) == 4:
+        n_spatial_dims = 2
+    else:
+        raise ValueError("Data shape must have 4 or 5 dimensions, equating to SC(Z)YX.")
+
+    stitch_coords_start = np.array(
+        grid_index_manager.get_location_from_dataset_idx(index)
+    )
+    stitch_coords_end = stitch_coords_start + np.array(grid_index_manager.grid_shape)
+
+    tile_coords_start = stitch_coords_start - grid_index_manager.patch_offset()
+
+    # --- replace out of bounds indices
+    out_of_lower_bound = stitch_coords_start < 0
+    out_of_upper_bound = stitch_coords_end > data_shape
+    stitch_coords_start[out_of_lower_bound] = 0
+    stitch_coords_end[out_of_upper_bound] = data_shape[out_of_upper_bound]
+
+    # TODO: TilingMode not in current version
+    # if grid_index_manager.tiling_mode == TilingMode.ShiftBoundary:
+    #     for dim in range(len(stitch_coords_start)):
+    #         if tile_coords_start[dim] == 0:
+    #             stitch_coords_start[dim] = 0
+    #         if tile_coords_end[dim] == grid_index_manager.data_shape[dim]:
+    #             tile_coords_end [dim]= grid_index_manager.data_shape[dim]
+
+    # --- calculate overlap crop coords
+    overlap_crop_coords_start = stitch_coords_start - tile_coords_start
+    overlap_crop_coords_end = overlap_crop_coords_start + (
+        stitch_coords_end - stitch_coords_start
+    )
+
+    last_tile = index == grid_index_manager.total_grid_count() - 1
+
+    # --- combine start and end
+    stitch_coords = tuple(
+        (start, end) for start, end in zip(stitch_coords_start, stitch_coords_end)
+    )
+    overlap_crop_coords = tuple(
+        (start, end)
+        for start, end in zip(overlap_crop_coords_start, overlap_crop_coords_end)
+    )
+
+    tile_info = TileInformation(
+        array_shape=data_shape[1:],  # remove S dim
+        last_tile=last_tile,
+        overlap_crop_coords=overlap_crop_coords[-n_spatial_dims:],
+        stitch_coords=stitch_coords[-n_spatial_dims:],
+        sample_id=0,
+    )
+    return tile_info
+
+
 def compute_tile_info(
-    tile_grid_coords: NDArray[np.int_],
+    tile_grid_indices: NDArray[np.int_],
     data_shape: NDArray[np.int_],
     tile_size: NDArray[np.int_],
     overlaps: NDArray[np.int_],
     sample_id: int = 0,
 ) -> TileInformation:
     """
-    Compute the tile information for a tile with the coordinates `tile_grid_coords`.
+    Compute the tile information for a tile with the coordinates `tile_grid_indices`.

     Parameters
     ----------
-    tile_grid_coords : 1D np.array of int
+    tile_grid_indices : 1D np.array of int
         The coordinates of the tile within the tile grid, ((Z), Y, X), i.e. for 2D
         tiling the coordinates for the second tile in the first row of tiles would be
         (0, 1).
@@ -127,7 +209,7 @@ def compute_tile_info(

     # The extent of the tile which will make up part of the stitched image.
     stitch_size = tile_size - overlaps
-    stitch_coords_start = tile_grid_coords * stitch_size
+    stitch_coords_start = tile_grid_indices * stitch_size
     stitch_coords_end = stitch_coords_start + stitch_size

     tile_coords_start = stitch_coords_start - overlaps // 2
@@ -155,7 +237,7 @@ def compute_tile_info(

     # --- Check if last tile
     tile_grid_shape = np.array(compute_tile_grid_shape(data_shape, tile_size, overlaps))
-    last_tile = (tile_grid_coords == (tile_grid_shape - 1)).all()
+    last_tile = (tile_grid_indices == (tile_grid_shape - 1)).all()

     tile_info = TileInformation(
         array_shape=data_shape,
careamics/lightning/lightning_module.py

@@ -1,6 +1,6 @@
 """CAREamics Lightning module."""

-from typing import Any, Callable, Literal, Optional, Union
+from typing import Any, Callable, Optional, Union

 import numpy as np
 import pytorch_lightning as L
@@ -271,6 +271,12 @@ class VAEModule(L.LightningModule):
         self.noise_model: NoiseModel = noise_model_factory(
             self.algorithm_config.noise_model
         )
+        # TODO: here we can add some code to check whether the noise model is not None
+        # and `self.algorithm_config.noise_model_likelihood_model.noise_model` is,
+        # instead, None. In that case we could assign the noise model to the latter.
+        # This is particular useful when loading an algorithm config from file.
+        # Indeed, in that case the noise model in the nm likelihood is likely
+        # not available since excluded from serializaion.
         self.noise_model_likelihood: NoiseModelLikelihood = likelihood_factory(
             self.algorithm_config.noise_model_likelihood_model
         )
@@ -550,7 +556,6 @@ class VAEModule(L.LightningModule):

 # TODO: make this LVAE compatible (?)
 def create_careamics_module(
-    algorithm_type: Literal["fcn"],
     algorithm: Union[SupportedAlgorithm, str],
     loss: Union[SupportedLoss, str],
     architecture: Union[SupportedArchitecture, str],
@@ -567,8 +572,6 @@ def create_careamics_module(

     Parameters
     ----------
-    algorithm_type : Literal["fcn"]
-        Algorithm type to use for training.
     algorithm : SupportedAlgorithm or str
         Algorithm to use for training (see SupportedAlgorithm).
     loss : SupportedLoss or str
@@ -604,7 +607,6 @@ def create_careamics_module(
     if model_parameters is None:
         model_parameters = {}
     algorithm_configuration: dict[str, Any] = {
-        "algorithm_type": algorithm_type,
         "algorithm": algorithm,
         "loss": loss,
         "optimizer": {
@@ -623,10 +625,10 @@ def create_careamics_module(
     algorithm_configuration["model"] = model_configuration

     # call the parent init using an AlgorithmModel instance
-    if algorithm_configuration["algorithm_type"] == "fcn":
+    algorithm_str = algorithm_configuration["algorithm"]
+    if algorithm_str in FCNAlgorithmConfig.get_compatible_algorithms():
         return FCNModule(FCNAlgorithmConfig(**algorithm_configuration))
     else:
         raise NotImplementedError(
-            f"Model {algorithm_configuration['model']['architecture']} is not"
-            f"implemented or unknown."
+            f"Model {algorithm_str} is not implemented or unknown."
         )
careamics/lightning/train_data_module.py

@@ -9,8 +9,8 @@ from numpy.typing import NDArray
 from torch.utils.data import DataLoader

 from careamics.config import DataConfig
-from careamics.config.data_model import TRANSFORMS_UNION
 from careamics.config.support import SupportedData
+from careamics.config.transformations import TransformModel
 from careamics.dataset.dataset_utils import (
     get_files_size,
     list_files,
@@ -472,7 +472,7 @@ def create_train_datamodule(
     axes: str,
     batch_size: int,
     val_data: Optional[Union[str, Path, NDArray]] = None,
-    transforms: Optional[list[TRANSFORMS_UNION]] = None,
+    transforms: Optional[list[TransformModel]] = None,
     train_target_data: Optional[Union[str, Path, NDArray]] = None,
     val_target_data: Optional[Union[str, Path, NDArray]] = None,
     read_source_func: Optional[Callable] = None,
careamics/losses/loss_factory.py

@@ -56,9 +56,9 @@ class LVAELossParameters:
     reconstruction_weight: float = 1.0
     """Weight for the reconstruction loss in the total net loss
     (i.e., `net_loss = reconstruction_weight * rec_loss + kl_weight * kl_loss`)."""
-    musplit_weight: float = 0.0
-    """Weight for the muSplit loss (used in the muSplit-deonoiSplit loss)."""
-    denoisplit_weight: float = 1.0
+    musplit_weight: float = 0.1
+    """Weight for the muSplit loss (used in the muSplit-denoiSplit loss)."""
+    denoisplit_weight: float = 0.9
     """Weight for the denoiSplit loss (used in the muSplit-deonoiSplit loss)."""
     kl_type: Literal["kl", "kl_restricted", "kl_spatial", "kl_channelwise"] = "kl"
     """Type of KL divergence used as KL loss."""
careamics/losses/lvae/losses.py

@@ -137,8 +137,8 @@ def reconstruction_loss_musplit_denoisplit(
     recons_loss : torch.Tensor
         The reconstruction loss. Shape is (1, ).
     """
-    # TODO: is this safe to check for predict_logvar value?
-    # otherwise use `gaussian_likelihood.predict_logvar` (or both)
+    # TODO: refactor this function to make it closer to `get_reconstruction_loss`
+    # (or viceversa)
     if predictions.shape[1] == 2 * targets.shape[1]:
         # predictions contain both mean and log-variance
         out_mean, _ = predictions.chunk(2, dim=1)
careamics/lvae_training/dataset/__init__.py

@@ -0,0 +1,15 @@
+from .multich_dataset import MultiChDloader
+from .lc_dataset import LCMultiChDloader
+from .multifile_dataset import MultiFileDset
+from .config import DatasetConfig
+from .types import DataType, DataSplitType, TilingMode
+
+__all__ = [
+    "DatasetConfig",
+    "MultiChDloader",
+    "LCMultiChDloader",
+    "MultiFileDset",
+    "DataType",
+    "DataSplitType",
+    "TilingMode",
+]