careamics 0.0.11__py3-none-any.whl → 0.0.13__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of careamics might be problematic. Click here for more details.

Files changed (98)
  1. careamics/careamist.py +24 -7
  2. careamics/cli/utils.py +1 -1
  3. careamics/config/algorithms/n2v_algorithm_model.py +1 -1
  4. careamics/config/architectures/unet_model.py +3 -0
  5. careamics/config/callback_model.py +23 -34
  6. careamics/config/configuration.py +55 -4
  7. careamics/config/configuration_factories.py +288 -23
  8. careamics/config/data/__init__.py +2 -0
  9. careamics/config/data/data_model.py +41 -4
  10. careamics/config/data/ng_data_model.py +381 -0
  11. careamics/config/data/patching_strategies/__init__.py +14 -0
  12. careamics/config/data/patching_strategies/_overlapping_patched_model.py +103 -0
  13. careamics/config/data/patching_strategies/_patched_model.py +56 -0
  14. careamics/config/data/patching_strategies/random_patching_model.py +21 -0
  15. careamics/config/data/patching_strategies/sequential_patching_model.py +25 -0
  16. careamics/config/data/patching_strategies/tiled_patching_model.py +40 -0
  17. careamics/config/data/patching_strategies/whole_patching_model.py +12 -0
  18. careamics/config/inference_model.py +6 -3
  19. careamics/config/optimizer_models.py +1 -3
  20. careamics/config/support/supported_data.py +7 -0
  21. careamics/config/support/supported_patching_strategies.py +22 -0
  22. careamics/config/training_model.py +0 -2
  23. careamics/config/validators/validator_utils.py +4 -3
  24. careamics/dataset/dataset_utils/iterate_over_files.py +2 -2
  25. careamics/dataset/in_memory_dataset.py +2 -1
  26. careamics/dataset/iterable_dataset.py +2 -2
  27. careamics/dataset/iterable_pred_dataset.py +2 -2
  28. careamics/dataset/iterable_tiled_pred_dataset.py +2 -2
  29. careamics/dataset/patching/patching.py +3 -2
  30. careamics/dataset/tiling/lvae_tiled_patching.py +16 -6
  31. careamics/dataset/tiling/tiled_patching.py +2 -1
  32. careamics/dataset_ng/README.md +212 -0
  33. careamics/dataset_ng/dataset.py +229 -0
  34. careamics/dataset_ng/demos/bsd68_demo.ipynb +361 -0
  35. careamics/dataset_ng/demos/care_U2OS_demo.ipynb +330 -0
  36. careamics/dataset_ng/demos/demo_custom_image_stack.ipynb +734 -0
  37. careamics/dataset_ng/demos/demo_datamodule.ipynb +447 -0
  38. careamics/dataset_ng/{demo_dataset.ipynb → demos/demo_dataset.ipynb} +60 -53
  39. careamics/dataset_ng/{demo_patch_extractor.py → demos/demo_patch_extractor.py} +7 -9
  40. careamics/dataset_ng/demos/mouse_nuclei_demo.ipynb +292 -0
  41. careamics/dataset_ng/factory.py +451 -0
  42. careamics/dataset_ng/legacy_interoperability.py +170 -0
  43. careamics/dataset_ng/patch_extractor/__init__.py +3 -8
  44. careamics/dataset_ng/patch_extractor/demo_custom_image_stack_loader.py +7 -5
  45. careamics/dataset_ng/patch_extractor/image_stack/__init__.py +4 -1
  46. careamics/dataset_ng/patch_extractor/image_stack/czi_image_stack.py +360 -0
  47. careamics/dataset_ng/patch_extractor/image_stack/image_stack_protocol.py +5 -1
  48. careamics/dataset_ng/patch_extractor/image_stack/in_memory_image_stack.py +1 -1
  49. careamics/dataset_ng/patch_extractor/image_stack_loader.py +5 -75
  50. careamics/dataset_ng/patch_extractor/patch_extractor.py +5 -4
  51. careamics/dataset_ng/patch_extractor/patch_extractor_factory.py +114 -105
  52. careamics/dataset_ng/patching_strategies/__init__.py +6 -1
  53. careamics/dataset_ng/patching_strategies/patching_strategy_protocol.py +31 -0
  54. careamics/dataset_ng/patching_strategies/random_patching.py +5 -1
  55. careamics/dataset_ng/patching_strategies/sequential_patching.py +5 -5
  56. careamics/dataset_ng/patching_strategies/tiling_strategy.py +172 -0
  57. careamics/dataset_ng/patching_strategies/whole_sample.py +36 -0
  58. careamics/file_io/read/get_func.py +2 -1
  59. careamics/lightning/dataset_ng/__init__.py +1 -0
  60. careamics/lightning/dataset_ng/data_module.py +678 -0
  61. careamics/lightning/dataset_ng/lightning_modules/__init__.py +9 -0
  62. careamics/lightning/dataset_ng/lightning_modules/care_module.py +97 -0
  63. careamics/lightning/dataset_ng/lightning_modules/n2v_module.py +106 -0
  64. careamics/lightning/dataset_ng/lightning_modules/unet_module.py +212 -0
  65. careamics/lightning/lightning_module.py +5 -1
  66. careamics/lightning/predict_data_module.py +2 -1
  67. careamics/lightning/train_data_module.py +2 -1
  68. careamics/losses/loss_factory.py +2 -1
  69. careamics/lvae_training/dataset/__init__.py +8 -3
  70. careamics/lvae_training/dataset/config.py +3 -3
  71. careamics/lvae_training/dataset/ms_dataset_ref.py +1067 -0
  72. careamics/lvae_training/dataset/multich_dataset.py +46 -17
  73. careamics/lvae_training/dataset/multicrop_dset.py +196 -0
  74. careamics/lvae_training/dataset/types.py +3 -3
  75. careamics/lvae_training/dataset/utils/index_manager.py +259 -0
  76. careamics/lvae_training/eval_utils.py +93 -3
  77. careamics/model_io/bioimage/bioimage_utils.py +1 -1
  78. careamics/model_io/bioimage/model_description.py +1 -1
  79. careamics/model_io/bmz_io.py +1 -1
  80. careamics/model_io/model_io_utils.py +2 -2
  81. careamics/models/activation.py +2 -1
  82. careamics/prediction_utils/prediction_outputs.py +1 -1
  83. careamics/prediction_utils/stitch_prediction.py +1 -1
  84. careamics/transforms/compose.py +1 -0
  85. careamics/transforms/n2v_manipulate_torch.py +15 -9
  86. careamics/transforms/normalize.py +18 -7
  87. careamics/transforms/pixel_manipulation_torch.py +59 -92
  88. careamics/utils/lightning_utils.py +25 -11
  89. careamics/utils/metrics.py +2 -1
  90. careamics/utils/torch_utils.py +23 -0
  91. {careamics-0.0.11.dist-info → careamics-0.0.13.dist-info}/METADATA +12 -11
  92. {careamics-0.0.11.dist-info → careamics-0.0.13.dist-info}/RECORD +95 -69
  93. careamics/dataset_ng/dataset/__init__.py +0 -3
  94. careamics/dataset_ng/dataset/dataset.py +0 -184
  95. careamics/dataset_ng/demo_patch_extractor_factory.py +0 -37
  96. {careamics-0.0.11.dist-info → careamics-0.0.13.dist-info}/WHEEL +0 -0
  97. {careamics-0.0.11.dist-info → careamics-0.0.13.dist-info}/entry_points.txt +0 -0
  98. {careamics-0.0.11.dist-info → careamics-0.0.13.dist-info}/licenses/LICENSE +0 -0
@@ -0,0 +1,229 @@
1
+ from collections.abc import Sequence
2
+ from enum import Enum
3
+ from pathlib import Path
4
+ from typing import Any, Generic, Literal, NamedTuple, Optional, Union
5
+
6
+ import numpy as np
7
+ from numpy.typing import NDArray
8
+ from torch.utils.data import Dataset
9
+ from tqdm.auto import tqdm
10
+
11
+ from careamics.config.data.ng_data_model import NGDataConfig
12
+ from careamics.config.support.supported_patching_strategies import (
13
+ SupportedPatchingStrategy,
14
+ )
15
+ from careamics.config.transformations import NormalizeModel
16
+ from careamics.dataset.dataset_utils.running_stats import WelfordStatistics
17
+ from careamics.dataset.patching.patching import Stats
18
+ from careamics.dataset_ng.patch_extractor import GenericImageStack, PatchExtractor
19
+ from careamics.dataset_ng.patching_strategies import (
20
+ FixedRandomPatchingStrategy,
21
+ PatchingStrategy,
22
+ PatchSpecs,
23
+ RandomPatchingStrategy,
24
+ TilingStrategy,
25
+ WholeSamplePatchingStrategy,
26
+ )
27
+ from careamics.transforms import Compose
28
+
29
+
class Mode(str, Enum):
    """Stage of the pipeline a dataset instance serves.

    Because the enum also derives from ``str``, each member compares equal to
    its lowercase string value (e.g. ``Mode.TRAINING == "training"``).
    """

    # Training stage.
    TRAINING = "training"
    # Validation stage.
    VALIDATING = "validating"
    # Prediction / inference stage.
    PREDICTING = "predicting"
34
+
35
+
class ImageRegionData(NamedTuple):
    """An extracted image region bundled with metadata about its origin.

    The field order is part of the tuple interface and must not change.
    """

    # The region's pixel data.
    data: NDArray
    # String form of the originating image stack's source, or "array".
    source: Union[str, Literal["array"]]
    # Shape of the image stack the region was extracted from.
    data_shape: Sequence[int]
    # Stored as a str (not a numpy dtype) so that it can be collated.
    dtype: str
    # Axes description string.
    axes: str
    # Specification of the region (data index, sample index, coords, size).
    region_spec: PatchSpecs
43
+
44
+
# Data accepted as input: either a collection of in-memory arrays or a
# collection of file-system paths.
InputType = Union[
    Sequence[NDArray[Any]],  # in-memory arrays
    Sequence[Path],  # paths on disk
]
46
+
47
+
class CareamicsDataset(Dataset, Generic[GenericImageStack]):
    """Map-style dataset serving image regions taken from patch extractors.

    Depending on ``mode``, a patching strategy is selected:

    - ``Mode.TRAINING``: random patching (requires ``random`` in the config).
    - ``Mode.VALIDATING``: fixed random patching (requires ``random``).
    - ``Mode.PREDICTING``: tiled or whole-sample patching (requires ``tiled``
      or ``whole``).

    Normalization statistics are read from the configuration when both means
    and stds are given, otherwise they are computed over the patches of the
    selected strategy. Each item is transformed and wrapped into
    ``ImageRegionData``.

    Parameters
    ----------
    data_config : NGDataConfig
        Data configuration (patching, statistics, transforms, axes, seed).
    mode : Mode
        Stage the dataset is used for; restricts the allowed patching.
    input_extractor : PatchExtractor[GenericImageStack]
        Extractor for the input images.
    target_extractor : PatchExtractor[GenericImageStack], optional
        Extractor for the target images. Target patches are extracted with the
        same patch specification as the input patches.

    Raises
    ------
    ValueError
        If the configured patching strategy is not allowed for ``mode``, or if
        ``mode`` is not recognised.
    """

    def __init__(
        self,
        data_config: NGDataConfig,
        mode: Mode,
        input_extractor: PatchExtractor[GenericImageStack],
        target_extractor: Optional[PatchExtractor[GenericImageStack]] = None,
    ):
        self.config = data_config
        self.mode = mode

        self.input_extractor = input_extractor
        self.target_extractor = target_extractor

        # Initialization order matters: statistics are computed over the
        # patches of the patching strategy, and the transforms need the stats.
        self.patching_strategy = self._initialize_patching_strategy()
        self.input_stats, self.target_stats = self._initialize_statistics()
        self.transforms = self._initialize_transforms()

    def _initialize_patching_strategy(self) -> PatchingStrategy:
        """Create the patching strategy corresponding to the dataset mode.

        Returns
        -------
        PatchingStrategy
            ``RandomPatchingStrategy`` for training,
            ``FixedRandomPatchingStrategy`` for validation, and
            ``TilingStrategy`` or ``WholeSamplePatchingStrategy`` for
            prediction.

        Raises
        ------
        ValueError
            If the configured patching strategy is not allowed for the mode,
            or if the mode is not recognised.
        """
        if self.mode == Mode.TRAINING:
            if self.config.patching.name != SupportedPatchingStrategy.RANDOM:
                raise ValueError(
                    f"Only `random` patching strategy supported during training, got "
                    f"{self.config.patching.name}."
                )
            return RandomPatchingStrategy(
                data_shapes=self.input_extractor.shape,
                patch_size=self.config.patching.patch_size,
                seed=self.config.seed,
            )

        if self.mode == Mode.VALIDATING:
            if self.config.patching.name != SupportedPatchingStrategy.RANDOM:
                raise ValueError(
                    f"Only `random` patching strategy supported during validation, "
                    f"got {self.config.patching.name}."
                )
            # Fixed random patches keep the validation set identical between
            # epochs for a given seed.
            return FixedRandomPatchingStrategy(
                data_shapes=self.input_extractor.shape,
                patch_size=self.config.patching.patch_size,
                seed=self.config.seed,
            )

        if self.mode == Mode.PREDICTING:
            if self.config.patching.name == SupportedPatchingStrategy.TILED:
                return TilingStrategy(
                    data_shapes=self.input_extractor.shape,
                    tile_size=self.config.patching.patch_size,
                    overlaps=self.config.patching.overlaps,
                )
            if self.config.patching.name == SupportedPatchingStrategy.WHOLE:
                return WholeSamplePatchingStrategy(
                    data_shapes=self.input_extractor.shape
                )
            raise ValueError(
                f"Only `tiled` and `whole` patching strategy supported during "
                f"prediction, got {self.config.patching.name}."
            )

        raise ValueError(f"Unrecognised dataset mode {self.mode}.")

    def _initialize_transforms(self) -> Optional[Compose]:
        """Build the transform pipeline, with normalization applied first.

        In training mode the transforms listed in the configuration are
        appended after normalization; in other modes only normalization is
        applied.
        """
        normalize = NormalizeModel(
            image_means=self.input_stats.means,
            image_stds=self.input_stats.stds,
            target_means=self.target_stats.means,
            target_stds=self.target_stats.stds,
        )
        if self.mode == Mode.TRAINING:
            # TODO: initialize normalization separately depending on configuration
            return Compose(transform_list=[normalize] + list(self.config.transforms))

        # TODO: add TTA
        return Compose(transform_list=[normalize])

    @staticmethod
    def _extract_patch(
        extractor: PatchExtractor[GenericImageStack], patch_spec: PatchSpecs
    ) -> NDArray:
        """Extract the region described by ``patch_spec`` from ``extractor``."""
        return extractor.extract_patch(
            data_idx=patch_spec["data_idx"],
            sample_idx=patch_spec["sample_idx"],
            coords=patch_spec["coords"],
            patch_size=patch_spec["patch_size"],
        )

    def _calculate_stats(
        self, data_extractor: PatchExtractor[GenericImageStack]
    ) -> Stats:
        """Compute mean and std over every patch of the patching strategy.

        Statistics are accumulated with ``WelfordStatistics`` while iterating
        over all ``n_patches`` patch specifications.

        Parameters
        ----------
        data_extractor : PatchExtractor[GenericImageStack]
            Extractor from which the patches are read.

        Returns
        -------
        Stats
            The finalized means and standard deviations.
        """
        image_stats = WelfordStatistics()
        n_patches = self.patching_strategy.n_patches

        for idx in tqdm(range(n_patches), desc="Computing statistics"):
            patch_spec = self.patching_strategy.get_patch_spec(idx)
            patch = self._extract_patch(data_extractor, patch_spec)
            # TODO: statistics accept SCYX format, while patch is CYX
            # (a leading singleton axis is added to match).
            image_stats.update(patch[None, ...], sample_idx=idx)

        image_means, image_stds = image_stats.finalize()
        return Stats(image_means, image_stds)

    # TODO: add running stats
    def _initialize_statistics(self) -> tuple[Stats, Stats]:
        """Return the ``(input_stats, target_stats)`` used for normalization.

        Configured statistics take precedence when both means and stds are
        set; otherwise they are computed from the data. Target statistics are
        left empty when neither configured nor computable (no target
        extractor).
        """
        if self.config.image_means is not None and self.config.image_stds is not None:
            input_stats = Stats(self.config.image_means, self.config.image_stds)
        else:
            input_stats = self._calculate_stats(self.input_extractor)

        target_stats = Stats((), ())

        if self.config.target_means is not None and self.config.target_stds is not None:
            target_stats = Stats(self.config.target_means, self.config.target_stds)
        elif self.target_extractor is not None:
            target_stats = self._calculate_stats(self.target_extractor)

        return input_stats, target_stats

    def __len__(self) -> int:
        """Return the number of patches defined by the patching strategy."""
        return self.patching_strategy.n_patches

    def _create_image_region(
        self, patch: np.ndarray, patch_spec: PatchSpecs, extractor: PatchExtractor
    ) -> ImageRegionData:
        """Wrap ``patch`` with metadata from the image stack it came from."""
        image_stack = extractor.image_stacks[patch_spec["data_idx"]]
        return ImageRegionData(
            data=patch,
            source=str(image_stack.source),
            dtype=str(image_stack.data_dtype),
            data_shape=image_stack.data_shape,
            # TODO: should it be axes of the original image instead?
            axes=self.config.axes,
            region_spec=patch_spec,
        )

    def __getitem__(
        self, index: int
    ) -> Union[tuple[ImageRegionData], tuple[ImageRegionData, ImageRegionData]]:
        """Return the transformed input region and, if available, the target.

        Parameters
        ----------
        index : int
            Patch index, forwarded to the patching strategy.

        Returns
        -------
        tuple of ImageRegionData
            ``(input,)`` without a target extractor, ``(input, target)``
            otherwise.
        """
        patch_spec = self.patching_strategy.get_patch_spec(index)
        input_patch = self._extract_patch(self.input_extractor, patch_spec)
        target_patch = (
            self._extract_patch(self.target_extractor, patch_spec)
            if self.target_extractor is not None
            else None
        )

        if self.transforms is not None:
            if self.target_extractor is not None:
                input_patch, target_patch = self.transforms(input_patch, target_patch)
            else:
                # TODO: compose doesn't return None for target patch anymore
                # so have to do this annoying if else
                (input_patch,) = self.transforms(input_patch, target_patch)
                target_patch = None

        input_data = self._create_image_region(
            patch=input_patch, patch_spec=patch_spec, extractor=self.input_extractor
        )

        if target_patch is None or self.target_extractor is None:
            return (input_data,)

        target_data = self._create_image_region(
            patch=target_patch,
            patch_spec=patch_spec,
            extractor=self.target_extractor,
        )
        return input_data, target_data
@@ -0,0 +1,361 @@
1
+ {
2
+ "cells": [
3
+ {
4
+ "cell_type": "code",
5
+ "execution_count": null,
6
+ "metadata": {},
7
+ "outputs": [],
8
+ "source": [
9
+ "from pathlib import Path\n",
10
+ "\n",
11
+ "import matplotlib.pyplot as plt\n",
12
+ "import numpy as np\n",
13
+ "import tifffile\n",
14
+ "from careamics_portfolio import PortfolioManager\n",
15
+ "\n",
16
+ "from careamics.config.configuration_factories import (\n",
17
+ " _create_ng_data_configuration,\n",
18
+ " create_n2v_configuration,\n",
19
+ ")\n",
20
+ "from careamics.config.data import NGDataConfig\n",
21
+ "from careamics.lightning.callbacks import HyperParametersCallback\n",
22
+ "from careamics.lightning.dataset_ng.data_module import CareamicsDataModule\n",
23
+ "from careamics.lightning.dataset_ng.lightning_modules import N2VModule"
24
+ ]
25
+ },
26
+ {
27
+ "cell_type": "code",
28
+ "execution_count": null,
29
+ "metadata": {},
30
+ "outputs": [],
31
+ "source": [
32
+ "# Set seeds for reproducibility\n",
33
+ "from pytorch_lightning import seed_everything\n",
34
+ "\n",
35
+ "seed = 42\n",
36
+ "seed_everything(seed)"
37
+ ]
38
+ },
39
+ {
40
+ "cell_type": "markdown",
41
+ "metadata": {},
42
+ "source": [
43
+ "### Load data and set paths to it"
44
+ ]
45
+ },
46
+ {
47
+ "cell_type": "code",
48
+ "execution_count": null,
49
+ "metadata": {},
50
+ "outputs": [],
51
+ "source": [
52
+ "# instantiate the data portfolio manager and download the data\n",
53
+ "root_path = Path(\"./data\")\n",
54
+ "\n",
55
+ "portfolio = PortfolioManager()\n",
56
+ "files = portfolio.denoising.N2V_BSD68.download(root_path)\n",
57
+ "\n",
58
+ "# create paths for the data\n",
59
+ "data_path = Path(root_path / \"denoising-N2V_BSD68.unzip/BSD68_reproducibility_data\")\n",
60
+ "train_path = data_path / \"train\"\n",
61
+ "val_path = data_path / \"val\"\n",
62
+ "test_path = data_path / \"test\" / \"images\"\n",
63
+ "gt_path = data_path / \"test\" / \"gt\"\n",
64
+ "\n",
65
+ "# list train, val and test files\n",
66
+ "train_files = sorted(train_path.rglob(\"*.tiff\"))\n",
67
+ "val_files = sorted(val_path.rglob(\"*.tiff\"))\n",
68
+ "test_files = sorted(test_path.rglob(\"*.tiff\"))"
69
+ ]
70
+ },
71
+ {
72
+ "cell_type": "markdown",
73
+ "metadata": {},
74
+ "source": [
75
+ "### Visualize a single train and val image"
76
+ ]
77
+ },
78
+ {
79
+ "cell_type": "code",
80
+ "execution_count": null,
81
+ "metadata": {},
82
+ "outputs": [],
83
+ "source": [
84
+ "# load training and validation image and show them side by side\n",
85
+ "single_train_image = tifffile.imread(train_files[0])[0]\n",
86
+ "single_val_image = tifffile.imread(val_files[0])[0]\n",
87
+ "\n",
88
+ "fig, ax = plt.subplots(1, 2, figsize=(10, 5))\n",
89
+ "ax[0].imshow(single_train_image, cmap=\"gray\")\n",
90
+ "ax[0].set_title(\"Training Image\")\n",
91
+ "ax[1].imshow(single_val_image, cmap=\"gray\")\n",
92
+ "ax[1].set_title(\"Validation Image\")"
93
+ ]
94
+ },
95
+ {
96
+ "cell_type": "markdown",
97
+ "metadata": {},
98
+ "source": [
99
+ "### Create config"
100
+ ]
101
+ },
102
+ {
103
+ "cell_type": "code",
104
+ "execution_count": null,
105
+ "metadata": {},
106
+ "outputs": [],
107
+ "source": [
108
+ "config = create_n2v_configuration(\n",
109
+ " experiment_name=\"bsd68_n2v\",\n",
110
+ " data_type=\"tiff\",\n",
111
+ " axes=\"SYX\",\n",
112
+ " patch_size=(64, 64),\n",
113
+ " batch_size=64,\n",
114
+ " num_epochs=100,\n",
115
+ ")\n",
116
+ "\n",
117
+ "# TODO: until NGDataConfig is accepted by the Configuration, these are separate\n",
118
+ "ng_data_config = _create_ng_data_configuration(\n",
119
+ " data_type=config.data_config.data_type,\n",
120
+ " axes=config.data_config.axes,\n",
121
+ " patch_size=config.data_config.patch_size,\n",
122
+ " batch_size=config.data_config.batch_size,\n",
123
+ " augmentations=config.data_config.transforms,\n",
124
+ " train_dataloader_params=config.data_config.train_dataloader_params,\n",
125
+ " val_dataloader_params=config.data_config.val_dataloader_params,\n",
126
+ " seed=seed,\n",
127
+ ")\n"
128
+ ]
129
+ },
130
+ {
131
+ "cell_type": "markdown",
132
+ "metadata": {},
133
+ "source": [
134
+ "### Create Lightning datamodule and model"
135
+ ]
136
+ },
137
+ {
138
+ "cell_type": "code",
139
+ "execution_count": null,
140
+ "metadata": {},
141
+ "outputs": [],
142
+ "source": [
143
+ "train_data_module = CareamicsDataModule(\n",
144
+ " data_config=ng_data_config,\n",
145
+ " train_data=train_files,\n",
146
+ " val_data=val_files,\n",
147
+ ")\n",
148
+ "\n",
149
+ "model = N2VModule(config.algorithm_config)"
150
+ ]
151
+ },
152
+ {
153
+ "cell_type": "markdown",
154
+ "metadata": {},
155
+ "source": [
156
+ "### Manually initialize the datamodule and visualize single train and val batches"
157
+ ]
158
+ },
159
+ {
160
+ "cell_type": "code",
161
+ "execution_count": null,
162
+ "metadata": {},
163
+ "outputs": [],
164
+ "source": [
165
+ "train_data_module.setup(\"fit\")\n",
166
+ "train_data_module.setup(\"validate\")\n",
167
+ "\n",
168
+ "train_batch = next(iter(train_data_module.train_dataloader()))\n",
169
+ "val_batch = next(iter(train_data_module.val_dataloader()))\n",
170
+ "\n",
171
+ "fig, ax = plt.subplots(1, 8, figsize=(10, 5))\n",
172
+ "ax[0].set_title(\"Training Batch\")\n",
173
+ "for i in range(8):\n",
174
+ " ax[i].imshow(train_batch[0].data[i][0].numpy(), cmap=\"gray\")\n",
175
+ "\n",
176
+ "fig, ax = plt.subplots(1, 8, figsize=(10, 5))\n",
177
+ "ax[0].set_title(\"Validation Batch\")\n",
178
+ "for i in range(8):\n",
179
+ " ax[i].imshow(val_batch[0].data[i][0].numpy(), cmap=\"gray\")"
180
+ ]
181
+ },
182
+ {
183
+ "cell_type": "markdown",
184
+ "metadata": {},
185
+ "source": [
186
+ "### Train the model"
187
+ ]
188
+ },
189
+ {
190
+ "cell_type": "code",
191
+ "execution_count": null,
192
+ "metadata": {},
193
+ "outputs": [],
194
+ "source": [
195
+ "from pytorch_lightning import Trainer\n",
196
+ "from pytorch_lightning.callbacks import ModelCheckpoint\n",
197
+ "from pytorch_lightning.loggers import WandbLogger\n",
198
+ "\n",
199
+ "root = Path(\"bsd68_n2v\")\n",
200
+ "callbacks = [\n",
201
+ " ModelCheckpoint(\n",
202
+ " dirpath=root / \"checkpoints\",\n",
203
+ " filename=\"bsd68_new_lightning_module\",\n",
204
+ " save_last=True,\n",
205
+ " monitor=\"val_loss\",\n",
206
+ " mode=\"min\",\n",
207
+ " ),\n",
208
+ " HyperParametersCallback(config),\n",
209
+ "]\n",
210
+ "logger = WandbLogger(project=\"bsd68-n2v\", name=\"bsd68_new_lightning_module\")\n",
211
+ "\n",
212
+ "trainer = Trainer(\n",
213
+ " max_epochs=50, default_root_dir=root, callbacks=callbacks, logger=logger\n",
214
+ ")\n",
215
+ "trainer.fit(model, datamodule=train_data_module)"
216
+ ]
217
+ },
218
+ {
219
+ "cell_type": "markdown",
220
+ "metadata": {},
221
+ "source": [
222
+ "### Create an inference config and datamodule"
223
+ ]
224
+ },
225
+ {
226
+ "cell_type": "code",
227
+ "execution_count": null,
228
+ "metadata": {},
229
+ "outputs": [],
230
+ "source": [
231
+ "from careamics.dataset_ng.legacy_interoperability import imageregions_to_tileinfos\n",
232
+ "from careamics.prediction_utils import convert_outputs\n",
233
+ "\n",
234
+ "config = NGDataConfig(\n",
235
+ " data_type=\"tiff\",\n",
236
+ " patching={\n",
237
+ " \"name\": \"tiled\",\n",
238
+ " \"patch_size\": (128, 128),\n",
239
+ " \"overlaps\": (32, 32),\n",
240
+ " },\n",
241
+ " axes=\"YX\",\n",
242
+ " batch_size=1,\n",
243
+ " image_means=train_data_module.train_dataset.input_stats.means,\n",
244
+ " image_stds=train_data_module.train_dataset.input_stats.stds,\n",
245
+ ")\n",
246
+ "\n",
247
+ "inf_data_module = CareamicsDataModule(data_config=config, pred_data=test_files)"
248
+ ]
249
+ },
250
+ {
251
+ "cell_type": "markdown",
252
+ "metadata": {},
253
+ "source": [
254
+ "### Convert outputs to the legacy format and stitch the tiles"
255
+ ]
256
+ },
257
+ {
258
+ "cell_type": "code",
259
+ "execution_count": null,
260
+ "metadata": {},
261
+ "outputs": [],
262
+ "source": [
263
+ "predictions = trainer.predict(model, datamodule=inf_data_module)\n",
264
+ "tile_infos = imageregions_to_tileinfos(predictions)\n",
265
+ "predictions = convert_outputs(tile_infos, tiled=True)"
266
+ ]
267
+ },
268
+ {
269
+ "cell_type": "markdown",
270
+ "metadata": {},
271
+ "source": [
272
+ "### Visualize predictions and compute metrics"
273
+ ]
274
+ },
275
+ {
276
+ "cell_type": "code",
277
+ "execution_count": null,
278
+ "metadata": {},
279
+ "outputs": [],
280
+ "source": [
281
+ "from careamics.utils.metrics import psnr, scale_invariant_psnr\n",
282
+ "\n",
283
+ "noises = [tifffile.imread(f) for f in sorted(test_path.glob(\"*.tiff\"))]\n",
284
+ "gts = [tifffile.imread(f) for f in sorted(gt_path.glob(\"*.tiff\"))]\n",
285
+ "\n",
286
+ "images = [0, 1, 2]\n",
287
+ "fig, ax = plt.subplots(3, 3, figsize=(15, 15))\n",
288
+ "fig.tight_layout()\n",
289
+ "\n",
290
+ "for i in range(3):\n",
291
+ " pred_image = predictions[images[i]].squeeze()\n",
292
+ " psnr_noisy = psnr(\n",
293
+ " gts[images[i]],\n",
294
+ " noises[images[i]],\n",
295
+ " data_range=gts[images[i]].max() - gts[images[i]].min(),\n",
296
+ " )\n",
297
+ " psnr_result = psnr(\n",
298
+ " gts[images[i]],\n",
299
+ " pred_image,\n",
300
+ " data_range=gts[images[i]].max() - gts[images[i]].min(),\n",
301
+ " )\n",
302
+ "\n",
303
+ " scale_invariant_psnr_result = scale_invariant_psnr(gts[images[i]], pred_image)\n",
304
+ "\n",
305
+ " ax[i, 0].imshow(noises[images[i]], cmap=\"gray\")\n",
306
+ " ax[i, 0].title.set_text(f\"Noisy\\nPSNR: {psnr_noisy:.2f}\")\n",
307
+ "\n",
308
+ " ax[i, 1].imshow(pred_image, cmap=\"gray\")\n",
309
+ " ax[i, 1].title.set_text(\n",
310
+ " f\"Prediction\\nPSNR: {psnr_result:.2f}\\n\"\n",
311
+ " f\"Scale invariant PSNR: {scale_invariant_psnr_result:.2f}\"\n",
312
+ " )\n",
313
+ "\n",
314
+ " ax[i, 2].imshow(gts[images[i]], cmap=\"gray\")\n",
315
+ " ax[i, 2].title.set_text(\"Ground-truth\")"
316
+ ]
317
+ },
318
+ {
319
+ "cell_type": "code",
320
+ "execution_count": null,
321
+ "metadata": {},
322
+ "outputs": [],
323
+ "source": [
324
+ "psnrs = np.zeros((len(predictions), 1))\n",
325
+ "scale_invariant_psnrs = np.zeros((len(predictions), 1))\n",
326
+ "\n",
327
+ "for i, (pred, gt) in enumerate(zip(predictions, gts, strict=False)):\n",
328
+ " psnrs[i] = psnr(gt, pred.squeeze(), data_range=gt.max() - gt.min())\n",
329
+ " scale_invariant_psnrs[i] = scale_invariant_psnr(gt, pred.squeeze())\n",
330
+ "\n",
331
+ "print(f\"PSNR: {psnrs.mean():.2f} +/- {psnrs.std():.2f}\")\n",
332
+ "print(\n",
333
+ " f\"Scale invariant PSNR: \"\n",
334
+ " f\"{scale_invariant_psnrs.mean():.2f} +/- {scale_invariant_psnrs.std():.2f}\"\n",
335
+ ")\n",
336
+ "print(\"Reported PSNR: 27.71\")"
337
+ ]
338
+ }
339
+ ],
340
+ "metadata": {
341
+ "kernelspec": {
342
+ "display_name": "czi",
343
+ "language": "python",
344
+ "name": "python3"
345
+ },
346
+ "language_info": {
347
+ "codemirror_mode": {
348
+ "name": "ipython",
349
+ "version": 3
350
+ },
351
+ "file_extension": ".py",
352
+ "mimetype": "text/x-python",
353
+ "name": "python",
354
+ "nbconvert_exporter": "python",
355
+ "pygments_lexer": "ipython3",
356
+ "version": "3.12.11"
357
+ }
358
+ },
359
+ "nbformat": 4,
360
+ "nbformat_minor": 2
361
+ }