careamics 0.0.10__py3-none-any.whl → 0.0.12__py3-none-any.whl
This diff shows the changes between publicly available package versions as they appear in their respective public registries. It is provided for informational purposes only.
Potentially problematic release.
This version of careamics might be problematic.
- careamics/careamist.py +20 -4
- careamics/config/configuration.py +10 -5
- careamics/config/data/data_model.py +38 -1
- careamics/config/optimizer_models.py +1 -3
- careamics/config/training_model.py +0 -2
- careamics/dataset/dataset_utils/running_stats.py +7 -3
- careamics/dataset_ng/README.md +212 -0
- careamics/dataset_ng/dataset.py +233 -0
- careamics/dataset_ng/demos/bsd68_demo.ipynb +356 -0
- careamics/dataset_ng/demos/care_U2OS_demo.ipynb +330 -0
- careamics/dataset_ng/demos/demo_custom_image_stack.ipynb +734 -0
- careamics/dataset_ng/demos/demo_datamodule.ipynb +443 -0
- careamics/dataset_ng/{demo_dataset.ipynb → demos/demo_dataset.ipynb} +39 -15
- careamics/dataset_ng/{demo_patch_extractor.py → demos/demo_patch_extractor.py} +7 -9
- careamics/dataset_ng/demos/mouse_nuclei_demo.ipynb +292 -0
- careamics/dataset_ng/factory.py +408 -0
- careamics/dataset_ng/legacy_interoperability.py +168 -0
- careamics/dataset_ng/patch_extractor/__init__.py +3 -8
- careamics/dataset_ng/patch_extractor/demo_custom_image_stack_loader.py +6 -4
- careamics/dataset_ng/patch_extractor/image_stack/__init__.py +2 -1
- careamics/dataset_ng/patch_extractor/image_stack/image_stack_protocol.py +5 -1
- careamics/dataset_ng/patch_extractor/image_stack_loader.py +5 -75
- careamics/dataset_ng/patch_extractor/patch_extractor.py +5 -4
- careamics/dataset_ng/patch_extractor/patch_extractor_factory.py +73 -106
- careamics/dataset_ng/patching_strategies/__init__.py +6 -1
- careamics/dataset_ng/patching_strategies/patching_strategy_protocol.py +31 -0
- careamics/dataset_ng/patching_strategies/random_patching.py +3 -1
- careamics/dataset_ng/patching_strategies/tiling_strategy.py +171 -0
- careamics/dataset_ng/patching_strategies/whole_sample.py +36 -0
- careamics/lightning/dataset_ng/data_module.py +488 -0
- careamics/lightning/dataset_ng/lightning_modules/__init__.py +9 -0
- careamics/lightning/dataset_ng/lightning_modules/care_module.py +58 -0
- careamics/lightning/dataset_ng/lightning_modules/n2v_module.py +67 -0
- careamics/lightning/dataset_ng/lightning_modules/unet_module.py +143 -0
- careamics/lightning/lightning_module.py +3 -0
- careamics/lvae_training/dataset/__init__.py +8 -3
- careamics/lvae_training/dataset/config.py +3 -3
- careamics/lvae_training/dataset/ms_dataset_ref.py +1067 -0
- careamics/lvae_training/dataset/multich_dataset.py +46 -17
- careamics/lvae_training/dataset/multicrop_dset.py +196 -0
- careamics/lvae_training/dataset/types.py +3 -3
- careamics/lvae_training/dataset/utils/index_manager.py +259 -0
- careamics/lvae_training/eval_utils.py +93 -3
- careamics/transforms/compose.py +1 -0
- careamics/transforms/normalize.py +18 -7
- careamics/utils/lightning_utils.py +25 -11
- {careamics-0.0.10.dist-info → careamics-0.0.12.dist-info}/METADATA +3 -3
- {careamics-0.0.10.dist-info → careamics-0.0.12.dist-info}/RECORD +51 -36
- careamics/dataset_ng/dataset/__init__.py +0 -3
- careamics/dataset_ng/dataset/dataset.py +0 -184
- careamics/dataset_ng/demo_patch_extractor_factory.py +0 -37
- {careamics-0.0.10.dist-info → careamics-0.0.12.dist-info}/WHEEL +0 -0
- {careamics-0.0.10.dist-info → careamics-0.0.12.dist-info}/entry_points.txt +0 -0
- {careamics-0.0.10.dist-info → careamics-0.0.12.dist-info}/licenses/LICENSE +0 -0
careamics/dataset_ng/demos/mouse_nuclei_demo.ipynb (added)

@@ -0,0 +1,292 @@
+{
+ "cells": [
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "from pathlib import Path\n",
+    "\n",
+    "import matplotlib.pyplot as plt\n",
+    "import numpy as np\n",
+    "from careamics_portfolio import PortfolioManager\n",
+    "\n",
+    "from careamics.config.configuration_factories import create_n2v_configuration\n",
+    "from careamics.config.support import SupportedTransform\n",
+    "from careamics.lightning.callbacks import HyperParametersCallback\n",
+    "from careamics.lightning.dataset_ng.data_module import CareamicsDataModule\n",
+    "from careamics.lightning.dataset_ng.lightning_modules import N2VModule"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "portfolio = PortfolioManager()\n",
+    "files = portfolio.denoiseg.MouseNuclei_n20.download()\n",
+    "files.sort()\n",
+    "\n",
+    "# load images\n",
+    "train_data = np.load(files[1])[\"X_train\"]\n",
+    "print(f\"Train data shape: {train_data.shape}\")"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "indices = [34, 293, 571, 783]\n",
+    "\n",
+    "fig, ax = plt.subplots(2, 2, figsize=(8, 8))\n",
+    "ax[0, 0].imshow(train_data[indices[0]], cmap=\"gray\")\n",
+    "ax[0, 0].set_title(f\"Image {indices[0]}\")\n",
+    "ax[0, 0].set_xticks([])\n",
+    "ax[0, 0].set_yticks([])\n",
+    "\n",
+    "ax[0, 1].imshow(train_data[indices[1]], cmap=\"gray\")\n",
+    "ax[0, 1].set_title(f\"Image {indices[1]}\")\n",
+    "ax[0, 1].set_xticks([])\n",
+    "ax[0, 1].set_yticks([])\n",
+    "\n",
+    "ax[1, 0].imshow(train_data[indices[2]], cmap=\"gray\")\n",
+    "ax[1, 0].set_title(f\"Image {indices[2]}\")\n",
+    "ax[1, 0].set_xticks([])\n",
+    "ax[1, 0].set_yticks([])\n",
+    "\n",
+    "ax[1, 1].imshow(train_data[indices[3]], cmap=\"gray\")\n",
+    "ax[1, 1].set_title(f\"Image {indices[3]}\")\n",
+    "ax[1, 1].set_xticks([])\n",
+    "ax[1, 1].set_yticks([])\n",
+    "\n",
+    "plt.show()"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "config = create_n2v_configuration(\n",
+    "    experiment_name=\"mouse_nuclei_n2v\",\n",
+    "    data_type=\"array\",\n",
+    "    axes=\"SYX\",\n",
+    "    patch_size=(64, 64),\n",
+    "    batch_size=16,\n",
+    "    num_epochs=10,\n",
+    ")\n",
+    "\n",
+    "print(config)"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "# Ensuring that transforms are set\n",
+    "config.data_config.transforms = [\n",
+    "    {\n",
+    "        \"name\": SupportedTransform.XY_FLIP.value,\n",
+    "        \"flip_x\": True,\n",
+    "        \"flip_y\": True,\n",
+    "    },\n",
+    "    {\n",
+    "        \"name\": SupportedTransform.XY_RANDOM_ROTATE90.value,\n",
+    "    },\n",
+    "]"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "from sklearn.model_selection import train_test_split\n",
+    "\n",
+    "train_data, val_data = train_test_split(train_data, test_size=0.1, random_state=42)\n"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "train_data_module = CareamicsDataModule(\n",
+    "    data_config=config.data_config,\n",
+    "    train_data=train_data,\n",
+    "    val_data=val_data,\n",
+    ")\n",
+    "\n",
+    "model = N2VModule(config.algorithm_config)"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "from pytorch_lightning import Trainer\n",
+    "from pytorch_lightning.callbacks import ModelCheckpoint\n",
+    "from pytorch_lightning.loggers import WandbLogger\n",
+    "\n",
+    "root = Path(\"nuclei_n2v\")\n",
+    "callbacks = [\n",
+    "    ModelCheckpoint(\n",
+    "        dirpath=root / \"checkpoints\",\n",
+    "        filename=\"nuclei_new_lightning_module\",\n",
+    "        save_last=True,\n",
+    "        monitor=\"val_loss\",\n",
+    "        mode=\"min\",\n",
+    "    ),\n",
+    "    HyperParametersCallback(config),\n",
+    "]\n",
+    "logger = WandbLogger(\n",
+    "    project=\"nuclei-n2v\", name=\"nuclei_new_lightning_module\"\n",
+    ")\n",
+    "\n",
+    "trainer = Trainer(\n",
+    "    max_epochs=10,\n",
+    "    default_root_dir=root,\n",
+    "    callbacks=callbacks,\n",
+    "    logger=logger,\n",
+    ")\n",
+    "trainer.fit(model, datamodule=train_data_module)"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "from careamics.config.inference_model import InferenceConfig\n",
+    "from careamics.dataset_ng.legacy_interoperability import imageregions_to_tileinfos\n",
+    "from careamics.prediction_utils import convert_outputs\n",
+    "\n",
+    "train_data = np.load(files[1])[\"X_train\"]\n",
+    "\n",
+    "config = InferenceConfig(\n",
+    "    model_config=config,\n",
+    "    data_type=\"array\",\n",
+    "    tile_size=(64, 64),\n",
+    "    tile_overlap=(32, 32),\n",
+    "    axes=\"SYX\",\n",
+    "    batch_size=1,\n",
+    "    image_means=train_data_module.train_dataset.input_stats.means,\n",
+    "    image_stds=train_data_module.train_dataset.input_stats.stds,\n",
+    ")\n",
+    "\n",
+    "inf_data_module = CareamicsDataModule(\n",
+    "    data_config=config,\n",
+    "    pred_data=train_data,\n",
+    ")"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "predictions = trainer.predict(model, datamodule=inf_data_module)\n",
+    "tile_infos = imageregions_to_tileinfos(predictions)\n",
+    "predictions = convert_outputs(tile_infos, tiled=True)\n",
+    "predictions = np.stack(predictions).squeeze()"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "files = portfolio.denoiseg.MouseNuclei_n0.download()\n",
+    "files.sort()\n",
+    "\n",
+    "gt_data = np.load(files[1])[\"X_train\"]\n",
+    "print(f\"GT data shape: {gt_data.shape}\")\n",
+    "print(f\"Predictions shape: {predictions.shape}\")"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "from careamics.utils.metrics import scale_invariant_psnr\n",
+    "\n",
+    "indices = [389, 621]\n",
+    "\n",
+    "for i in indices:\n",
+    "    # compute psnr\n",
+    "    psnr_noisy = scale_invariant_psnr(gt_data[i], train_data[i])\n",
+    "    psnr_denoised = scale_invariant_psnr(gt_data[i], predictions[i].squeeze())\n",
+    "\n",
+    "    # plot images\n",
+    "    fig, ax = plt.subplots(1, 3, figsize=(10, 10))\n",
+    "    ax[0].imshow(train_data[i], cmap=\"gray\")\n",
+    "    ax[0].set_title(f\"Noisy Image\\nPSNR: {psnr_noisy:.2f}\")\n",
+    "    ax[0].set_xticks([])\n",
+    "    ax[0].set_yticks([])\n",
+    "\n",
+    "    ax[1].imshow(predictions[i].squeeze(), cmap=\"gray\")\n",
+    "    ax[1].set_title(f\"Denoised Image\\nPSNR: {psnr_denoised:.2f}\")\n",
+    "    ax[1].set_xticks([])\n",
+    "    ax[1].set_yticks([])\n",
+    "\n",
+    "    ax[2].imshow(gt_data[i], cmap=\"gray\")\n",
+    "    ax[2].set_title(\"GT Image\")\n",
+    "    ax[2].set_xticks([])\n",
+    "    ax[2].set_yticks([])\n",
+    "\n",
+    "    plt.show()"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "psnrs = np.zeros(gt_data.shape[0])\n",
+    "\n",
+    "for i in range(gt_data.shape[0]):\n",
+    "    psnrs[i] = scale_invariant_psnr(gt_data[i], predictions[i].squeeze())\n",
+    "\n",
+    "print(f\"PSNR: {np.mean(psnrs):.2f} ± {np.std(psnrs):.2f}\")"
+   ]
+  }
+ ],
+ "metadata": {
+  "kernelspec": {
+   "display_name": "Python 3",
+   "language": "python",
+   "name": "python3"
+  },
+  "language_info": {
+   "codemirror_mode": {
+    "name": "ipython",
+    "version": 3
+   },
+   "file_extension": ".py",
+   "mimetype": "text/x-python",
+   "name": "python",
+   "nbconvert_exporter": "python",
+   "pygments_lexer": "ipython3",
+   "version": "3.9.20"
+  }
+ },
+ "nbformat": 4,
+ "nbformat_minor": 2
+}
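Leaving aside data download, plotting, and logging, the new-style training loop demonstrated above reduces to a few calls. The following is a minimal sketch, not part of the diff: the random arrays are stand-ins, and it assumes the imported names behave as they do in the notebook.

    import numpy as np
    from pytorch_lightning import Trainer

    from careamics.config.configuration_factories import create_n2v_configuration
    from careamics.lightning.dataset_ng.data_module import CareamicsDataModule
    from careamics.lightning.dataset_ng.lightning_modules import N2VModule

    # build an N2V configuration for 2D arrays with a sample axis
    config = create_n2v_configuration(
        experiment_name="sketch",
        data_type="array",
        axes="SYX",
        patch_size=(64, 64),
        batch_size=16,
        num_epochs=10,
    )

    # stand-in data: any (S, Y, X) float arrays
    train = np.random.rand(8, 128, 128).astype(np.float32)
    val = np.random.rand(2, 128, 128).astype(np.float32)

    # the new data module wraps the dataset_ng CareamicsDataset
    data_module = CareamicsDataModule(
        data_config=config.data_config,
        train_data=train,
        val_data=val,
    )
    model = N2VModule(config.algorithm_config)
    Trainer(max_epochs=10).fit(model, datamodule=data_module)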
careamics/dataset_ng/factory.py (added)

@@ -0,0 +1,408 @@
+from collections.abc import Sequence
+from enum import Enum
+from pathlib import Path
+from typing import Any, Optional, Union
+
+from numpy.typing import NDArray
+from typing_extensions import ParamSpec
+
+from careamics.config import DataConfig, InferenceConfig
+from careamics.config.support import SupportedData
+from careamics.dataset_ng.patch_extractor import ImageStackLoader, PatchExtractor
+from careamics.dataset_ng.patch_extractor.image_stack import (
+    GenericImageStack,
+    ImageStack,
+    InMemoryImageStack,
+    ZarrImageStack,
+)
+from careamics.dataset_ng.patch_extractor.patch_extractor_factory import (
+    create_array_extractor,
+    create_custom_file_extractor,
+    create_custom_image_stack_extractor,
+    create_ome_zarr_extractor,
+    create_tiff_extractor,
+)
+from careamics.file_io.read import ReadFunc
+
+from .dataset import CareamicsDataset, Mode
+
+P = ParamSpec("P")
+
+
+# Enum class used to determine which loading functions should be used
+class DatasetType(Enum):
+    """Labels for the dataset based on the underlying data and how it is loaded."""
+
+    ARRAY = "array"
+    IN_MEM_TIFF = "in_mem_tiff"
+    LAZY_TIFF = "lazy_tiff"
+    IN_MEM_CUSTOM_FILE = "in_mem_custom_file"
+    OME_ZARR = "ome_zarr"
+    CUSTOM_IMAGE_STACK = "custom_image_stack"
+
+
+# bit of a mess of if-else statements
+def determine_dataset_type(
+    data_type: SupportedData,
+    in_memory: bool,
+    read_func: Optional[ReadFunc] = None,
+    image_stack_loader: Optional[ImageStackLoader] = None,
+) -> DatasetType:
+    """Determine what the dataset type should be based on the input arguments.
+
+    Parameters
+    ----------
+    data_type : SupportedData
+        The underlying data type.
+    in_memory : bool
+        Whether all the data should be loaded into memory. This argument is ignored
+        unless the `data_type` is "tiff" or "custom".
+    read_func : ReadFunc, optional
+        A function that can be used to load custom data. This argument is ignored
+        unless the `data_type` is "custom".
+    image_stack_loader : ImageStackLoader, optional
+        A function for custom image stack loading. This argument is ignored unless
+        the `data_type` is "custom".
+
+    Returns
+    -------
+    DatasetType
+        The dataset type.
+
+    Raises
+    ------
+    NotImplementedError
+        For lazy-loading (`in_memory=False`) of a custom file type.
+    ValueError
+        If the `data_type` is "custom" but both `read_func` and
+        `image_stack_loader` are None.
+    ValueError
+        If the `data_type` is unrecognized.
+    """
+    if data_type == SupportedData.ARRAY:
+        # TODO: ignoring in_memory arg, error if False?
+        return DatasetType.ARRAY
+    elif data_type == SupportedData.TIFF:
+        if in_memory:
+            return DatasetType.IN_MEM_TIFF
+        else:
+            return DatasetType.LAZY_TIFF
+    elif data_type == SupportedData.CUSTOM:
+        if read_func is not None:
+            if in_memory:
+                return DatasetType.IN_MEM_CUSTOM_FILE
+            else:
+                raise NotImplementedError(
+                    "Lazy loading has not been implemented for custom file types yet."
+                )
+        elif image_stack_loader is not None:
+            # TODO: ignoring in_memory arg
+            return DatasetType.CUSTOM_IMAGE_STACK
+        else:
+            raise ValueError(
+                "Found `data_type='custom'` but no `read_func` or "
+                "`image_stack_loader` has been provided."
+            )
+    # TODO: ZARR
+    else:
+        raise ValueError(f"Unrecognized `data_type`, '{data_type}'.")
+
+
+# convenience function, but the `create_dataloader` function should be used instead;
+# for lazy loading, a custom batch sampler also needs to be set
+def create_dataset(
+    config: Union[DataConfig, InferenceConfig],
+    mode: Mode,
+    inputs: Any,
+    targets: Any,
+    in_memory: bool,
+    read_func: Optional[ReadFunc] = None,
+    read_kwargs: Optional[dict[str, Any]] = None,
+    image_stack_loader: Optional[ImageStackLoader] = None,
+    image_stack_loader_kwargs: Optional[dict[str, Any]] = None,
+) -> CareamicsDataset[ImageStack]:
+    """
+    Convenience function to create the CAREamicsDataset.
+
+    Parameters
+    ----------
+    config : DataConfig or InferenceConfig
+        The data configuration.
+    mode : Mode
+        Whether to create the dataset in "training", "validation" or "predicting"
+        mode.
+    inputs : Any
+        The input sources to the dataset.
+    targets : Any, optional
+        The target sources to the dataset.
+    in_memory : bool
+        Whether all the data should be loaded into memory. This argument is ignored
+        unless the `data_type` in `config` is "tiff" or "custom".
+    read_func : ReadFunc, optional
+        A function that can be used to load custom data. This argument is ignored
+        unless the `data_type` in the `config` is "custom".
+    read_kwargs : dict of {str, Any}, optional
+        Additional keyword arguments to pass to the `read_func`.
+    image_stack_loader : ImageStackLoader, optional
+        A function for custom image stack loading. This argument is ignored unless
+        the `data_type` in the `config` is "custom".
+    image_stack_loader_kwargs : dict of {str, Any}, optional
+        Additional keyword arguments to pass to the `image_stack_loader`.
+
+    Returns
+    -------
+    CareamicsDataset[ImageStack]
+        The CAREamicsDataset.
+
+    Raises
+    ------
+    ValueError
+        For an unrecognized `data_type` in the `config`.
+    """
+    data_type = SupportedData(config.data_type)
+    dataset_type = determine_dataset_type(
+        data_type, in_memory, read_func, image_stack_loader
+    )
+    if dataset_type == DatasetType.ARRAY:
+        return create_array_dataset(config, mode, inputs, targets)
+    elif dataset_type == DatasetType.IN_MEM_TIFF:
+        return create_tiff_dataset(config, mode, inputs, targets)
+    # TODO: Lazy tiff
+    elif dataset_type == DatasetType.IN_MEM_CUSTOM_FILE:
+        if read_kwargs is None:
+            read_kwargs = {}
+        assert read_func is not None  # should be true from `determine_dataset_type`
+        return create_custom_file_dataset(
+            config, mode, inputs, targets, read_func=read_func, read_kwargs=read_kwargs
+        )
+    elif dataset_type == DatasetType.CUSTOM_IMAGE_STACK:
+        if image_stack_loader_kwargs is None:
+            image_stack_loader_kwargs = {}
+        assert image_stack_loader is not None  # should be true
+        return create_custom_image_stack_dataset(
+            config,
+            mode,
+            inputs,
+            targets,
+            image_stack_loader,
+            **image_stack_loader_kwargs,
+        )
+    else:
+        raise ValueError(f"Unrecognized dataset type, {dataset_type}.")
+
+
+def create_array_dataset(
+    config: Union[DataConfig, InferenceConfig],
+    mode: Mode,
+    inputs: Sequence[NDArray[Any]],
+    targets: Optional[Sequence[NDArray[Any]]],
+) -> CareamicsDataset[InMemoryImageStack]:
+    """
+    Create a CAREamicsDataset from array data.
+
+    Parameters
+    ----------
+    config : DataConfig or InferenceConfig
+        The data configuration.
+    mode : Mode
+        Whether to create the dataset in "training", "validation" or "predicting"
+        mode.
+    inputs : Sequence[NDArray]
+        The input sources to the dataset.
+    targets : Sequence[NDArray], optional
+        The target sources to the dataset.
+
+    Returns
+    -------
+    CareamicsDataset[InMemoryImageStack]
+        A CAREamicsDataset.
+    """
+    input_extractor = create_array_extractor(source=inputs, axes=config.axes)
+    target_extractor: Optional[PatchExtractor[InMemoryImageStack]]
+    if targets is not None:
+        target_extractor = create_array_extractor(source=targets, axes=config.axes)
+    else:
+        target_extractor = None
+    return CareamicsDataset(config, mode, input_extractor, target_extractor)
+
+
+def create_tiff_dataset(
+    config: Union[DataConfig, InferenceConfig],
+    mode: Mode,
+    inputs: Sequence[Path],
+    targets: Optional[Sequence[Path]],
+) -> CareamicsDataset[InMemoryImageStack]:
+    """
+    Create a CAREamicsDataset from tiff files that will all be loaded into memory.
+
+    Parameters
+    ----------
+    config : DataConfig or InferenceConfig
+        The data configuration.
+    mode : Mode
+        Whether to create the dataset in "training", "validation" or "predicting"
+        mode.
+    inputs : Sequence[Path]
+        The input sources to the dataset.
+    targets : Sequence[Path], optional
+        The target sources to the dataset.
+
+    Returns
+    -------
+    CareamicsDataset[InMemoryImageStack]
+        A CAREamicsDataset.
+    """
+    input_extractor = create_tiff_extractor(
+        source=inputs,
+        axes=config.axes,
+    )
+    target_extractor: Optional[PatchExtractor[InMemoryImageStack]]
+    if targets is not None:
+        target_extractor = create_tiff_extractor(source=targets, axes=config.axes)
+    else:
+        target_extractor = None
+    dataset = CareamicsDataset(config, mode, input_extractor, target_extractor)
+    return dataset
+
+
+def create_ome_zarr_dataset(
+    config: Union[DataConfig, InferenceConfig],
+    mode: Mode,
+    inputs: Sequence[Path],
+    targets: Optional[Sequence[Path]],
+) -> CareamicsDataset[ZarrImageStack]:
+    """
+    Create a dataset from OME-Zarr files.
+
+    Parameters
+    ----------
+    config : DataConfig or InferenceConfig
+        The data configuration.
+    mode : Mode
+        Whether to create the dataset in "training", "validation" or "predicting"
+        mode.
+    inputs : Sequence[Path]
+        The input sources to the dataset.
+    targets : Sequence[Path], optional
+        The target sources to the dataset.
+
+    Returns
+    -------
+    CareamicsDataset[ZarrImageStack]
+        A CAREamicsDataset.
+    """
+    input_extractor = create_ome_zarr_extractor(source=inputs, axes=config.axes)
+    target_extractor: Optional[PatchExtractor[ZarrImageStack]]
+    if targets is not None:
+        target_extractor = create_ome_zarr_extractor(source=targets, axes=config.axes)
+    else:
+        target_extractor = None
+    dataset = CareamicsDataset(config, mode, input_extractor, target_extractor)
+    return dataset
+
+
+def create_custom_file_dataset(
+    config: Union[DataConfig, InferenceConfig],
+    mode: Mode,
+    inputs: Sequence[Path],
+    targets: Optional[Sequence[Path]],
+    *,
+    read_func: ReadFunc,
+    read_kwargs: dict[str, Any],
+) -> CareamicsDataset[InMemoryImageStack]:
+    """
+    Create a CAREamicsDataset from custom files that will all be loaded into memory.
+
+    Parameters
+    ----------
+    config : DataConfig or InferenceConfig
+        The data configuration.
+    mode : Mode
+        Whether to create the dataset in "training", "validation" or "predicting"
+        mode.
+    inputs : Sequence[Path]
+        The input sources to the dataset.
+    targets : Sequence[Path], optional
+        The target sources to the dataset.
+    read_func : ReadFunc
+        A function that can be used to load the custom file type.
+    read_kwargs : dict of {str, Any}
+        Additional keyword arguments to pass to the `read_func`.
+
+    Returns
+    -------
+    CareamicsDataset[InMemoryImageStack]
+        A CAREamicsDataset.
+    """
+    input_extractor = create_custom_file_extractor(
+        source=inputs, axes=config.axes, read_func=read_func, read_kwargs=read_kwargs
+    )
+    target_extractor: Optional[PatchExtractor[InMemoryImageStack]]
+    if targets is not None:
+        target_extractor = create_custom_file_extractor(
+            source=targets,
+            axes=config.axes,
+            read_func=read_func,
+            read_kwargs=read_kwargs,
+        )
+    else:
+        target_extractor = None
+    dataset = CareamicsDataset(config, mode, input_extractor, target_extractor)
+    return dataset
+
+
+def create_custom_image_stack_dataset(
+    config: Union[DataConfig, InferenceConfig],
+    mode: Mode,
+    inputs: Any,
+    targets: Optional[Any],
+    image_stack_loader: ImageStackLoader[P, GenericImageStack],
+    *args: P.args,
+    **kwargs: P.kwargs,
+) -> CareamicsDataset[GenericImageStack]:
+    """
+    Create a CAREamicsDataset from a custom `ImageStack` class.
+
+    The custom `ImageStack` class can be loaded using the `image_stack_loader`
+    function.
+
+    Parameters
+    ----------
+    config : DataConfig or InferenceConfig
+        The data configuration.
+    mode : Mode
+        Whether to create the dataset in "training", "validation" or "predicting"
+        mode.
+    inputs : Any
+        The input sources to the dataset.
+    targets : Any, optional
+        The target sources to the dataset.
+    image_stack_loader : ImageStackLoader
+        A function for custom image stack loading.
+    *args : Any
+        Positional arguments to pass to the `image_stack_loader`.
+    **kwargs : Any
+        Keyword arguments to pass to the `image_stack_loader`.
+
+    Returns
+    -------
+    CareamicsDataset[GenericImageStack]
+        A CAREamicsDataset.
+    """
+    input_extractor = create_custom_image_stack_extractor(
+        inputs,
+        config.axes,
+        image_stack_loader,
+        *args,
+        **kwargs,
+    )
+    target_extractor: Optional[PatchExtractor[GenericImageStack]]
+    if targets is not None:
+        target_extractor = create_custom_image_stack_extractor(
+            targets,
+            config.axes,
+            image_stack_loader,
+            *args,
+            **kwargs,
+        )
+    else:
+        target_extractor = None
+    dataset = CareamicsDataset(config, mode, input_extractor, target_extractor)
+    return dataset
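For review purposes, a minimal sketch of how the factory above would be invoked for the in-memory array case. This is not part of the diff: the `Mode.TRAINING` member name and the exact set of required `DataConfig` fields are assumptions inferred from the signatures and docstrings above, not verified against the released wheel.

    import numpy as np

    from careamics.config import DataConfig
    from careamics.dataset_ng.dataset import Mode
    from careamics.dataset_ng.factory import create_dataset

    # assumed-minimal data configuration; field names taken from the calls above
    config = DataConfig(
        data_type="array",
        axes="SYX",
        patch_size=[64, 64],
        batch_size=16,
    )

    # one stack of 2D samples; create_array_dataset expects a sequence of arrays
    inputs = [np.random.rand(4, 256, 256).astype(np.float32)]

    dataset = create_dataset(
        config=config,
        mode=Mode.TRAINING,  # assumed member; the docstrings name a "training" mode
        inputs=inputs,
        targets=None,  # no targets, e.g. for Noise2Void-style training
        in_memory=True,  # ignored for arrays, per determine_dataset_type
    )
    print(len(dataset))  # number of patches, as determined by the patching strategy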