careamics 0.0.11__py3-none-any.whl → 0.0.13__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.


Files changed (98)
  1. careamics/careamist.py +24 -7
  2. careamics/cli/utils.py +1 -1
  3. careamics/config/algorithms/n2v_algorithm_model.py +1 -1
  4. careamics/config/architectures/unet_model.py +3 -0
  5. careamics/config/callback_model.py +23 -34
  6. careamics/config/configuration.py +55 -4
  7. careamics/config/configuration_factories.py +288 -23
  8. careamics/config/data/__init__.py +2 -0
  9. careamics/config/data/data_model.py +41 -4
  10. careamics/config/data/ng_data_model.py +381 -0
  11. careamics/config/data/patching_strategies/__init__.py +14 -0
  12. careamics/config/data/patching_strategies/_overlapping_patched_model.py +103 -0
  13. careamics/config/data/patching_strategies/_patched_model.py +56 -0
  14. careamics/config/data/patching_strategies/random_patching_model.py +21 -0
  15. careamics/config/data/patching_strategies/sequential_patching_model.py +25 -0
  16. careamics/config/data/patching_strategies/tiled_patching_model.py +40 -0
  17. careamics/config/data/patching_strategies/whole_patching_model.py +12 -0
  18. careamics/config/inference_model.py +6 -3
  19. careamics/config/optimizer_models.py +1 -3
  20. careamics/config/support/supported_data.py +7 -0
  21. careamics/config/support/supported_patching_strategies.py +22 -0
  22. careamics/config/training_model.py +0 -2
  23. careamics/config/validators/validator_utils.py +4 -3
  24. careamics/dataset/dataset_utils/iterate_over_files.py +2 -2
  25. careamics/dataset/in_memory_dataset.py +2 -1
  26. careamics/dataset/iterable_dataset.py +2 -2
  27. careamics/dataset/iterable_pred_dataset.py +2 -2
  28. careamics/dataset/iterable_tiled_pred_dataset.py +2 -2
  29. careamics/dataset/patching/patching.py +3 -2
  30. careamics/dataset/tiling/lvae_tiled_patching.py +16 -6
  31. careamics/dataset/tiling/tiled_patching.py +2 -1
  32. careamics/dataset_ng/README.md +212 -0
  33. careamics/dataset_ng/dataset.py +229 -0
  34. careamics/dataset_ng/demos/bsd68_demo.ipynb +361 -0
  35. careamics/dataset_ng/demos/care_U2OS_demo.ipynb +330 -0
  36. careamics/dataset_ng/demos/demo_custom_image_stack.ipynb +734 -0
  37. careamics/dataset_ng/demos/demo_datamodule.ipynb +447 -0
  38. careamics/dataset_ng/{demo_dataset.ipynb → demos/demo_dataset.ipynb} +60 -53
  39. careamics/dataset_ng/{demo_patch_extractor.py → demos/demo_patch_extractor.py} +7 -9
  40. careamics/dataset_ng/demos/mouse_nuclei_demo.ipynb +292 -0
  41. careamics/dataset_ng/factory.py +451 -0
  42. careamics/dataset_ng/legacy_interoperability.py +170 -0
  43. careamics/dataset_ng/patch_extractor/__init__.py +3 -8
  44. careamics/dataset_ng/patch_extractor/demo_custom_image_stack_loader.py +7 -5
  45. careamics/dataset_ng/patch_extractor/image_stack/__init__.py +4 -1
  46. careamics/dataset_ng/patch_extractor/image_stack/czi_image_stack.py +360 -0
  47. careamics/dataset_ng/patch_extractor/image_stack/image_stack_protocol.py +5 -1
  48. careamics/dataset_ng/patch_extractor/image_stack/in_memory_image_stack.py +1 -1
  49. careamics/dataset_ng/patch_extractor/image_stack_loader.py +5 -75
  50. careamics/dataset_ng/patch_extractor/patch_extractor.py +5 -4
  51. careamics/dataset_ng/patch_extractor/patch_extractor_factory.py +114 -105
  52. careamics/dataset_ng/patching_strategies/__init__.py +6 -1
  53. careamics/dataset_ng/patching_strategies/patching_strategy_protocol.py +31 -0
  54. careamics/dataset_ng/patching_strategies/random_patching.py +5 -1
  55. careamics/dataset_ng/patching_strategies/sequential_patching.py +5 -5
  56. careamics/dataset_ng/patching_strategies/tiling_strategy.py +172 -0
  57. careamics/dataset_ng/patching_strategies/whole_sample.py +36 -0
  58. careamics/file_io/read/get_func.py +2 -1
  59. careamics/lightning/dataset_ng/__init__.py +1 -0
  60. careamics/lightning/dataset_ng/data_module.py +678 -0
  61. careamics/lightning/dataset_ng/lightning_modules/__init__.py +9 -0
  62. careamics/lightning/dataset_ng/lightning_modules/care_module.py +97 -0
  63. careamics/lightning/dataset_ng/lightning_modules/n2v_module.py +106 -0
  64. careamics/lightning/dataset_ng/lightning_modules/unet_module.py +212 -0
  65. careamics/lightning/lightning_module.py +5 -1
  66. careamics/lightning/predict_data_module.py +2 -1
  67. careamics/lightning/train_data_module.py +2 -1
  68. careamics/losses/loss_factory.py +2 -1
  69. careamics/lvae_training/dataset/__init__.py +8 -3
  70. careamics/lvae_training/dataset/config.py +3 -3
  71. careamics/lvae_training/dataset/ms_dataset_ref.py +1067 -0
  72. careamics/lvae_training/dataset/multich_dataset.py +46 -17
  73. careamics/lvae_training/dataset/multicrop_dset.py +196 -0
  74. careamics/lvae_training/dataset/types.py +3 -3
  75. careamics/lvae_training/dataset/utils/index_manager.py +259 -0
  76. careamics/lvae_training/eval_utils.py +93 -3
  77. careamics/model_io/bioimage/bioimage_utils.py +1 -1
  78. careamics/model_io/bioimage/model_description.py +1 -1
  79. careamics/model_io/bmz_io.py +1 -1
  80. careamics/model_io/model_io_utils.py +2 -2
  81. careamics/models/activation.py +2 -1
  82. careamics/prediction_utils/prediction_outputs.py +1 -1
  83. careamics/prediction_utils/stitch_prediction.py +1 -1
  84. careamics/transforms/compose.py +1 -0
  85. careamics/transforms/n2v_manipulate_torch.py +15 -9
  86. careamics/transforms/normalize.py +18 -7
  87. careamics/transforms/pixel_manipulation_torch.py +59 -92
  88. careamics/utils/lightning_utils.py +25 -11
  89. careamics/utils/metrics.py +2 -1
  90. careamics/utils/torch_utils.py +23 -0
  91. {careamics-0.0.11.dist-info → careamics-0.0.13.dist-info}/METADATA +12 -11
  92. {careamics-0.0.11.dist-info → careamics-0.0.13.dist-info}/RECORD +95 -69
  93. careamics/dataset_ng/dataset/__init__.py +0 -3
  94. careamics/dataset_ng/dataset/dataset.py +0 -184
  95. careamics/dataset_ng/demo_patch_extractor_factory.py +0 -37
  96. {careamics-0.0.11.dist-info → careamics-0.0.13.dist-info}/WHEEL +0 -0
  97. {careamics-0.0.11.dist-info → careamics-0.0.13.dist-info}/entry_points.txt +0 -0
  98. {careamics-0.0.11.dist-info → careamics-0.0.13.dist-info}/licenses/LICENSE +0 -0
careamics/careamist.py CHANGED
@@ -1,7 +1,8 @@
 """A class to train, predict and export models in CAREamics."""
 
+from collections.abc import Callable
 from pathlib import Path
-from typing import Any, Callable, Literal, Optional, Union, overload
+from typing import Any, Literal, Optional, Union, overload
 
 import numpy as np
 from numpy.typing import NDArray
@@ -52,6 +53,9 @@ class CAREamist:
         by default None.
     callbacks : list of Callback, optional
        List of callbacks to use during training and prediction, by default None.
+    enable_progress_bar : bool
+        Whether a progress bar will be displayed during training, validation and
+        prediction.
 
     Attributes
     ----------
@@ -77,6 +81,7 @@ class CAREamist:
         source: Union[Path, str],
         work_dir: Optional[Union[Path, str]] = None,
         callbacks: Optional[list[Callback]] = None,
+        enable_progress_bar: bool = True,
     ) -> None: ...
 
     @overload
@@ -85,6 +90,7 @@ class CAREamist:
         source: Configuration,
         work_dir: Optional[Union[Path, str]] = None,
         callbacks: Optional[list[Callback]] = None,
+        enable_progress_bar: bool = True,
     ) -> None: ...
 
     def __init__(
@@ -92,6 +98,7 @@ class CAREamist:
         source: Union[Path, str, Configuration],
         work_dir: Optional[Union[Path, str]] = None,
         callbacks: Optional[list[Callback]] = None,
+        enable_progress_bar: bool = True,
     ) -> None:
         """
         Initialize CAREamist with a configuration object or a path.
@@ -112,6 +119,9 @@ class CAREamist:
             by default None.
         callbacks : list of Callback, optional
             List of callbacks to use during training and prediction, by default None.
+        enable_progress_bar : bool
+            Whether a progress bar will be displayed during training, validation and
+            prediction.
 
         Raises
         ------
@@ -169,7 +179,7 @@ class CAREamist:
             self.model, self.cfg = load_pretrained(source)
 
         # define the checkpoint saving callback
-        self._define_callbacks(callbacks)
+        self._define_callbacks(callbacks, enable_progress_bar)
 
         # instantiate logger
         csv_logger = CSVLogger(
@@ -202,7 +212,7 @@ class CAREamist:
             precision=self.cfg.training_config.precision,
             max_steps=self.cfg.training_config.max_steps,
             check_val_every_n_epoch=self.cfg.training_config.check_val_every_n_epoch,
-            enable_progress_bar=self.cfg.training_config.enable_progress_bar,
+            enable_progress_bar=enable_progress_bar,
             accumulate_grad_batches=self.cfg.training_config.accumulate_grad_batches,
             gradient_clip_val=self.cfg.training_config.gradient_clip_val,
             gradient_clip_algorithm=self.cfg.training_config.gradient_clip_algorithm,
@@ -215,13 +225,19 @@ class CAREamist:
         self.train_datamodule: Optional[TrainDataModule] = None
         self.pred_datamodule: Optional[PredictDataModule] = None
 
-    def _define_callbacks(self, callbacks: Optional[list[Callback]] = None) -> None:
+    def _define_callbacks(
+        self, callbacks: Optional[list[Callback]], enable_progress_bar: bool
+    ) -> None:
         """Define the callbacks for the training loop.
 
         Parameters
         ----------
         callbacks : list of Callback, optional
             List of callbacks to use during training and prediction, by default None.
+        enable_progress_bar : bool
+            Whether a progress bar will be displayed during training, validation and
+            prediction. It controls whether a `ProgressBarCallback` is added to the
+            callback list.
         """
         self.callbacks = [] if callbacks is None else callbacks
@@ -251,9 +267,10 @@ class CAREamist:
                     filename=self.cfg.experiment_name,
                     **self.cfg.training_config.checkpoint_callback.model_dump(),
                 ),
-                ProgressBarCallback(),
             ]
         )
+        if enable_progress_bar:
+            self.callbacks.append(ProgressBarCallback())
 
         # early stopping callback
         if self.cfg.training_config.early_stopping_callback is not None:
@@ -811,7 +828,7 @@ class CAREamist:
             source_path = source.pred_data
             source_data_type = source.data_type
             extension_filter = source.extension_filter
-        elif isinstance(source, (str, Path)):
+        elif isinstance(source, str | Path):
             source_path = source
             source_data_type = data_type or self.cfg.data_config.data_type
             extension_filter = SupportedData.get_extension_pattern(
@@ -824,7 +841,7 @@ class CAREamist:
             raise ValueError(
                 "Predicting to disk is not supported for input type 'array'."
             )
-        assert isinstance(source_path, (Path, str))  # because data_type != "array"
+        assert isinstance(source_path, str | Path)  # because data_type != "array"
         source_path = Path(source_path)
 
         file_paths = list_files(source_path, source_data_type, extension_filter)
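The net effect of the changes above is that the progress bar is now toggled per `CAREamist` instance instead of through `training_config.enable_progress_bar`; the same argument both gates the `ProgressBarCallback` and feeds Lightning's `enable_progress_bar` flag. A minimal usage sketch, assuming a configuration file at a hypothetical path:

    from careamics import CAREamist

    # default: progress bars during training, validation and prediction
    engine = CAREamist(source="n2v_config.yml")

    # opt out of all progress bars (no ProgressBarCallback is registered)
    quiet_engine = CAREamist(source="n2v_config.yml", enable_progress_bar=False)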
careamics/cli/utils.py CHANGED
@@ -4,7 +4,7 @@ from typing import Optional
 
 
 def handle_2D_3D_callback(
-    value: Optional[tuple[int, int, int]]
+    value: Optional[tuple[int, int, int]],
 ) -> Optional[tuple[int, ...]]:
     """
     Callback for options that require 2D or 3D inputs.
careamics/config/algorithms/n2v_algorithm_model.py CHANGED
@@ -1,4 +1,4 @@
-""""N2V Algorithm configuration."""
+"""N2V Algorithm configuration."""
 
 from typing import Annotated, Literal
 
careamics/config/architectures/unet_model.py CHANGED
@@ -63,6 +63,9 @@ class UNetModel(ArchitectureModel):
     """Whether information is processed independently in each channel, used to train
     channels independently."""
 
+    use_batch_norm: bool = Field(default=True, validate_default=True)
+    """Whether to use batch normalization in the model."""
+
     @field_validator("num_channels_init")
     @classmethod
     def validate_num_channels_init(cls, num_channels_init: int) -> int:
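`use_batch_norm` defaults to `True`, so existing configurations keep their behaviour and only explicit opt-outs change anything. A sketch of disabling it, assuming the `architecture` discriminator value `"UNet"` (the other details are illustrative):

    from careamics.config.architectures import UNetModel

    # build a UNet architecture config without batch normalization layers
    model_cfg = UNetModel(architecture="UNet", use_batch_norm=False)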
careamics/config/callback_model.py CHANGED
@@ -22,52 +22,42 @@ class CheckpointModel(BaseModel):
     https://lightning.ai/docs/pytorch/stable/api/lightning.pytorch.callbacks.ModelCheckpoint.html#modelcheckpoint
     """
 
-    model_config = ConfigDict(
-        validate_assignment=True,
-    )
+    model_config = ConfigDict(validate_assignment=True, validate_default=True)
 
-    monitor: Literal["val_loss"] = Field(default="val_loss", validate_default=True)
-    """Quantity to monitor."""
+    monitor: Literal["val_loss"] = Field(default="val_loss")
+    """Quantity to monitor, currently only `val_loss`."""
 
-    verbose: bool = Field(default=False, validate_default=True)
+    verbose: bool = Field(default=False)
     """Verbosity mode."""
 
-    save_weights_only: bool = Field(default=False, validate_default=True)
+    save_weights_only: bool = Field(default=False)
     """When `True`, only the model's weights will be saved (model.save_weights)."""
 
-    save_last: Optional[Literal[True, False, "link"]] = Field(
-        default=True, validate_default=True
-    )
+    save_last: Optional[Literal[True, False, "link"]] = Field(default=True)
     """When `True`, saves a last.ckpt copy whenever a checkpoint file gets saved."""
 
-    save_top_k: int = Field(default=3, ge=1, le=10, validate_default=True)
+    save_top_k: int = Field(default=3, ge=-1, le=100)
     """If `save_top_k == k`, the best k models according to the quantity monitored
     will be saved. If `save_top_k == 0`, no models are saved. If `save_top_k == -1`,
     all models are saved."""
 
-    mode: Literal["min", "max"] = Field(default="min", validate_default=True)
+    mode: Literal["min", "max"] = Field(default="min")
     """One of {min, max}. If `save_top_k != 0`, the decision to overwrite the current
     save file is made based on either the maximization or the minimization of the
     monitored quantity. For 'val_acc', this should be 'max', for 'val_loss' this should
     be 'min', etc.
     """
 
-    auto_insert_metric_name: bool = Field(default=False, validate_default=True)
+    auto_insert_metric_name: bool = Field(default=False)
     """When `True`, the checkpoints filenames will contain the metric name."""
 
-    every_n_train_steps: Optional[int] = Field(
-        default=None, ge=1, le=10, validate_default=True
-    )
+    every_n_train_steps: Optional[int] = Field(default=None, ge=1, le=1000)
     """Number of training steps between checkpoints."""
 
-    train_time_interval: Optional[timedelta] = Field(
-        default=None, validate_default=True
-    )
+    train_time_interval: Optional[timedelta] = Field(default=None)
     """Checkpoints are monitored at the specified time interval."""
 
-    every_n_epochs: Optional[int] = Field(
-        default=None, ge=1, le=10, validate_default=True
-    )
+    every_n_epochs: Optional[int] = Field(default=None, ge=1, le=100)
     """Number of epochs between checkpoints."""
 
 
@@ -83,41 +73,40 @@ class EarlyStoppingModel(BaseModel):
 
     model_config = ConfigDict(
         validate_assignment=True,
+        validate_default=True,
     )
 
-    monitor: Literal["val_loss"] = Field(default="val_loss", validate_default=True)
+    monitor: Literal["val_loss"] = Field(default="val_loss")
     """Quantity to monitor."""
 
-    min_delta: float = Field(default=0.0, ge=0.0, le=1.0, validate_default=True)
+    min_delta: float = Field(default=0.0, ge=0.0, le=1.0)
     """Minimum change in the monitored quantity to qualify as an improvement, i.e. an
     absolute change of less than or equal to min_delta will count as no improvement."""
 
-    patience: int = Field(default=3, ge=1, le=10, validate_default=True)
+    patience: int = Field(default=3, ge=1, le=10)
     """Number of checks with no improvement after which training will be stopped."""
 
-    verbose: bool = Field(default=False, validate_default=True)
+    verbose: bool = Field(default=False)
     """Verbosity mode."""
 
-    mode: Literal["min", "max", "auto"] = Field(default="min", validate_default=True)
+    mode: Literal["min", "max", "auto"] = Field(default="min")
     """One of {min, max, auto}."""
 
-    check_finite: bool = Field(default=True, validate_default=True)
+    check_finite: bool = Field(default=True)
     """When `True`, stops training when the monitored quantity becomes `NaN` or
     `inf`."""
 
-    stopping_threshold: Optional[float] = Field(default=None, validate_default=True)
+    stopping_threshold: Optional[float] = Field(default=None)
     """Stop training immediately once the monitored quantity reaches this threshold."""
 
-    divergence_threshold: Optional[float] = Field(default=None, validate_default=True)
+    divergence_threshold: Optional[float] = Field(default=None)
     """Stop training as soon as the monitored quantity becomes worse than this
     threshold."""
 
-    check_on_train_epoch_end: Optional[bool] = Field(
-        default=False, validate_default=True
-    )
+    check_on_train_epoch_end: Optional[bool] = Field(default=False)
     """Whether to run early stopping at the end of the training epoch. If this is
     `False`, then the check runs at the end of the validation."""
 
-    log_rank_zero_only: bool = Field(default=False, validate_default=True)
+    log_rank_zero_only: bool = Field(default=False)
     """When set `True`, logs the status of the early stopping callback only for rank 0
     process."""
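The recurring edit in this file moves `validate_default=True` from every individual `Field(...)` into the model-level `ConfigDict`, which requests default validation for all fields at once. A minimal Pydantic v2 sketch of the equivalence (class and field names are illustrative, not from careamics):

    from pydantic import BaseModel, ConfigDict, Field

    class Before(BaseModel):
        model_config = ConfigDict(validate_assignment=True)
        # default validation requested field by field
        patience: int = Field(default=3, ge=1, validate_default=True)

    class After(BaseModel):
        # one model-wide switch replaces the per-field flags
        model_config = ConfigDict(validate_assignment=True, validate_default=True)
        patience: int = Field(default=3, ge=1)

    # both classes validate their defaults against the constraints
    assert Before().patience == After().patience == 3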
careamics/config/configuration.py CHANGED
@@ -3,11 +3,14 @@
 from __future__ import annotations
 
 import re
+from collections.abc import Callable
 from pprint import pformat
 from typing import Any, Literal, Union
 
+import numpy as np
 from bioimageio.spec.generic.v0_3 import CiteEntry
 from pydantic import BaseModel, ConfigDict, Field, field_validator, model_validator
+from pydantic.main import IncEx
 from typing_extensions import Self
 
 from careamics.config.algorithms import (
@@ -182,6 +185,50 @@ class Configuration(BaseModel):
 
         return name
 
+    @model_validator(mode="after")
+    def validate_n2v_mask_pixel_perc(self: Self) -> Self:
+        """
+        Validate that there will always be at least one blind-spot pixel in every patch.
+
+        The probability of creating a blind-spot pixel is a function of the chosen
+        masked pixel percentage and patch size.
+
+        Returns
+        -------
+        Self
+            Validated configuration.
+
+        Raises
+        ------
+        ValueError
+            If the probability of masking a pixel within a patch is less than 1 for the
+            chosen masked pixel percentage and patch size.
+        """
+        # No validation needed for non-N2V algorithms
+        if not isinstance(self.algorithm_config, N2VAlgorithm):
+            return self
+
+        mask_pixel_perc = self.algorithm_config.n2v_config.masked_pixel_percentage
+        patch_size = self.data_config.patch_size
+        expected_area_per_pixel = 1 / (mask_pixel_perc / 100)
+
+        n_dims = 3 if self.algorithm_config.model.is_3D() else 2
+        patch_size_lower_bound = int(np.ceil(expected_area_per_pixel ** (1 / n_dims)))
+        required_patch_size = tuple(
+            2 ** int(np.ceil(np.log2(patch_size_lower_bound))) for _ in range(n_dims)
+        )
+        required_mask_pixel_perc = (1 / np.prod(patch_size)) * 100
+        if expected_area_per_pixel > np.prod(patch_size):
+            raise ValueError(
+                "The probability of creating a blind-spot pixel within a patch is "
+                f"below 1, for a patch size of {patch_size} with a masked pixel "
+                f"percentage of {mask_pixel_perc}%. Either increase the patch size to "
+                f"{required_patch_size} or increase the masked pixel percentage to "
+                f"at least {required_mask_pixel_perc}%."
+            )
+
+        return self
+
     @model_validator(mode="after")
     def validate_3D(self: Self) -> Self:
         """
@@ -297,17 +344,18 @@ class Configuration(BaseModel):
         self,
         *,
         mode: Literal["json", "python"] | str = "python",
-        include: Any | None = None,
-        exclude: Any | None = None,
+        include: IncEx | None = None,
+        exclude: IncEx | None = None,
         context: Any | None = None,
-        by_alias: bool = False,
+        by_alias: bool | None = False,
         exclude_unset: bool = False,
         exclude_defaults: bool = False,
         exclude_none: bool = True,
         round_trip: bool = False,
         warnings: bool | Literal["none", "warn", "error"] = True,
+        fallback: Callable[[Any], Any] | None = None,
         serialize_as_any: bool = False,
-    ) -> dict:
+    ) -> dict[str, Any]:
         """
         Override model_dump method in order to set default values.
 
@@ -337,6 +385,8 @@ class Configuration(BaseModel):
             representation.
         warnings : bool | Literal['none', 'warn', 'error'], default=True
             Whether to emit warnings.
+        fallback : Callable[[Any], Any] | None, default=None
+            A function to call when an unknown value is encountered.
         serialize_as_any : bool, default=False
             Whether to serialize all types as Any.
 
@@ -356,6 +406,7 @@ class Configuration(BaseModel):
             exclude_none=exclude_none,
             round_trip=round_trip,
             warnings=warnings,
+            fallback=fallback,
             serialize_as_any=serialize_as_any,
         )