careamics 0.0.10__py3-none-any.whl → 0.0.12__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. It is provided for informational purposes only and reflects the changes between the two versions as they appear in their respective public registries.


Files changed (54)
  1. careamics/careamist.py +20 -4
  2. careamics/config/configuration.py +10 -5
  3. careamics/config/data/data_model.py +38 -1
  4. careamics/config/optimizer_models.py +1 -3
  5. careamics/config/training_model.py +0 -2
  6. careamics/dataset/dataset_utils/running_stats.py +7 -3
  7. careamics/dataset_ng/README.md +212 -0
  8. careamics/dataset_ng/dataset.py +233 -0
  9. careamics/dataset_ng/demos/bsd68_demo.ipynb +356 -0
  10. careamics/dataset_ng/demos/care_U2OS_demo.ipynb +330 -0
  11. careamics/dataset_ng/demos/demo_custom_image_stack.ipynb +734 -0
  12. careamics/dataset_ng/demos/demo_datamodule.ipynb +443 -0
  13. careamics/dataset_ng/{demo_dataset.ipynb → demos/demo_dataset.ipynb} +39 -15
  14. careamics/dataset_ng/{demo_patch_extractor.py → demos/demo_patch_extractor.py} +7 -9
  15. careamics/dataset_ng/demos/mouse_nuclei_demo.ipynb +292 -0
  16. careamics/dataset_ng/factory.py +408 -0
  17. careamics/dataset_ng/legacy_interoperability.py +168 -0
  18. careamics/dataset_ng/patch_extractor/__init__.py +3 -8
  19. careamics/dataset_ng/patch_extractor/demo_custom_image_stack_loader.py +6 -4
  20. careamics/dataset_ng/patch_extractor/image_stack/__init__.py +2 -1
  21. careamics/dataset_ng/patch_extractor/image_stack/image_stack_protocol.py +5 -1
  22. careamics/dataset_ng/patch_extractor/image_stack_loader.py +5 -75
  23. careamics/dataset_ng/patch_extractor/patch_extractor.py +5 -4
  24. careamics/dataset_ng/patch_extractor/patch_extractor_factory.py +73 -106
  25. careamics/dataset_ng/patching_strategies/__init__.py +6 -1
  26. careamics/dataset_ng/patching_strategies/patching_strategy_protocol.py +31 -0
  27. careamics/dataset_ng/patching_strategies/random_patching.py +3 -1
  28. careamics/dataset_ng/patching_strategies/tiling_strategy.py +171 -0
  29. careamics/dataset_ng/patching_strategies/whole_sample.py +36 -0
  30. careamics/lightning/dataset_ng/data_module.py +488 -0
  31. careamics/lightning/dataset_ng/lightning_modules/__init__.py +9 -0
  32. careamics/lightning/dataset_ng/lightning_modules/care_module.py +58 -0
  33. careamics/lightning/dataset_ng/lightning_modules/n2v_module.py +67 -0
  34. careamics/lightning/dataset_ng/lightning_modules/unet_module.py +143 -0
  35. careamics/lightning/lightning_module.py +3 -0
  36. careamics/lvae_training/dataset/__init__.py +8 -3
  37. careamics/lvae_training/dataset/config.py +3 -3
  38. careamics/lvae_training/dataset/ms_dataset_ref.py +1067 -0
  39. careamics/lvae_training/dataset/multich_dataset.py +46 -17
  40. careamics/lvae_training/dataset/multicrop_dset.py +196 -0
  41. careamics/lvae_training/dataset/types.py +3 -3
  42. careamics/lvae_training/dataset/utils/index_manager.py +259 -0
  43. careamics/lvae_training/eval_utils.py +93 -3
  44. careamics/transforms/compose.py +1 -0
  45. careamics/transforms/normalize.py +18 -7
  46. careamics/utils/lightning_utils.py +25 -11
  47. {careamics-0.0.10.dist-info → careamics-0.0.12.dist-info}/METADATA +3 -3
  48. {careamics-0.0.10.dist-info → careamics-0.0.12.dist-info}/RECORD +51 -36
  49. careamics/dataset_ng/dataset/__init__.py +0 -3
  50. careamics/dataset_ng/dataset/dataset.py +0 -184
  51. careamics/dataset_ng/demo_patch_extractor_factory.py +0 -37
  52. {careamics-0.0.10.dist-info → careamics-0.0.12.dist-info}/WHEEL +0 -0
  53. {careamics-0.0.10.dist-info → careamics-0.0.12.dist-info}/entry_points.txt +0 -0
  54. {careamics-0.0.10.dist-info → careamics-0.0.12.dist-info}/licenses/LICENSE +0 -0
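
The headline addition in this release is the dataset_ng pipeline: a new CareamicsDataModule, per-algorithm Lightning modules (CAREModule, N2VModule, plus a shared UNet module), tiling and whole-sample patching strategies, and demo notebooks exercising the API. As rough orientation before the notebook diffs below, the pieces compose as follows. This is a sketch assembled from the demo code shown in this diff, not official documentation; the file paths are placeholders.

    # Sketch of the new dataset_ng training flow, assembled from the
    # bsd68_demo.ipynb diff below; paths are placeholders, calls as shown there.
    from pathlib import Path

    from pytorch_lightning import Trainer

    from careamics.config.configuration_factories import create_n2v_configuration
    from careamics.lightning.dataset_ng.data_module import CareamicsDataModule
    from careamics.lightning.dataset_ng.lightning_modules import N2VModule

    # algorithm, data, and training settings live in one configuration object
    config = create_n2v_configuration(
        experiment_name="demo",
        data_type="tiff",
        axes="SYX",
        patch_size=(64, 64),
        batch_size=64,
        num_epochs=100,
    )

    # the datamodule consumes the data sub-config, the model the algorithm sub-config
    data_module = CareamicsDataModule(
        data_config=config.data_config,
        train_data=sorted(Path("data/train").rglob("*.tiff")),
        val_data=sorted(Path("data/val").rglob("*.tiff")),
    )
    model = N2VModule(config.algorithm_config)

    Trainer(max_epochs=50).fit(model, datamodule=data_module)

The split mirrors plain PyTorch Lightning: CAREamics supplies only the datamodule and the module, so standard Trainer features such as callbacks and loggers apply unchanged, which is exactly what both demos do with ModelCheckpoint and WandbLogger.
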
careamics/dataset_ng/demos/bsd68_demo.ipynb
@@ -0,0 +1,356 @@
+ {
+ "cells": [
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "from pathlib import Path\n",
+ "\n",
+ "import matplotlib.pyplot as plt\n",
+ "import numpy as np\n",
+ "import tifffile\n",
+ "from careamics_portfolio import PortfolioManager\n",
+ "\n",
+ "from careamics.config.configuration_factories import create_n2v_configuration\n",
+ "from careamics.config.support import SupportedTransform\n",
+ "from careamics.lightning.callbacks import HyperParametersCallback\n",
+ "from careamics.lightning.dataset_ng.data_module import CareamicsDataModule\n",
+ "from careamics.lightning.dataset_ng.lightning_modules import N2VModule"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "# Set seeds for reproducibility\n",
+ "from pytorch_lightning import seed_everything\n",
+ "\n",
+ "seed_everything(42)"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "### Load data and set paths to it"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "# instantiate data portfolio manage and download the data\n",
+ "root_path = Path(\"./data\")\n",
+ "\n",
+ "portfolio = PortfolioManager()\n",
+ "files = portfolio.denoising.N2V_BSD68.download(root_path)\n",
+ "\n",
+ "# create paths for the data\n",
+ "data_path = Path(root_path / \"denoising-N2V_BSD68.unzip/BSD68_reproducibility_data\")\n",
+ "train_path = data_path / \"train\"\n",
+ "val_path = data_path / \"val\"\n",
+ "test_path = data_path / \"test\" / \"images\"\n",
+ "gt_path = data_path / \"test\" / \"gt\"\n",
+ "\n",
+ "# list train, val and test files\n",
+ "train_files = sorted(train_path.rglob(\"*.tiff\"))\n",
+ "val_files = sorted(val_path.rglob(\"*.tiff\"))\n",
+ "test_files = sorted(test_path.rglob(\"*.tiff\"))"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "### Visualize a single train and val image"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "# load training and validation image and show them side by side\n",
+ "single_train_image = tifffile.imread(train_files[0])[0]\n",
+ "single_val_image = tifffile.imread(val_files[0])[0]\n",
+ "\n",
+ "fig, ax = plt.subplots(1, 2, figsize=(10, 5))\n",
+ "ax[0].imshow(single_train_image, cmap=\"gray\")\n",
+ "ax[0].set_title(\"Training Image\")\n",
+ "ax[1].imshow(single_val_image, cmap=\"gray\")\n",
+ "ax[1].set_title(\"Validation Image\")"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "### Create config"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "config = create_n2v_configuration(\n",
+ " experiment_name=\"bsd68_n2v\",\n",
+ " data_type=\"tiff\",\n",
+ " axes=\"SYX\",\n",
+ " patch_size=(64, 64),\n",
+ " batch_size=64,\n",
+ " num_epochs=100,\n",
+ ")\n",
+ "\n",
+ "# Ensuring that transforms are set\n",
+ "config.data_config.transforms = [\n",
+ " {\n",
+ " \"name\": SupportedTransform.XY_FLIP.value,\n",
+ " \"flip_x\": True,\n",
+ " \"flip_y\": True,\n",
+ " },\n",
+ " {\n",
+ " \"name\": SupportedTransform.XY_RANDOM_ROTATE90.value,\n",
+ " },\n",
+ "]"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "### Create Lightning datamodule and model"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "train_data_module = CareamicsDataModule(\n",
+ " data_config=config.data_config,\n",
+ " train_data=train_files,\n",
+ " val_data=val_files,\n",
+ ")\n",
+ "\n",
+ "model = N2VModule(config.algorithm_config)"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "### Manually initialize the datamodule and visualize single train and val batches"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "train_data_module.setup(\"fit\")\n",
+ "train_data_module.setup(\"validate\")\n",
+ "\n",
+ "train_batch = next(iter(train_data_module.train_dataloader()))\n",
+ "val_batch = next(iter(train_data_module.val_dataloader()))\n",
+ "\n",
+ "fig, ax = plt.subplots(1, 8, figsize=(10, 5))\n",
+ "ax[0].set_title(\"Training Batch\")\n",
+ "for i in range(8):\n",
+ " ax[i].imshow(train_batch[0].data[i][0].numpy(), cmap=\"gray\")\n",
+ "\n",
+ "fig, ax = plt.subplots(1, 8, figsize=(10, 5))\n",
+ "ax[0].set_title(\"Validation Batch\")\n",
+ "for i in range(8):\n",
+ " ax[i].imshow(val_batch[0].data[i][0].numpy(), cmap=\"gray\")"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "### Train the model"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "from pytorch_lightning import Trainer\n",
+ "from pytorch_lightning.callbacks import ModelCheckpoint\n",
+ "from pytorch_lightning.loggers import WandbLogger\n",
+ "\n",
+ "root = Path(\"bsd68_n2v\")\n",
+ "callbacks = [\n",
+ " ModelCheckpoint(\n",
+ " dirpath=root / \"checkpoints\",\n",
+ " filename=\"bsd68_new_lightning_module\",\n",
+ " save_last=True,\n",
+ " monitor=\"val_loss\",\n",
+ " mode=\"min\",\n",
+ " ),\n",
+ " HyperParametersCallback(config),\n",
+ "]\n",
+ "logger = WandbLogger(project=\"bsd68-n2v\", name=\"bsd68_new_lightning_module\")\n",
+ "\n",
+ "trainer = Trainer(\n",
+ " max_epochs=50, default_root_dir=root, callbacks=callbacks, logger=logger\n",
+ ")\n",
+ "trainer.fit(model, datamodule=train_data_module)"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "### Create an inference config and datamodule"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "from careamics.config.inference_model import InferenceConfig\n",
+ "from careamics.dataset_ng.legacy_interoperability import imageregions_to_tileinfos\n",
+ "from careamics.prediction_utils import convert_outputs\n",
+ "\n",
+ "config = InferenceConfig(\n",
+ " model_config=config,\n",
+ " data_type=\"tiff\",\n",
+ " tile_size=(128, 128),\n",
+ " tile_overlap=(32, 32),\n",
+ " axes=\"YX\",\n",
+ " batch_size=1,\n",
+ " image_means=train_data_module.train_dataset.input_stats.means,\n",
+ " image_stds=train_data_module.train_dataset.input_stats.stds,\n",
+ ")\n",
+ "\n",
+ "inf_data_module = CareamicsDataModule(data_config=config, pred_data=test_files)"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "### Convert outputs to the legacy format and stitch the tiles"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "predictions = trainer.predict(model, datamodule=inf_data_module)\n",
+ "tile_infos = imageregions_to_tileinfos(predictions)\n",
+ "predictions = convert_outputs(tile_infos, tiled=True)"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "### Visualize predictions and count metrics"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "from careamics.utils.metrics import psnr, scale_invariant_psnr\n",
+ "\n",
+ "noises = [tifffile.imread(f) for f in sorted(test_path.glob(\"*.tiff\"))]\n",
+ "gts = [tifffile.imread(f) for f in sorted(gt_path.glob(\"*.tiff\"))]\n",
+ "\n",
+ "images = [0, 1, 2]\n",
+ "fig, ax = plt.subplots(3, 3, figsize=(15, 15))\n",
+ "fig.tight_layout()\n",
+ "\n",
+ "for i in range(3):\n",
+ " pred_image = predictions[images[i]].squeeze()\n",
+ " psnr_noisy = psnr(\n",
+ " gts[images[i]],\n",
+ " noises[images[i]],\n",
+ " data_range=gts[images[i]].max() - gts[images[i]].min(),\n",
+ " )\n",
+ " psnr_result = psnr(\n",
+ " gts[images[i]],\n",
+ " pred_image,\n",
+ " data_range=gts[images[i]].max() - gts[images[i]].min(),\n",
+ " )\n",
+ "\n",
+ " scale_invariant_psnr_result = scale_invariant_psnr(gts[images[i]], pred_image)\n",
+ "\n",
+ " ax[i, 0].imshow(noises[images[i]], cmap=\"gray\")\n",
+ " ax[i, 0].title.set_text(f\"Noisy\\nPSNR: {psnr_noisy:.2f}\")\n",
+ "\n",
+ " ax[i, 1].imshow(pred_image, cmap=\"gray\")\n",
+ " ax[i, 1].title.set_text(\n",
+ " f\"Prediction\\nPSNR: {psnr_result:.2f}\\n\"\n",
+ " f\"Scale invariant PSNR: {scale_invariant_psnr_result:.2f}\"\n",
+ " )\n",
+ "\n",
+ " ax[i, 2].imshow(gts[images[i]], cmap=\"gray\")\n",
+ " ax[i, 2].title.set_text(\"Ground-truth\")"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "psnrs = np.zeros((len(predictions), 1))\n",
+ "scale_invariant_psnrs = np.zeros((len(predictions), 1))\n",
+ "\n",
+ "for i, (pred, gt) in enumerate(zip(predictions, gts)):\n",
+ " psnrs[i] = psnr(gt, pred.squeeze(), data_range=gt.max() - gt.min())\n",
+ " scale_invariant_psnrs[i] = scale_invariant_psnr(gt, pred.squeeze())\n",
+ "\n",
+ "print(f\"PSNR: {psnrs.mean():.2f} +/- {psnrs.std():.2f}\")\n",
+ "print(\n",
+ " f\"Scale invariant PSNR: \"\n",
+ " f\"{scale_invariant_psnrs.mean():.2f} +/- {scale_invariant_psnrs.std():.2f}\"\n",
+ ")\n",
+ "print(\"Reported PSNR: 27.71\")"
+ ]
+ }
+ ],
+ "metadata": {
+ "kernelspec": {
+ "display_name": "Python 3",
+ "language": "python",
+ "name": "python3"
+ },
+ "language_info": {
+ "codemirror_mode": {
+ "name": "ipython",
+ "version": 3
+ },
+ "file_extension": ".py",
+ "mimetype": "text/x-python",
+ "name": "python",
+ "nbconvert_exporter": "python",
+ "pygments_lexer": "ipython3",
+ "version": "3.10.14"
+ }
+ },
+ "nbformat": 4,
+ "nbformat_minor": 2
+ }
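
Both demo notebooks end with the same tiled-inference flow: trainer.predict yields per-tile outputs from the new dataset, and the legacy_interoperability helper converts them into the legacy tile-information format so the existing stitching utility can be reused. Condensed below, with names exactly as in the notebooks; the nature of the intermediate values is an assumption inferred from the function names.

    # predictions come back tile by tile from the new dataset
    predictions = trainer.predict(model, datamodule=inf_data_module)
    # bridge new-style image regions to the legacy tile-info format (assumed naming)
    tile_infos = imageregions_to_tileinfos(predictions)
    # reuse the existing stitcher to assemble full images from the tiles
    stitched = convert_outputs(tile_infos, tiled=True)
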
careamics/dataset_ng/demos/care_U2OS_demo.ipynb
@@ -0,0 +1,330 @@
+ {
+ "cells": [
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "from pathlib import Path\n",
+ "\n",
+ "import matplotlib.pyplot as plt\n",
+ "import numpy as np\n",
+ "import tifffile\n",
+ "from careamics_portfolio import PortfolioManager\n",
+ "\n",
+ "from careamics.config import create_care_configuration\n",
+ "from careamics.dataset_ng.legacy_interoperability import imageregions_to_tileinfos\n",
+ "from careamics.lightning.callbacks import HyperParametersCallback\n",
+ "from careamics.lightning.dataset_ng.data_module import CareamicsDataModule\n",
+ "from careamics.lightning.dataset_ng.lightning_modules import CAREModule\n",
+ "from careamics.prediction_utils import convert_outputs"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "# Set seed for reproducibility\n",
+ "from pytorch_lightning import seed_everything\n",
+ "\n",
+ "seed_everything(42)"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "### Load data and set paths to it"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "# instantiate data portfolio manager and download the data\n",
+ "root_path = Path(\"./data\")\n",
+ "\n",
+ "portfolio = PortfolioManager()\n",
+ "download = portfolio.denoising.CARE_U2OS.download(root_path)\n",
+ "\n",
+ "root_path = root_path / \"denoising-CARE_U2OS.unzip\" / \"data\" / \"U2OS\"\n",
+ "train_path = root_path / \"train\" / \"low\"\n",
+ "target_path = root_path / \"train\" / \"GT\"\n",
+ "test_path = root_path / \"test\" / \"low\"\n",
+ "test_target_path = root_path / \"test\" / \"GT\""
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "### Create config"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "train_files = sorted(train_path.glob(\"*.tif\"))\n",
+ "train_target_files = sorted(target_path.glob(\"*.tif\"))\n",
+ "\n",
+ "config = create_care_configuration(\n",
+ " experiment_name=\"care_U20S\",\n",
+ " data_type=\"tiff\",\n",
+ " axes=\"YX\",\n",
+ " patch_size=(128, 128),\n",
+ " batch_size=32,\n",
+ " num_epochs=50,\n",
+ ")"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "### Create Lightning datamodule and model"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "train_data_module = CareamicsDataModule(\n",
+ " data_config=config.data_config,\n",
+ " train_data=train_path,\n",
+ " train_data_target=target_path,\n",
+ " val_data=test_path,\n",
+ " val_data_target=test_target_path,\n",
+ ")\n",
+ "\n",
+ "model = CAREModule(config.algorithm_config)"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "### Manually initialize the datamodule and visualize single train and val batches"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "train_data_module.setup(\"fit\")\n",
+ "train_data_module.setup(\"validate\")\n",
+ "\n",
+ "train_batch = next(iter(train_data_module.train_dataloader()))\n",
+ "val_batch = next(iter(train_data_module.val_dataloader()))\n",
+ "\n",
+ "fig, ax = plt.subplots(2, 8, figsize=(10, 3))\n",
+ "\n",
+ "ax[0][0].set_title(\"Train batch\")\n",
+ "ax[1][0].set_title(\"Train target\")\n",
+ "for i in range(8):\n",
+ " ax[0][i].imshow(train_batch[0].data[i][0].numpy(), cmap=\"gray\")\n",
+ " ax[1][i].imshow(train_batch[1].data[i][0].numpy(), cmap=\"gray\")\n",
+ "\n",
+ "\n",
+ "fig, ax = plt.subplots(2, 8, figsize=(10, 3))\n",
+ "ax[0][0].set_title(\"Val batch\")\n",
+ "ax[1][0].set_title(\"Val target\")\n",
+ "for i in range(8):\n",
+ " ax[0][i].imshow(val_batch[0].data[i][0].numpy(), cmap=\"gray\")\n",
+ " ax[1][i].imshow(val_batch[1].data[i][0].numpy(), cmap=\"gray\")"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "### Train the model"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "from pytorch_lightning import Trainer\n",
+ "from pytorch_lightning.callbacks import ModelCheckpoint\n",
+ "from pytorch_lightning.loggers import WandbLogger\n",
+ "\n",
+ "root = Path(\"care_baseline\")\n",
+ "callbacks = [\n",
+ " ModelCheckpoint(\n",
+ " dirpath=root / \"checkpoints\",\n",
+ " filename=\"care_baseline\",\n",
+ " save_last=True,\n",
+ " monitor=\"val_loss\",\n",
+ " mode=\"min\",\n",
+ " ),\n",
+ " HyperParametersCallback(config),\n",
+ "]\n",
+ "\n",
+ "wandb_logger = WandbLogger(project=\"care-U2OS\", name=\"new-dataset\")\n",
+ "\n",
+ "trainer = Trainer(\n",
+ " max_epochs=50, default_root_dir=root, callbacks=callbacks, logger=wandb_logger\n",
+ ")\n",
+ "trainer.fit(model, datamodule=train_data_module)"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "### Create an inference config and datamodule"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "from careamics.config.inference_model import InferenceConfig\n",
+ "\n",
+ "config = InferenceConfig(\n",
+ " model_config=config,\n",
+ " data_type=\"tiff\",\n",
+ " tile_size=(128, 128),\n",
+ " tile_overlap=(32, 32),\n",
+ " axes=\"YX\",\n",
+ " batch_size=1,\n",
+ " image_means=train_data_module.train_dataset.input_stats.means,\n",
+ " image_stds=train_data_module.train_dataset.input_stats.stds,\n",
+ ")\n",
+ "\n",
+ "inf_data_module = CareamicsDataModule(\n",
+ " data_config=config, pred_data=test_path\n",
+ ")"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "### Convert outputs to the legacy format and stitch the tiles"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "predictions = trainer.predict(model, datamodule=inf_data_module)\n",
+ "tile_infos = imageregions_to_tileinfos(predictions)\n",
+ "prediction = convert_outputs(tile_infos, tiled=True)"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "### Visualize predictions and count metrics"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "from careamics.utils.metrics import psnr, scale_invariant_psnr\n",
+ "\n",
+ "# Show two images\n",
+ "noises = [tifffile.imread(f) for f in sorted(test_path.glob(\"*.tif\"))]\n",
+ "gts = [tifffile.imread(f) for f in sorted(test_target_path.glob(\"*.tif\"))]\n",
+ "\n",
+ "# images to show\n",
+ "images = [0, 1, 2]\n",
+ "\n",
+ "fig, ax = plt.subplots(3, 3, figsize=(15, 15))\n",
+ "fig.tight_layout()\n",
+ "\n",
+ "for i in range(3):\n",
+ " pred_image = prediction[images[i]].squeeze()\n",
+ " psnr_noisy = psnr(\n",
+ " gts[images[i]],\n",
+ " noises[images[i]],\n",
+ " data_range=gts[images[i]].max() - gts[images[i]].min(),\n",
+ " )\n",
+ " psnr_result = psnr(\n",
+ " gts[images[i]],\n",
+ " pred_image,\n",
+ " data_range=gts[images[i]].max() - gts[images[i]].min(),\n",
+ " )\n",
+ "\n",
+ " scale_invariant_psnr_result = scale_invariant_psnr(gts[images[i]], pred_image)\n",
+ "\n",
+ " ax[i, 0].imshow(noises[images[i]], cmap=\"gray\")\n",
+ " ax[i, 0].title.set_text(f\"Noisy\\nPSNR: {psnr_noisy:.2f}\")\n",
+ "\n",
+ " ax[i, 1].imshow(pred_image, cmap=\"gray\")\n",
+ " ax[i, 1].title.set_text(\n",
+ " f\"Prediction\\nPSNR: {psnr_result:.2f}\\n\"\n",
+ " f\"Scale invariant PSNR: {scale_invariant_psnr_result:.2f}\"\n",
+ " )\n",
+ "\n",
+ " ax[i, 2].imshow(gts[images[i]], cmap=\"gray\")\n",
+ " ax[i, 2].title.set_text(\"Ground-truth\")"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "psnrs = np.zeros((len(prediction), 1))\n",
+ "scale_invariant_psnrs = np.zeros((len(prediction), 1))\n",
+ "\n",
+ "for i, (pred, gt) in enumerate(zip(prediction, gts)):\n",
+ " psnrs[i] = psnr(gt, pred.squeeze(), data_range=gt.max() - gt.min())\n",
+ " scale_invariant_psnrs[i] = scale_invariant_psnr(gt, pred.squeeze())\n",
+ "\n",
+ "print(f\"PSNR: {psnrs.mean():.2f} +/- {psnrs.std():.2f}\")\n",
+ "print(\n",
+ " f\"Scale invariant PSNR: \"\n",
+ " f\"{scale_invariant_psnrs.mean():.2f} +/- {scale_invariant_psnrs.std():.2f}\"\n",
+ ")\n",
+ "print(\"Target PSNR: 31.53 +/- 3.71\")"
+ ]
+ }
+ ],
+ "metadata": {
+ "kernelspec": {
+ "display_name": "Python 3",
+ "language": "python",
+ "name": "python3"
+ },
+ "language_info": {
+ "codemirror_mode": {
+ "name": "ipython",
+ "version": 3
+ },
+ "file_extension": ".py",
+ "mimetype": "text/x-python",
+ "name": "python",
+ "nbconvert_exporter": "python",
+ "pygments_lexer": "ipython3",
+ "version": "3.9.20"
+ }
+ },
+ "nbformat": 4,
+ "nbformat_minor": 2
+ }
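
For reference, both notebooks report PSNR with data_range taken as the ground-truth dynamic range (gt.max() - gt.min()). The helpers in careamics.utils.metrics are not part of this diff; the sketch below shows the conventional definitions they presumably follow, with scale-invariant PSNR commonly fitting a least-squares rescaling of the prediction to the ground truth first. Both functions here are assumptions for orientation, not taken from the careamics source.

    # Conventional PSNR; assumed to match careamics.utils.metrics.psnr in spirit.
    import numpy as np

    def psnr_reference(gt: np.ndarray, pred: np.ndarray, data_range: float) -> float:
        mse = np.mean((gt.astype(np.float64) - pred.astype(np.float64)) ** 2)
        return 10.0 * np.log10(data_range**2 / mse)

    def scale_invariant_psnr_reference(gt: np.ndarray, pred: np.ndarray) -> float:
        # least-squares fit of the prediction to the ground truth on zero-mean
        # images, then standard PSNR; careamics' exact normalization may differ
        gt0 = gt.astype(np.float64) - gt.mean()
        pred0 = pred.astype(np.float64) - pred.mean()
        alpha = (gt0 * pred0).sum() / (pred0**2).sum()
        return psnr_reference(gt0, alpha * pred0, data_range=gt0.max() - gt0.min())
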