careamics 0.0.12__py3-none-any.whl → 0.0.13__py3-none-any.whl

This diff compares the contents of two publicly released versions of the package as they appear in their public registry. It is provided for informational purposes only.


Files changed (73)
  1. careamics/careamist.py +4 -3
  2. careamics/cli/utils.py +1 -1
  3. careamics/config/algorithms/n2v_algorithm_model.py +1 -1
  4. careamics/config/architectures/unet_model.py +3 -0
  5. careamics/config/callback_model.py +23 -34
  6. careamics/config/configuration.py +47 -1
  7. careamics/config/configuration_factories.py +288 -23
  8. careamics/config/data/__init__.py +2 -0
  9. careamics/config/data/data_model.py +3 -3
  10. careamics/config/data/ng_data_model.py +381 -0
  11. careamics/config/data/patching_strategies/__init__.py +14 -0
  12. careamics/config/data/patching_strategies/_overlapping_patched_model.py +103 -0
  13. careamics/config/data/patching_strategies/_patched_model.py +56 -0
  14. careamics/config/data/patching_strategies/random_patching_model.py +21 -0
  15. careamics/config/data/patching_strategies/sequential_patching_model.py +25 -0
  16. careamics/config/data/patching_strategies/tiled_patching_model.py +40 -0
  17. careamics/config/data/patching_strategies/whole_patching_model.py +12 -0
  18. careamics/config/inference_model.py +6 -3
  19. careamics/config/support/supported_data.py +7 -0
  20. careamics/config/support/supported_patching_strategies.py +22 -0
  21. careamics/config/validators/validator_utils.py +4 -3
  22. careamics/dataset/dataset_utils/iterate_over_files.py +2 -2
  23. careamics/dataset/in_memory_dataset.py +2 -1
  24. careamics/dataset/iterable_dataset.py +2 -2
  25. careamics/dataset/iterable_pred_dataset.py +2 -2
  26. careamics/dataset/iterable_tiled_pred_dataset.py +2 -2
  27. careamics/dataset/patching/patching.py +3 -2
  28. careamics/dataset/tiling/lvae_tiled_patching.py +16 -6
  29. careamics/dataset/tiling/tiled_patching.py +2 -1
  30. careamics/dataset_ng/dataset.py +46 -50
  31. careamics/dataset_ng/demos/bsd68_demo.ipynb +28 -23
  32. careamics/dataset_ng/demos/care_U2OS_demo.ipynb +1 -1
  33. careamics/dataset_ng/demos/demo_custom_image_stack.ipynb +1 -1
  34. careamics/dataset_ng/demos/demo_datamodule.ipynb +50 -46
  35. careamics/dataset_ng/demos/demo_dataset.ipynb +32 -49
  36. careamics/dataset_ng/factory.py +58 -15
  37. careamics/dataset_ng/legacy_interoperability.py +3 -1
  38. careamics/dataset_ng/patch_extractor/demo_custom_image_stack_loader.py +1 -1
  39. careamics/dataset_ng/patch_extractor/image_stack/__init__.py +2 -0
  40. careamics/dataset_ng/patch_extractor/image_stack/czi_image_stack.py +360 -0
  41. careamics/dataset_ng/patch_extractor/image_stack/in_memory_image_stack.py +1 -1
  42. careamics/dataset_ng/patch_extractor/patch_extractor_factory.py +43 -1
  43. careamics/dataset_ng/patching_strategies/random_patching.py +3 -1
  44. careamics/dataset_ng/patching_strategies/sequential_patching.py +5 -5
  45. careamics/dataset_ng/patching_strategies/tiling_strategy.py +2 -1
  46. careamics/file_io/read/get_func.py +2 -1
  47. careamics/lightning/dataset_ng/__init__.py +1 -0
  48. careamics/lightning/dataset_ng/data_module.py +218 -28
  49. careamics/lightning/dataset_ng/lightning_modules/care_module.py +44 -5
  50. careamics/lightning/dataset_ng/lightning_modules/n2v_module.py +42 -3
  51. careamics/lightning/dataset_ng/lightning_modules/unet_module.py +73 -4
  52. careamics/lightning/lightning_module.py +2 -1
  53. careamics/lightning/predict_data_module.py +2 -1
  54. careamics/lightning/train_data_module.py +2 -1
  55. careamics/losses/loss_factory.py +2 -1
  56. careamics/lvae_training/dataset/multicrop_dset.py +1 -1
  57. careamics/model_io/bioimage/bioimage_utils.py +1 -1
  58. careamics/model_io/bioimage/model_description.py +1 -1
  59. careamics/model_io/bmz_io.py +1 -1
  60. careamics/model_io/model_io_utils.py +2 -2
  61. careamics/models/activation.py +2 -1
  62. careamics/prediction_utils/prediction_outputs.py +1 -1
  63. careamics/prediction_utils/stitch_prediction.py +1 -1
  64. careamics/transforms/n2v_manipulate_torch.py +15 -9
  65. careamics/transforms/pixel_manipulation_torch.py +59 -92
  66. careamics/utils/lightning_utils.py +2 -2
  67. careamics/utils/metrics.py +2 -1
  68. careamics/utils/torch_utils.py +23 -0
  69. {careamics-0.0.12.dist-info → careamics-0.0.13.dist-info}/METADATA +10 -9
  70. {careamics-0.0.12.dist-info → careamics-0.0.13.dist-info}/RECORD +73 -62
  71. {careamics-0.0.12.dist-info → careamics-0.0.13.dist-info}/WHEEL +0 -0
  72. {careamics-0.0.12.dist-info → careamics-0.0.13.dist-info}/entry_points.txt +0 -0
  73. {careamics-0.0.12.dist-info → careamics-0.0.13.dist-info}/licenses/LICENSE +0 -0
careamics/careamist.py CHANGED
@@ -1,7 +1,8 @@
 """A class to train, predict and export models in CAREamics."""
 
+from collections.abc import Callable
 from pathlib import Path
-from typing import Any, Callable, Literal, Optional, Union, overload
+from typing import Any, Literal, Optional, Union, overload
 
 import numpy as np
 from numpy.typing import NDArray
@@ -827,7 +828,7 @@ class CAREamist:
            source_path = source.pred_data
            source_data_type = source.data_type
            extension_filter = source.extension_filter
-        elif isinstance(source, (str, Path)):
+        elif isinstance(source, str | Path):
            source_path = source
            source_data_type = data_type or self.cfg.data_config.data_type
            extension_filter = SupportedData.get_extension_pattern(
@@ -840,7 +841,7 @@ class CAREamist:
            raise ValueError(
                "Predicting to disk is not supported for input type 'array'."
            )
-        assert isinstance(source_path, (Path, str))  # because data_type != "array"
+        assert isinstance(source_path, str | Path)  # because data_type != "array"
        source_path = Path(source_path)
 
        file_paths = list_files(source_path, source_data_type, extension_filter)
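
Both `isinstance` updates above swap the tuple form for PEP 604 union syntax, which `isinstance` accepts at runtime from Python 3.10 onwards. A minimal standalone sketch of the equivalence (not taken from the package):

```python
from pathlib import Path

source = Path("image.tif")

# Since Python 3.10, isinstance accepts X | Y unions directly;
# it behaves exactly like the older tuple form.
assert isinstance(source, str | Path)
assert isinstance(source, (str, Path))  # legacy spelling, still valid
```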
careamics/cli/utils.py CHANGED
@@ -4,7 +4,7 @@ from typing import Optional
 
 
 def handle_2D_3D_callback(
-    value: Optional[tuple[int, int, int]]
+    value: Optional[tuple[int, int, int]],
 ) -> Optional[tuple[int, ...]]:
     """
     Callback for options that require 2D or 3D inputs.
careamics/config/algorithms/n2v_algorithm_model.py CHANGED
@@ -1,4 +1,4 @@
-""""N2V Algorithm configuration."""
+"""N2V Algorithm configuration."""
 
 from typing import Annotated, Literal
 
careamics/config/architectures/unet_model.py CHANGED
@@ -63,6 +63,9 @@ class UNetModel(ArchitectureModel):
     """Whether information is processed independently in each channel, used to train
     channels independently."""
 
+    use_batch_norm: bool = Field(default=True, validate_default=True)
+    """Whether to use batch normalization in the model."""
+
     @field_validator("num_channels_init")
     @classmethod
     def validate_num_channels_init(cls, num_channels_init: int) -> int:
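
The new `use_batch_norm` field enables batch normalization by default. A hypothetical usage sketch; only the field itself comes from the diff above, while the import path and the `architecture` constructor argument are assumptions:

```python
# Hypothetical usage: import path and constructor arguments are assumptions.
from careamics.config.architectures import UNetModel

default_model = UNetModel(architecture="UNet")  # use_batch_norm defaults to True
no_bn_model = UNetModel(architecture="UNet", use_batch_norm=False)
print(default_model.use_batch_norm, no_bn_model.use_batch_norm)  # True False
```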
careamics/config/callback_model.py CHANGED
@@ -22,52 +22,42 @@ class CheckpointModel(BaseModel):
     https://lightning.ai/docs/pytorch/stable/api/lightning.pytorch.callbacks.ModelCheckpoint.html#modelcheckpoint
     """
 
-    model_config = ConfigDict(
-        validate_assignment=True,
-    )
+    model_config = ConfigDict(validate_assignment=True, validate_default=True)
 
-    monitor: Literal["val_loss"] = Field(default="val_loss", validate_default=True)
-    """Quantity to monitor."""
+    monitor: Literal["val_loss"] = Field(default="val_loss")
+    """Quantity to monitor, currently only `val_loss`."""
 
-    verbose: bool = Field(default=False, validate_default=True)
+    verbose: bool = Field(default=False)
     """Verbosity mode."""
 
-    save_weights_only: bool = Field(default=False, validate_default=True)
+    save_weights_only: bool = Field(default=False)
     """When `True`, only the model's weights will be saved (model.save_weights)."""
 
-    save_last: Optional[Literal[True, False, "link"]] = Field(
-        default=True, validate_default=True
-    )
+    save_last: Optional[Literal[True, False, "link"]] = Field(default=True)
     """When `True`, saves a last.ckpt copy whenever a checkpoint file gets saved."""
 
-    save_top_k: int = Field(default=3, ge=1, le=10, validate_default=True)
+    save_top_k: int = Field(default=3, ge=-1, le=100)
     """If `save_top_k == k`, the best k models according to the quantity monitored
     will be saved. If `save_top_k == 0`, no models are saved. If `save_top_k == -1`,
     all models are saved."""
 
-    mode: Literal["min", "max"] = Field(default="min", validate_default=True)
+    mode: Literal["min", "max"] = Field(default="min")
     """One of {min, max}. If `save_top_k != 0`, the decision to overwrite the current
     save file is made based on either the maximization or the minimization of the
     monitored quantity. For 'val_acc', this should be 'max', for 'val_loss' this should
     be 'min', etc.
     """
 
-    auto_insert_metric_name: bool = Field(default=False, validate_default=True)
+    auto_insert_metric_name: bool = Field(default=False)
     """When `True`, the checkpoints filenames will contain the metric name."""
 
-    every_n_train_steps: Optional[int] = Field(
-        default=None, ge=1, le=10, validate_default=True
-    )
+    every_n_train_steps: Optional[int] = Field(default=None, ge=1, le=1000)
     """Number of training steps between checkpoints."""
 
-    train_time_interval: Optional[timedelta] = Field(
-        default=None, validate_default=True
-    )
+    train_time_interval: Optional[timedelta] = Field(default=None)
     """Checkpoints are monitored at the specified time interval."""
 
-    every_n_epochs: Optional[int] = Field(
-        default=None, ge=1, le=10, validate_default=True
-    )
+    every_n_epochs: Optional[int] = Field(default=None, ge=1, le=100)
     """Number of epochs between checkpoints."""
@@ -83,41 +73,40 @@ class EarlyStoppingModel(BaseModel):
 
     model_config = ConfigDict(
         validate_assignment=True,
+        validate_default=True,
     )
 
-    monitor: Literal["val_loss"] = Field(default="val_loss", validate_default=True)
+    monitor: Literal["val_loss"] = Field(default="val_loss")
     """Quantity to monitor."""
 
-    min_delta: float = Field(default=0.0, ge=0.0, le=1.0, validate_default=True)
+    min_delta: float = Field(default=0.0, ge=0.0, le=1.0)
     """Minimum change in the monitored quantity to qualify as an improvement, i.e. an
     absolute change of less than or equal to min_delta, will count as no improvement."""
 
-    patience: int = Field(default=3, ge=1, le=10, validate_default=True)
+    patience: int = Field(default=3, ge=1, le=10)
     """Number of checks with no improvement after which training will be stopped."""
 
-    verbose: bool = Field(default=False, validate_default=True)
+    verbose: bool = Field(default=False)
     """Verbosity mode."""
 
-    mode: Literal["min", "max", "auto"] = Field(default="min", validate_default=True)
+    mode: Literal["min", "max", "auto"] = Field(default="min")
     """One of {min, max, auto}."""
 
-    check_finite: bool = Field(default=True, validate_default=True)
+    check_finite: bool = Field(default=True)
     """When `True`, stops training when the monitored quantity becomes `NaN` or
     `inf`."""
 
-    stopping_threshold: Optional[float] = Field(default=None, validate_default=True)
+    stopping_threshold: Optional[float] = Field(default=None)
     """Stop training immediately once the monitored quantity reaches this threshold."""
 
-    divergence_threshold: Optional[float] = Field(default=None, validate_default=True)
+    divergence_threshold: Optional[float] = Field(default=None)
     """Stop training as soon as the monitored quantity becomes worse than this
     threshold."""
 
-    check_on_train_epoch_end: Optional[bool] = Field(
-        default=False, validate_default=True
-    )
+    check_on_train_epoch_end: Optional[bool] = Field(default=False)
     """Whether to run early stopping at the end of the training epoch. If this is
     `False`, then the check runs at the end of the validation."""
 
-    log_rank_zero_only: bool = Field(default=False, validate_default=True)
+    log_rank_zero_only: bool = Field(default=False)
     """When set `True`, logs the status of the early stopping callback only for rank 0
     process."""
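
The recurring change in this file moves `validate_default=True` from every individual `Field` into a single model-level `ConfigDict` entry, which Pydantic v2 supports. A minimal self-contained sketch of why the two are equivalent (hypothetical model, not the package's):

```python
from pydantic import BaseModel, ConfigDict, Field, ValidationError

class Callbackish(BaseModel):
    # One model-level switch replaces validate_default=True on every field.
    model_config = ConfigDict(validate_assignment=True, validate_default=True)

    save_top_k: int = Field(default=3, ge=-1, le=100)

class BadDefault(Callbackish):
    save_top_k: int = Field(default=-5, ge=-1, le=100)  # out-of-range default

Callbackish()  # fine: the default 3 passes the ge/le bounds
try:
    BadDefault()  # with validate_default=True, the default itself is checked
except ValidationError as err:
    print(err)
```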
careamics/config/configuration.py CHANGED
@@ -3,9 +3,11 @@
 from __future__ import annotations
 
 import re
+from collections.abc import Callable
 from pprint import pformat
-from typing import Any, Callable, Literal, Union
+from typing import Any, Literal, Union
 
+import numpy as np
 from bioimageio.spec.generic.v0_3 import CiteEntry
 from pydantic import BaseModel, ConfigDict, Field, field_validator, model_validator
 from pydantic.main import IncEx
@@ -183,6 +185,50 @@ class Configuration(BaseModel):
 
         return name
 
+    @model_validator(mode="after")
+    def validate_n2v_mask_pixel_perc(self: Self) -> Self:
+        """
+        Validate that there will always be at least one blind-spot pixel in every patch.
+
+        The probability of creating a blind-spot pixel is a function of the chosen
+        masked pixel percentage and patch size.
+
+        Returns
+        -------
+        Self
+            Validated configuration.
+
+        Raises
+        ------
+        ValueError
+            If the probability of masking a pixel within a patch is less than 1 for the
+            chosen masked pixel percentage and patch size.
+        """
+        # No validation needed for non n2v algorithms
+        if not isinstance(self.algorithm_config, N2VAlgorithm):
+            return self
+
+        mask_pixel_perc = self.algorithm_config.n2v_config.masked_pixel_percentage
+        patch_size = self.data_config.patch_size
+        expected_area_per_pixel = 1 / (mask_pixel_perc / 100)
+
+        n_dims = 3 if self.algorithm_config.model.is_3D() else 2
+        patch_size_lower_bound = int(np.ceil(expected_area_per_pixel ** (1 / n_dims)))
+        required_patch_size = tuple(
+            2 ** int(np.ceil(np.log2(patch_size_lower_bound))) for _ in range(n_dims)
+        )
+        required_mask_pixel_perc = (1 / np.prod(patch_size)) * 100
+        if expected_area_per_pixel > np.prod(patch_size):
+            raise ValueError(
+                "The probability of creating a blind-spot pixel within a patch is "
+                f"below 1, for a patch size of {patch_size} with a masked pixel "
+                f"percentage of {mask_pixel_perc}%. Either increase the patch size to "
+                f"{required_patch_size} or increase the masked pixel percentage to "
+                f"at least {required_mask_pixel_perc}%."
+            )
+
+        return self
+
     @model_validator(mode="after")
     def validate_3D(self: Self) -> Self:
         """