careamics 0.0.1__py3-none-any.whl → 0.0.2__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of careamics might be problematic. Click here for more details.

Files changed (141) hide show
  1. careamics/__init__.py +6 -1
  2. careamics/careamist.py +726 -0
  3. careamics/config/__init__.py +35 -0
  4. careamics/config/algorithm_model.py +162 -0
  5. careamics/config/architectures/__init__.py +17 -0
  6. careamics/config/architectures/architecture_model.py +37 -0
  7. careamics/config/architectures/custom_model.py +159 -0
  8. careamics/config/architectures/register_model.py +103 -0
  9. careamics/config/architectures/unet_model.py +118 -0
  10. careamics/config/architectures/vae_model.py +42 -0
  11. careamics/config/callback_model.py +123 -0
  12. careamics/config/configuration_factory.py +575 -0
  13. careamics/config/configuration_model.py +600 -0
  14. careamics/config/data_model.py +502 -0
  15. careamics/config/inference_model.py +239 -0
  16. careamics/config/optimizer_models.py +187 -0
  17. careamics/config/references/__init__.py +45 -0
  18. careamics/config/references/algorithm_descriptions.py +132 -0
  19. careamics/config/references/references.py +39 -0
  20. careamics/config/support/__init__.py +31 -0
  21. careamics/config/support/supported_activations.py +26 -0
  22. careamics/config/support/supported_algorithms.py +20 -0
  23. careamics/config/support/supported_architectures.py +20 -0
  24. careamics/config/support/supported_data.py +109 -0
  25. careamics/config/support/supported_loggers.py +10 -0
  26. careamics/config/support/supported_losses.py +27 -0
  27. careamics/config/support/supported_optimizers.py +57 -0
  28. careamics/config/support/supported_pixel_manipulations.py +15 -0
  29. careamics/config/support/supported_struct_axis.py +21 -0
  30. careamics/config/support/supported_transforms.py +11 -0
  31. careamics/config/tile_information.py +65 -0
  32. careamics/config/training_model.py +72 -0
  33. careamics/config/transformations/__init__.py +15 -0
  34. careamics/config/transformations/n2v_manipulate_model.py +64 -0
  35. careamics/config/transformations/normalize_model.py +60 -0
  36. careamics/config/transformations/transform_model.py +45 -0
  37. careamics/config/transformations/xy_flip_model.py +43 -0
  38. careamics/config/transformations/xy_random_rotate90_model.py +35 -0
  39. careamics/config/validators/__init__.py +5 -0
  40. careamics/config/validators/validator_utils.py +101 -0
  41. careamics/conftest.py +39 -0
  42. careamics/dataset/__init__.py +17 -0
  43. careamics/dataset/dataset_utils/__init__.py +19 -0
  44. careamics/dataset/dataset_utils/dataset_utils.py +101 -0
  45. careamics/dataset/dataset_utils/file_utils.py +141 -0
  46. careamics/dataset/dataset_utils/iterate_over_files.py +83 -0
  47. careamics/dataset/dataset_utils/running_stats.py +186 -0
  48. careamics/dataset/in_memory_dataset.py +310 -0
  49. careamics/dataset/in_memory_pred_dataset.py +88 -0
  50. careamics/dataset/in_memory_tiled_pred_dataset.py +129 -0
  51. careamics/dataset/iterable_dataset.py +295 -0
  52. careamics/dataset/iterable_pred_dataset.py +122 -0
  53. careamics/dataset/iterable_tiled_pred_dataset.py +140 -0
  54. careamics/dataset/patching/__init__.py +1 -0
  55. careamics/dataset/patching/patching.py +299 -0
  56. careamics/dataset/patching/random_patching.py +201 -0
  57. careamics/dataset/patching/sequential_patching.py +212 -0
  58. careamics/dataset/patching/validate_patch_dimension.py +64 -0
  59. careamics/dataset/tiling/__init__.py +10 -0
  60. careamics/dataset/tiling/collate_tiles.py +33 -0
  61. careamics/dataset/tiling/tiled_patching.py +164 -0
  62. careamics/dataset/zarr_dataset.py +151 -0
  63. careamics/file_io/__init__.py +15 -0
  64. careamics/file_io/read/__init__.py +12 -0
  65. careamics/file_io/read/get_func.py +56 -0
  66. careamics/file_io/read/tiff.py +58 -0
  67. careamics/file_io/read/zarr.py +60 -0
  68. careamics/file_io/write/__init__.py +15 -0
  69. careamics/file_io/write/get_func.py +63 -0
  70. careamics/file_io/write/tiff.py +40 -0
  71. careamics/lightning/__init__.py +17 -0
  72. careamics/lightning/callbacks/__init__.py +11 -0
  73. careamics/lightning/callbacks/hyperparameters_callback.py +49 -0
  74. careamics/lightning/callbacks/prediction_writer_callback/__init__.py +20 -0
  75. careamics/lightning/callbacks/prediction_writer_callback/file_path_utils.py +56 -0
  76. careamics/lightning/callbacks/prediction_writer_callback/prediction_writer_callback.py +233 -0
  77. careamics/lightning/callbacks/prediction_writer_callback/write_strategy.py +398 -0
  78. careamics/lightning/callbacks/prediction_writer_callback/write_strategy_factory.py +215 -0
  79. careamics/lightning/callbacks/progress_bar_callback.py +90 -0
  80. careamics/lightning/lightning_module.py +276 -0
  81. careamics/lightning/predict_data_module.py +333 -0
  82. careamics/lightning/train_data_module.py +680 -0
  83. careamics/losses/__init__.py +5 -0
  84. careamics/losses/loss_factory.py +49 -0
  85. careamics/losses/losses.py +98 -0
  86. careamics/lvae_training/__init__.py +0 -0
  87. careamics/lvae_training/data_modules.py +1220 -0
  88. careamics/lvae_training/data_utils.py +618 -0
  89. careamics/lvae_training/eval_utils.py +905 -0
  90. careamics/lvae_training/get_config.py +84 -0
  91. careamics/lvae_training/lightning_module.py +701 -0
  92. careamics/lvae_training/metrics.py +214 -0
  93. careamics/lvae_training/train_lvae.py +339 -0
  94. careamics/lvae_training/train_utils.py +121 -0
  95. careamics/model_io/__init__.py +7 -0
  96. careamics/model_io/bioimage/__init__.py +11 -0
  97. careamics/model_io/bioimage/_readme_factory.py +121 -0
  98. careamics/model_io/bioimage/bioimage_utils.py +52 -0
  99. careamics/model_io/bioimage/model_description.py +327 -0
  100. careamics/model_io/bmz_io.py +233 -0
  101. careamics/model_io/model_io_utils.py +83 -0
  102. careamics/models/__init__.py +7 -0
  103. careamics/models/activation.py +37 -0
  104. careamics/models/layers.py +493 -0
  105. careamics/models/lvae/__init__.py +0 -0
  106. careamics/models/lvae/layers.py +1998 -0
  107. careamics/models/lvae/likelihoods.py +312 -0
  108. careamics/models/lvae/lvae.py +985 -0
  109. careamics/models/lvae/noise_models.py +409 -0
  110. careamics/models/lvae/utils.py +395 -0
  111. careamics/models/model_factory.py +52 -0
  112. careamics/models/unet.py +443 -0
  113. careamics/prediction_utils/__init__.py +10 -0
  114. careamics/prediction_utils/prediction_outputs.py +135 -0
  115. careamics/prediction_utils/stitch_prediction.py +98 -0
  116. careamics/transforms/__init__.py +20 -0
  117. careamics/transforms/compose.py +107 -0
  118. careamics/transforms/n2v_manipulate.py +146 -0
  119. careamics/transforms/normalize.py +243 -0
  120. careamics/transforms/pixel_manipulation.py +407 -0
  121. careamics/transforms/struct_mask_parameters.py +20 -0
  122. careamics/transforms/transform.py +24 -0
  123. careamics/transforms/tta.py +88 -0
  124. careamics/transforms/xy_flip.py +123 -0
  125. careamics/transforms/xy_random_rotate90.py +101 -0
  126. careamics/utils/__init__.py +19 -0
  127. careamics/utils/autocorrelation.py +40 -0
  128. careamics/utils/base_enum.py +60 -0
  129. careamics/utils/context.py +66 -0
  130. careamics/utils/logging.py +322 -0
  131. careamics/utils/metrics.py +115 -0
  132. careamics/utils/path_utils.py +26 -0
  133. careamics/utils/ram.py +15 -0
  134. careamics/utils/receptive_field.py +108 -0
  135. careamics/utils/torch_utils.py +127 -0
  136. careamics-0.0.2.dist-info/METADATA +78 -0
  137. careamics-0.0.2.dist-info/RECORD +140 -0
  138. {careamics-0.0.1.dist-info → careamics-0.0.2.dist-info}/WHEEL +1 -1
  139. {careamics-0.0.1.dist-info → careamics-0.0.2.dist-info}/licenses/LICENSE +1 -1
  140. careamics-0.0.1.dist-info/METADATA +0 -46
  141. careamics-0.0.1.dist-info/RECORD +0 -6
@@ -0,0 +1,49 @@
1
+ """
2
+ Loss factory module.
3
+
4
+ This module contains a factory function for creating loss functions.
5
+ """
6
+
7
+ from typing import Callable, Union
8
+
9
+ from ..config.support import SupportedLoss
10
+ from .losses import mae_loss, mse_loss, n2v_loss
11
+
12
+
13
+ # TODO add tests
14
+ # TODO add custom?
15
def loss_factory(loss: Union[SupportedLoss, str]) -> Callable:
    """Look up the loss function matching the requested loss.

    Parameters
    ----------
    loss : Union[SupportedLoss, str]
        Requested loss.

    Returns
    -------
    Callable
        The matching loss function (`n2v_loss`, `mae_loss` or `mse_loss`).

    Raises
    ------
    NotImplementedError
        If the requested loss is none of N2V, MAE or MSE.
    """
    # Guard clauses: each supported loss returns as soon as it matches.
    if loss == SupportedLoss.N2V:
        return n2v_loss

    if loss == SupportedLoss.MAE:
        return mae_loss

    if loss == SupportedLoss.MSE:
        return mse_loss

    # PN2V and DICE losses are not wired in, so they end up here as well.
    raise NotImplementedError(f"Loss {loss} is not yet supported.")
@@ -0,0 +1,98 @@
1
+ """
2
+ Loss submodule.
3
+
4
+ This submodule contains the various losses used in CAREamics.
5
+ """
6
+
7
+ import torch
8
+ from torch.nn import L1Loss, MSELoss
9
+
10
+
11
def mse_loss(source: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
    """
    Compute the mean squared error between source and target patches.

    Parameters
    ----------
    source : torch.Tensor
        Source patches.
    target : torch.Tensor
        Target patches.

    Returns
    -------
    torch.Tensor
        Loss value, as returned by a default-constructed `MSELoss`.
    """
    # A fresh criterion is built on every call, exactly as before.
    return MSELoss()(source, target)
29
+
30
+
31
def n2v_loss(
    manipulated_patches: torch.Tensor,
    original_patches: torch.Tensor,
    masks: torch.Tensor,
) -> torch.Tensor:
    """
    N2V loss function described in A Krull et al 2018.

    The squared difference between the original and manipulated patches is
    weighted by `masks`, summed, and divided by the sum of `masks`.

    Parameters
    ----------
    manipulated_patches : torch.Tensor
        Patches with manipulated pixels.
    original_patches : torch.Tensor
        Noisy patches.
    masks : torch.Tensor
        Array containing masked pixel locations.

    Returns
    -------
    torch.Tensor
        Loss value.

    Notes
    -----
    There is no guard against `masks` summing to zero; in that case the
    division presumably yields NaN (0 / 0) - confirm callers never pass an
    all-zero mask.
    """
    squared_diff = (original_patches - manipulated_patches).pow(2)
    # Sum only the error weighted by the mask ...
    masked_total = (squared_diff * masks).sum()
    # ... and normalise by the mask sum (average over pixels and batch).
    return masked_total / masks.sum()
57
+
58
+
59
def mae_loss(samples: torch.Tensor, labels: torch.Tensor) -> torch.Tensor:
    """
    N2N loss function described in J Lehtinen et al 2018.

    Computed with a default-constructed `L1Loss`.

    Parameters
    ----------
    samples : torch.Tensor
        Raw patches.
    labels : torch.Tensor
        Different subset of noisy patches.

    Returns
    -------
    torch.Tensor
        Loss value.
    """
    # A fresh criterion is built on every call, exactly as before.
    return L1Loss()(samples, labels)
77
+
78
+
79
+ # def pn2v_loss(
80
+ # samples: torch.Tensor,
81
+ # labels: torch.Tensor,
82
+ # masks: torch.Tensor,
83
+ # noise_model: HistogramNoiseModel,
84
+ # ) -> torch.Tensor:
85
+ # """Probabilistic N2V loss function described in A Krull et al., CVF (2019)."""
86
+ # likelihoods = noise_model.likelihood(labels, samples)
87
+ # likelihoods_avg = torch.log(torch.mean(likelihoods, dim=0, keepdim=True)[0, ...])
88
+
89
+ # # Average over pixels and batch
90
+ # loss = -torch.sum(likelihoods_avg * masks) / torch.sum(masks)
91
+ # return loss
92
+
93
+
94
+ # def dice_loss(
95
+ # samples: torch.Tensor, labels: torch.Tensor, mode: str = "multiclass"
96
+ # ) -> torch.Tensor:
97
+ # """Dice loss function."""
98
+ # return DiceLoss(mode=mode)(samples, labels.long())
File without changes