careamics 0.0.10__py3-none-any.whl → 0.0.12__py3-none-any.whl

Files changed (54)
  1. careamics/careamist.py +20 -4
  2. careamics/config/configuration.py +10 -5
  3. careamics/config/data/data_model.py +38 -1
  4. careamics/config/optimizer_models.py +1 -3
  5. careamics/config/training_model.py +0 -2
  6. careamics/dataset/dataset_utils/running_stats.py +7 -3
  7. careamics/dataset_ng/README.md +212 -0
  8. careamics/dataset_ng/dataset.py +233 -0
  9. careamics/dataset_ng/demos/bsd68_demo.ipynb +356 -0
  10. careamics/dataset_ng/demos/care_U2OS_demo.ipynb +330 -0
  11. careamics/dataset_ng/demos/demo_custom_image_stack.ipynb +734 -0
  12. careamics/dataset_ng/demos/demo_datamodule.ipynb +443 -0
  13. careamics/dataset_ng/{demo_dataset.ipynb → demos/demo_dataset.ipynb} +39 -15
  14. careamics/dataset_ng/{demo_patch_extractor.py → demos/demo_patch_extractor.py} +7 -9
  15. careamics/dataset_ng/demos/mouse_nuclei_demo.ipynb +292 -0
  16. careamics/dataset_ng/factory.py +408 -0
  17. careamics/dataset_ng/legacy_interoperability.py +168 -0
  18. careamics/dataset_ng/patch_extractor/__init__.py +3 -8
  19. careamics/dataset_ng/patch_extractor/demo_custom_image_stack_loader.py +6 -4
  20. careamics/dataset_ng/patch_extractor/image_stack/__init__.py +2 -1
  21. careamics/dataset_ng/patch_extractor/image_stack/image_stack_protocol.py +5 -1
  22. careamics/dataset_ng/patch_extractor/image_stack_loader.py +5 -75
  23. careamics/dataset_ng/patch_extractor/patch_extractor.py +5 -4
  24. careamics/dataset_ng/patch_extractor/patch_extractor_factory.py +73 -106
  25. careamics/dataset_ng/patching_strategies/__init__.py +6 -1
  26. careamics/dataset_ng/patching_strategies/patching_strategy_protocol.py +31 -0
  27. careamics/dataset_ng/patching_strategies/random_patching.py +3 -1
  28. careamics/dataset_ng/patching_strategies/tiling_strategy.py +171 -0
  29. careamics/dataset_ng/patching_strategies/whole_sample.py +36 -0
  30. careamics/lightning/dataset_ng/data_module.py +488 -0
  31. careamics/lightning/dataset_ng/lightning_modules/__init__.py +9 -0
  32. careamics/lightning/dataset_ng/lightning_modules/care_module.py +58 -0
  33. careamics/lightning/dataset_ng/lightning_modules/n2v_module.py +67 -0
  34. careamics/lightning/dataset_ng/lightning_modules/unet_module.py +143 -0
  35. careamics/lightning/lightning_module.py +3 -0
  36. careamics/lvae_training/dataset/__init__.py +8 -3
  37. careamics/lvae_training/dataset/config.py +3 -3
  38. careamics/lvae_training/dataset/ms_dataset_ref.py +1067 -0
  39. careamics/lvae_training/dataset/multich_dataset.py +46 -17
  40. careamics/lvae_training/dataset/multicrop_dset.py +196 -0
  41. careamics/lvae_training/dataset/types.py +3 -3
  42. careamics/lvae_training/dataset/utils/index_manager.py +259 -0
  43. careamics/lvae_training/eval_utils.py +93 -3
  44. careamics/transforms/compose.py +1 -0
  45. careamics/transforms/normalize.py +18 -7
  46. careamics/utils/lightning_utils.py +25 -11
  47. {careamics-0.0.10.dist-info → careamics-0.0.12.dist-info}/METADATA +3 -3
  48. {careamics-0.0.10.dist-info → careamics-0.0.12.dist-info}/RECORD +51 -36
  49. careamics/dataset_ng/dataset/__init__.py +0 -3
  50. careamics/dataset_ng/dataset/dataset.py +0 -184
  51. careamics/dataset_ng/demo_patch_extractor_factory.py +0 -37
  52. {careamics-0.0.10.dist-info → careamics-0.0.12.dist-info}/WHEEL +0 -0
  53. {careamics-0.0.10.dist-info → careamics-0.0.12.dist-info}/entry_points.txt +0 -0
  54. {careamics-0.0.10.dist-info → careamics-0.0.12.dist-info}/licenses/LICENSE +0 -0
careamics/dataset_ng/demos/mouse_nuclei_demo.ipynb
@@ -0,0 +1,292 @@
+ {
+ "cells": [
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "from pathlib import Path\n",
+ "\n",
+ "import matplotlib.pyplot as plt\n",
+ "import numpy as np\n",
+ "from careamics_portfolio import PortfolioManager\n",
+ "\n",
+ "from careamics.config.configuration_factories import create_n2v_configuration\n",
+ "from careamics.config.support import SupportedTransform\n",
+ "from careamics.lightning.callbacks import HyperParametersCallback\n",
+ "from careamics.lightning.dataset_ng.data_module import CareamicsDataModule\n",
+ "from careamics.lightning.dataset_ng.lightning_modules import N2VModule"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "portfolio = PortfolioManager()\n",
+ "files = portfolio.denoiseg.MouseNuclei_n20.download()\n",
+ "files.sort()\n",
+ "\n",
+ "# load images\n",
+ "train_data = np.load(files[1])[\"X_train\"]\n",
+ "print(f\"Train data shape: {train_data.shape}\")"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "indices = [34, 293, 571, 783]\n",
+ "\n",
+ "fig, ax = plt.subplots(2, 2, figsize=(8, 8))\n",
+ "ax[0, 0].imshow(train_data[indices[0]], cmap=\"gray\")\n",
+ "ax[0, 0].set_title(f\"Image {indices[0]}\")\n",
+ "ax[0, 0].set_xticks([])\n",
+ "ax[0, 0].set_yticks([])\n",
+ "\n",
+ "ax[0, 1].imshow(train_data[indices[1]], cmap=\"gray\")\n",
+ "ax[0, 1].set_title(f\"Image {indices[1]}\")\n",
+ "ax[0, 1].set_xticks([])\n",
+ "ax[0, 1].set_yticks([])\n",
+ "\n",
+ "ax[1, 0].imshow(train_data[indices[2]], cmap=\"gray\")\n",
+ "ax[1, 0].set_title(f\"Image {indices[2]}\")\n",
+ "ax[1, 0].set_xticks([])\n",
+ "ax[1, 0].set_yticks([])\n",
+ "\n",
+ "ax[1, 1].imshow(train_data[indices[3]], cmap=\"gray\")\n",
+ "ax[1, 1].set_title(f\"Image {indices[3]}\")\n",
+ "ax[1, 1].set_xticks([])\n",
+ "ax[1, 1].set_yticks([])\n",
+ "\n",
+ "plt.show()"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "config = create_n2v_configuration(\n",
+ "    experiment_name=\"mouse_nuclei_n2v\",\n",
+ "    data_type=\"array\",\n",
+ "    axes=\"SYX\",\n",
+ "    patch_size=(64, 64),\n",
+ "    batch_size=16,\n",
+ "    num_epochs=10,\n",
+ ")\n",
+ "\n",
+ "print(config)"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "# Ensuring that transforms are set\n",
+ "config.data_config.transforms = [\n",
+ "    {\n",
+ "        \"name\": SupportedTransform.XY_FLIP.value,\n",
+ "        \"flip_x\": True,\n",
+ "        \"flip_y\": True,\n",
+ "    },\n",
+ "    {\n",
+ "        \"name\": SupportedTransform.XY_RANDOM_ROTATE90.value,\n",
+ "    },\n",
+ "]"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "from sklearn.model_selection import train_test_split\n",
+ "\n",
+ "train_data, val_data = train_test_split(train_data, test_size=0.1, random_state=42)\n"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "train_data_module = CareamicsDataModule(\n",
+ "    data_config=config.data_config,\n",
+ "    train_data=train_data,\n",
+ "    val_data=val_data,\n",
+ ")\n",
+ "\n",
+ "model = N2VModule(config.algorithm_config)"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "from pytorch_lightning import Trainer\n",
+ "from pytorch_lightning.callbacks import ModelCheckpoint\n",
+ "from pytorch_lightning.loggers import WandbLogger\n",
+ "\n",
+ "root = Path(\"nuclei_n2v\")\n",
+ "callbacks = [\n",
+ "    ModelCheckpoint(\n",
+ "        dirpath=root / \"checkpoints\",\n",
+ "        filename=\"nuclei_new_lightning_module\",\n",
+ "        save_last=True,\n",
+ "        monitor=\"val_loss\",\n",
+ "        mode=\"min\",\n",
+ "    ),\n",
+ "    HyperParametersCallback(config)\n",
+ "]\n",
+ "logger = WandbLogger(\n",
+ "    project=\"nuclei-n2v\", name=\"nuclei_new_lightning_module\"\n",
+ ")\n",
+ "\n",
+ "trainer = Trainer(\n",
+ "    max_epochs=10,\n",
+ "    default_root_dir=root,\n",
+ "    callbacks=callbacks,\n",
+ "    logger=logger\n",
+ ")\n",
+ "trainer.fit(model, datamodule=train_data_module)"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "from careamics.config.inference_model import InferenceConfig\n",
+ "from careamics.dataset_ng.legacy_interoperability import imageregions_to_tileinfos\n",
+ "from careamics.prediction_utils import convert_outputs\n",
+ "\n",
+ "train_data = np.load(files[1])[\"X_train\"]\n",
+ "\n",
+ "config = InferenceConfig(\n",
+ "    model_config=config,\n",
+ "    data_type=\"array\",\n",
+ "    tile_size=(64, 64),\n",
+ "    tile_overlap=(32, 32),\n",
+ "    axes=\"SYX\",\n",
+ "    batch_size=1,\n",
+ "    image_means=train_data_module.train_dataset.input_stats.means,\n",
+ "    image_stds=train_data_module.train_dataset.input_stats.stds\n",
+ ")\n",
+ "\n",
+ "inf_data_module = CareamicsDataModule(\n",
+ "    data_config=config,\n",
+ "    pred_data=train_data\n",
+ ")"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "predictions = trainer.predict(model, datamodule=inf_data_module)\n",
+ "tile_infos = imageregions_to_tileinfos(predictions)\n",
+ "predictions = convert_outputs(tile_infos, tiled=True)\n",
+ "predictions = np.stack(predictions).squeeze()"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "files = portfolio.denoiseg.MouseNuclei_n0.download()\n",
+ "files.sort()\n",
+ "\n",
+ "gt_data = np.load(files[1])[\"X_train\"]\n",
+ "print(f\"GT data shape: {gt_data.shape}\")\n",
+ "print(f\"Predictions shape: {predictions.shape}\")"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "from careamics.utils.metrics import scale_invariant_psnr\n",
+ "\n",
+ "indices = [389, 621]\n",
+ "\n",
+ "for i in indices:\n",
+ "    # compute psnr\n",
+ "    psnr_noisy = scale_invariant_psnr(gt_data[i], train_data[i])\n",
+ "    psnr_denoised = scale_invariant_psnr(gt_data[i], predictions[i].squeeze())\n",
+ "\n",
+ "    # plot images\n",
+ "    fig, ax = plt.subplots(1, 3, figsize=(10, 10))\n",
+ "    ax[0].imshow(train_data[i], cmap=\"gray\")\n",
+ "    ax[0].set_title(f\"Noisy Image\\nPSNR: {psnr_noisy:.2f}\")\n",
+ "    ax[0].set_xticks([])\n",
+ "    ax[0].set_yticks([])\n",
+ "\n",
+ "    ax[1].imshow(predictions[i].squeeze(), cmap=\"gray\")\n",
+ "    ax[1].set_title(f\"Denoised Image\\nPSNR: {psnr_denoised:.2f}\")\n",
+ "    ax[1].set_xticks([])\n",
+ "    ax[1].set_yticks([])\n",
+ "\n",
+ "    ax[2].imshow(gt_data[i], cmap=\"gray\")\n",
+ "    ax[2].set_title(\"GT Image\")\n",
+ "    ax[2].set_xticks([])\n",
+ "    ax[2].set_yticks([])\n",
+ "\n",
+ "    plt.show()"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "psnrs = np.zeros(gt_data.shape[0])\n",
+ "\n",
+ "for i in range(gt_data.shape[0]):\n",
+ "    psnrs[i] = scale_invariant_psnr(gt_data[i], predictions[i].squeeze())\n",
+ "\n",
+ "print(f\"PSNR: {np.mean(psnrs):.2f} ± {np.std(psnrs):.2f}\")"
+ ]
+ }
+ ],
+ "metadata": {
+ "kernelspec": {
+ "display_name": "Python 3",
+ "language": "python",
+ "name": "python3"
+ },
+ "language_info": {
+ "codemirror_mode": {
+ "name": "ipython",
+ "version": 3
+ },
+ "file_extension": ".py",
+ "mimetype": "text/x-python",
+ "name": "python",
+ "nbconvert_exporter": "python",
+ "pygments_lexer": "ipython3",
+ "version": "3.9.20"
+ }
+ },
+ "nbformat": 4,
+ "nbformat_minor": 2
+ }
careamics/dataset_ng/factory.py
@@ -0,0 +1,408 @@
+ from collections.abc import Sequence
+ from enum import Enum
+ from pathlib import Path
+ from typing import Any, Optional, Union
+
+ from numpy.typing import NDArray
+ from typing_extensions import ParamSpec
+
+ from careamics.config import DataConfig, InferenceConfig
+ from careamics.config.support import SupportedData
+ from careamics.dataset_ng.patch_extractor import ImageStackLoader, PatchExtractor
+ from careamics.dataset_ng.patch_extractor.image_stack import (
+     GenericImageStack,
+     ImageStack,
+     InMemoryImageStack,
+     ZarrImageStack,
+ )
+ from careamics.dataset_ng.patch_extractor.patch_extractor_factory import (
+     create_array_extractor,
+     create_custom_file_extractor,
+     create_custom_image_stack_extractor,
+     create_ome_zarr_extractor,
+     create_tiff_extractor,
+ )
+ from careamics.file_io.read import ReadFunc
+
+ from .dataset import CareamicsDataset, Mode
+
+ P = ParamSpec("P")
+
+
+ # Enum class used to determine which loading functions should be used
+ class DatasetType(Enum):
+     """Labels for the dataset based on the underlying data and how it is loaded."""
+
+     ARRAY = "array"
+     IN_MEM_TIFF = "in_mem_tiff"
+     LAZY_TIFF = "lazy_tiff"
+     IN_MEM_CUSTOM_FILE = "in_mem_custom_file"
+     OME_ZARR = "ome_zarr"
+     CUSTOM_IMAGE_STACK = "custom_image_stack"
+
+
+ # bit of a mess of if-else statements
+ def determine_dataset_type(
+     data_type: SupportedData,
+     in_memory: bool,
+     read_func: Optional[ReadFunc] = None,
+     image_stack_loader: Optional[ImageStackLoader] = None,
+ ) -> DatasetType:
+     """Determine what the dataset type should be based on the input arguments.
+
+     Parameters
+     ----------
+     data_type : SupportedData
+         The underlying datatype.
+     in_memory : bool
+         Whether all the data should be loaded into memory. This argument is ignored
+         unless the `data_type` is "tiff" or "custom".
+     read_func : ReadFunc, optional
+         A function that can be used to load custom data. This argument is
+         ignored unless the `data_type` is "custom".
+     image_stack_loader : ImageStackLoader, optional
+         A function for custom image stack loading. This argument is ignored unless the
+         `data_type` is "custom".
+
+     Returns
+     -------
+     DatasetType
+         The dataset type.
+
+     Raises
+     ------
+     NotImplementedError
+         For lazy-loading (`in_memory=False`) of a custom file type.
+     ValueError
+         If the `data_type` is "custom" but both `read_func` and `image_stack_loader` are
+         None.
+     ValueError
+         If the `data_type` is unrecognized.
+     """
+     if data_type == SupportedData.ARRAY:
+         # TODO: ignoring in_memory arg, error if False?
+         return DatasetType.ARRAY
+     elif data_type == SupportedData.TIFF:
+         if in_memory:
+             return DatasetType.IN_MEM_TIFF
+         else:
+             return DatasetType.LAZY_TIFF
+     elif data_type == SupportedData.CUSTOM:
+         if read_func is not None:
+             if in_memory:
+                 return DatasetType.IN_MEM_CUSTOM_FILE
+             else:
+                 raise NotImplementedError(
+                     "Lazy loading has not been implemented for custom file types yet."
+                 )
+         elif image_stack_loader is not None:
+             # TODO: ignoring in_memory arg
+             return DatasetType.CUSTOM_IMAGE_STACK
+         else:
+             raise ValueError(
+                 "Found `data_type='custom'` but no `read_func` or `image_stack_loader` "
+                 "has been provided."
+             )
+     # TODO: ZARR
+     else:
+         raise ValueError(f"Unrecognized `data_type`, '{data_type}'.")
+
+
+ # Convenience function; prefer using the `create_dataloader` function instead.
+ # For lazy loading, a custom batch sampler also needs to be set.
+ def create_dataset(
+     config: Union[DataConfig, InferenceConfig],
+     mode: Mode,
+     inputs: Any,
+     targets: Any,
+     in_memory: bool,
+     read_func: Optional[ReadFunc] = None,
+     read_kwargs: Optional[dict[str, Any]] = None,
+     image_stack_loader: Optional[ImageStackLoader] = None,
+     image_stack_loader_kwargs: Optional[dict[str, Any]] = None,
+ ) -> CareamicsDataset[ImageStack]:
+     """
+     Convenience function to create the CAREamicsDataset.
+
+     Parameters
+     ----------
+     config : DataConfig or InferenceConfig
+         The data configuration.
+     mode : Mode
+         Whether to create the dataset in "training", "validation" or "predicting" mode.
+     inputs : Any
+         The input sources to the dataset.
+     targets : Any, optional
+         The target sources to the dataset.
+     in_memory : bool
+         Whether all the data should be loaded into memory. This argument is ignored
+         unless the `data_type` in `config` is "tiff" or "custom".
+     read_func : ReadFunc, optional
+         A function that can be used to load custom data. This argument is
+         ignored unless the `data_type` in the `config` is "custom".
+     read_kwargs : dict of {str, Any}, optional
+         Additional keyword arguments to pass to the `read_func`.
+     image_stack_loader : ImageStackLoader, optional
+         A function for custom image stack loading. This argument is ignored unless the
+         `data_type` in the `config` is "custom".
+     image_stack_loader_kwargs : dict of {str, Any}, optional
+         Additional keyword arguments to pass to the `image_stack_loader`.
+
+     Returns
+     -------
+     CareamicsDataset[ImageStack]
+         The CAREamicsDataset.
+
+     Raises
+     ------
+     ValueError
+         For an unrecognized `data_type` in the `config`.
+     """
+     data_type = SupportedData(config.data_type)
+     dataset_type = determine_dataset_type(
+         data_type, in_memory, read_func, image_stack_loader
+     )
+     if dataset_type == DatasetType.ARRAY:
+         return create_array_dataset(config, mode, inputs, targets)
+     elif dataset_type == DatasetType.IN_MEM_TIFF:
+         return create_tiff_dataset(config, mode, inputs, targets)
+     # TODO: Lazy tiff
+     elif dataset_type == DatasetType.IN_MEM_CUSTOM_FILE:
+         if read_kwargs is None:
+             read_kwargs = {}
+         assert read_func is not None  # should be true from `determine_dataset_type`
+         return create_custom_file_dataset(
+             config, mode, inputs, targets, read_func=read_func, read_kwargs=read_kwargs
+         )
+     elif dataset_type == DatasetType.CUSTOM_IMAGE_STACK:
+         if image_stack_loader_kwargs is None:
+             image_stack_loader_kwargs = {}
+         assert image_stack_loader is not None  # should be true
+         return create_custom_image_stack_dataset(
+             config,
+             mode,
+             inputs,
+             targets,
+             image_stack_loader,
+             **image_stack_loader_kwargs,
+         )
+     else:
+         raise ValueError(f"Unrecognized dataset type, {dataset_type}.")
+
+
+ def create_array_dataset(
+     config: Union[DataConfig, InferenceConfig],
+     mode: Mode,
+     inputs: Sequence[NDArray[Any]],
+     targets: Optional[Sequence[NDArray[Any]]],
+ ) -> CareamicsDataset[InMemoryImageStack]:
+     """
+     Create a CAREamicsDataset from array data.
+
+     Parameters
+     ----------
+     config : DataConfig or InferenceConfig
+         The data configuration.
+     mode : Mode
+         Whether to create the dataset in "training", "validation" or "predicting" mode.
+     inputs : sequence of NDArray
+         The input sources to the dataset.
+     targets : sequence of NDArray, optional
+         The target sources to the dataset.
+
+     Returns
+     -------
+     CareamicsDataset[InMemoryImageStack]
+         A CAREamicsDataset.
+     """
+     input_extractor = create_array_extractor(source=inputs, axes=config.axes)
+     target_extractor: Optional[PatchExtractor[InMemoryImageStack]]
+     if targets is not None:
+         target_extractor = create_array_extractor(source=targets, axes=config.axes)
+     else:
+         target_extractor = None
+     return CareamicsDataset(config, mode, input_extractor, target_extractor)
+
+
+ def create_tiff_dataset(
+     config: Union[DataConfig, InferenceConfig],
+     mode: Mode,
+     inputs: Sequence[Path],
+     targets: Optional[Sequence[Path]],
+ ) -> CareamicsDataset[InMemoryImageStack]:
+     """
+     Create a CAREamicsDataset from TIFF files that will all be loaded into memory.
+
+     Parameters
+     ----------
+     config : DataConfig or InferenceConfig
+         The data configuration.
+     mode : Mode
+         Whether to create the dataset in "training", "validation" or "predicting" mode.
+     inputs : sequence of Path
+         The input sources to the dataset.
+     targets : sequence of Path, optional
+         The target sources to the dataset.
+
+     Returns
+     -------
+     CareamicsDataset[InMemoryImageStack]
+         A CAREamicsDataset.
+     """
+     input_extractor = create_tiff_extractor(
+         source=inputs,
+         axes=config.axes,
+     )
+     target_extractor: Optional[PatchExtractor[InMemoryImageStack]]
+     if targets is not None:
+         target_extractor = create_tiff_extractor(source=targets, axes=config.axes)
+     else:
+         target_extractor = None
+     dataset = CareamicsDataset(config, mode, input_extractor, target_extractor)
+     return dataset
+
+
+ def create_ome_zarr_dataset(
+     config: Union[DataConfig, InferenceConfig],
+     mode: Mode,
+     inputs: Sequence[Path],
+     targets: Optional[Sequence[Path]],
+ ) -> CareamicsDataset[ZarrImageStack]:
+     """
+     Create a dataset from OME-Zarr files.
+
+     Parameters
+     ----------
+     config : DataConfig or InferenceConfig
+         The data configuration.
+     mode : Mode
+         Whether to create the dataset in "training", "validation" or "predicting" mode.
+     inputs : sequence of Path
+         The input sources to the dataset.
+     targets : sequence of Path, optional
+         The target sources to the dataset.
+
+     Returns
+     -------
+     CareamicsDataset[ZarrImageStack]
+         A CAREamicsDataset.
+     """
+
+     input_extractor = create_ome_zarr_extractor(source=inputs, axes=config.axes)
+     target_extractor: Optional[PatchExtractor[ZarrImageStack]]
+     if targets is not None:
+         target_extractor = create_ome_zarr_extractor(source=targets, axes=config.axes)
+     else:
+         target_extractor = None
+     dataset = CareamicsDataset(config, mode, input_extractor, target_extractor)
+     return dataset
+
+
+ def create_custom_file_dataset(
+     config: Union[DataConfig, InferenceConfig],
+     mode: Mode,
+     inputs: Sequence[Path],
+     targets: Optional[Sequence[Path]],
+     *,
+     read_func: ReadFunc,
+     read_kwargs: dict[str, Any],
+ ) -> CareamicsDataset[InMemoryImageStack]:
+     """
+     Create a CAREamicsDataset from custom files that will all be loaded into memory.
+
+     Parameters
+     ----------
+     config : DataConfig or InferenceConfig
+         The data configuration.
+     mode : Mode
+         Whether to create the dataset in "training", "validation" or "predicting" mode.
+     inputs : sequence of Path
+         The input sources to the dataset.
+     targets : sequence of Path, optional
+         The target sources to the dataset.
+     read_func : ReadFunc
+         A function that can be used to load the custom file type; it is applied to
+         both the input and target files.
+     read_kwargs : dict of {str, Any}
+         Additional keyword arguments to pass to the `read_func` when loading each
+         file.
+
+     Returns
+     -------
+     CareamicsDataset[InMemoryImageStack]
+         A CAREamicsDataset.
+     """
+     input_extractor = create_custom_file_extractor(
+         source=inputs, axes=config.axes, read_func=read_func, read_kwargs=read_kwargs
+     )
+     target_extractor: Optional[PatchExtractor[InMemoryImageStack]]
+     if targets is not None:
+         target_extractor = create_custom_file_extractor(
+             source=targets,
+             axes=config.axes,
+             read_func=read_func,
+             read_kwargs=read_kwargs,
+         )
+     else:
+         target_extractor = None
+     dataset = CareamicsDataset(config, mode, input_extractor, target_extractor)
+     return dataset
+
+
+ def create_custom_image_stack_dataset(
+     config: Union[DataConfig, InferenceConfig],
+     mode: Mode,
+     inputs: Any,
+     targets: Optional[Any],
+     image_stack_loader: ImageStackLoader[P, GenericImageStack],
+     *args: P.args,
+     **kwargs: P.kwargs,
+ ) -> CareamicsDataset[GenericImageStack]:
+     """
+     Create a CAREamicsDataset from a custom `ImageStack` class.
+
+     The custom `ImageStack` class can be loaded using the `image_stack_loader` function.
+
+     Parameters
+     ----------
+     config : DataConfig or InferenceConfig
+         The data configuration.
+     mode : Mode
+         Whether to create the dataset in "training", "validation" or "predicting" mode.
+     inputs : Any
+         The input sources to the dataset.
+     targets : Any, optional
+         The target sources to the dataset.
+     image_stack_loader : ImageStackLoader
+         A function for custom image stack loading; it is used to load the inputs
+         and targets into the custom `ImageStack` class.
+     *args : Any
+         Positional arguments to pass to the `image_stack_loader`.
+     **kwargs : Any
+         Keyword arguments to pass to the `image_stack_loader`.
+
+     Returns
+     -------
+     CareamicsDataset[GenericImageStack]
+         A CAREamicsDataset.
+     """
+     input_extractor = create_custom_image_stack_extractor(
+         inputs,
+         config.axes,
+         image_stack_loader,
+         *args,
+         **kwargs,
+     )
+     target_extractor: Optional[PatchExtractor[GenericImageStack]]
+     if targets is not None:
+         target_extractor = create_custom_image_stack_extractor(
+             targets,
+             config.axes,
+             image_stack_loader,
+             *args,
+             **kwargs,
+         )
+     else:
+         target_extractor = None
+     dataset = CareamicsDataset(config, mode, input_extractor, target_extractor)
+     return dataset
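
A minimal usage sketch of the new factory (not part of the diff): it assumes the modules import as shown in factory.py above and that `Mode` exposes a `TRAINING` member; the `DataConfig` values and file paths below are hypothetical.

from pathlib import Path

from careamics.config import DataConfig
from careamics.dataset_ng.dataset import Mode
from careamics.dataset_ng.factory import create_dataset

# Hypothetical config: 2D TIFF training data, 64x64 patches.
config = DataConfig(data_type="tiff", axes="YX", patch_size=[64, 64], batch_size=16)

# Hypothetical noisy/clean file lists.
inputs = sorted(Path("train/noisy").glob("*.tif"))
targets = sorted(Path("train/clean").glob("*.tif"))

# data_type="tiff" with in_memory=True resolves to DatasetType.IN_MEM_TIFF,
# so `create_tiff_dataset` is used under the hood.
dataset = create_dataset(
    config=config,
    mode=Mode.TRAINING,
    inputs=inputs,
    targets=targets,
    in_memory=True,
)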