careamics 0.0.15__py3-none-any.whl → 0.0.17__py3-none-any.whl
- careamics/careamist.py +11 -14
- careamics/cli/conf.py +18 -3
- careamics/config/__init__.py +8 -0
- careamics/config/algorithms/__init__.py +4 -0
- careamics/config/algorithms/hdn_algorithm_model.py +103 -0
- careamics/config/algorithms/microsplit_algorithm_model.py +103 -0
- careamics/config/algorithms/n2v_algorithm_model.py +1 -2
- careamics/config/algorithms/vae_algorithm_model.py +51 -16
- careamics/config/architectures/lvae_model.py +12 -8
- careamics/config/callback_model.py +7 -3
- careamics/config/configuration.py +15 -63
- careamics/config/configuration_factories.py +853 -29
- careamics/config/data/data_model.py +50 -11
- careamics/config/data/ng_data_model.py +168 -4
- careamics/config/data/patch_filter/__init__.py +15 -0
- careamics/config/data/patch_filter/filter_model.py +16 -0
- careamics/config/data/patch_filter/mask_filter_model.py +17 -0
- careamics/config/data/patch_filter/max_filter_model.py +15 -0
- careamics/config/data/patch_filter/meanstd_filter_model.py +18 -0
- careamics/config/data/patch_filter/shannon_filter_model.py +15 -0
- careamics/config/inference_model.py +1 -2
- careamics/config/likelihood_model.py +2 -2
- careamics/config/loss_model.py +6 -2
- careamics/config/nm_model.py +26 -1
- careamics/config/optimizer_models.py +1 -2
- careamics/config/support/supported_algorithms.py +5 -3
- careamics/config/support/supported_filters.py +17 -0
- careamics/config/support/supported_losses.py +5 -2
- careamics/config/training_model.py +6 -36
- careamics/config/transformations/normalize_model.py +1 -2
- careamics/dataset_ng/dataset.py +57 -5
- careamics/dataset_ng/factory.py +101 -18
- careamics/dataset_ng/patch_extractor/demo_custom_image_stack_loader.py +4 -4
- careamics/dataset_ng/patch_extractor/image_stack/in_memory_image_stack.py +1 -2
- careamics/dataset_ng/patch_extractor/image_stack/zarr_image_stack.py +33 -7
- careamics/dataset_ng/patch_extractor/image_stack_loader.py +2 -2
- careamics/dataset_ng/patch_filter/__init__.py +20 -0
- careamics/dataset_ng/patch_filter/coordinate_filter_protocol.py +27 -0
- careamics/dataset_ng/patch_filter/filter_factory.py +94 -0
- careamics/dataset_ng/patch_filter/mask_filter.py +95 -0
- careamics/dataset_ng/patch_filter/max_filter.py +188 -0
- careamics/dataset_ng/patch_filter/mean_std_filter.py +218 -0
- careamics/dataset_ng/patch_filter/patch_filter_protocol.py +50 -0
- careamics/dataset_ng/patch_filter/shannon_filter.py +188 -0
- careamics/file_io/read/__init__.py +0 -1
- careamics/lightning/__init__.py +16 -2
- careamics/lightning/callbacks/__init__.py +2 -0
- careamics/lightning/callbacks/data_stats_callback.py +33 -0
- careamics/lightning/dataset_ng/data_module.py +79 -2
- careamics/lightning/lightning_module.py +162 -61
- careamics/lightning/microsplit_data_module.py +636 -0
- careamics/lightning/predict_data_module.py +8 -1
- careamics/lightning/train_data_module.py +19 -8
- careamics/losses/__init__.py +7 -1
- careamics/losses/loss_factory.py +9 -1
- careamics/losses/lvae/losses.py +85 -0
- careamics/lvae_training/dataset/__init__.py +8 -8
- careamics/lvae_training/dataset/config.py +56 -44
- careamics/lvae_training/dataset/lc_dataset.py +18 -12
- careamics/lvae_training/dataset/ms_dataset_ref.py +5 -5
- careamics/lvae_training/dataset/multich_dataset.py +24 -18
- careamics/lvae_training/dataset/multifile_dataset.py +6 -6
- careamics/lvae_training/eval_utils.py +46 -24
- careamics/model_io/bmz_io.py +9 -5
- careamics/models/lvae/likelihoods.py +31 -14
- careamics/models/lvae/lvae.py +2 -2
- careamics/models/lvae/noise_models.py +20 -14
- careamics/prediction_utils/__init__.py +8 -2
- careamics/prediction_utils/prediction_outputs.py +49 -3
- careamics/prediction_utils/stitch_prediction.py +83 -1
- careamics/transforms/xy_random_rotate90.py +1 -1
- careamics/utils/version.py +4 -4
- {careamics-0.0.15.dist-info → careamics-0.0.17.dist-info}/METADATA +19 -22
- {careamics-0.0.15.dist-info → careamics-0.0.17.dist-info}/RECORD +77 -60
- careamics/dataset/zarr_dataset.py +0 -151
- careamics/file_io/read/zarr.py +0 -60
- {careamics-0.0.15.dist-info → careamics-0.0.17.dist-info}/WHEEL +0 -0
- {careamics-0.0.15.dist-info → careamics-0.0.17.dist-info}/entry_points.txt +0 -0
- {careamics-0.0.15.dist-info → careamics-0.0.17.dist-info}/licenses/LICENSE +0 -0

careamics/config/data/data_model.py
CHANGED

@@ -6,7 +6,7 @@ import os
 import sys
 from collections.abc import Sequence
 from pprint import pformat
-from typing import Annotated, Any, Literal, Union
+from typing import Annotated, Any, Literal, Self, Union
 from warnings import warn
 
 import numpy as np
@@ -19,7 +19,6 @@ from pydantic import (
     field_validator,
     model_validator,
 )
-from typing_extensions import Self
 
 from ..transformations import XYFlipModel, XYRandomRotate90Model
 from ..validators import check_axes_validity, patch_size_ge_than_8_power_of_2
@@ -208,13 +207,12 @@ class DataConfig(BaseModel):
 
     @field_validator("train_dataloader_params", "val_dataloader_params", mode="before")
     @classmethod
-    def
+    def set_default_pin_memory(
         cls, dataloader_params: dict[str, Any]
     ) -> dict[str, Any]:
         """
-        Set default dataloader parameters if not provided.
+        Set default pin_memory for dataloader parameters if not provided.
 
-        - If 'num_workers' is not set, it defaults to the number of available CPU cores.
         - If 'pin_memory' is not set, it defaults to True if CUDA is available.
 
         Parameters
@@ -225,21 +223,62 @@ class DataConfig(BaseModel):
         Returns
         -------
         dict of {str: Any}
-            The dataloader parameters with
+            The dataloader parameters with pin_memory default applied.
+        """
+        if "pin_memory" not in dataloader_params:
+            import torch
+
+            dataloader_params["pin_memory"] = torch.cuda.is_available()
+
+        return dataloader_params
+
+    @field_validator("train_dataloader_params", mode="before")
+    @classmethod
+    def set_default_train_workers(
+        cls, dataloader_params: dict[str, Any]
+    ) -> dict[str, Any]:
+        """
+        Set default num_workers for training dataloader if not provided.
+
+        - If 'num_workers' is not set, it defaults to the number of available CPU cores.
+
+        Parameters
+        ----------
+        dataloader_params : dict of {str: Any}
+            The training dataloader parameters.
+
+        Returns
+        -------
+        dict of {str: Any}
+            The dataloader parameters with num_workers default applied.
         """
         if "num_workers" not in dataloader_params:
-            # Use
+            # Use 0 workers during tests, otherwise use all available CPU cores
             if "pytest" in sys.modules:
                 dataloader_params["num_workers"] = 0
             else:
                 dataloader_params["num_workers"] = os.cpu_count()
 
-
-            import torch
-
+        return dataloader_params
+
+    @model_validator(mode="after")
+    def set_val_workers_to_match_train(self: Self) -> Self:
+        """
+        Set validation dataloader num_workers to match training dataloader.
+
+        If num_workers is not specified in val_dataloader_params, it will be set to the
+        same value as train_dataloader_params["num_workers"].
+
+        Returns
+        -------
+        Self
+            Validated data model with synchronized num_workers.
+        """
+        if "num_workers" not in self.val_dataloader_params:
+            self.val_dataloader_params["num_workers"] = self.train_dataloader_params[
+                "num_workers"
+            ]
+        return self
 
     @field_validator("train_dataloader_params")
     @classmethod
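
Note: the following is a minimal standalone sketch of the default logic added above (it mirrors the validators but is not the CAREamics class itself); it assumes torch is installed, as it is for careamics.

    # Sketch of the new defaults: pin_memory follows CUDA availability,
    # num_workers follows the CPU count (0 when running under pytest).
    import os
    import sys
    from typing import Any

    import torch


    def apply_dataloader_defaults(params: dict[str, Any]) -> dict[str, Any]:
        if "pin_memory" not in params:
            params["pin_memory"] = torch.cuda.is_available()
        if "num_workers" not in params:
            params["num_workers"] = 0 if "pytest" in sys.modules else os.cpu_count()
        return params


    print(apply_dataloader_defaults({}))  # e.g. {'pin_memory': False, 'num_workers': 8}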

careamics/config/data/ng_data_model.py
CHANGED

@@ -2,9 +2,12 @@
 
 from __future__ import annotations
 
+import os
+import random
+import sys
 from collections.abc import Sequence
 from pprint import pformat
-from typing import Annotated, Any, Literal, Union
+from typing import Annotated, Any, Literal, Self, Union
 from warnings import warn
 
 import numpy as np
@@ -17,10 +20,15 @@ from pydantic import (
     field_validator,
     model_validator,
 )
-from typing_extensions import Self
 
 from ..transformations import XYFlipModel, XYRandomRotate90Model
 from ..validators import check_axes_validity
+from .patch_filter import (
+    MaskFilterModel,
+    MaxFilterModel,
+    MeanSTDFilterModel,
+    ShannonFilterModel,
+)
 from .patching_strategies import (
     RandomPatchingModel,
     TiledPatchingModel,
@@ -38,6 +46,17 @@ from .patching_strategies import (
 # - or is the responsibility of the creator (e.g. conveneince functions)
 
 
+def generate_random_seed() -> int:
+    """Generate a random seed for reproducibility.
+
+    Returns
+    -------
+    int
+        A random integer between 1 and 2^31 - 1.
+    """
+    return random.randint(1, 2**31 - 1)
+
+
 def np_float_to_scientific_str(x: float) -> str:
     """Return a string scientific representation of a float.
 
@@ -68,6 +87,16 @@ PatchingStrategies = Union[
 ]
 """Patching strategies."""
 
+PatchFilters = Union[
+    MaxFilterModel,
+    MeanSTDFilterModel,
+    ShannonFilterModel,
+]
+"""Patch filters."""
+
+CoordFilters = Union[MaskFilterModel]  # add more here as needed
+"""Coordinate filters."""
+
 
 class NGDataConfig(BaseModel):
     """Next-Generation Dataset configuration.
@@ -106,6 +135,18 @@ class NGDataConfig(BaseModel):
     batch_size: int = Field(default=1, ge=1, validate_default=True)
     """Batch size for training."""
 
+    patch_filter: PatchFilters | None = Field(default=None, discriminator="name")
+    """Patch filter to apply when using random patching. Only available during
+    training."""
+
+    coord_filter: CoordFilters | None = Field(default=None, discriminator="name")
+    """Coordinate filter to apply when using random patching. Only available during
+    training."""
+
+    patch_filter_patience: int = Field(default=5, ge=1)
+    """Number of consecutive patches not passing the filter before accepting the next
+    patch."""
+
     image_means: list[Float] | None = Field(default=None, min_length=0, max_length=32)
     """Means of the data across channels, used for normalization."""
 
@@ -142,8 +183,8 @@ class NGDataConfig(BaseModel):
     test_dataloader_params: dict[str, Any] = Field(default={})
     """Dictionary of PyTorch test dataloader parameters."""
 
-    seed: int | None = Field(
-    """Random seed for reproducibility."""
+    seed: int | None = Field(default_factory=generate_random_seed, gt=0)
+    """Random seed for reproducibility. If not specified, a random seed is generated."""
 
     @field_validator("axes")
     @classmethod
@@ -297,6 +338,129 @@ class NGDataConfig(BaseModel):
 
         return self
 
+    @model_validator(mode="after")
+    def propagate_seed_to_filters(self: Self) -> Self:
+        """
+        Propagate the main seed to patch and coordinate filters that support seeds.
+
+        This ensures that all filters use the same seed for reproducibility,
+        unless they already have a seed explicitly set.
+
+        Returns
+        -------
+        Self
+            Data model with propagated seeds.
+        """
+        if self.seed is not None:
+            if self.patch_filter is not None:
+                if (
+                    hasattr(self.patch_filter, "seed")
+                    and self.patch_filter.seed is None
+                ):
+                    self.patch_filter.seed = self.seed
+
+            if self.coord_filter is not None:
+                if (
+                    hasattr(self.coord_filter, "seed")
+                    and self.coord_filter.seed is None
+                ):
+                    self.coord_filter.seed = self.seed
+
+        return self
+
+    @model_validator(mode="after")
+    def propagate_seed_to_transforms(self: Self) -> Self:
+        """
+        Propagate the main seed to all transforms that support seeds.
+
+        This ensures that all transforms use the same seed for reproducibility,
+        unless they already have a seed explicitly set.
+
+        Returns
+        -------
+        Self
+            Data model with propagated seeds.
+        """
+        if self.seed is not None:
+            for transform in self.transforms:
+                if hasattr(transform, "seed") and transform.seed is None:
+                    transform.seed = self.seed
+        return self
+
+    @field_validator("train_dataloader_params", "val_dataloader_params", mode="before")
+    @classmethod
+    def set_default_pin_memory(
+        cls, dataloader_params: dict[str, Any]
+    ) -> dict[str, Any]:
+        """
+        Set default pin_memory for dataloader parameters if not provided.
+
+        - If 'pin_memory' is not set, it defaults to True if CUDA is available.
+
+        Parameters
+        ----------
+        dataloader_params : dict of {str: Any}
+            The dataloader parameters.
+
+        Returns
+        -------
+        dict of {str: Any}
+            The dataloader parameters with pin_memory default applied.
+        """
+        if "pin_memory" not in dataloader_params:
+            import torch
+
+            dataloader_params["pin_memory"] = torch.cuda.is_available()
+        return dataloader_params
+
+    @field_validator("train_dataloader_params", mode="before")
+    @classmethod
+    def set_default_train_workers(
+        cls, dataloader_params: dict[str, Any]
+    ) -> dict[str, Any]:
+        """
+        Set default num_workers for training dataloader if not provided.
+
+        - If 'num_workers' is not set, it defaults to the number of available CPU cores.
+
+        Parameters
+        ----------
+        dataloader_params : dict of {str: Any}
+            The training dataloader parameters.
+
+        Returns
+        -------
+        dict of {str: Any}
+            The dataloader parameters with num_workers default applied.
+        """
+        if "num_workers" not in dataloader_params:
+            # Use 0 workers during tests, otherwise use all available CPU cores
+            if "pytest" in sys.modules:
+                dataloader_params["num_workers"] = 0
+            else:
+                dataloader_params["num_workers"] = os.cpu_count()
+
+        return dataloader_params
+
+    @model_validator(mode="after")
+    def set_val_workers_to_match_train(self: Self) -> Self:
+        """
+        Set validation dataloader num_workers to match training dataloader.
+
+        If num_workers is not specified in val_dataloader_params, it will be set to the
+        same value as train_dataloader_params["num_workers"].
+
+        Returns
+        -------
+        Self
+            Validated data model with synchronized num_workers.
+        """
+        if "num_workers" not in self.val_dataloader_params:
+            self.val_dataloader_params["num_workers"] = self.train_dataloader_params[
+                "num_workers"
+            ]
+        return self
+
     def __str__(self) -> str:
         """
         Pretty string reprensenting the configuration.
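
Note: the seed-propagation validators above follow a simple pattern (copy the top-level seed into any sub-model whose own seed is unset). A minimal, self-contained sketch with stand-in pydantic v2 models (not the CAREamics classes):

    # Sketch of the seed-propagation pattern using stand-in models.
    from pydantic import BaseModel, model_validator


    class StubFilter(BaseModel):
        seed: int | None = None


    class StubConfig(BaseModel):
        seed: int | None = 42
        patch_filter: StubFilter | None = None

        @model_validator(mode="after")
        def propagate_seed(self) -> "StubConfig":
            # Copy the main seed into the filter unless one was set explicitly.
            if self.seed is not None and self.patch_filter is not None:
                if self.patch_filter.seed is None:
                    self.patch_filter.seed = self.seed
            return self


    print(StubConfig(patch_filter=StubFilter()).patch_filter.seed)        # 42
    print(StubConfig(patch_filter=StubFilter(seed=7)).patch_filter.seed)  # 7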

careamics/config/data/patch_filter/__init__.py
ADDED

@@ -0,0 +1,15 @@
+"""Pydantic models representing coordinate and patch filters."""
+
+__all__ = [
+    "FilterModel",
+    "MaskFilterModel",
+    "MaxFilterModel",
+    "MeanSTDFilterModel",
+    "ShannonFilterModel",
+]
+
+from .filter_model import FilterModel
+from .mask_filter_model import MaskFilterModel
+from .max_filter_model import MaxFilterModel
+from .meanstd_filter_model import MeanSTDFilterModel
+from .shannon_filter_model import ShannonFilterModel

careamics/config/data/patch_filter/filter_model.py
ADDED

@@ -0,0 +1,16 @@
+"""Base class for patch and coordinate filtering models."""
+
+from pydantic import BaseModel, Field
+
+
+class FilterModel(BaseModel):
+    """Base class for patch and coordinate filtering models."""
+
+    name: str
+    """Name of the filter."""
+
+    p: float = Field(1.0, ge=0.0, le=1.0)
+    """Probability of applying the filter to a patch or coordinate."""
+
+    seed: int | None = Field(default=None, gt=0)
+    """Seed for the random number generator for reproducibility."""

careamics/config/data/patch_filter/mask_filter_model.py
ADDED

@@ -0,0 +1,17 @@
+"""Pydantic model for the mask coordinate filter."""
+
+from typing import Literal
+
+from pydantic import Field
+
+from .filter_model import FilterModel
+
+
+class MaskFilterModel(FilterModel):
+    """Pydantic model for the mask coordinate filter."""
+
+    name: Literal["mask"] = "mask"
+    """Name of the filter."""
+
+    coverage: float = Field(0.5, ge=0.0, le=1.0)
+    """Percentage of masked pixels required to keep a patch."""

careamics/config/data/patch_filter/max_filter_model.py
ADDED

@@ -0,0 +1,15 @@
+"""Pydantic model for the max patch filter."""
+
+from typing import Literal
+
+from .filter_model import FilterModel
+
+
+class MaxFilterModel(FilterModel):
+    """Pydantic model for the max patch filter."""
+
+    name: Literal["max"] = "max"
+    """Name of the filter."""
+
+    threshold: float
+    """Threshold for the minimum of the max-filtered patch."""

careamics/config/data/patch_filter/meanstd_filter_model.py
ADDED

@@ -0,0 +1,18 @@
+"""Pydantic model for the mean std patch filter."""
+
+from typing import Literal
+
+from .filter_model import FilterModel
+
+
+class MeanSTDFilterModel(FilterModel):
+    """Pydantic model for the mean std patch filter."""
+
+    name: Literal["mean_std"] = "mean_std"
+    """Name of the filter."""
+
+    mean_threshold: float
+    """Minimum mean intensity required to keep a patch."""
+
+    std_threshold: float | None = None
+    """Minimum standard deviation required to keep a patch."""

careamics/config/data/patch_filter/shannon_filter_model.py
ADDED

@@ -0,0 +1,15 @@
+"""Pydantic model for the Shannon entropy patch filter."""
+
+from typing import Literal
+
+from .filter_model import FilterModel
+
+
+class ShannonFilterModel(FilterModel):
+    """Pydantic model for the Shannon entropy patch filter."""
+
+    name: Literal["shannon"] = "shannon"
+    """Name of the filter."""
+
+    threshold: float
+    """Minimum Shannon entropy required to keep a patch."""
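
Note: a short usage sketch of the new filter models, assuming careamics 0.0.17 is installed (the import path follows the new module files listed above); threshold values here are illustrative only.

    from careamics.config.data.patch_filter import (
        MaskFilterModel,
        MaxFilterModel,
        MeanSTDFilterModel,
        ShannonFilterModel,
    )

    # Keep patches whose max-filtered minimum exceeds 0.1, applied to every patch (p=1.0).
    max_filter = MaxFilterModel(threshold=0.1)

    # Keep patches with mean >= 10 and std >= 2, with a fixed seed for reproducibility.
    mean_std_filter = MeanSTDFilterModel(mean_threshold=10.0, std_threshold=2.0, seed=42)

    # Keep patches with Shannon entropy above 1.5, applied to 80% of candidate patches.
    shannon_filter = ShannonFilterModel(threshold=1.5, p=0.8)

    # Keep coordinates whose patch covers at least 30% masked pixels.
    mask_filter = MaskFilterModel(coverage=0.3)

    print(max_filter.name, mean_std_filter.name, shannon_filter.name, mask_filter.name)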

careamics/config/inference_model.py
CHANGED

@@ -2,10 +2,9 @@
 
 from __future__ import annotations
 
-from typing import Any, Literal, Union
+from typing import Any, Literal, Self, Union
 
 from pydantic import BaseModel, ConfigDict, Field, field_validator, model_validator
-from typing_extensions import Self
 
 from .validators import check_axes_validity, patch_size_ge_than_8_power_of_2
 

careamics/config/likelihood_model.py
CHANGED

@@ -50,11 +50,11 @@ class NMLikelihoodConfig(BaseModel):
     model_config = ConfigDict(validate_assignment=True, arbitrary_types_allowed=True)
 
     # TODO remove and use as parameters to the likelihood functions?
-    data_mean: Tensor =
+    data_mean: Tensor | None = None
     """The mean of the data, used to unnormalize data for noise model evaluation.
     Shape is (target_ch,) (or (1, target_ch, [1], 1, 1))."""
 
     # TODO remove and use as parameters to the likelihood functions?
-    data_std: Tensor =
+    data_std: Tensor | None = None
     """The standard deviation of the data, used to unnormalize data for noise
     model evaluation. Shape is (target_ch,) (or (1, target_ch, [1], 1, 1))."""

careamics/config/loss_model.py
CHANGED

@@ -35,7 +35,9 @@ class LVAELossConfig(BaseModel):
         validate_assignment=True, validate_default=True, arbitrary_types_allowed=True
     )
 
-    loss_type: Literal[
+    loss_type: Literal[
+        "hdn", "microsplit", "musplit", "denoisplit", "denoisplit_musplit"
+    ]
     """Type of loss to use for LVAE."""
 
     reconstruction_weight: float = 1.0
@@ -50,7 +52,9 @@ class LVAELossConfig(BaseModel):
     """Weight for the denoiSplit loss (used in the muSplit-deonoiSplit loss)."""
     kl_params: KLLossConfig = KLLossConfig()
     """KL loss configuration."""
-
+    # TODO revisit weights for the losses
     # TODO: remove?
     non_stochastic: bool = False
     """Whether to sample latents and compute KL."""
+
+    # TODO what are the correct parameters for HDN ?

careamics/config/nm_model.py
CHANGED

@@ -1,7 +1,7 @@
 """Noise models config."""
 
 from pathlib import Path
-from typing import Annotated, Literal, Union
+from typing import Annotated, Literal, Self, Union
 
 import numpy as np
 import torch
@@ -11,6 +11,7 @@ from pydantic import (
     Field,
     PlainSerializer,
     PlainValidator,
+    model_validator,
 )
 
 from careamics.utils.serializers import _array_to_json, _to_numpy
@@ -86,6 +87,30 @@ class GaussianMixtureNMConfig(BaseModel):
     tol: float = Field(default=1e-10)
     """Tolerance used in the computation of the noise model likelihood."""
 
+    @model_validator(mode="after")
+    def validate_path(self: Self) -> Self:
+        """Validate that the path points to a valid .npz file if provided.
+
+        Returns
+        -------
+        Self
+            Returns itself.
+
+        Raises
+        ------
+        ValueError
+            If the path is provided but does not point to a valid .npz file.
+        """
+        if self.path is not None:
+            path = Path(self.path)
+            if not path.exists():
+                raise ValueError(f"Path {path} does not exist.")
+            if path.suffix != ".npz":
+                raise ValueError(f"Path {path} must point to a .npz file.")
+            if not path.is_file():
+                raise ValueError(f"Path {path} must point to a file.")
+        return self
+
     # @model_validator(mode="after")
     # def validate_path_to_pretrained_vs_training_data(self: Self) -> Self:
     #     """Validate paths provided in the config.

careamics/config/optimizer_models.py
CHANGED

@@ -2,7 +2,7 @@
 
 from __future__ import annotations
 
-from typing import Literal
+from typing import Literal, Self
 
 from pydantic import (
     BaseModel,
@@ -13,7 +13,6 @@ from pydantic import (
     model_validator,
 )
 from torch import optim
-from typing_extensions import Self
 
 from careamics.utils.torch_utils import filter_parameters
 

careamics/config/support/supported_algorithms.py
CHANGED

@@ -26,9 +26,11 @@ class SupportedAlgorithm(str, BaseEnum):
     MUSPLIT = "musplit"
     """An image splitting approach based on ladder VAE architectures."""
 
+    MICROSPLIT = "microsplit"
+    """A micro-level image splitting approach based on ladder VAE architectures."""
+
     DENOISPLIT = "denoisplit"
     """An image splitting and denoising approach based on ladder VAE architectures."""
 
-
-
-    # SEG = "segmentation"
+    HDN = "hdn"
+    """Hierarchical Denoising Network, an unsupervised denoising algorithm"""

careamics/config/support/supported_filters.py
ADDED

@@ -0,0 +1,17 @@
+"""Coordinate and patch filters supported by CAREamics."""
+
+from careamics.utils import BaseEnum
+
+
+class SupportedPatchFilters(str, BaseEnum):
+    """Supported patch filters."""
+
+    MAX = "max"
+    MEANSTD = "mean_std"
+    SHANNON = "shannon"
+
+
+class SupportedCoordinateFilters(str, BaseEnum):
+    """Supported coordinate filters."""
+
+    MASK = "mask"

careamics/config/support/supported_losses.py
CHANGED

@@ -21,9 +21,12 @@ class SupportedLoss(str, BaseEnum):
     MAE = "mae"
     N2V = "n2v"
     # PN2V = "pn2v"
-
+    HDN = "hdn"
     MUSPLIT = "musplit"
+    MICROSPLIT = "microsplit"
     DENOISPLIT = "denoisplit"
-    DENOISPLIT_MUSPLIT =
+    DENOISPLIT_MUSPLIT = (
+        "denoisplit_musplit"  # TODO refac losses, leave only microsplit
+    )
     # CE = "ce"
     # DICE = "dice"

careamics/config/training_model.py
CHANGED

@@ -3,9 +3,9 @@
 from __future__ import annotations
 
 from pprint import pformat
-from typing import Literal
+from typing import Literal
 
-from pydantic import BaseModel, ConfigDict, Field
+from pydantic import BaseModel, ConfigDict, Field
 
 from .callback_model import CheckpointModel, EarlyStoppingModel
 
@@ -29,26 +29,15 @@ class TrainingConfig(BaseModel):
     model_config = ConfigDict(
         validate_assignment=True,
     )
+    lightning_trainer_config: dict | None = None
+    """Configuration for the PyTorch Lightning Trainer, following PyTorch Lightning
+    Trainer class"""
 
-    num_epochs: int = Field(default=20, ge=1)
-    """Number of epochs, greater than 0."""
-
-    precision: Literal["64", "32", "16-mixed", "bf16-mixed"] = Field(default="32")
-    """Numerical precision"""
-    max_steps: int = Field(default=-1, ge=-1)
-    """Maximum number of steps to train for. -1 means no limit."""
-    check_val_every_n_epoch: int = Field(default=1, ge=1)
-    """Validation step frequency."""
-    accumulate_grad_batches: int = Field(default=1, ge=1)
-    """Number of batches to accumulate gradients over before stepping the optimizer."""
-    gradient_clip_val: Union[int, float] | None = None
-    """The value to which to clip the gradient"""
-    gradient_clip_algorithm: Literal["value", "norm"] = "norm"
-    """The algorithm to use for gradient clipping (see lightning `Trainer`)."""
     logger: Literal["wandb", "tensorboard"] | None = None
     """Logger to use during training. If None, no logger will be used. Available
     loggers are defined in SupportedLogger."""
 
+    # Only basic callbacks
     checkpoint_callback: CheckpointModel = CheckpointModel()
     """Checkpoint callback configuration, following PyTorch Lightning Checkpoint
     callback."""
@@ -78,22 +67,3 @@ class TrainingConfig(BaseModel):
         Whether the logger is defined or not.
         """
         return self.logger is not None
-
-    @field_validator("max_steps")
-    @classmethod
-    def validate_max_steps(cls, max_steps: int) -> int:
-        """Validate the max_steps parameter.
-
-        Parameters
-        ----------
-        max_steps : int
-            Maximum number of steps to train for. -1 means no limit.
-
-        Returns
-        -------
-        int
-            Validated max_steps.
-        """
-        if max_steps == 0:
-            raise ValueError("max_steps must be greater than 0. Use -1 for no limit.")
-        return max_steps
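
Note: with the per-field Trainer options removed above, Trainer arguments are expected to go through the new lightning_trainer_config dict instead. A hedged sketch, assuming careamics 0.0.17 is installed, that the remaining TrainingConfig fields keep their defaults, and that the dict keys map onto Lightning Trainer arguments as the field docstring states; the values shown are illustrative only.

    from careamics.config.training_model import TrainingConfig

    # 0.0.15 style such as TrainingConfig(num_epochs=100, precision="16-mixed") no longer applies.
    training = TrainingConfig(
        lightning_trainer_config={
            "max_epochs": 100,        # Lightning Trainer argument names, not the old CAREamics fields
            "precision": "16-mixed",
            "gradient_clip_val": 1.0,
        },
        logger="tensorboard",
    )
    print(training.lightning_trainer_config)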

careamics/config/transformations/normalize_model.py
CHANGED

@@ -1,9 +1,8 @@
 """Pydantic model for the Normalize transform."""
 
-from typing import Literal
+from typing import Literal, Self
 
 from pydantic import ConfigDict, Field, model_validator
-from typing_extensions import Self
 
 from .transform_model import TransformModel
 