careamics 0.0.10__py3-none-any.whl → 0.0.12__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of careamics might be problematic. Click here for more details.

Files changed (54) hide show
  1. careamics/careamist.py +20 -4
  2. careamics/config/configuration.py +10 -5
  3. careamics/config/data/data_model.py +38 -1
  4. careamics/config/optimizer_models.py +1 -3
  5. careamics/config/training_model.py +0 -2
  6. careamics/dataset/dataset_utils/running_stats.py +7 -3
  7. careamics/dataset_ng/README.md +212 -0
  8. careamics/dataset_ng/dataset.py +233 -0
  9. careamics/dataset_ng/demos/bsd68_demo.ipynb +356 -0
  10. careamics/dataset_ng/demos/care_U2OS_demo.ipynb +330 -0
  11. careamics/dataset_ng/demos/demo_custom_image_stack.ipynb +734 -0
  12. careamics/dataset_ng/demos/demo_datamodule.ipynb +443 -0
  13. careamics/dataset_ng/{demo_dataset.ipynb → demos/demo_dataset.ipynb} +39 -15
  14. careamics/dataset_ng/{demo_patch_extractor.py → demos/demo_patch_extractor.py} +7 -9
  15. careamics/dataset_ng/demos/mouse_nuclei_demo.ipynb +292 -0
  16. careamics/dataset_ng/factory.py +408 -0
  17. careamics/dataset_ng/legacy_interoperability.py +168 -0
  18. careamics/dataset_ng/patch_extractor/__init__.py +3 -8
  19. careamics/dataset_ng/patch_extractor/demo_custom_image_stack_loader.py +6 -4
  20. careamics/dataset_ng/patch_extractor/image_stack/__init__.py +2 -1
  21. careamics/dataset_ng/patch_extractor/image_stack/image_stack_protocol.py +5 -1
  22. careamics/dataset_ng/patch_extractor/image_stack_loader.py +5 -75
  23. careamics/dataset_ng/patch_extractor/patch_extractor.py +5 -4
  24. careamics/dataset_ng/patch_extractor/patch_extractor_factory.py +73 -106
  25. careamics/dataset_ng/patching_strategies/__init__.py +6 -1
  26. careamics/dataset_ng/patching_strategies/patching_strategy_protocol.py +31 -0
  27. careamics/dataset_ng/patching_strategies/random_patching.py +3 -1
  28. careamics/dataset_ng/patching_strategies/tiling_strategy.py +171 -0
  29. careamics/dataset_ng/patching_strategies/whole_sample.py +36 -0
  30. careamics/lightning/dataset_ng/data_module.py +488 -0
  31. careamics/lightning/dataset_ng/lightning_modules/__init__.py +9 -0
  32. careamics/lightning/dataset_ng/lightning_modules/care_module.py +58 -0
  33. careamics/lightning/dataset_ng/lightning_modules/n2v_module.py +67 -0
  34. careamics/lightning/dataset_ng/lightning_modules/unet_module.py +143 -0
  35. careamics/lightning/lightning_module.py +3 -0
  36. careamics/lvae_training/dataset/__init__.py +8 -3
  37. careamics/lvae_training/dataset/config.py +3 -3
  38. careamics/lvae_training/dataset/ms_dataset_ref.py +1067 -0
  39. careamics/lvae_training/dataset/multich_dataset.py +46 -17
  40. careamics/lvae_training/dataset/multicrop_dset.py +196 -0
  41. careamics/lvae_training/dataset/types.py +3 -3
  42. careamics/lvae_training/dataset/utils/index_manager.py +259 -0
  43. careamics/lvae_training/eval_utils.py +93 -3
  44. careamics/transforms/compose.py +1 -0
  45. careamics/transforms/normalize.py +18 -7
  46. careamics/utils/lightning_utils.py +25 -11
  47. {careamics-0.0.10.dist-info → careamics-0.0.12.dist-info}/METADATA +3 -3
  48. {careamics-0.0.10.dist-info → careamics-0.0.12.dist-info}/RECORD +51 -36
  49. careamics/dataset_ng/dataset/__init__.py +0 -3
  50. careamics/dataset_ng/dataset/dataset.py +0 -184
  51. careamics/dataset_ng/demo_patch_extractor_factory.py +0 -37
  52. {careamics-0.0.10.dist-info → careamics-0.0.12.dist-info}/WHEEL +0 -0
  53. {careamics-0.0.10.dist-info → careamics-0.0.12.dist-info}/entry_points.txt +0 -0
  54. {careamics-0.0.10.dist-info → careamics-0.0.12.dist-info}/licenses/LICENSE +0 -0
@@ -0,0 +1,143 @@
1
+ from typing import Any, Union
2
+
3
+ import pytorch_lightning as L
4
+ import torch
5
+ from torch import nn
6
+ from torchmetrics import MetricCollection
7
+ from torchmetrics.image import PeakSignalNoiseRatio
8
+
9
+ from careamics.config import algorithm_factory
10
+ from careamics.config.algorithms import CAREAlgorithm, N2NAlgorithm, N2VAlgorithm
11
+ from careamics.dataset_ng.dataset import ImageRegionData
12
+ from careamics.models.unet import UNet
13
+ from careamics.transforms import Denormalize
14
+ from careamics.utils.logging import get_logger
15
+ from careamics.utils.torch_utils import get_optimizer, get_scheduler
16
+
17
+ logger = get_logger(__name__)
18
+
19
+
20
+ class UnetModule(L.LightningModule):
21
+ """CAREamics PyTorch Lightning module for UNet based algorithms."""
22
+
23
+ def __init__(
24
+ self, algorithm_config: Union[CAREAlgorithm, N2VAlgorithm, N2NAlgorithm, dict]
25
+ ) -> None:
26
+ super().__init__()
27
+
28
+ if isinstance(algorithm_config, dict):
29
+ algorithm_config = algorithm_factory(algorithm_config)
30
+
31
+ self.config = algorithm_config
32
+ self.model: nn.Module = UNet(**algorithm_config.model.model_dump())
33
+
34
+ self._best_checkpoint_loaded = False
35
+
36
+ # TODO: how to support metric evaluation better
37
+ self.metrics = MetricCollection(PeakSignalNoiseRatio())
38
+
39
+ def forward(self, x: Any) -> Any:
40
+ """Default forward method."""
41
+ return self.model(x)
42
+
43
+ def _log_training_stats(self, loss: Any, batch_size: Any) -> None:
44
+ self.log(
45
+ "train_loss",
46
+ loss,
47
+ on_step=True,
48
+ on_epoch=True,
49
+ prog_bar=True,
50
+ logger=True,
51
+ batch_size=batch_size,
52
+ )
53
+
54
+ optimizer = self.optimizers()
55
+ if isinstance(optimizer, list):
56
+ current_lr = optimizer[0].param_groups[0]["lr"]
57
+ else:
58
+ current_lr = optimizer.param_groups[0]["lr"]
59
+ self.log(
60
+ "learning_rate",
61
+ current_lr,
62
+ on_step=False,
63
+ on_epoch=True,
64
+ logger=True,
65
+ batch_size=batch_size,
66
+ )
67
+
68
+ def _log_validation_stats(self, loss: Any, batch_size: Any) -> None:
69
+ self.log(
70
+ "val_loss",
71
+ loss,
72
+ on_step=False,
73
+ on_epoch=True,
74
+ prog_bar=True,
75
+ logger=True,
76
+ batch_size=batch_size,
77
+ )
78
+ self.log_dict(self.metrics, on_step=False, on_epoch=True, batch_size=batch_size)
79
+
80
+ def _load_best_checkpoint(self) -> None:
81
+ if (
82
+ not hasattr(self.trainer, "checkpoint_callback")
83
+ or self.trainer.checkpoint_callback is None
84
+ ):
85
+ logger.warning("No checkpoint callback found, cannot load best checkpoint.")
86
+ return
87
+
88
+ best_model_path = self.trainer.checkpoint_callback.best_model_path
89
+ if best_model_path and best_model_path != "":
90
+ logger.info(f"Loading best checkpoint from: {best_model_path}")
91
+ model_state = torch.load(best_model_path, weights_only=True)["state_dict"]
92
+ self.load_state_dict(model_state)
93
+ else:
94
+ logger.warning("No best checkpoint found.")
95
+
96
+ def predict_step(
97
+ self,
98
+ batch: Union[tuple[ImageRegionData], tuple[ImageRegionData, ImageRegionData]],
99
+ batch_idx: Any,
100
+ load_best_checkpoint=False,
101
+ ) -> Any:
102
+ """Default predict step."""
103
+ if self._best_checkpoint_loaded is False and load_best_checkpoint:
104
+ self._load_best_checkpoint()
105
+ self._best_checkpoint_loaded = True
106
+
107
+ x = batch[0]
108
+ # TODO: add TTA
109
+ prediction = self.model(x.data).cpu().numpy()
110
+
111
+ means = self._trainer.datamodule.stats.means
112
+ stds = self._trainer.datamodule.stats.stds
113
+ denormalize = Denormalize(
114
+ image_means=means,
115
+ image_stds=stds,
116
+ )
117
+ denormalized_output = denormalize(prediction)
118
+
119
+ output_batch = ImageRegionData(
120
+ data=denormalized_output,
121
+ source=x.source,
122
+ data_shape=x.data_shape,
123
+ dtype=x.dtype,
124
+ axes=x.axes,
125
+ region_spec=x.region_spec,
126
+ )
127
+ return output_batch
128
+
129
+ def configure_optimizers(self) -> Any:
130
+ """Configure optimizers."""
131
+ optimizer_func = get_optimizer(self.config.optimizer.name)
132
+ optimizer = optimizer_func(
133
+ self.model.parameters(), **self.config.optimizer.parameters
134
+ )
135
+
136
+ scheduler_func = get_scheduler(self.config.lr_scheduler.name)
137
+ scheduler = scheduler_func(optimizer, **self.config.lr_scheduler.parameters)
138
+
139
+ return {
140
+ "optimizer": optimizer,
141
+ "lr_scheduler": scheduler,
142
+ "monitor": "val_loss", # otherwise triggers MisconfigurationException
143
+ }
@@ -148,6 +148,9 @@ class FCNModule(L.LightningModule):
148
148
  self.log(
149
149
  "train_loss", loss, on_step=True, on_epoch=True, prog_bar=True, logger=True
150
150
  )
151
+ optimizer = self.optimizers()
152
+ current_lr = optimizer.param_groups[0]["lr"]
153
+ self.log("learning_rate", current_lr, on_step=False, on_epoch=True, logger=True)
151
154
  return loss
152
155
 
153
156
  def validation_step(self, batch: Tensor, batch_idx: Any) -> None:
@@ -1,14 +1,19 @@
1
- from .multich_dataset import MultiChDloader
1
+ from .config import DatasetConfig
2
2
  from .lc_dataset import LCMultiChDloader
3
+ from .ms_dataset_ref import MultiChDloaderRef
4
+ from .multich_dataset import MultiChDloader
5
+ from .multicrop_dset import MultiCropDset
3
6
  from .multifile_dataset import MultiFileDset
4
- from .config import DatasetConfig
5
- from .types import DataType, DataSplitType, TilingMode
7
+ from .types import DataSplitType, DataType, TilingMode
6
8
 
7
9
  __all__ = [
8
10
  "DatasetConfig",
9
11
  "MultiChDloader",
10
12
  "LCMultiChDloader",
11
13
  "MultiFileDset",
14
+ "MultiCropDset",
15
+ "MultiChDloaderRef",
16
+ "LCMultiChDloaderRef",
12
17
  "DataType",
13
18
  "DataSplitType",
14
19
  "TilingMode",
@@ -1,4 +1,4 @@
1
- from typing import Any, Optional
1
+ from typing import Any, Optional, Union
2
2
 
3
3
  from pydantic import BaseModel, ConfigDict
4
4
 
@@ -43,7 +43,7 @@ class DatasetConfig(BaseModel):
43
43
  image_size: tuple # TODO: revisit, new model_config uses tuple
44
44
  """Size of one patch of data"""
45
45
 
46
- grid_size: Optional[int] = None
46
+ grid_size: Optional[Union[int, tuple[int, int, int]]] = None
47
47
  """Frame is divided into square grids of this size. A patch centered on a grid
48
48
  having size `image_size` is returned. Grid size not used in training,
49
49
  used only during val / test, grid size controls the overlap of the patches"""
@@ -82,7 +82,7 @@ class DatasetConfig(BaseModel):
82
82
  # TODO: why is this not used?
83
83
  enable_rotation_aug: Optional[bool] = False
84
84
 
85
- max_val: Optional[float] = None
85
+ max_val: Optional[Union[float, tuple]] = None
86
86
  """Maximum data in the dataset. Is calculated for train split, and should be
87
87
  externally set for val and test splits."""
88
88