careamics 0.0.15__py3-none-any.whl → 0.0.17__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.


Files changed (79)
  1. careamics/careamist.py +11 -14
  2. careamics/cli/conf.py +18 -3
  3. careamics/config/__init__.py +8 -0
  4. careamics/config/algorithms/__init__.py +4 -0
  5. careamics/config/algorithms/hdn_algorithm_model.py +103 -0
  6. careamics/config/algorithms/microsplit_algorithm_model.py +103 -0
  7. careamics/config/algorithms/n2v_algorithm_model.py +1 -2
  8. careamics/config/algorithms/vae_algorithm_model.py +51 -16
  9. careamics/config/architectures/lvae_model.py +12 -8
  10. careamics/config/callback_model.py +7 -3
  11. careamics/config/configuration.py +15 -63
  12. careamics/config/configuration_factories.py +853 -29
  13. careamics/config/data/data_model.py +50 -11
  14. careamics/config/data/ng_data_model.py +168 -4
  15. careamics/config/data/patch_filter/__init__.py +15 -0
  16. careamics/config/data/patch_filter/filter_model.py +16 -0
  17. careamics/config/data/patch_filter/mask_filter_model.py +17 -0
  18. careamics/config/data/patch_filter/max_filter_model.py +15 -0
  19. careamics/config/data/patch_filter/meanstd_filter_model.py +18 -0
  20. careamics/config/data/patch_filter/shannon_filter_model.py +15 -0
  21. careamics/config/inference_model.py +1 -2
  22. careamics/config/likelihood_model.py +2 -2
  23. careamics/config/loss_model.py +6 -2
  24. careamics/config/nm_model.py +26 -1
  25. careamics/config/optimizer_models.py +1 -2
  26. careamics/config/support/supported_algorithms.py +5 -3
  27. careamics/config/support/supported_filters.py +17 -0
  28. careamics/config/support/supported_losses.py +5 -2
  29. careamics/config/training_model.py +6 -36
  30. careamics/config/transformations/normalize_model.py +1 -2
  31. careamics/dataset_ng/dataset.py +57 -5
  32. careamics/dataset_ng/factory.py +101 -18
  33. careamics/dataset_ng/patch_extractor/demo_custom_image_stack_loader.py +4 -4
  34. careamics/dataset_ng/patch_extractor/image_stack/in_memory_image_stack.py +1 -2
  35. careamics/dataset_ng/patch_extractor/image_stack/zarr_image_stack.py +33 -7
  36. careamics/dataset_ng/patch_extractor/image_stack_loader.py +2 -2
  37. careamics/dataset_ng/patch_filter/__init__.py +20 -0
  38. careamics/dataset_ng/patch_filter/coordinate_filter_protocol.py +27 -0
  39. careamics/dataset_ng/patch_filter/filter_factory.py +94 -0
  40. careamics/dataset_ng/patch_filter/mask_filter.py +95 -0
  41. careamics/dataset_ng/patch_filter/max_filter.py +188 -0
  42. careamics/dataset_ng/patch_filter/mean_std_filter.py +218 -0
  43. careamics/dataset_ng/patch_filter/patch_filter_protocol.py +50 -0
  44. careamics/dataset_ng/patch_filter/shannon_filter.py +188 -0
  45. careamics/file_io/read/__init__.py +0 -1
  46. careamics/lightning/__init__.py +16 -2
  47. careamics/lightning/callbacks/__init__.py +2 -0
  48. careamics/lightning/callbacks/data_stats_callback.py +33 -0
  49. careamics/lightning/dataset_ng/data_module.py +79 -2
  50. careamics/lightning/lightning_module.py +162 -61
  51. careamics/lightning/microsplit_data_module.py +636 -0
  52. careamics/lightning/predict_data_module.py +8 -1
  53. careamics/lightning/train_data_module.py +19 -8
  54. careamics/losses/__init__.py +7 -1
  55. careamics/losses/loss_factory.py +9 -1
  56. careamics/losses/lvae/losses.py +85 -0
  57. careamics/lvae_training/dataset/__init__.py +8 -8
  58. careamics/lvae_training/dataset/config.py +56 -44
  59. careamics/lvae_training/dataset/lc_dataset.py +18 -12
  60. careamics/lvae_training/dataset/ms_dataset_ref.py +5 -5
  61. careamics/lvae_training/dataset/multich_dataset.py +24 -18
  62. careamics/lvae_training/dataset/multifile_dataset.py +6 -6
  63. careamics/lvae_training/eval_utils.py +46 -24
  64. careamics/model_io/bmz_io.py +9 -5
  65. careamics/models/lvae/likelihoods.py +31 -14
  66. careamics/models/lvae/lvae.py +2 -2
  67. careamics/models/lvae/noise_models.py +20 -14
  68. careamics/prediction_utils/__init__.py +8 -2
  69. careamics/prediction_utils/prediction_outputs.py +49 -3
  70. careamics/prediction_utils/stitch_prediction.py +83 -1
  71. careamics/transforms/xy_random_rotate90.py +1 -1
  72. careamics/utils/version.py +4 -4
  73. {careamics-0.0.15.dist-info → careamics-0.0.17.dist-info}/METADATA +19 -22
  74. {careamics-0.0.15.dist-info → careamics-0.0.17.dist-info}/RECORD +77 -60
  75. careamics/dataset/zarr_dataset.py +0 -151
  76. careamics/file_io/read/zarr.py +0 -60
  77. {careamics-0.0.15.dist-info → careamics-0.0.17.dist-info}/WHEEL +0 -0
  78. {careamics-0.0.15.dist-info → careamics-0.0.17.dist-info}/entry_points.txt +0 -0
  79. {careamics-0.0.15.dist-info → careamics-0.0.17.dist-info}/licenses/LICENSE +0 -0
@@ -6,7 +6,7 @@ import os
 import sys
 from collections.abc import Sequence
 from pprint import pformat
-from typing import Annotated, Any, Literal, Union
+from typing import Annotated, Any, Literal, Self, Union
 from warnings import warn

 import numpy as np
@@ -19,7 +19,6 @@ from pydantic import (
     field_validator,
     model_validator,
 )
-from typing_extensions import Self

 from ..transformations import XYFlipModel, XYRandomRotate90Model
 from ..validators import check_axes_validity, patch_size_ge_than_8_power_of_2
@@ -208,13 +207,12 @@ class DataConfig(BaseModel):

     @field_validator("train_dataloader_params", "val_dataloader_params", mode="before")
     @classmethod
-    def set_default_dataloader_params(
+    def set_default_pin_memory(
         cls, dataloader_params: dict[str, Any]
     ) -> dict[str, Any]:
         """
-        Set default dataloader parameters if not provided.
+        Set default pin_memory for dataloader parameters if not provided.

-        - If 'num_workers' is not set, it defaults to the number of available CPU cores.
         - If 'pin_memory' is not set, it defaults to True if CUDA is available.

         Parameters
@@ -225,21 +223,62 @@ class DataConfig(BaseModel):
         Returns
         -------
         dict of {str: Any}
-            The dataloader parameters with defaults applied.
+            The dataloader parameters with pin_memory default applied.
+        """
+        if "pin_memory" not in dataloader_params:
+            import torch
+
+            dataloader_params["pin_memory"] = torch.cuda.is_available()
+
+        return dataloader_params
+
+    @field_validator("train_dataloader_params", mode="before")
+    @classmethod
+    def set_default_train_workers(
+        cls, dataloader_params: dict[str, Any]
+    ) -> dict[str, Any]:
+        """
+        Set default num_workers for training dataloader if not provided.
+
+        - If 'num_workers' is not set, it defaults to the number of available CPU cores.
+
+        Parameters
+        ----------
+        dataloader_params : dict of {str: Any}
+            The training dataloader parameters.
+
+        Returns
+        -------
+        dict of {str: Any}
+            The dataloader parameters with num_workers default applied.
         """
         if "num_workers" not in dataloader_params:
-            # Use 1 worker during tests, otherwise use all available CPU cores
+            # Use 0 workers during tests, otherwise use all available CPU cores
             if "pytest" in sys.modules:
                 dataloader_params["num_workers"] = 0
             else:
                 dataloader_params["num_workers"] = os.cpu_count()

-        if "pin_memory" not in dataloader_params:
-            import torch
+        return dataloader_params

-            dataloader_params["pin_memory"] = torch.cuda.is_available()
+    @model_validator(mode="after")
+    def set_val_workers_to_match_train(self: Self) -> Self:
+        """
+        Set validation dataloader num_workers to match training dataloader.

-        return dataloader_params
+        If num_workers is not specified in val_dataloader_params, it will be set to the
+        same value as train_dataloader_params["num_workers"].
+
+        Returns
+        -------
+        Self
+            Validated data model with synchronized num_workers.
+        """
+        if "num_workers" not in self.val_dataloader_params:
+            self.val_dataloader_params["num_workers"] = self.train_dataloader_params[
+                "num_workers"
+            ]
+        return self

     @field_validator("train_dataloader_params")
     @classmethod
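
Not part of the diff: a minimal sketch of how the reworked dataloader defaults are expected to behave after this change. The import path and the required fields (data_type, patch_size, axes, batch_size) are assumptions based on the surrounding code; values are illustrative.

from careamics.config.data import DataConfig

# Leaving the dataloader params empty lets the new validators fill them in:
#   train num_workers -> os.cpu_count() (0 under pytest)
#   val num_workers   -> copied from the train dataloader
#   pin_memory        -> torch.cuda.is_available() for both
config = DataConfig(
    data_type="array",
    axes="YX",
    patch_size=[64, 64],
    batch_size=8,
)
print(config.train_dataloader_params)
print(config.val_dataloader_params)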
@@ -2,9 +2,12 @@

 from __future__ import annotations

+import os
+import random
+import sys
 from collections.abc import Sequence
 from pprint import pformat
-from typing import Annotated, Any, Literal, Union
+from typing import Annotated, Any, Literal, Self, Union
 from warnings import warn

 import numpy as np
@@ -17,10 +20,15 @@ from pydantic import (
     field_validator,
     model_validator,
 )
-from typing_extensions import Self

 from ..transformations import XYFlipModel, XYRandomRotate90Model
 from ..validators import check_axes_validity
+from .patch_filter import (
+    MaskFilterModel,
+    MaxFilterModel,
+    MeanSTDFilterModel,
+    ShannonFilterModel,
+)
 from .patching_strategies import (
     RandomPatchingModel,
     TiledPatchingModel,
@@ -38,6 +46,17 @@ from .patching_strategies import (
 # - or is the responsibility of the creator (e.g. conveneince functions)


+def generate_random_seed() -> int:
+    """Generate a random seed for reproducibility.
+
+    Returns
+    -------
+    int
+        A random integer between 1 and 2^31 - 1.
+    """
+    return random.randint(1, 2**31 - 1)
+
+
 def np_float_to_scientific_str(x: float) -> str:
     """Return a string scientific representation of a float.

@@ -68,6 +87,16 @@ PatchingStrategies = Union[
 ]
 """Patching strategies."""

+PatchFilters = Union[
+    MaxFilterModel,
+    MeanSTDFilterModel,
+    ShannonFilterModel,
+]
+"""Patch filters."""
+
+CoordFilters = Union[MaskFilterModel]  # add more here as needed
+"""Coordinate filters."""
+

 class NGDataConfig(BaseModel):
     """Next-Generation Dataset configuration.
@@ -106,6 +135,18 @@ class NGDataConfig(BaseModel):
     batch_size: int = Field(default=1, ge=1, validate_default=True)
     """Batch size for training."""

+    patch_filter: PatchFilters | None = Field(default=None, discriminator="name")
+    """Patch filter to apply when using random patching. Only available during
+    training."""
+
+    coord_filter: CoordFilters | None = Field(default=None, discriminator="name")
+    """Coordinate filter to apply when using random patching. Only available during
+    training."""
+
+    patch_filter_patience: int = Field(default=5, ge=1)
+    """Number of consecutive patches not passing the filter before accepting the next
+    patch."""
+
     image_means: list[Float] | None = Field(default=None, min_length=0, max_length=32)
     """Means of the data across channels, used for normalization."""

@@ -142,8 +183,8 @@ class NGDataConfig(BaseModel):
     test_dataloader_params: dict[str, Any] = Field(default={})
     """Dictionary of PyTorch test dataloader parameters."""

-    seed: int | None = Field(default=None, gt=0)
-    """Random seed for reproducibility."""
+    seed: int | None = Field(default_factory=generate_random_seed, gt=0)
+    """Random seed for reproducibility. If not specified, a random seed is generated."""

     @field_validator("axes")
     @classmethod
@@ -297,6 +338,129 @@ class NGDataConfig(BaseModel):

         return self

+    @model_validator(mode="after")
+    def propagate_seed_to_filters(self: Self) -> Self:
+        """
+        Propagate the main seed to patch and coordinate filters that support seeds.
+
+        This ensures that all filters use the same seed for reproducibility,
+        unless they already have a seed explicitly set.
+
+        Returns
+        -------
+        Self
+            Data model with propagated seeds.
+        """
+        if self.seed is not None:
+            if self.patch_filter is not None:
+                if (
+                    hasattr(self.patch_filter, "seed")
+                    and self.patch_filter.seed is None
+                ):
+                    self.patch_filter.seed = self.seed
+
+            if self.coord_filter is not None:
+                if (
+                    hasattr(self.coord_filter, "seed")
+                    and self.coord_filter.seed is None
+                ):
+                    self.coord_filter.seed = self.seed
+
+        return self
+
+    @model_validator(mode="after")
+    def propagate_seed_to_transforms(self: Self) -> Self:
+        """
+        Propagate the main seed to all transforms that support seeds.
+
+        This ensures that all transforms use the same seed for reproducibility,
+        unless they already have a seed explicitly set.
+
+        Returns
+        -------
+        Self
+            Data model with propagated seeds.
+        """
+        if self.seed is not None:
+            for transform in self.transforms:
+                if hasattr(transform, "seed") and transform.seed is None:
+                    transform.seed = self.seed
+        return self
+
+    @field_validator("train_dataloader_params", "val_dataloader_params", mode="before")
+    @classmethod
+    def set_default_pin_memory(
+        cls, dataloader_params: dict[str, Any]
+    ) -> dict[str, Any]:
+        """
+        Set default pin_memory for dataloader parameters if not provided.
+
+        - If 'pin_memory' is not set, it defaults to True if CUDA is available.
+
+        Parameters
+        ----------
+        dataloader_params : dict of {str: Any}
+            The dataloader parameters.
+
+        Returns
+        -------
+        dict of {str: Any}
+            The dataloader parameters with pin_memory default applied.
+        """
+        if "pin_memory" not in dataloader_params:
+            import torch
+
+            dataloader_params["pin_memory"] = torch.cuda.is_available()
+        return dataloader_params
+
+    @field_validator("train_dataloader_params", mode="before")
+    @classmethod
+    def set_default_train_workers(
+        cls, dataloader_params: dict[str, Any]
+    ) -> dict[str, Any]:
+        """
+        Set default num_workers for training dataloader if not provided.
+
+        - If 'num_workers' is not set, it defaults to the number of available CPU cores.
+
+        Parameters
+        ----------
+        dataloader_params : dict of {str: Any}
+            The training dataloader parameters.
+
+        Returns
+        -------
+        dict of {str: Any}
+            The dataloader parameters with num_workers default applied.
+        """
+        if "num_workers" not in dataloader_params:
+            # Use 0 workers during tests, otherwise use all available CPU cores
+            if "pytest" in sys.modules:
+                dataloader_params["num_workers"] = 0
+            else:
+                dataloader_params["num_workers"] = os.cpu_count()
+
+        return dataloader_params
+
+    @model_validator(mode="after")
+    def set_val_workers_to_match_train(self: Self) -> Self:
+        """
+        Set validation dataloader num_workers to match training dataloader.
+
+        If num_workers is not specified in val_dataloader_params, it will be set to the
+        same value as train_dataloader_params["num_workers"].
+
+        Returns
+        -------
+        Self
+            Validated data model with synchronized num_workers.
+        """
+        if "num_workers" not in self.val_dataloader_params:
+            self.val_dataloader_params["num_workers"] = self.train_dataloader_params[
+                "num_workers"
+            ]
+        return self
+
     def __str__(self) -> str:
         """
         Pretty string reprensenting the configuration.
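
Not part of the diff: a hedged sketch of how the new NGDataConfig options compose. The patching payload and the other constructor arguments are assumptions; only patch_filter, coord_filter, patch_filter_patience and the seed behaviour come from the hunks above.

from careamics.config.data import NGDataConfig

config = NGDataConfig(
    data_type="array",
    axes="YX",
    patching={"name": "random", "patch_size": [64, 64]},  # hypothetical payload
    patch_filter={"name": "shannon", "threshold": 2.0},
    coord_filter={"name": "mask", "coverage": 0.3},
    patch_filter_patience=10,
)

# With no explicit seed, generate_random_seed() supplies one, and the "after"
# validators push it into every filter/transform whose own seed is None.
assert config.seed is not None
assert config.patch_filter.seed == config.seed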
@@ -0,0 +1,15 @@
+"""Pydantic models representing coordinate and patch filters."""
+
+__all__ = [
+    "FilterModel",
+    "MaskFilterModel",
+    "MaxFilterModel",
+    "MeanSTDFilterModel",
+    "ShannonFilterModel",
+]
+
+from .filter_model import FilterModel
+from .mask_filter_model import MaskFilterModel
+from .max_filter_model import MaxFilterModel
+from .meanstd_filter_model import MeanSTDFilterModel
+from .shannon_filter_model import ShannonFilterModel
@@ -0,0 +1,16 @@
+"""Base class for patch and coordinate filtering models."""
+
+from pydantic import BaseModel, Field
+
+
+class FilterModel(BaseModel):
+    """Base class for patch and coordinate filtering models."""
+
+    name: str
+    """Name of the filter."""
+
+    p: float = Field(1.0, ge=0.0, le=1.0)
+    """Probability of applying the filter to a patch or coordinate."""
+
+    seed: int | None = Field(default=None, gt=0)
+    """Seed for the random number generator for reproducibility."""
@@ -0,0 +1,17 @@
+"""Pydantic model for the mask coordinate filter."""
+
+from typing import Literal
+
+from pydantic import Field
+
+from .filter_model import FilterModel
+
+
+class MaskFilterModel(FilterModel):
+    """Pydantic model for the mask coordinate filter."""
+
+    name: Literal["mask"] = "mask"
+    """Name of the filter."""
+
+    coverage: float = Field(0.5, ge=0.0, le=1.0)
+    """Percentage of masked pixels required to keep a patch."""
@@ -0,0 +1,15 @@
+"""Pydantic model for the max patch filter."""
+
+from typing import Literal
+
+from .filter_model import FilterModel
+
+
+class MaxFilterModel(FilterModel):
+    """Pydantic model for the max patch filter."""
+
+    name: Literal["max"] = "max"
+    """Name of the filter."""
+
+    threshold: float
+    """Threshold for the minimum of the max-filtered patch."""
@@ -0,0 +1,18 @@
+"""Pydantic model for the mean std patch filter."""
+
+from typing import Literal
+
+from .filter_model import FilterModel
+
+
+class MeanSTDFilterModel(FilterModel):
+    """Pydantic model for the mean std patch filter."""
+
+    name: Literal["mean_std"] = "mean_std"
+    """Name of the filter."""
+
+    mean_threshold: float
+    """Minimum mean intensity required to keep a patch."""
+
+    std_threshold: float | None = None
+    """Minimum standard deviation required to keep a patch."""
@@ -0,0 +1,15 @@
+"""Pydantic model for the Shannon entropy patch filter."""
+
+from typing import Literal
+
+from .filter_model import FilterModel
+
+
+class ShannonFilterModel(FilterModel):
+    """Pydantic model for the Shannon entropy patch filter."""
+
+    name: Literal["shannon"] = "shannon"
+    """Name of the filter."""
+
+    threshold: float
+    """Minimum Shannon entropy required to keep a patch."""
@@ -2,10 +2,9 @@

 from __future__ import annotations

-from typing import Any, Literal, Union
+from typing import Any, Literal, Self, Union

 from pydantic import BaseModel, ConfigDict, Field, field_validator, model_validator
-from typing_extensions import Self

 from .validators import check_axes_validity, patch_size_ge_than_8_power_of_2

@@ -50,11 +50,11 @@ class NMLikelihoodConfig(BaseModel):
     model_config = ConfigDict(validate_assignment=True, arbitrary_types_allowed=True)

     # TODO remove and use as parameters to the likelihood functions?
-    data_mean: Tensor = torch.zeros(1)
+    data_mean: Tensor | None = None
     """The mean of the data, used to unnormalize data for noise model evaluation.
     Shape is (target_ch,) (or (1, target_ch, [1], 1, 1))."""

     # TODO remove and use as parameters to the likelihood functions?
-    data_std: Tensor = torch.ones(1)
+    data_std: Tensor | None = None
     """The standard deviation of the data, used to unnormalize data for noise
     model evaluation. Shape is (target_ch,) (or (1, target_ch, [1], 1, 1))."""
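
Not part of the diff: since data_mean and data_std no longer default to zero/one tensors, callers are expected to supply the statistics explicitly before evaluating the likelihood. A minimal sketch, assuming NMLikelihoodConfig is importable from careamics.config.likelihood_model:

import torch

from careamics.config.likelihood_model import NMLikelihoodConfig

config = NMLikelihoodConfig()  # data_mean and data_std are now None
config.data_mean = torch.tensor([128.0])
config.data_std = torch.tensor([16.0])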
@@ -35,7 +35,9 @@ class LVAELossConfig(BaseModel):
         validate_assignment=True, validate_default=True, arbitrary_types_allowed=True
     )

-    loss_type: Literal["musplit", "denoisplit", "denoisplit_musplit"]
+    loss_type: Literal[
+        "hdn", "microsplit", "musplit", "denoisplit", "denoisplit_musplit"
+    ]
     """Type of loss to use for LVAE."""

     reconstruction_weight: float = 1.0
@@ -50,7 +52,9 @@ class LVAELossConfig(BaseModel):
     """Weight for the denoiSplit loss (used in the muSplit-deonoiSplit loss)."""
     kl_params: KLLossConfig = KLLossConfig()
     """KL loss configuration."""
-
+    # TODO revisit weights for the losses
     # TODO: remove?
     non_stochastic: bool = False
     """Whether to sample latents and compute KL."""
+
+    # TODO what are the correct parameters for HDN ?
@@ -1,7 +1,7 @@
 """Noise models config."""

 from pathlib import Path
-from typing import Annotated, Literal, Union
+from typing import Annotated, Literal, Self, Union

 import numpy as np
 import torch
@@ -11,6 +11,7 @@ from pydantic import (
     Field,
     PlainSerializer,
     PlainValidator,
+    model_validator,
 )

 from careamics.utils.serializers import _array_to_json, _to_numpy
@@ -86,6 +87,30 @@ class GaussianMixtureNMConfig(BaseModel):
     tol: float = Field(default=1e-10)
     """Tolerance used in the computation of the noise model likelihood."""

+    @model_validator(mode="after")
+    def validate_path(self: Self) -> Self:
+        """Validate that the path points to a valid .npz file if provided.
+
+        Returns
+        -------
+        Self
+            Returns itself.
+
+        Raises
+        ------
+        ValueError
+            If the path is provided but does not point to a valid .npz file.
+        """
+        if self.path is not None:
+            path = Path(self.path)
+            if not path.exists():
+                raise ValueError(f"Path {path} does not exist.")
+            if path.suffix != ".npz":
+                raise ValueError(f"Path {path} must point to a .npz file.")
+            if not path.is_file():
+                raise ValueError(f"Path {path} must point to a file.")
+        return self
+
     # @model_validator(mode="after")
     # def validate_path_to_pretrained_vs_training_data(self: Self) -> Self:
     #     """Validate paths provided in the config.
@@ -2,7 +2,7 @@

 from __future__ import annotations

-from typing import Literal
+from typing import Literal, Self

 from pydantic import (
     BaseModel,
@@ -13,7 +13,6 @@ from pydantic import (
     model_validator,
 )
 from torch import optim
-from typing_extensions import Self

 from careamics.utils.torch_utils import filter_parameters

@@ -26,9 +26,11 @@ class SupportedAlgorithm(str, BaseEnum):
     MUSPLIT = "musplit"
     """An image splitting approach based on ladder VAE architectures."""

+    MICROSPLIT = "microsplit"
+    """A micro-level image splitting approach based on ladder VAE architectures."""
+
     DENOISPLIT = "denoisplit"
     """An image splitting and denoising approach based on ladder VAE architectures."""

-    # PN2V = "pn2v"
-    # HDN = "hdn"
-    # SEG = "segmentation"
+    HDN = "hdn"
+    """Hierarchical Denoising Network, an unsupervised denoising algorithm"""
@@ -0,0 +1,17 @@
+"""Coordinate and patch filters supported by CAREamics."""
+
+from careamics.utils import BaseEnum
+
+
+class SupportedPatchFilters(str, BaseEnum):
+    """Supported patch filters."""
+
+    MAX = "max"
+    MEANSTD = "mean_std"
+    SHANNON = "shannon"
+
+
+class SupportedCoordinateFilters(str, BaseEnum):
+    """Supported coordinate filters."""
+
+    MASK = "mask"
@@ -21,9 +21,12 @@ class SupportedLoss(str, BaseEnum):
     MAE = "mae"
     N2V = "n2v"
     # PN2V = "pn2v"
-    # HDN = "hdn"
+    HDN = "hdn"
     MUSPLIT = "musplit"
+    MICROSPLIT = "microsplit"
     DENOISPLIT = "denoisplit"
-    DENOISPLIT_MUSPLIT = "denoisplit_musplit"
+    DENOISPLIT_MUSPLIT = (
+        "denoisplit_musplit"  # TODO refac losses, leave only microsplit
+    )
     # CE = "ce"
     # DICE = "dice"
@@ -3,9 +3,9 @@
 from __future__ import annotations

 from pprint import pformat
-from typing import Literal, Union
+from typing import Literal

-from pydantic import BaseModel, ConfigDict, Field, field_validator
+from pydantic import BaseModel, ConfigDict, Field

 from .callback_model import CheckpointModel, EarlyStoppingModel

@@ -29,26 +29,15 @@ class TrainingConfig(BaseModel):
     model_config = ConfigDict(
         validate_assignment=True,
     )
+    lightning_trainer_config: dict | None = None
+    """Configuration for the PyTorch Lightning Trainer, following PyTorch Lightning
+    Trainer class"""

-    num_epochs: int = Field(default=20, ge=1)
-    """Number of epochs, greater than 0."""
-
-    precision: Literal["64", "32", "16-mixed", "bf16-mixed"] = Field(default="32")
-    """Numerical precision"""
-    max_steps: int = Field(default=-1, ge=-1)
-    """Maximum number of steps to train for. -1 means no limit."""
-    check_val_every_n_epoch: int = Field(default=1, ge=1)
-    """Validation step frequency."""
-    accumulate_grad_batches: int = Field(default=1, ge=1)
-    """Number of batches to accumulate gradients over before stepping the optimizer."""
-    gradient_clip_val: Union[int, float] | None = None
-    """The value to which to clip the gradient"""
-    gradient_clip_algorithm: Literal["value", "norm"] = "norm"
-    """The algorithm to use for gradient clipping (see lightning `Trainer`)."""
     logger: Literal["wandb", "tensorboard"] | None = None
     """Logger to use during training. If None, no logger will be used. Available
     loggers are defined in SupportedLogger."""

+    # Only basic callbacks
     checkpoint_callback: CheckpointModel = CheckpointModel()
     """Checkpoint callback configuration, following PyTorch Lightning Checkpoint
     callback."""
@@ -78,22 +67,3 @@ class TrainingConfig(BaseModel):
             Whether the logger is defined or not.
         """
         return self.logger is not None
-
-    @field_validator("max_steps")
-    @classmethod
-    def validate_max_steps(cls, max_steps: int) -> int:
-        """Validate the max_steps parameter.
-
-        Parameters
-        ----------
-        max_steps : int
-            Maximum number of steps to train for. -1 means no limit.
-
-        Returns
-        -------
-        int
-            Validated max_steps.
-        """
-        if max_steps == 0:
-            raise ValueError("max_steps must be greater than 0. Use -1 for no limit.")
-        return max_steps
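
Not part of the diff: a migration sketch. Options that used to be top-level TrainingConfig fields (num_epochs, precision, max_steps, gradient clipping, ...) now go through lightning_trainer_config, which follows the PyTorch Lightning Trainer signature; the mapping below (e.g. num_epochs -> max_epochs) is an assumption based on the Trainer API.

from careamics.config.training_model import TrainingConfig

training = TrainingConfig(
    lightning_trainer_config={
        "max_epochs": 20,
        "precision": "32",
        "gradient_clip_val": 1.0,
        "gradient_clip_algorithm": "norm",
    },
    logger="tensorboard",
)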
@@ -1,9 +1,8 @@
 """Pydantic model for the Normalize transform."""

-from typing import Literal
+from typing import Literal, Self

 from pydantic import ConfigDict, Field, model_validator
-from typing_extensions import Self

 from .transform_model import TransformModel
