careamics 0.0.14__py3-none-any.whl → 0.0.15__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of careamics might be problematic. Click here for more details.
- careamics/careamist.py +49 -49
- careamics/cli/conf.py +6 -6
- careamics/cli/main.py +8 -8
- careamics/cli/utils.py +2 -4
- careamics/config/algorithms/vae_algorithm_model.py +4 -4
- careamics/config/callback_model.py +8 -8
- careamics/config/configuration_factories.py +49 -49
- careamics/config/data/data_model.py +7 -13
- careamics/config/data/ng_data_model.py +8 -14
- careamics/config/data/patching_strategies/_overlapping_patched_model.py +4 -5
- careamics/config/inference_model.py +6 -10
- careamics/config/likelihood_model.py +2 -2
- careamics/config/nm_model.py +5 -7
- careamics/config/training_model.py +4 -4
- careamics/config/transformations/normalize_model.py +3 -3
- careamics/config/transformations/xy_flip_model.py +2 -2
- careamics/config/transformations/xy_random_rotate90_model.py +2 -2
- careamics/config/validators/validator_utils.py +1 -2
- careamics/dataset/dataset_utils/iterate_over_files.py +3 -3
- careamics/dataset/in_memory_dataset.py +2 -2
- careamics/dataset/iterable_dataset.py +1 -2
- careamics/dataset/patching/random_patching.py +6 -6
- careamics/dataset/patching/sequential_patching.py +4 -4
- careamics/dataset/tiling/lvae_tiled_patching.py +2 -2
- careamics/dataset_ng/dataset.py +3 -3
- careamics/dataset_ng/factory.py +19 -19
- careamics/dataset_ng/patching_strategies/random_patching.py +2 -3
- careamics/dataset_ng/patching_strategies/sequential_patching.py +1 -2
- careamics/lightning/callbacks/prediction_writer_callback/prediction_writer_callback.py +5 -5
- careamics/lightning/callbacks/prediction_writer_callback/write_strategy.py +5 -5
- careamics/lightning/callbacks/prediction_writer_callback/write_strategy_factory.py +8 -8
- careamics/lightning/dataset_ng/data_module.py +43 -43
- careamics/lightning/lightning_module.py +12 -14
- careamics/lightning/predict_data_module.py +8 -8
- careamics/lightning/train_data_module.py +11 -11
- careamics/losses/lvae/losses.py +9 -9
- careamics/model_io/bioimage/model_description.py +12 -11
- careamics/model_io/bmz_io.py +4 -4
- careamics/models/layers.py +5 -5
- careamics/prediction_utils/lvae_prediction.py +5 -5
- careamics/transforms/compose.py +9 -9
- careamics/transforms/n2v_manipulate.py +3 -3
- careamics/transforms/n2v_manipulate_torch.py +4 -4
- careamics/transforms/normalize.py +4 -6
- careamics/transforms/pixel_manipulation.py +6 -8
- careamics/transforms/pixel_manipulation_torch.py +5 -7
- careamics/transforms/xy_flip.py +3 -5
- careamics/transforms/xy_random_rotate90.py +3 -5
- careamics/utils/logging.py +8 -8
- careamics/utils/metrics.py +2 -2
- careamics/utils/plotting.py +1 -3
- {careamics-0.0.14.dist-info → careamics-0.0.15.dist-info}/METADATA +2 -3
- {careamics-0.0.14.dist-info → careamics-0.0.15.dist-info}/RECORD +56 -56
- {careamics-0.0.14.dist-info → careamics-0.0.15.dist-info}/WHEEL +0 -0
- {careamics-0.0.14.dist-info → careamics-0.0.15.dist-info}/entry_points.txt +0 -0
- {careamics-0.0.14.dist-info → careamics-0.0.15.dist-info}/licenses/LICENSE +0 -0
|
@@ -2,7 +2,7 @@
|
|
|
2
2
|
|
|
3
3
|
from collections.abc import Sequence
|
|
4
4
|
from pathlib import Path
|
|
5
|
-
from typing import Any,
|
|
5
|
+
from typing import Any, Protocol, Union
|
|
6
6
|
|
|
7
7
|
import numpy as np
|
|
8
8
|
from numpy.typing import NDArray
|
|
@@ -25,7 +25,7 @@ class WriteStrategy(Protocol):
|
|
|
25
25
|
trainer: Trainer,
|
|
26
26
|
pl_module: LightningModule,
|
|
27
27
|
prediction: Any, # TODO: change to expected type
|
|
28
|
-
batch_indices:
|
|
28
|
+
batch_indices: Sequence[int] | None,
|
|
29
29
|
batch: Any, # TODO: change to expected type
|
|
30
30
|
batch_idx: int,
|
|
31
31
|
dataloader_idx: int,
|
|
@@ -133,7 +133,7 @@ class CacheTiles(WriteStrategy):
|
|
|
133
133
|
trainer: Trainer,
|
|
134
134
|
pl_module: LightningModule,
|
|
135
135
|
prediction: tuple[NDArray, list[TileInformation]],
|
|
136
|
-
batch_indices:
|
|
136
|
+
batch_indices: Sequence[int] | None,
|
|
137
137
|
batch: tuple[NDArray, list[TileInformation]],
|
|
138
138
|
batch_idx: int,
|
|
139
139
|
dataloader_idx: int,
|
|
@@ -259,7 +259,7 @@ class WriteTilesZarr(WriteStrategy):
|
|
|
259
259
|
trainer: Trainer,
|
|
260
260
|
pl_module: LightningModule,
|
|
261
261
|
prediction: Any,
|
|
262
|
-
batch_indices:
|
|
262
|
+
batch_indices: Sequence[int] | None,
|
|
263
263
|
batch: Any,
|
|
264
264
|
batch_idx: int,
|
|
265
265
|
dataloader_idx: int,
|
|
@@ -346,7 +346,7 @@ class WriteImage(WriteStrategy):
|
|
|
346
346
|
trainer: Trainer,
|
|
347
347
|
pl_module: LightningModule,
|
|
348
348
|
prediction: NDArray,
|
|
349
|
-
batch_indices:
|
|
349
|
+
batch_indices: Sequence[int] | None,
|
|
350
350
|
batch: NDArray,
|
|
351
351
|
batch_idx: int,
|
|
352
352
|
dataloader_idx: int,
|
|
@@ -1,6 +1,6 @@
|
|
|
1
1
|
"""Module containing convenience function to create `WriteStrategy`."""
|
|
2
2
|
|
|
3
|
-
from typing import Any
|
|
3
|
+
from typing import Any
|
|
4
4
|
|
|
5
5
|
from careamics.config.support import SupportedData
|
|
6
6
|
from careamics.file_io import SupportedWriteType, WriteFunc, get_write_func
|
|
@@ -11,9 +11,9 @@ from .write_strategy import CacheTiles, WriteImage, WriteStrategy
|
|
|
11
11
|
def create_write_strategy(
|
|
12
12
|
write_type: SupportedWriteType,
|
|
13
13
|
tiled: bool,
|
|
14
|
-
write_func:
|
|
15
|
-
write_extension:
|
|
16
|
-
write_func_kwargs:
|
|
14
|
+
write_func: WriteFunc | None = None,
|
|
15
|
+
write_extension: str | None = None,
|
|
16
|
+
write_func_kwargs: dict[str, Any] | None = None,
|
|
17
17
|
) -> WriteStrategy:
|
|
18
18
|
"""
|
|
19
19
|
Create a write strategy from convenient parameters.
|
|
@@ -78,8 +78,8 @@ def create_write_strategy(
|
|
|
78
78
|
|
|
79
79
|
def _create_tiled_write_strategy(
|
|
80
80
|
write_type: SupportedWriteType,
|
|
81
|
-
write_func:
|
|
82
|
-
write_extension:
|
|
81
|
+
write_func: WriteFunc | None,
|
|
82
|
+
write_extension: str | None,
|
|
83
83
|
write_func_kwargs: dict[str, Any],
|
|
84
84
|
) -> WriteStrategy:
|
|
85
85
|
"""
|
|
@@ -130,7 +130,7 @@ def _create_tiled_write_strategy(
|
|
|
130
130
|
|
|
131
131
|
|
|
132
132
|
def select_write_func(
|
|
133
|
-
write_type: SupportedWriteType, write_func:
|
|
133
|
+
write_type: SupportedWriteType, write_func: WriteFunc | None = None
|
|
134
134
|
) -> WriteFunc:
|
|
135
135
|
"""
|
|
136
136
|
Return a function to write images.
|
|
@@ -177,7 +177,7 @@ def select_write_func(
|
|
|
177
177
|
|
|
178
178
|
|
|
179
179
|
def select_write_extension(
|
|
180
|
-
write_type: SupportedWriteType, write_extension:
|
|
180
|
+
write_type: SupportedWriteType, write_extension: str | None = None
|
|
181
181
|
) -> str:
|
|
182
182
|
"""
|
|
183
183
|
Return an extension to add to file paths.
|
|
@@ -2,7 +2,7 @@
|
|
|
2
2
|
|
|
3
3
|
from collections.abc import Callable
|
|
4
4
|
from pathlib import Path
|
|
5
|
-
from typing import Any,
|
|
5
|
+
from typing import Any, Union, overload
|
|
6
6
|
|
|
7
7
|
import numpy as np
|
|
8
8
|
import pytorch_lightning as L
|
|
@@ -124,14 +124,14 @@ class CareamicsDataModule(L.LightningDataModule):
|
|
|
124
124
|
self,
|
|
125
125
|
data_config: NGDataConfig,
|
|
126
126
|
*,
|
|
127
|
-
train_data:
|
|
128
|
-
train_data_target:
|
|
129
|
-
val_data:
|
|
130
|
-
val_data_target:
|
|
131
|
-
pred_data:
|
|
132
|
-
pred_data_target:
|
|
127
|
+
train_data: InputType | None = None,
|
|
128
|
+
train_data_target: InputType | None = None,
|
|
129
|
+
val_data: InputType | None = None,
|
|
130
|
+
val_data_target: InputType | None = None,
|
|
131
|
+
pred_data: InputType | None = None,
|
|
132
|
+
pred_data_target: InputType | None = None,
|
|
133
133
|
extension_filter: str = "",
|
|
134
|
-
val_percentage:
|
|
134
|
+
val_percentage: float | None = None,
|
|
135
135
|
val_minimum_split: int = 5,
|
|
136
136
|
use_in_memory: bool = True,
|
|
137
137
|
) -> None: ...
|
|
@@ -142,16 +142,16 @@ class CareamicsDataModule(L.LightningDataModule):
|
|
|
142
142
|
self,
|
|
143
143
|
data_config: NGDataConfig,
|
|
144
144
|
*,
|
|
145
|
-
train_data:
|
|
146
|
-
train_data_target:
|
|
147
|
-
val_data:
|
|
148
|
-
val_data_target:
|
|
149
|
-
pred_data:
|
|
150
|
-
pred_data_target:
|
|
145
|
+
train_data: InputType | None = None,
|
|
146
|
+
train_data_target: InputType | None = None,
|
|
147
|
+
val_data: InputType | None = None,
|
|
148
|
+
val_data_target: InputType | None = None,
|
|
149
|
+
pred_data: InputType | None = None,
|
|
150
|
+
pred_data_target: InputType | None = None,
|
|
151
151
|
read_source_func: Callable,
|
|
152
|
-
read_kwargs:
|
|
152
|
+
read_kwargs: dict[str, Any] | None = None,
|
|
153
153
|
extension_filter: str = "",
|
|
154
|
-
val_percentage:
|
|
154
|
+
val_percentage: float | None = None,
|
|
155
155
|
val_minimum_split: int = 5,
|
|
156
156
|
use_in_memory: bool = True,
|
|
157
157
|
) -> None: ...
|
|
@@ -161,16 +161,16 @@ class CareamicsDataModule(L.LightningDataModule):
|
|
|
161
161
|
self,
|
|
162
162
|
data_config: NGDataConfig,
|
|
163
163
|
*,
|
|
164
|
-
train_data:
|
|
165
|
-
train_data_target:
|
|
166
|
-
val_data:
|
|
167
|
-
val_data_target:
|
|
168
|
-
pred_data:
|
|
169
|
-
pred_data_target:
|
|
164
|
+
train_data: Any | None = None,
|
|
165
|
+
train_data_target: Any | None = None,
|
|
166
|
+
val_data: Any | None = None,
|
|
167
|
+
val_data_target: Any | None = None,
|
|
168
|
+
pred_data: Any | None = None,
|
|
169
|
+
pred_data_target: Any | None = None,
|
|
170
170
|
image_stack_loader: ImageStackLoader,
|
|
171
|
-
image_stack_loader_kwargs:
|
|
171
|
+
image_stack_loader_kwargs: dict[str, Any] | None = None,
|
|
172
172
|
extension_filter: str = "",
|
|
173
|
-
val_percentage:
|
|
173
|
+
val_percentage: float | None = None,
|
|
174
174
|
val_minimum_split: int = 5,
|
|
175
175
|
use_in_memory: bool = True,
|
|
176
176
|
) -> None: ...
|
|
@@ -179,18 +179,18 @@ class CareamicsDataModule(L.LightningDataModule):
|
|
|
179
179
|
self,
|
|
180
180
|
data_config: NGDataConfig,
|
|
181
181
|
*,
|
|
182
|
-
train_data:
|
|
183
|
-
train_data_target:
|
|
184
|
-
val_data:
|
|
185
|
-
val_data_target:
|
|
186
|
-
pred_data:
|
|
187
|
-
pred_data_target:
|
|
188
|
-
read_source_func:
|
|
189
|
-
read_kwargs:
|
|
190
|
-
image_stack_loader:
|
|
191
|
-
image_stack_loader_kwargs:
|
|
182
|
+
train_data: Any | None = None,
|
|
183
|
+
train_data_target: Any | None = None,
|
|
184
|
+
val_data: Any | None = None,
|
|
185
|
+
val_data_target: Any | None = None,
|
|
186
|
+
pred_data: Any | None = None,
|
|
187
|
+
pred_data_target: Any | None = None,
|
|
188
|
+
read_source_func: Callable | None = None,
|
|
189
|
+
read_kwargs: dict[str, Any] | None = None,
|
|
190
|
+
image_stack_loader: ImageStackLoader | None = None,
|
|
191
|
+
image_stack_loader_kwargs: dict[str, Any] | None = None,
|
|
192
192
|
extension_filter: str = "",
|
|
193
|
-
val_percentage:
|
|
193
|
+
val_percentage: float | None = None,
|
|
194
194
|
val_minimum_split: int = 5,
|
|
195
195
|
use_in_memory: bool = True,
|
|
196
196
|
) -> None:
|
|
@@ -280,7 +280,7 @@ class CareamicsDataModule(L.LightningDataModule):
|
|
|
280
280
|
def _validate_input_target_type_consistency(
|
|
281
281
|
self,
|
|
282
282
|
input_data: InputType,
|
|
283
|
-
target_data:
|
|
283
|
+
target_data: InputType | None,
|
|
284
284
|
) -> None:
|
|
285
285
|
"""Validate if the input and target data types are consistent.
|
|
286
286
|
|
|
@@ -314,7 +314,7 @@ class CareamicsDataModule(L.LightningDataModule):
|
|
|
314
314
|
self,
|
|
315
315
|
input_data,
|
|
316
316
|
target_data=None,
|
|
317
|
-
) -> tuple[list[Path],
|
|
317
|
+
) -> tuple[list[Path], list[Path] | None]:
|
|
318
318
|
"""List files from input and target directories.
|
|
319
319
|
|
|
320
320
|
Parameters
|
|
@@ -347,7 +347,7 @@ class CareamicsDataModule(L.LightningDataModule):
|
|
|
347
347
|
self,
|
|
348
348
|
input_data,
|
|
349
349
|
target_data=None,
|
|
350
|
-
) -> tuple[list[Path],
|
|
350
|
+
) -> tuple[list[Path], list[Path] | None]:
|
|
351
351
|
"""Create a list of file paths from the input and target data.
|
|
352
352
|
|
|
353
353
|
Parameters
|
|
@@ -379,7 +379,7 @@ class CareamicsDataModule(L.LightningDataModule):
|
|
|
379
379
|
def _validate_array_input(
|
|
380
380
|
self,
|
|
381
381
|
input_data: InputType,
|
|
382
|
-
target_data:
|
|
382
|
+
target_data: InputType | None,
|
|
383
383
|
) -> tuple[Any, Any]:
|
|
384
384
|
"""Validate if the input data is a numpy array.
|
|
385
385
|
|
|
@@ -408,8 +408,8 @@ class CareamicsDataModule(L.LightningDataModule):
|
|
|
408
408
|
)
|
|
409
409
|
|
|
410
410
|
def _validate_path_input(
|
|
411
|
-
self, input_data: InputType, target_data:
|
|
412
|
-
) -> tuple[list[Path],
|
|
411
|
+
self, input_data: InputType, target_data: InputType | None
|
|
412
|
+
) -> tuple[list[Path], list[Path] | None]:
|
|
413
413
|
"""Validate if the input data is a path or a list of paths.
|
|
414
414
|
|
|
415
415
|
Parameters
|
|
@@ -488,8 +488,8 @@ class CareamicsDataModule(L.LightningDataModule):
|
|
|
488
488
|
|
|
489
489
|
def _initialize_data_pair(
|
|
490
490
|
self,
|
|
491
|
-
input_data:
|
|
492
|
-
target_data:
|
|
491
|
+
input_data: InputType | None,
|
|
492
|
+
target_data: InputType | None,
|
|
493
493
|
) -> tuple[Any, Any]:
|
|
494
494
|
"""
|
|
495
495
|
Initialize a pair of input and target data.
|
|
@@ -1,7 +1,7 @@
|
|
|
1
1
|
"""CAREamics Lightning module."""
|
|
2
2
|
|
|
3
3
|
from collections.abc import Callable
|
|
4
|
-
from typing import Any, Literal,
|
|
4
|
+
from typing import Any, Literal, Union
|
|
5
5
|
|
|
6
6
|
import numpy as np
|
|
7
7
|
import pytorch_lightning as L
|
|
@@ -90,7 +90,7 @@ class FCNModule(L.LightningModule):
|
|
|
90
90
|
# create preprocessing, model and loss function
|
|
91
91
|
if isinstance(algorithm_config, N2VAlgorithm):
|
|
92
92
|
self.use_n2v = True
|
|
93
|
-
self.n2v_preprocess:
|
|
93
|
+
self.n2v_preprocess: N2VManipulateTorch | None = N2VManipulateTorch(
|
|
94
94
|
n2v_manipulate_config=algorithm_config.n2v_config
|
|
95
95
|
)
|
|
96
96
|
else:
|
|
@@ -333,18 +333,16 @@ class VAEModule(L.LightningModule):
|
|
|
333
333
|
self.model: nn.Module = model_factory(self.algorithm_config.model)
|
|
334
334
|
|
|
335
335
|
# create loss function
|
|
336
|
-
self.noise_model:
|
|
336
|
+
self.noise_model: NoiseModel | None = noise_model_factory(
|
|
337
337
|
self.algorithm_config.noise_model
|
|
338
338
|
)
|
|
339
339
|
|
|
340
|
-
self.noise_model_likelihood:
|
|
341
|
-
|
|
342
|
-
|
|
343
|
-
noise_model=self.noise_model,
|
|
344
|
-
)
|
|
340
|
+
self.noise_model_likelihood: NoiseModelLikelihood | None = likelihood_factory(
|
|
341
|
+
config=self.algorithm_config.noise_model_likelihood,
|
|
342
|
+
noise_model=self.noise_model,
|
|
345
343
|
)
|
|
346
344
|
|
|
347
|
-
self.gaussian_likelihood:
|
|
345
|
+
self.gaussian_likelihood: GaussianLikelihood | None = likelihood_factory(
|
|
348
346
|
self.algorithm_config.gaussian_likelihood
|
|
349
347
|
)
|
|
350
348
|
|
|
@@ -380,7 +378,7 @@ class VAEModule(L.LightningModule):
|
|
|
380
378
|
|
|
381
379
|
def training_step(
|
|
382
380
|
self, batch: tuple[Tensor, Tensor], batch_idx: Any
|
|
383
|
-
) ->
|
|
381
|
+
) -> dict[str, Tensor] | None:
|
|
384
382
|
"""Training step.
|
|
385
383
|
|
|
386
384
|
Parameters
|
|
@@ -603,7 +601,7 @@ class VAEModule(L.LightningModule):
|
|
|
603
601
|
for i in range(out_channels)
|
|
604
602
|
]
|
|
605
603
|
|
|
606
|
-
def reduce_running_psnr(self) ->
|
|
604
|
+
def reduce_running_psnr(self) -> float | None:
|
|
607
605
|
"""Reduce the running PSNR statistics and reset the running PSNR.
|
|
608
606
|
|
|
609
607
|
Returns
|
|
@@ -634,11 +632,11 @@ def create_careamics_module(
|
|
|
634
632
|
use_n2v2: bool = False,
|
|
635
633
|
struct_n2v_axis: Literal["horizontal", "vertical", "none"] = "none",
|
|
636
634
|
struct_n2v_span: int = 5,
|
|
637
|
-
model_parameters:
|
|
635
|
+
model_parameters: dict | None = None,
|
|
638
636
|
optimizer: Union[SupportedOptimizer, str] = "Adam",
|
|
639
|
-
optimizer_parameters:
|
|
637
|
+
optimizer_parameters: dict | None = None,
|
|
640
638
|
lr_scheduler: Union[SupportedScheduler, str] = "ReduceLROnPlateau",
|
|
641
|
-
lr_scheduler_parameters:
|
|
639
|
+
lr_scheduler_parameters: dict | None = None,
|
|
642
640
|
) -> Union[FCNModule, VAEModule]:
|
|
643
641
|
"""Create a CAREamics Lightning module.
|
|
644
642
|
|
|
@@ -2,7 +2,7 @@
|
|
|
2
2
|
|
|
3
3
|
from collections.abc import Callable
|
|
4
4
|
from pathlib import Path
|
|
5
|
-
from typing import Any, Literal,
|
|
5
|
+
from typing import Any, Literal, Union
|
|
6
6
|
|
|
7
7
|
import numpy as np
|
|
8
8
|
import pytorch_lightning as L
|
|
@@ -65,9 +65,9 @@ class PredictDataModule(L.LightningDataModule):
|
|
|
65
65
|
self,
|
|
66
66
|
pred_config: InferenceConfig,
|
|
67
67
|
pred_data: Union[Path, str, NDArray],
|
|
68
|
-
read_source_func:
|
|
68
|
+
read_source_func: Callable | None = None,
|
|
69
69
|
extension_filter: str = "",
|
|
70
|
-
dataloader_params:
|
|
70
|
+
dataloader_params: dict | None = None,
|
|
71
71
|
) -> None:
|
|
72
72
|
"""
|
|
73
73
|
Constructor.
|
|
@@ -173,7 +173,7 @@ class PredictDataModule(L.LightningDataModule):
|
|
|
173
173
|
self.pred_data, self.data_type, self.extension_filter
|
|
174
174
|
)
|
|
175
175
|
|
|
176
|
-
def setup(self, stage:
|
|
176
|
+
def setup(self, stage: str | None = None) -> None:
|
|
177
177
|
"""
|
|
178
178
|
Hook called at the beginning of predict.
|
|
179
179
|
|
|
@@ -231,13 +231,13 @@ def create_predict_datamodule(
|
|
|
231
231
|
axes: str,
|
|
232
232
|
image_means: list[float],
|
|
233
233
|
image_stds: list[float],
|
|
234
|
-
tile_size:
|
|
235
|
-
tile_overlap:
|
|
234
|
+
tile_size: tuple[int, ...] | None = None,
|
|
235
|
+
tile_overlap: tuple[int, ...] | None = None,
|
|
236
236
|
batch_size: int = 1,
|
|
237
237
|
tta_transforms: bool = True,
|
|
238
|
-
read_source_func:
|
|
238
|
+
read_source_func: Callable | None = None,
|
|
239
239
|
extension_filter: str = "",
|
|
240
|
-
dataloader_params:
|
|
240
|
+
dataloader_params: dict | None = None,
|
|
241
241
|
) -> PredictDataModule:
|
|
242
242
|
"""Create a CAREamics prediction Lightning datamodule.
|
|
243
243
|
|
|
@@ -2,7 +2,7 @@
|
|
|
2
2
|
|
|
3
3
|
from collections.abc import Callable
|
|
4
4
|
from pathlib import Path
|
|
5
|
-
from typing import Any, Literal,
|
|
5
|
+
from typing import Any, Literal, Union
|
|
6
6
|
|
|
7
7
|
import numpy as np
|
|
8
8
|
import pytorch_lightning as L
|
|
@@ -121,10 +121,10 @@ class TrainDataModule(L.LightningDataModule):
|
|
|
121
121
|
self,
|
|
122
122
|
data_config: DataConfig,
|
|
123
123
|
train_data: Union[Path, str, NDArray],
|
|
124
|
-
val_data:
|
|
125
|
-
train_data_target:
|
|
126
|
-
val_data_target:
|
|
127
|
-
read_source_func:
|
|
124
|
+
val_data: Union[Path, str, NDArray] | None = None,
|
|
125
|
+
train_data_target: Union[Path, str, NDArray] | None = None,
|
|
126
|
+
val_data_target: Union[Path, str, NDArray] | None = None,
|
|
127
|
+
read_source_func: Callable | None = None,
|
|
128
128
|
extension_filter: str = "",
|
|
129
129
|
val_percentage: float = 0.1,
|
|
130
130
|
val_minimum_split: int = 5,
|
|
@@ -477,15 +477,15 @@ def create_train_datamodule(
|
|
|
477
477
|
patch_size: list[int],
|
|
478
478
|
axes: str,
|
|
479
479
|
batch_size: int,
|
|
480
|
-
val_data:
|
|
481
|
-
transforms:
|
|
482
|
-
train_target_data:
|
|
483
|
-
val_target_data:
|
|
484
|
-
read_source_func:
|
|
480
|
+
val_data: Union[str, Path, NDArray] | None = None,
|
|
481
|
+
transforms: list[TransformModel] | None = None,
|
|
482
|
+
train_target_data: Union[str, Path, NDArray] | None = None,
|
|
483
|
+
val_target_data: Union[str, Path, NDArray] | None = None,
|
|
484
|
+
read_source_func: Callable | None = None,
|
|
485
485
|
extension_filter: str = "",
|
|
486
486
|
val_percentage: float = 0.1,
|
|
487
487
|
val_minimum_patches: int = 5,
|
|
488
|
-
dataloader_params:
|
|
488
|
+
dataloader_params: dict | None = None,
|
|
489
489
|
use_in_memory: bool = True,
|
|
490
490
|
) -> TrainDataModule:
|
|
491
491
|
"""Create a TrainDataModule.
|
careamics/losses/lvae/losses.py
CHANGED
|
@@ -2,7 +2,7 @@
|
|
|
2
2
|
|
|
3
3
|
from __future__ import annotations
|
|
4
4
|
|
|
5
|
-
from typing import TYPE_CHECKING, Any, Literal,
|
|
5
|
+
from typing import TYPE_CHECKING, Any, Literal, Union
|
|
6
6
|
|
|
7
7
|
import numpy as np
|
|
8
8
|
import torch
|
|
@@ -112,7 +112,7 @@ def get_kl_divergence_loss(
|
|
|
112
112
|
rescaling: Literal["latent_dim", "image_dim"],
|
|
113
113
|
aggregation: Literal["mean", "sum"],
|
|
114
114
|
free_bits_coeff: float,
|
|
115
|
-
img_shape:
|
|
115
|
+
img_shape: tuple[int] | None = None,
|
|
116
116
|
) -> torch.Tensor:
|
|
117
117
|
"""Compute the KL divergence loss.
|
|
118
118
|
|
|
@@ -273,9 +273,9 @@ def musplit_loss(
|
|
|
273
273
|
model_outputs: tuple[torch.Tensor, dict[str, Any]],
|
|
274
274
|
targets: torch.Tensor,
|
|
275
275
|
config: LVAELossConfig,
|
|
276
|
-
gaussian_likelihood:
|
|
277
|
-
noise_model_likelihood:
|
|
278
|
-
) ->
|
|
276
|
+
gaussian_likelihood: GaussianLikelihood | None,
|
|
277
|
+
noise_model_likelihood: NoiseModelLikelihood | None = None, # TODO: ugly
|
|
278
|
+
) -> dict[str, torch.Tensor] | None:
|
|
279
279
|
"""Loss function for muSplit.
|
|
280
280
|
|
|
281
281
|
Parameters
|
|
@@ -351,9 +351,9 @@ def denoisplit_loss(
|
|
|
351
351
|
model_outputs: tuple[torch.Tensor, dict[str, Any]],
|
|
352
352
|
targets: torch.Tensor,
|
|
353
353
|
config: LVAELossConfig,
|
|
354
|
-
gaussian_likelihood:
|
|
355
|
-
noise_model_likelihood:
|
|
356
|
-
) ->
|
|
354
|
+
gaussian_likelihood: GaussianLikelihood | None = None,
|
|
355
|
+
noise_model_likelihood: NoiseModelLikelihood | None = None,
|
|
356
|
+
) -> dict[str, torch.Tensor] | None:
|
|
357
357
|
"""Loss function for DenoiSplit.
|
|
358
358
|
|
|
359
359
|
Parameters
|
|
@@ -430,7 +430,7 @@ def denoisplit_musplit_loss(
|
|
|
430
430
|
config: LVAELossConfig,
|
|
431
431
|
gaussian_likelihood: GaussianLikelihood,
|
|
432
432
|
noise_model_likelihood: NoiseModelLikelihood,
|
|
433
|
-
) ->
|
|
433
|
+
) -> dict[str, torch.Tensor] | None:
|
|
434
434
|
"""Loss function for DenoiSplit.
|
|
435
435
|
|
|
436
436
|
Parameters
|
|
@@ -1,10 +1,10 @@
|
|
|
1
1
|
"""Module use to build BMZ model description."""
|
|
2
2
|
|
|
3
3
|
from pathlib import Path
|
|
4
|
-
from typing import
|
|
4
|
+
from typing import Union
|
|
5
5
|
|
|
6
6
|
import numpy as np
|
|
7
|
-
from bioimageio.spec._internal.io import
|
|
7
|
+
from bioimageio.spec._internal.io import extract
|
|
8
8
|
from bioimageio.spec.model.v0_5 import (
|
|
9
9
|
ArchitectureFromLibraryDescr,
|
|
10
10
|
Author,
|
|
@@ -12,7 +12,6 @@ from bioimageio.spec.model.v0_5 import (
|
|
|
12
12
|
AxisId,
|
|
13
13
|
BatchAxis,
|
|
14
14
|
ChannelAxis,
|
|
15
|
-
EnvironmentFileDescr,
|
|
16
15
|
FileDescr,
|
|
17
16
|
FixedZeroMeanUnitVarianceAlongAxisKwargs,
|
|
18
17
|
FixedZeroMeanUnitVarianceDescr,
|
|
@@ -36,7 +35,7 @@ from ._readme_factory import readme_factory
|
|
|
36
35
|
def _create_axes(
|
|
37
36
|
array: np.ndarray,
|
|
38
37
|
data_config: DataConfig,
|
|
39
|
-
channel_names:
|
|
38
|
+
channel_names: list[str] | None = None,
|
|
40
39
|
is_input: bool = True,
|
|
41
40
|
) -> list[AxisBase]:
|
|
42
41
|
"""Create axes description.
|
|
@@ -105,7 +104,7 @@ def _create_inputs_ouputs(
|
|
|
105
104
|
data_config: DataConfig,
|
|
106
105
|
input_path: Union[Path, str],
|
|
107
106
|
output_path: Union[Path, str],
|
|
108
|
-
channel_names:
|
|
107
|
+
channel_names: list[str] | None = None,
|
|
109
108
|
) -> tuple[InputTensorDescr, OutputTensorDescr]:
|
|
110
109
|
"""Create input and output tensor description.
|
|
111
110
|
|
|
@@ -197,7 +196,7 @@ def create_model_description(
|
|
|
197
196
|
config_path: Union[Path, str],
|
|
198
197
|
env_path: Union[Path, str],
|
|
199
198
|
covers: list[Union[Path, str]],
|
|
200
|
-
channel_names:
|
|
199
|
+
channel_names: list[str] | None = None,
|
|
201
200
|
model_version: str = "0.1.0",
|
|
202
201
|
) -> ModelDescr:
|
|
203
202
|
"""Create model description.
|
|
@@ -269,7 +268,7 @@ def create_model_description(
|
|
|
269
268
|
source=weights_path,
|
|
270
269
|
architecture=architecture_descr,
|
|
271
270
|
pytorch_version=Version(torch_version),
|
|
272
|
-
dependencies=
|
|
271
|
+
dependencies=FileDescr(source=Path(env_path)),
|
|
273
272
|
),
|
|
274
273
|
)
|
|
275
274
|
|
|
@@ -322,9 +321,11 @@ def extract_model_path(model_desc: ModelDescr) -> tuple[Path, Path]:
|
|
|
322
321
|
"""
|
|
323
322
|
if model_desc.weights.pytorch_state_dict is None:
|
|
324
323
|
raise ValueError("No model weights found in model description.")
|
|
325
|
-
|
|
326
|
-
|
|
327
|
-
)
|
|
324
|
+
|
|
325
|
+
# extract the zip model and return the directory
|
|
326
|
+
model_dir = extract(model_desc.root)
|
|
327
|
+
|
|
328
|
+
weights_path = model_dir.joinpath(model_desc.weights.pytorch_state_dict.source.path)
|
|
328
329
|
|
|
329
330
|
for file in model_desc.attachments:
|
|
330
331
|
file_path = file.source if isinstance(file.source, Path) else file.source.path
|
|
@@ -332,7 +333,7 @@ def extract_model_path(model_desc: ModelDescr) -> tuple[Path, Path]:
|
|
|
332
333
|
continue
|
|
333
334
|
file_path = Path(file_path)
|
|
334
335
|
if file_path.name == "careamics.yaml":
|
|
335
|
-
config_path =
|
|
336
|
+
config_path = model_dir.joinpath(file.source.path)
|
|
336
337
|
break
|
|
337
338
|
else:
|
|
338
339
|
raise ValueError("Configuration file not found.")
|
careamics/model_io/bmz_io.py
CHANGED
|
@@ -2,7 +2,7 @@
|
|
|
2
2
|
|
|
3
3
|
import tempfile
|
|
4
4
|
from pathlib import Path
|
|
5
|
-
from typing import
|
|
5
|
+
from typing import Union
|
|
6
6
|
|
|
7
7
|
import numpy as np
|
|
8
8
|
from bioimageio.core import load_model_description, test_model
|
|
@@ -90,8 +90,8 @@ def export_to_bmz(
|
|
|
90
90
|
authors: list[dict],
|
|
91
91
|
input_array: np.ndarray,
|
|
92
92
|
output_array: np.ndarray,
|
|
93
|
-
covers:
|
|
94
|
-
channel_names:
|
|
93
|
+
covers: list[Union[Path, str]] | None = None,
|
|
94
|
+
channel_names: list[str] | None = None,
|
|
95
95
|
model_version: str = "0.1.0",
|
|
96
96
|
) -> None:
|
|
97
97
|
"""Export the model to BioImage Model Zoo format.
|
|
@@ -187,7 +187,7 @@ def export_to_bmz(
|
|
|
187
187
|
|
|
188
188
|
# test model description
|
|
189
189
|
test_kwargs = (
|
|
190
|
-
model_description.config.
|
|
190
|
+
model_description.config.bioimageio.model_dump()
|
|
191
191
|
.get("test_kwargs", {})
|
|
192
192
|
.get("pytorch_state_dict", {})
|
|
193
193
|
)
|
careamics/models/layers.py
CHANGED
|
@@ -4,7 +4,7 @@ Layer module.
|
|
|
4
4
|
This submodule contains layers used in the CAREamics models.
|
|
5
5
|
"""
|
|
6
6
|
|
|
7
|
-
from typing import
|
|
7
|
+
from typing import Union
|
|
8
8
|
|
|
9
9
|
import torch
|
|
10
10
|
import torch.nn as nn
|
|
@@ -207,8 +207,8 @@ def get_pascal_kernel_1d(
|
|
|
207
207
|
kernel_size: int,
|
|
208
208
|
norm: bool = False,
|
|
209
209
|
*,
|
|
210
|
-
device:
|
|
211
|
-
dtype:
|
|
210
|
+
device: torch.device | None = None,
|
|
211
|
+
dtype: torch.dtype | None = None,
|
|
212
212
|
) -> torch.Tensor:
|
|
213
213
|
"""Generate Yang Hui triangle (Pascal's triangle) for a given number.
|
|
214
214
|
|
|
@@ -270,8 +270,8 @@ def _get_pascal_kernel_nd(
|
|
|
270
270
|
norm: bool = True,
|
|
271
271
|
dim: int = 2,
|
|
272
272
|
*,
|
|
273
|
-
device:
|
|
274
|
-
dtype:
|
|
273
|
+
device: torch.device | None = None,
|
|
274
|
+
dtype: torch.dtype | None = None,
|
|
275
275
|
) -> torch.Tensor:
|
|
276
276
|
"""Generate pascal filter kernel by kernel size.
|
|
277
277
|
|
|
@@ -1,6 +1,6 @@
|
|
|
1
1
|
"""Module containing pytorch implementations for obtaining predictions from an LVAE."""
|
|
2
2
|
|
|
3
|
-
from typing import Any
|
|
3
|
+
from typing import Any
|
|
4
4
|
|
|
5
5
|
import torch
|
|
6
6
|
|
|
@@ -18,7 +18,7 @@ def lvae_predict_single_sample(
|
|
|
18
18
|
model: LVAE,
|
|
19
19
|
likelihood_obj: LikelihoodModule,
|
|
20
20
|
input: torch.Tensor,
|
|
21
|
-
) -> tuple[torch.Tensor,
|
|
21
|
+
) -> tuple[torch.Tensor, torch.Tensor | None]:
|
|
22
22
|
"""
|
|
23
23
|
Generate a single sample prediction from an LVAE model, for a given input.
|
|
24
24
|
|
|
@@ -57,7 +57,7 @@ def lvae_predict_tiled_batch(
|
|
|
57
57
|
model: LVAE,
|
|
58
58
|
likelihood_obj: LikelihoodModule,
|
|
59
59
|
input: tuple[Any],
|
|
60
|
-
) -> tuple[tuple[Any],
|
|
60
|
+
) -> tuple[tuple[Any], tuple[Any] | None]:
|
|
61
61
|
# TODO: fix docstring return types, ... too many output options
|
|
62
62
|
"""
|
|
63
63
|
Generate a single sample prediction from an LVAE model, for a given input.
|
|
@@ -98,7 +98,7 @@ def lvae_predict_mmse_tiled_batch(
|
|
|
98
98
|
likelihood_obj: LikelihoodModule,
|
|
99
99
|
input: tuple[Any],
|
|
100
100
|
mmse_count: int,
|
|
101
|
-
) -> tuple[tuple[Any], tuple[Any],
|
|
101
|
+
) -> tuple[tuple[Any], tuple[Any], tuple[Any] | None]:
|
|
102
102
|
# TODO: fix docstring return types, ... hard to make readable
|
|
103
103
|
"""
|
|
104
104
|
Generate the MMSE (minimum mean squared error) prediction, for a given input.
|
|
@@ -137,7 +137,7 @@ def lvae_predict_mmse_tiled_batch(
|
|
|
137
137
|
|
|
138
138
|
input_shape = x.shape
|
|
139
139
|
output_shape = (input_shape[0], model.target_ch, *input_shape[2:])
|
|
140
|
-
log_var:
|
|
140
|
+
log_var: torch.Tensor | None = None
|
|
141
141
|
# pre-declare empty array to fill with individual sample predictions
|
|
142
142
|
sample_predictions = torch.zeros(size=(mmse_count, *output_shape))
|
|
143
143
|
for mmse_idx in range(mmse_count):
|