careamics 0.0.2__py3-none-any.whl → 0.0.3__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of careamics might be problematic. Click here for more details.
- careamics/careamist.py +14 -11
- careamics/config/__init__.py +7 -3
- careamics/config/architectures/__init__.py +2 -2
- careamics/config/architectures/architecture_model.py +1 -1
- careamics/config/architectures/custom_model.py +11 -8
- careamics/config/architectures/lvae_model.py +174 -0
- careamics/config/configuration_factory.py +11 -3
- careamics/config/configuration_model.py +7 -3
- careamics/config/data_model.py +33 -8
- careamics/config/{algorithm_model.py → fcn_algorithm_model.py} +28 -43
- careamics/config/likelihood_model.py +43 -0
- careamics/config/nm_model.py +101 -0
- careamics/config/support/supported_activations.py +1 -0
- careamics/config/support/supported_algorithms.py +17 -4
- careamics/config/support/supported_architectures.py +8 -11
- careamics/config/support/supported_losses.py +3 -1
- careamics/config/transformations/n2v_manipulate_model.py +1 -1
- careamics/config/vae_algorithm_model.py +171 -0
- careamics/dataset/tiling/lvae_tiled_patching.py +282 -0
- careamics/file_io/read/tiff.py +1 -1
- careamics/lightning/__init__.py +3 -2
- careamics/lightning/callbacks/hyperparameters_callback.py +1 -1
- careamics/lightning/callbacks/prediction_writer_callback/write_strategy_factory.py +1 -1
- careamics/lightning/lightning_module.py +365 -9
- careamics/lightning/predict_data_module.py +2 -2
- careamics/lightning/train_data_module.py +2 -2
- careamics/losses/__init__.py +11 -1
- careamics/losses/fcn/__init__.py +1 -0
- careamics/losses/{losses.py → fcn/losses.py} +1 -1
- careamics/losses/loss_factory.py +112 -6
- careamics/losses/lvae/__init__.py +1 -0
- careamics/losses/lvae/loss_utils.py +83 -0
- careamics/losses/lvae/losses.py +445 -0
- careamics/lvae_training/dataset/__init__.py +0 -0
- careamics/lvae_training/{data_utils.py → dataset/data_utils.py} +277 -194
- careamics/lvae_training/dataset/lc_dataset.py +259 -0
- careamics/lvae_training/dataset/lc_dataset_config.py +13 -0
- careamics/lvae_training/dataset/vae_data_config.py +179 -0
- careamics/lvae_training/{data_modules.py → dataset/vae_dataset.py} +306 -472
- careamics/lvae_training/get_config.py +1 -1
- careamics/lvae_training/train_lvae.py +6 -3
- careamics/model_io/bioimage/bioimage_utils.py +1 -1
- careamics/model_io/bioimage/model_description.py +2 -2
- careamics/model_io/bmz_io.py +19 -6
- careamics/model_io/model_io_utils.py +16 -4
- careamics/models/__init__.py +1 -3
- careamics/models/activation.py +2 -0
- careamics/models/lvae/__init__.py +3 -0
- careamics/models/lvae/layers.py +21 -21
- careamics/models/lvae/likelihoods.py +180 -128
- careamics/models/lvae/lvae.py +52 -136
- careamics/models/lvae/noise_models.py +318 -186
- careamics/models/lvae/utils.py +2 -2
- careamics/models/model_factory.py +22 -7
- careamics/prediction_utils/lvae_prediction.py +158 -0
- careamics/prediction_utils/lvae_tiling_manager.py +362 -0
- careamics/prediction_utils/stitch_prediction.py +16 -2
- careamics/transforms/pixel_manipulation.py +1 -1
- careamics/utils/metrics.py +74 -1
- {careamics-0.0.2.dist-info → careamics-0.0.3.dist-info}/METADATA +2 -2
- {careamics-0.0.2.dist-info → careamics-0.0.3.dist-info}/RECORD +63 -49
- careamics/config/architectures/vae_model.py +0 -42
- {careamics-0.0.2.dist-info → careamics-0.0.3.dist-info}/WHEEL +0 -0
- {careamics-0.0.2.dist-info → careamics-0.0.3.dist-info}/licenses/LICENSE +0 -0
careamics/file_io/read/tiff.py
CHANGED
careamics/lightning/__init__.py
CHANGED
|
@@ -1,7 +1,8 @@
|
|
|
1
1
|
"""CAREamics PyTorch Lightning modules."""
|
|
2
2
|
|
|
3
3
|
__all__ = [
|
|
4
|
-
"
|
|
4
|
+
"FCNModule",
|
|
5
|
+
"VAEModule",
|
|
5
6
|
"create_careamics_module",
|
|
6
7
|
"TrainDataModule",
|
|
7
8
|
"create_train_datamodule",
|
|
@@ -12,6 +13,6 @@ __all__ = [
|
|
|
12
13
|
]
|
|
13
14
|
|
|
14
15
|
from .callbacks import HyperParametersCallback, ProgressBarCallback
|
|
15
|
-
from .lightning_module import
|
|
16
|
+
from .lightning_module import FCNModule, VAEModule, create_careamics_module
|
|
16
17
|
from .predict_data_module import PredictDataModule, create_predict_datamodule
|
|
17
18
|
from .train_data_module import TrainDataModule, create_train_datamodule
|
|
@@ -10,7 +10,7 @@ class HyperParametersCallback(Callback):
|
|
|
10
10
|
"""
|
|
11
11
|
Callback allowing saving CAREamics configuration as hyperparameters in the model.
|
|
12
12
|
|
|
13
|
-
This allows saving the configuration as
|
|
13
|
+
This allows saving the configuration as dictionary in the checkpoints, and
|
|
14
14
|
loading it subsequently in a CAREamist instance.
|
|
15
15
|
|
|
16
16
|
Parameters
|
|
@@ -1,11 +1,12 @@
|
|
|
1
1
|
"""CAREamics Lightning module."""
|
|
2
2
|
|
|
3
|
-
from typing import Any, Optional, Union
|
|
3
|
+
from typing import Any, Callable, Literal, Optional, Union
|
|
4
4
|
|
|
5
|
+
import numpy as np
|
|
5
6
|
import pytorch_lightning as L
|
|
6
7
|
from torch import Tensor, nn
|
|
7
8
|
|
|
8
|
-
from careamics.config import
|
|
9
|
+
from careamics.config import FCNAlgorithmConfig, VAEAlgorithmConfig
|
|
9
10
|
from careamics.config.support import (
|
|
10
11
|
SupportedAlgorithm,
|
|
11
12
|
SupportedArchitecture,
|
|
@@ -14,12 +15,26 @@ from careamics.config.support import (
|
|
|
14
15
|
SupportedScheduler,
|
|
15
16
|
)
|
|
16
17
|
from careamics.losses import loss_factory
|
|
18
|
+
from careamics.losses.loss_factory import LVAELossParameters
|
|
19
|
+
from careamics.models.lvae.likelihoods import (
|
|
20
|
+
GaussianLikelihood,
|
|
21
|
+
NoiseModelLikelihood,
|
|
22
|
+
likelihood_factory,
|
|
23
|
+
)
|
|
24
|
+
from careamics.models.lvae.noise_models import (
|
|
25
|
+
GaussianMixtureNoiseModel,
|
|
26
|
+
MultiChannelNoiseModel,
|
|
27
|
+
noise_model_factory,
|
|
28
|
+
)
|
|
17
29
|
from careamics.models.model_factory import model_factory
|
|
18
30
|
from careamics.transforms import Denormalize, ImageRestorationTTA
|
|
31
|
+
from careamics.utils.metrics import RunningPSNR, scale_invariant_psnr
|
|
19
32
|
from careamics.utils.torch_utils import get_optimizer, get_scheduler
|
|
20
33
|
|
|
34
|
+
NoiseModel = Union[GaussianMixtureNoiseModel, MultiChannelNoiseModel]
|
|
35
|
+
|
|
21
36
|
|
|
22
|
-
class
|
|
37
|
+
class FCNModule(L.LightningModule):
|
|
23
38
|
"""
|
|
24
39
|
CAREamics Lightning module.
|
|
25
40
|
|
|
@@ -45,7 +60,7 @@ class CAREamicsModule(L.LightningModule):
|
|
|
45
60
|
Learning rate scheduler name.
|
|
46
61
|
"""
|
|
47
62
|
|
|
48
|
-
def __init__(self, algorithm_config: Union[
|
|
63
|
+
def __init__(self, algorithm_config: Union[FCNAlgorithmConfig, dict]) -> None:
|
|
49
64
|
"""Lightning module for CAREamics.
|
|
50
65
|
|
|
51
66
|
This class encapsulates the a PyTorch model along with the training, validation,
|
|
@@ -59,7 +74,7 @@ class CAREamicsModule(L.LightningModule):
|
|
|
59
74
|
super().__init__()
|
|
60
75
|
# if loading from a checkpoint, AlgorithmModel needs to be instantiated
|
|
61
76
|
if isinstance(algorithm_config, dict):
|
|
62
|
-
algorithm_config =
|
|
77
|
+
algorithm_config = FCNAlgorithmConfig(**algorithm_config)
|
|
63
78
|
|
|
64
79
|
# create model and loss function
|
|
65
80
|
self.model: nn.Module = model_factory(algorithm_config.model)
|
|
@@ -203,7 +218,339 @@ class CAREamicsModule(L.LightningModule):
|
|
|
203
218
|
}
|
|
204
219
|
|
|
205
220
|
|
|
221
|
+
class VAEModule(L.LightningModule):
|
|
222
|
+
"""
|
|
223
|
+
CAREamics Lightning module.
|
|
224
|
+
|
|
225
|
+
This class encapsulates the a PyTorch model along with the training, validation,
|
|
226
|
+
and testing logic. It is configured using an `AlgorithmModel` Pydantic class.
|
|
227
|
+
|
|
228
|
+
Parameters
|
|
229
|
+
----------
|
|
230
|
+
algorithm_config : Union[VAEAlgorithmConfig, dict]
|
|
231
|
+
Algorithm configuration.
|
|
232
|
+
|
|
233
|
+
Attributes
|
|
234
|
+
----------
|
|
235
|
+
model : nn.Module
|
|
236
|
+
PyTorch model.
|
|
237
|
+
loss_func : nn.Module
|
|
238
|
+
Loss function.
|
|
239
|
+
optimizer_name : str
|
|
240
|
+
Optimizer name.
|
|
241
|
+
optimizer_params : dict
|
|
242
|
+
Optimizer parameters.
|
|
243
|
+
lr_scheduler_name : str
|
|
244
|
+
Learning rate scheduler name.
|
|
245
|
+
"""
|
|
246
|
+
|
|
247
|
+
def __init__(self, algorithm_config: Union[VAEAlgorithmConfig, dict]) -> None:
|
|
248
|
+
"""Lightning module for CAREamics.
|
|
249
|
+
|
|
250
|
+
This class encapsulates the a PyTorch model along with the training, validation,
|
|
251
|
+
and testing logic. It is configured using an `AlgorithmModel` Pydantic class.
|
|
252
|
+
|
|
253
|
+
Parameters
|
|
254
|
+
----------
|
|
255
|
+
algorithm_config : Union[AlgorithmModel, dict]
|
|
256
|
+
Algorithm configuration.
|
|
257
|
+
"""
|
|
258
|
+
super().__init__()
|
|
259
|
+
# if loading from a checkpoint, AlgorithmModel needs to be instantiated
|
|
260
|
+
self.algorithm_config = (
|
|
261
|
+
VAEAlgorithmConfig(**algorithm_config)
|
|
262
|
+
if isinstance(algorithm_config, dict)
|
|
263
|
+
else algorithm_config
|
|
264
|
+
)
|
|
265
|
+
|
|
266
|
+
# TODO: log algorithm config
|
|
267
|
+
# self.save_hyperparameters(self.algorithm_config.model_dump())
|
|
268
|
+
|
|
269
|
+
# create model and loss function
|
|
270
|
+
self.model: nn.Module = model_factory(self.algorithm_config.model)
|
|
271
|
+
self.noise_model: NoiseModel = noise_model_factory(
|
|
272
|
+
self.algorithm_config.noise_model
|
|
273
|
+
)
|
|
274
|
+
self.noise_model_likelihood: NoiseModelLikelihood = likelihood_factory(
|
|
275
|
+
self.algorithm_config.noise_model_likelihood_model
|
|
276
|
+
)
|
|
277
|
+
self.gaussian_likelihood: GaussianLikelihood = likelihood_factory(
|
|
278
|
+
self.algorithm_config.gaussian_likelihood_model
|
|
279
|
+
)
|
|
280
|
+
self.loss_parameters = LVAELossParameters(
|
|
281
|
+
noise_model_likelihood=self.noise_model_likelihood,
|
|
282
|
+
gaussian_likelihood=self.gaussian_likelihood,
|
|
283
|
+
# TODO: musplit/denoisplit weights ?
|
|
284
|
+
) # type: ignore
|
|
285
|
+
self.loss_func = loss_factory(self.algorithm_config.loss)
|
|
286
|
+
|
|
287
|
+
# save optimizer and lr_scheduler names and parameters
|
|
288
|
+
self.optimizer_name = self.algorithm_config.optimizer.name
|
|
289
|
+
self.optimizer_params = self.algorithm_config.optimizer.parameters
|
|
290
|
+
self.lr_scheduler_name = self.algorithm_config.lr_scheduler.name
|
|
291
|
+
self.lr_scheduler_params = self.algorithm_config.lr_scheduler.parameters
|
|
292
|
+
|
|
293
|
+
# initialize running PSNR
|
|
294
|
+
self.running_psnr = [
|
|
295
|
+
RunningPSNR() for _ in range(self.algorithm_config.model.output_channels)
|
|
296
|
+
]
|
|
297
|
+
|
|
298
|
+
def forward(self, x: Tensor) -> tuple[Tensor, dict[str, Any]]:
|
|
299
|
+
"""Forward pass.
|
|
300
|
+
|
|
301
|
+
Parameters
|
|
302
|
+
----------
|
|
303
|
+
x : Tensor
|
|
304
|
+
Input tensor of shape (B, (1 + n_LC), [Z], Y, X), where n_LC is the
|
|
305
|
+
number of lateral inputs.
|
|
306
|
+
|
|
307
|
+
Returns
|
|
308
|
+
-------
|
|
309
|
+
tuple[Tensor, dict[str, Any]]
|
|
310
|
+
A tuple with the output tensor and additional data from the top-down pass.
|
|
311
|
+
"""
|
|
312
|
+
return self.model(x) # TODO Different model can have more than one output
|
|
313
|
+
|
|
314
|
+
def training_step(
|
|
315
|
+
self, batch: tuple[Tensor, Tensor], batch_idx: Any
|
|
316
|
+
) -> Optional[dict[str, Tensor]]:
|
|
317
|
+
"""Training step.
|
|
318
|
+
|
|
319
|
+
Parameters
|
|
320
|
+
----------
|
|
321
|
+
batch : tuple[Tensor, Tensor]
|
|
322
|
+
Input batch. It is a tuple with the input tensor and the target tensor.
|
|
323
|
+
The input tensor has shape (B, (1 + n_LC), [Z], Y, X), where n_LC is the
|
|
324
|
+
number of lateral inputs. The target tensor has shape (B, C, [Z], Y, X),
|
|
325
|
+
where C is the number of target channels (e.g., 1 in HDN, >1 in
|
|
326
|
+
muSplit/denoiSplit).
|
|
327
|
+
batch_idx : Any
|
|
328
|
+
Batch index.
|
|
329
|
+
|
|
330
|
+
Returns
|
|
331
|
+
-------
|
|
332
|
+
Any
|
|
333
|
+
Loss value.
|
|
334
|
+
"""
|
|
335
|
+
x, target = batch
|
|
336
|
+
|
|
337
|
+
# Forward pass
|
|
338
|
+
out = self.model(x)
|
|
339
|
+
|
|
340
|
+
# Update loss parameters
|
|
341
|
+
# TODO rethink loss parameters
|
|
342
|
+
self.loss_parameters.current_epoch = self.current_epoch
|
|
343
|
+
|
|
344
|
+
# Compute loss
|
|
345
|
+
loss = self.loss_func(out, target, self.loss_parameters) # TODO ugly ?
|
|
346
|
+
|
|
347
|
+
# Logging
|
|
348
|
+
# TODO: implement a separate logging method?
|
|
349
|
+
self.log_dict(loss, on_step=True, on_epoch=True)
|
|
350
|
+
# self.log("lr", self, on_epoch=True)
|
|
351
|
+
return loss
|
|
352
|
+
|
|
353
|
+
def validation_step(self, batch: tuple[Tensor, Tensor], batch_idx: Any) -> None:
|
|
354
|
+
"""Validation step.
|
|
355
|
+
|
|
356
|
+
Parameters
|
|
357
|
+
----------
|
|
358
|
+
batch : tuple[Tensor, Tensor]
|
|
359
|
+
Input batch. It is a tuple with the input tensor and the target tensor.
|
|
360
|
+
The input tensor has shape (B, (1 + n_LC), [Z], Y, X), where n_LC is the
|
|
361
|
+
number of lateral inputs. The target tensor has shape (B, C, [Z], Y, X),
|
|
362
|
+
where C is the number of target channels (e.g., 1 in HDN, >1 in
|
|
363
|
+
muSplit/denoiSplit).
|
|
364
|
+
batch_idx : Any
|
|
365
|
+
Batch index.
|
|
366
|
+
"""
|
|
367
|
+
x, target = batch
|
|
368
|
+
|
|
369
|
+
# Forward pass
|
|
370
|
+
out = self.model(x)
|
|
371
|
+
|
|
372
|
+
# Compute loss
|
|
373
|
+
loss = self.loss_func(out, target, self.loss_parameters)
|
|
374
|
+
|
|
375
|
+
# Logging
|
|
376
|
+
# Rename val_loss dict
|
|
377
|
+
loss = {"_".join(["val", k]): v for k, v in loss.items()}
|
|
378
|
+
self.log_dict(loss, on_epoch=True, prog_bar=True)
|
|
379
|
+
curr_psnr = self.compute_val_psnr(out, target)
|
|
380
|
+
for i, psnr in enumerate(curr_psnr):
|
|
381
|
+
self.log(f"val_psnr_ch{i+1}_batch", psnr, on_epoch=True)
|
|
382
|
+
|
|
383
|
+
def on_validation_epoch_end(self) -> None:
|
|
384
|
+
"""Validation epoch end."""
|
|
385
|
+
psnr_ = self.reduce_running_psnr()
|
|
386
|
+
if psnr_ is not None:
|
|
387
|
+
self.log("val_psnr", psnr_, on_epoch=True, prog_bar=True)
|
|
388
|
+
else:
|
|
389
|
+
self.log("val_psnr", 0.0, on_epoch=True, prog_bar=True)
|
|
390
|
+
|
|
391
|
+
def predict_step(self, batch: Tensor, batch_idx: Any) -> Any:
|
|
392
|
+
"""Prediction step.
|
|
393
|
+
|
|
394
|
+
Parameters
|
|
395
|
+
----------
|
|
396
|
+
batch : Tensor
|
|
397
|
+
Input batch.
|
|
398
|
+
batch_idx : Any
|
|
399
|
+
Batch index.
|
|
400
|
+
|
|
401
|
+
Returns
|
|
402
|
+
-------
|
|
403
|
+
Any
|
|
404
|
+
Model output.
|
|
405
|
+
"""
|
|
406
|
+
if self._trainer.datamodule.tiled:
|
|
407
|
+
x, *aux = batch
|
|
408
|
+
else:
|
|
409
|
+
x = batch
|
|
410
|
+
aux = []
|
|
411
|
+
|
|
412
|
+
# apply test-time augmentation if available
|
|
413
|
+
# TODO: probably wont work with batch size > 1
|
|
414
|
+
if self._trainer.datamodule.prediction_config.tta_transforms:
|
|
415
|
+
tta = ImageRestorationTTA()
|
|
416
|
+
augmented_batch = tta.forward(x) # list of augmented tensors
|
|
417
|
+
augmented_output = []
|
|
418
|
+
for augmented in augmented_batch:
|
|
419
|
+
augmented_pred = self.model(augmented)
|
|
420
|
+
augmented_output.append(augmented_pred)
|
|
421
|
+
output = tta.backward(augmented_output)
|
|
422
|
+
else:
|
|
423
|
+
output = self.model(x)
|
|
424
|
+
|
|
425
|
+
# Denormalize the output
|
|
426
|
+
denorm = Denormalize(
|
|
427
|
+
image_means=self._trainer.datamodule.predict_dataset.image_means,
|
|
428
|
+
image_stds=self._trainer.datamodule.predict_dataset.image_stds,
|
|
429
|
+
)
|
|
430
|
+
denormalized_output = denorm(patch=output.cpu().numpy())
|
|
431
|
+
|
|
432
|
+
if len(aux) > 0: # aux can be tiling information
|
|
433
|
+
return denormalized_output, *aux
|
|
434
|
+
else:
|
|
435
|
+
return denormalized_output
|
|
436
|
+
|
|
437
|
+
def configure_optimizers(self) -> Any:
|
|
438
|
+
"""Configure optimizers and learning rate schedulers.
|
|
439
|
+
|
|
440
|
+
Returns
|
|
441
|
+
-------
|
|
442
|
+
Any
|
|
443
|
+
Optimizer and learning rate scheduler.
|
|
444
|
+
"""
|
|
445
|
+
# instantiate optimizer
|
|
446
|
+
optimizer_func = get_optimizer(self.optimizer_name)
|
|
447
|
+
optimizer = optimizer_func(self.model.parameters(), **self.optimizer_params)
|
|
448
|
+
|
|
449
|
+
# and scheduler
|
|
450
|
+
scheduler_func = get_scheduler(self.lr_scheduler_name)
|
|
451
|
+
scheduler = scheduler_func(optimizer, **self.lr_scheduler_params)
|
|
452
|
+
|
|
453
|
+
return {
|
|
454
|
+
"optimizer": optimizer,
|
|
455
|
+
"lr_scheduler": scheduler,
|
|
456
|
+
"monitor": "val_loss", # otherwise triggers MisconfigurationException
|
|
457
|
+
}
|
|
458
|
+
|
|
459
|
+
# TODO: find a way to move the following methods to a separate module
|
|
460
|
+
# TODO: this same operation is done in many other places, like in loss_func
|
|
461
|
+
# should we refactor LadderVAE so that it already outputs
|
|
462
|
+
# tuple(`mean`, `logvar`, `td_data`)?
|
|
463
|
+
def get_reconstructed_tensor(
|
|
464
|
+
self, model_outputs: tuple[Tensor, dict[str, Any]]
|
|
465
|
+
) -> Tensor:
|
|
466
|
+
"""Get the reconstructed tensor from the LVAE model outputs.
|
|
467
|
+
|
|
468
|
+
Parameters
|
|
469
|
+
----------
|
|
470
|
+
model_outputs : tuple[Tensor, dict[str, Any]]
|
|
471
|
+
Model outputs. It is a tuple with a tensor representing the predicted mean
|
|
472
|
+
and (optionally) logvar, and the top-down data dictionary.
|
|
473
|
+
|
|
474
|
+
Returns
|
|
475
|
+
-------
|
|
476
|
+
Tensor
|
|
477
|
+
Reconstructed tensor, i.e., the predicted mean.
|
|
478
|
+
"""
|
|
479
|
+
predictions, _ = model_outputs
|
|
480
|
+
if self.model.predict_logvar is None:
|
|
481
|
+
return predictions
|
|
482
|
+
elif self.model.predict_logvar == "pixelwise":
|
|
483
|
+
return predictions.chunk(2, dim=1)[0]
|
|
484
|
+
|
|
485
|
+
def compute_val_psnr(
|
|
486
|
+
self,
|
|
487
|
+
model_output: tuple[Tensor, dict[str, Any]],
|
|
488
|
+
target: Tensor,
|
|
489
|
+
psnr_func: Callable = scale_invariant_psnr,
|
|
490
|
+
) -> list[float]:
|
|
491
|
+
"""Compute the PSNR for the current validation batch.
|
|
492
|
+
|
|
493
|
+
Parameters
|
|
494
|
+
----------
|
|
495
|
+
model_output : tuple[Tensor, dict[str, Any]]
|
|
496
|
+
Model output, a tuple with the predicted mean and (optionally) logvar,
|
|
497
|
+
and the top-down data dictionary.
|
|
498
|
+
target : Tensor
|
|
499
|
+
Target tensor.
|
|
500
|
+
psnr_func : Callable, optional
|
|
501
|
+
PSNR function to use, by default `scale_invariant_psnr`.
|
|
502
|
+
|
|
503
|
+
Returns
|
|
504
|
+
-------
|
|
505
|
+
list[float]
|
|
506
|
+
PSNR for each channel in the current batch.
|
|
507
|
+
"""
|
|
508
|
+
out_channels = target.shape[1]
|
|
509
|
+
|
|
510
|
+
# get the reconstructed image
|
|
511
|
+
recons_img = self.get_reconstructed_tensor(model_output)
|
|
512
|
+
|
|
513
|
+
# update running psnr
|
|
514
|
+
for i in range(out_channels):
|
|
515
|
+
self.running_psnr[i].update(rec=recons_img[:, i], tar=target[:, i])
|
|
516
|
+
|
|
517
|
+
# compute psnr for each channel in the current batch
|
|
518
|
+
# TODO: this doesn't need do be a method of this class
|
|
519
|
+
# and hence can be moved to a separate module
|
|
520
|
+
return [
|
|
521
|
+
psnr_func(
|
|
522
|
+
gt=target[:, i].clone().detach().cpu().numpy(),
|
|
523
|
+
pred=recons_img[:, i].clone().detach().cpu().numpy(),
|
|
524
|
+
)
|
|
525
|
+
for i in range(out_channels)
|
|
526
|
+
]
|
|
527
|
+
|
|
528
|
+
def reduce_running_psnr(self) -> Optional[float]:
|
|
529
|
+
"""Reduce the running PSNR statistics and reset the running PSNR.
|
|
530
|
+
|
|
531
|
+
Returns
|
|
532
|
+
-------
|
|
533
|
+
Optional[float]
|
|
534
|
+
Running PSNR averaged over the different output channels.
|
|
535
|
+
"""
|
|
536
|
+
psnr_arr = [] # type: ignore
|
|
537
|
+
for i in range(len(self.running_psnr)):
|
|
538
|
+
psnr = self.running_psnr[i].get()
|
|
539
|
+
if psnr is None:
|
|
540
|
+
psnr_arr = None # type: ignore
|
|
541
|
+
break
|
|
542
|
+
psnr_arr.append(psnr.cpu().numpy())
|
|
543
|
+
self.running_psnr[i].reset()
|
|
544
|
+
# TODO: this line forces it to be a method of this class
|
|
545
|
+
# alternative is returning also the reset `running_psnr`
|
|
546
|
+
if psnr_arr is not None:
|
|
547
|
+
psnr = np.mean(psnr_arr)
|
|
548
|
+
return psnr
|
|
549
|
+
|
|
550
|
+
|
|
551
|
+
# TODO: make this LVAE compatible (?)
|
|
206
552
|
def create_careamics_module(
|
|
553
|
+
algorithm_type: Literal["fcn"],
|
|
207
554
|
algorithm: Union[SupportedAlgorithm, str],
|
|
208
555
|
loss: Union[SupportedLoss, str],
|
|
209
556
|
architecture: Union[SupportedArchitecture, str],
|
|
@@ -212,14 +559,16 @@ def create_careamics_module(
|
|
|
212
559
|
optimizer_parameters: Optional[dict] = None,
|
|
213
560
|
lr_scheduler: Union[SupportedScheduler, str] = "ReduceLROnPlateau",
|
|
214
561
|
lr_scheduler_parameters: Optional[dict] = None,
|
|
215
|
-
) ->
|
|
216
|
-
"""Create a CAREamics
|
|
562
|
+
) -> Union[FCNModule, VAEModule]:
|
|
563
|
+
"""Create a CAREamics Lightning module.
|
|
217
564
|
|
|
218
565
|
This function exposes parameters used to create an AlgorithmModel instance,
|
|
219
566
|
triggering parameters validation.
|
|
220
567
|
|
|
221
568
|
Parameters
|
|
222
569
|
----------
|
|
570
|
+
algorithm_type : Literal["fcn"]
|
|
571
|
+
Algorithm type to use for training.
|
|
223
572
|
algorithm : SupportedAlgorithm or str
|
|
224
573
|
Algorithm to use for training (see SupportedAlgorithm).
|
|
225
574
|
loss : SupportedLoss or str
|
|
@@ -254,7 +603,8 @@ def create_careamics_module(
|
|
|
254
603
|
optimizer_parameters = {}
|
|
255
604
|
if model_parameters is None:
|
|
256
605
|
model_parameters = {}
|
|
257
|
-
algorithm_configuration = {
|
|
606
|
+
algorithm_configuration: dict[str, Any] = {
|
|
607
|
+
"algorithm_type": algorithm_type,
|
|
258
608
|
"algorithm": algorithm,
|
|
259
609
|
"loss": loss,
|
|
260
610
|
"optimizer": {
|
|
@@ -273,4 +623,10 @@ def create_careamics_module(
|
|
|
273
623
|
algorithm_configuration["model"] = model_configuration
|
|
274
624
|
|
|
275
625
|
# call the parent init using an AlgorithmModel instance
|
|
276
|
-
|
|
626
|
+
if algorithm_configuration["algorithm_type"] == "fcn":
|
|
627
|
+
return FCNModule(FCNAlgorithmConfig(**algorithm_configuration))
|
|
628
|
+
else:
|
|
629
|
+
raise NotImplementedError(
|
|
630
|
+
f"Model {algorithm_configuration['model']['architecture']} is not"
|
|
631
|
+
f"implemented or unknown."
|
|
632
|
+
)
|
|
@@ -240,7 +240,7 @@ def create_predict_datamodule(
|
|
|
240
240
|
) -> PredictDataModule:
|
|
241
241
|
"""Create a CAREamics prediction Lightning datamodule.
|
|
242
242
|
|
|
243
|
-
This function is used to
|
|
243
|
+
This function is used to explicitly pass the parameters usually contained in an
|
|
244
244
|
`inference_model` configuration.
|
|
245
245
|
|
|
246
246
|
Since the lightning datamodule has no access to the model, make sure that the
|
|
@@ -268,7 +268,7 @@ def create_predict_datamodule(
|
|
|
268
268
|
data_type : {"array", "tiff", "custom"}
|
|
269
269
|
Data type, see `SupportedData` for available options.
|
|
270
270
|
axes : str
|
|
271
|
-
Axes of the data,
|
|
271
|
+
Axes of the data, chosen among SCZYX.
|
|
272
272
|
image_means : list of float
|
|
273
273
|
Mean values for normalization, only used if Normalization is defined.
|
|
274
274
|
image_stds : list of float
|
|
@@ -487,7 +487,7 @@ def create_train_datamodule(
|
|
|
487
487
|
) -> TrainDataModule:
|
|
488
488
|
"""Create a TrainDataModule.
|
|
489
489
|
|
|
490
|
-
This function is used to
|
|
490
|
+
This function is used to explicitly pass the parameters usually contained in a
|
|
491
491
|
`data_model` configuration to a TrainDataModule.
|
|
492
492
|
|
|
493
493
|
Since the lightning datamodule has no access to the model, make sure that the
|
|
@@ -537,7 +537,7 @@ def create_train_datamodule(
|
|
|
537
537
|
patch_size : list of int
|
|
538
538
|
Patch size, 2D or 3D patch size.
|
|
539
539
|
axes : str
|
|
540
|
-
Axes of the data,
|
|
540
|
+
Axes of the data, chosen amongst SCZYX.
|
|
541
541
|
batch_size : int
|
|
542
542
|
Batch size.
|
|
543
543
|
val_data : pathlib.Path or str or numpy.ndarray, optional
|
careamics/losses/__init__.py
CHANGED
|
@@ -1,5 +1,15 @@
|
|
|
1
1
|
"""Losses module."""
|
|
2
2
|
|
|
3
|
-
__all__ = [
|
|
3
|
+
__all__ = [
|
|
4
|
+
"loss_factory",
|
|
5
|
+
"mae_loss",
|
|
6
|
+
"mse_loss",
|
|
7
|
+
"n2v_loss",
|
|
8
|
+
"denoisplit_loss",
|
|
9
|
+
"musplit_loss",
|
|
10
|
+
"denoisplit_musplit_loss",
|
|
11
|
+
]
|
|
4
12
|
|
|
13
|
+
from .fcn.losses import mae_loss, mse_loss, n2v_loss
|
|
5
14
|
from .loss_factory import loss_factory
|
|
15
|
+
from .lvae.losses import denoisplit_loss, denoisplit_musplit_loss, musplit_loss
|
|
@@ -0,0 +1 @@
|
|
|
1
|
+
"""FCN losses."""
|
|
@@ -53,7 +53,7 @@ def n2v_loss(
|
|
|
53
53
|
errors = (original_patches - manipulated_patches) ** 2
|
|
54
54
|
# Average over pixels and batch
|
|
55
55
|
loss = torch.sum(errors * masks) / torch.sum(masks)
|
|
56
|
-
return loss
|
|
56
|
+
return loss # TODO change output to dict ?
|
|
57
57
|
|
|
58
58
|
|
|
59
59
|
def mae_loss(samples: torch.Tensor, labels: torch.Tensor) -> torch.Tensor:
|
careamics/losses/loss_factory.py
CHANGED
|
@@ -4,14 +4,114 @@ Loss factory module.
|
|
|
4
4
|
This module contains a factory function for creating loss functions.
|
|
5
5
|
"""
|
|
6
6
|
|
|
7
|
-
from
|
|
7
|
+
from __future__ import annotations
|
|
8
|
+
|
|
9
|
+
from dataclasses import dataclass
|
|
10
|
+
from typing import TYPE_CHECKING, Callable, Literal, Optional, Union
|
|
11
|
+
|
|
12
|
+
from torch import Tensor as tensor
|
|
8
13
|
|
|
9
14
|
from ..config.support import SupportedLoss
|
|
10
|
-
from .losses import mae_loss, mse_loss, n2v_loss
|
|
15
|
+
from .fcn.losses import mae_loss, mse_loss, n2v_loss
|
|
16
|
+
from .lvae.losses import denoisplit_loss, denoisplit_musplit_loss, musplit_loss
|
|
17
|
+
|
|
18
|
+
if TYPE_CHECKING:
|
|
19
|
+
from careamics.models.lvae.likelihoods import (
|
|
20
|
+
GaussianLikelihood,
|
|
21
|
+
NoiseModelLikelihood,
|
|
22
|
+
)
|
|
23
|
+
from careamics.models.lvae.noise_models import (
|
|
24
|
+
GaussianMixtureNoiseModel,
|
|
25
|
+
MultiChannelNoiseModel,
|
|
26
|
+
)
|
|
27
|
+
|
|
28
|
+
NoiseModel = Union[GaussianMixtureNoiseModel, MultiChannelNoiseModel]
|
|
29
|
+
|
|
30
|
+
|
|
31
|
+
@dataclass
|
|
32
|
+
class FCNLossParameters:
|
|
33
|
+
"""Dataclass for FCN loss."""
|
|
34
|
+
|
|
35
|
+
# TODO check
|
|
36
|
+
prediction: tensor
|
|
37
|
+
targets: tensor
|
|
38
|
+
mask: tensor
|
|
39
|
+
current_epoch: int
|
|
40
|
+
loss_weight: float
|
|
41
|
+
|
|
42
|
+
|
|
43
|
+
@dataclass # TODO why not pydantic?
|
|
44
|
+
class LVAELossParameters:
|
|
45
|
+
"""Dataclass for LVAE loss."""
|
|
46
|
+
|
|
47
|
+
# TODO: refactor in more modular blocks (otherwise it gets messy very easily)
|
|
48
|
+
# e.g., - weights, - kl_params, ...
|
|
49
|
+
|
|
50
|
+
noise_model_likelihood: Optional[NoiseModelLikelihood] = None
|
|
51
|
+
"""Noise model likelihood instance."""
|
|
52
|
+
gaussian_likelihood: Optional[GaussianLikelihood] = None
|
|
53
|
+
"""Gaussian likelihood instance."""
|
|
54
|
+
current_epoch: int = 0
|
|
55
|
+
"""Current epoch in the training loop."""
|
|
56
|
+
reconstruction_weight: float = 1.0
|
|
57
|
+
"""Weight for the reconstruction loss in the total net loss
|
|
58
|
+
(i.e., `net_loss = reconstruction_weight * rec_loss + kl_weight * kl_loss`)."""
|
|
59
|
+
musplit_weight: float = 0.0
|
|
60
|
+
"""Weight for the muSplit loss (used in the muSplit-deonoiSplit loss)."""
|
|
61
|
+
denoisplit_weight: float = 1.0
|
|
62
|
+
"""Weight for the denoiSplit loss (used in the muSplit-deonoiSplit loss)."""
|
|
63
|
+
kl_type: Literal["kl", "kl_restricted", "kl_spatial", "kl_channelwise"] = "kl"
|
|
64
|
+
"""Type of KL divergence used as KL loss."""
|
|
65
|
+
kl_weight: float = 1.0
|
|
66
|
+
"""Weight for the KL loss in the total net loss.
|
|
67
|
+
(i.e., `net_loss = reconstruction_weight * rec_loss + kl_weight * kl_loss`)."""
|
|
68
|
+
kl_annealing: bool = False
|
|
69
|
+
"""Whether to apply KL loss annealing."""
|
|
70
|
+
kl_start: int = -1
|
|
71
|
+
"""Epoch at which KL loss annealing starts."""
|
|
72
|
+
kl_annealtime: int = 10
|
|
73
|
+
"""Number of epochs for which KL loss annealing is applied."""
|
|
74
|
+
non_stochastic: bool = False
|
|
75
|
+
"""Whether to sample latents and compute KL."""
|
|
76
|
+
|
|
77
|
+
|
|
78
|
+
# TODO: really needed?
|
|
79
|
+
# like it is now, it is difficult to use, we need a way to specify the
|
|
80
|
+
# loss parameters in a more user-friendly way.
|
|
81
|
+
def loss_parameters_factory(
|
|
82
|
+
type: SupportedLoss,
|
|
83
|
+
) -> Union[FCNLossParameters, LVAELossParameters]:
|
|
84
|
+
"""Return loss parameters.
|
|
85
|
+
|
|
86
|
+
Parameters
|
|
87
|
+
----------
|
|
88
|
+
type : SupportedLoss
|
|
89
|
+
Requested loss.
|
|
90
|
+
|
|
91
|
+
Returns
|
|
92
|
+
-------
|
|
93
|
+
Union[FCNLossParameters, LVAELossParameters]
|
|
94
|
+
Loss parameters.
|
|
95
|
+
|
|
96
|
+
Raises
|
|
97
|
+
------
|
|
98
|
+
NotImplementedError
|
|
99
|
+
If the loss is unknown.
|
|
100
|
+
"""
|
|
101
|
+
if type in [SupportedLoss.N2V, SupportedLoss.MSE, SupportedLoss.MAE]:
|
|
102
|
+
return FCNLossParameters
|
|
103
|
+
|
|
104
|
+
elif type in [
|
|
105
|
+
SupportedLoss.MUSPLIT,
|
|
106
|
+
SupportedLoss.DENOISPLIT,
|
|
107
|
+
SupportedLoss.DENOISPLIT_MUSPLIT,
|
|
108
|
+
]:
|
|
109
|
+
return LVAELossParameters # it returns the class, not an instance
|
|
110
|
+
|
|
111
|
+
else:
|
|
112
|
+
raise NotImplementedError(f"Loss {type} is not yet supported.")
|
|
11
113
|
|
|
12
114
|
|
|
13
|
-
# TODO add tests
|
|
14
|
-
# TODO add custom?
|
|
15
115
|
def loss_factory(loss: Union[SupportedLoss, str]) -> Callable:
|
|
16
116
|
"""Return loss function.
|
|
17
117
|
|
|
@@ -42,8 +142,14 @@ def loss_factory(loss: Union[SupportedLoss, str]) -> Callable:
|
|
|
42
142
|
elif loss == SupportedLoss.MSE:
|
|
43
143
|
return mse_loss
|
|
44
144
|
|
|
45
|
-
|
|
46
|
-
|
|
145
|
+
elif loss == SupportedLoss.MUSPLIT:
|
|
146
|
+
return musplit_loss
|
|
147
|
+
|
|
148
|
+
elif loss == SupportedLoss.DENOISPLIT:
|
|
149
|
+
return denoisplit_loss
|
|
150
|
+
|
|
151
|
+
elif loss == SupportedLoss.DENOISPLIT_MUSPLIT:
|
|
152
|
+
return denoisplit_musplit_loss
|
|
47
153
|
|
|
48
154
|
else:
|
|
49
155
|
raise NotImplementedError(f"Loss {loss} is not yet supported.")
|
|
@@ -0,0 +1 @@
|
|
|
1
|
+
"""LVAE losses."""
|