careamics 0.0.10__py3-none-any.whl → 0.0.12__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- careamics/careamist.py +20 -4
- careamics/config/configuration.py +10 -5
- careamics/config/data/data_model.py +38 -1
- careamics/config/optimizer_models.py +1 -3
- careamics/config/training_model.py +0 -2
- careamics/dataset/dataset_utils/running_stats.py +7 -3
- careamics/dataset_ng/README.md +212 -0
- careamics/dataset_ng/dataset.py +233 -0
- careamics/dataset_ng/demos/bsd68_demo.ipynb +356 -0
- careamics/dataset_ng/demos/care_U2OS_demo.ipynb +330 -0
- careamics/dataset_ng/demos/demo_custom_image_stack.ipynb +734 -0
- careamics/dataset_ng/demos/demo_datamodule.ipynb +443 -0
- careamics/dataset_ng/{demo_dataset.ipynb → demos/demo_dataset.ipynb} +39 -15
- careamics/dataset_ng/{demo_patch_extractor.py → demos/demo_patch_extractor.py} +7 -9
- careamics/dataset_ng/demos/mouse_nuclei_demo.ipynb +292 -0
- careamics/dataset_ng/factory.py +408 -0
- careamics/dataset_ng/legacy_interoperability.py +168 -0
- careamics/dataset_ng/patch_extractor/__init__.py +3 -8
- careamics/dataset_ng/patch_extractor/demo_custom_image_stack_loader.py +6 -4
- careamics/dataset_ng/patch_extractor/image_stack/__init__.py +2 -1
- careamics/dataset_ng/patch_extractor/image_stack/image_stack_protocol.py +5 -1
- careamics/dataset_ng/patch_extractor/image_stack_loader.py +5 -75
- careamics/dataset_ng/patch_extractor/patch_extractor.py +5 -4
- careamics/dataset_ng/patch_extractor/patch_extractor_factory.py +73 -106
- careamics/dataset_ng/patching_strategies/__init__.py +6 -1
- careamics/dataset_ng/patching_strategies/patching_strategy_protocol.py +31 -0
- careamics/dataset_ng/patching_strategies/random_patching.py +3 -1
- careamics/dataset_ng/patching_strategies/tiling_strategy.py +171 -0
- careamics/dataset_ng/patching_strategies/whole_sample.py +36 -0
- careamics/lightning/dataset_ng/data_module.py +488 -0
- careamics/lightning/dataset_ng/lightning_modules/__init__.py +9 -0
- careamics/lightning/dataset_ng/lightning_modules/care_module.py +58 -0
- careamics/lightning/dataset_ng/lightning_modules/n2v_module.py +67 -0
- careamics/lightning/dataset_ng/lightning_modules/unet_module.py +143 -0
- careamics/lightning/lightning_module.py +3 -0
- careamics/lvae_training/dataset/__init__.py +8 -3
- careamics/lvae_training/dataset/config.py +3 -3
- careamics/lvae_training/dataset/ms_dataset_ref.py +1067 -0
- careamics/lvae_training/dataset/multich_dataset.py +46 -17
- careamics/lvae_training/dataset/multicrop_dset.py +196 -0
- careamics/lvae_training/dataset/types.py +3 -3
- careamics/lvae_training/dataset/utils/index_manager.py +259 -0
- careamics/lvae_training/eval_utils.py +93 -3
- careamics/transforms/compose.py +1 -0
- careamics/transforms/normalize.py +18 -7
- careamics/utils/lightning_utils.py +25 -11
- {careamics-0.0.10.dist-info → careamics-0.0.12.dist-info}/METADATA +3 -3
- {careamics-0.0.10.dist-info → careamics-0.0.12.dist-info}/RECORD +51 -36
- careamics/dataset_ng/dataset/__init__.py +0 -3
- careamics/dataset_ng/dataset/dataset.py +0 -184
- careamics/dataset_ng/demo_patch_extractor_factory.py +0 -37
- {careamics-0.0.10.dist-info → careamics-0.0.12.dist-info}/WHEEL +0 -0
- {careamics-0.0.10.dist-info → careamics-0.0.12.dist-info}/entry_points.txt +0 -0
- {careamics-0.0.10.dist-info → careamics-0.0.12.dist-info}/licenses/LICENSE +0 -0
careamics/careamist.py CHANGED

```diff
@@ -52,6 +52,9 @@ class CAREamist:
         by default None.
     callbacks : list of Callback, optional
         List of callbacks to use during training and prediction, by default None.
+    enable_progress_bar : bool
+        Whether a progress bar will be displayed during training, validation and
+        prediction.
 
     Attributes
     ----------
@@ -77,6 +80,7 @@ class CAREamist:
         source: Union[Path, str],
         work_dir: Optional[Union[Path, str]] = None,
         callbacks: Optional[list[Callback]] = None,
+        enable_progress_bar: bool = True,
     ) -> None: ...
 
     @overload
@@ -85,6 +89,7 @@ class CAREamist:
         source: Configuration,
         work_dir: Optional[Union[Path, str]] = None,
         callbacks: Optional[list[Callback]] = None,
+        enable_progress_bar: bool = True,
     ) -> None: ...
 
     def __init__(
@@ -92,6 +97,7 @@ class CAREamist:
         source: Union[Path, str, Configuration],
         work_dir: Optional[Union[Path, str]] = None,
         callbacks: Optional[list[Callback]] = None,
+        enable_progress_bar: bool = True,
     ) -> None:
         """
         Initialize CAREamist with a configuration object or a path.
@@ -112,6 +118,9 @@ class CAREamist:
             by default None.
         callbacks : list of Callback, optional
             List of callbacks to use during training and prediction, by default None.
+        enable_progress_bar : bool
+            Whether a progress bar will be displayed during training, validation and
+            prediction.
 
         Raises
         ------
@@ -169,7 +178,7 @@ class CAREamist:
             self.model, self.cfg = load_pretrained(source)
 
         # define the checkpoint saving callback
-        self._define_callbacks(callbacks)
+        self._define_callbacks(callbacks, enable_progress_bar)
 
         # instantiate logger
         csv_logger = CSVLogger(
@@ -202,7 +211,7 @@ class CAREamist:
             precision=self.cfg.training_config.precision,
             max_steps=self.cfg.training_config.max_steps,
             check_val_every_n_epoch=self.cfg.training_config.check_val_every_n_epoch,
-            enable_progress_bar=self.cfg.training_config.enable_progress_bar,
+            enable_progress_bar=enable_progress_bar,
             accumulate_grad_batches=self.cfg.training_config.accumulate_grad_batches,
             gradient_clip_val=self.cfg.training_config.gradient_clip_val,
             gradient_clip_algorithm=self.cfg.training_config.gradient_clip_algorithm,
@@ -215,13 +224,19 @@ class CAREamist:
         self.train_datamodule: Optional[TrainDataModule] = None
         self.pred_datamodule: Optional[PredictDataModule] = None
 
-    def _define_callbacks(self, callbacks: Optional[list[Callback]]) -> None:
+    def _define_callbacks(
+        self, callbacks: Optional[list[Callback]], enable_progress_bar: bool
+    ) -> None:
         """Define the callbacks for the training loop.
 
         Parameters
         ----------
         callbacks : list of Callback, optional
             List of callbacks to use during training and prediction, by default None.
+        enable_progress_bar : bool
+            Whether a progress bar will be displayed during training, validation and
+            prediction. It controls whether a `ProgressBarCallback` is added to the
+            callback list.
         """
         self.callbacks = [] if callbacks is None else callbacks
@@ -251,9 +266,10 @@ class CAREamist:
                     filename=self.cfg.experiment_name,
                     **self.cfg.training_config.checkpoint_callback.model_dump(),
                 ),
-                ProgressBarCallback(),
             ]
         )
+        if enable_progress_bar:
+            self.callbacks.append(ProgressBarCallback())
 
         # early stopping callback
         if self.cfg.training_config.early_stopping_callback is not None:
```
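Taken together, these hunks move progress-bar control out of `TrainingConfig` (see the `careamics/config/training_model.py` hunk below) and into the `CAREamist` constructor. A minimal usage sketch of the new parameter, assuming a valid configuration lives at the placeholder path `config.yml`:

```python
from careamics import CAREamist

# The progress bar stays on by default; passing False now both sets
# enable_progress_bar=False on the Lightning Trainer and skips adding
# the ProgressBarCallback to the callback list.
careamist = CAREamist(
    source="config.yml",
    work_dir="runs",
    enable_progress_bar=False,
)
```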
careamics/config/configuration.py CHANGED

```diff
@@ -4,10 +4,11 @@ from __future__ import annotations
 
 import re
 from pprint import pformat
-from typing import Any, Literal, Union
+from typing import Any, Callable, Literal, Union
 
 from bioimageio.spec.generic.v0_3 import CiteEntry
 from pydantic import BaseModel, ConfigDict, Field, field_validator, model_validator
+from pydantic.main import IncEx
 from typing_extensions import Self
 
 from careamics.config.algorithms import (
@@ -297,17 +298,18 @@ class Configuration(BaseModel):
         self,
         *,
         mode: Literal["json", "python"] | str = "python",
-        include:
-        exclude:
+        include: IncEx | None = None,
+        exclude: IncEx | None = None,
         context: Any | None = None,
-        by_alias: bool = False,
+        by_alias: bool | None = False,
         exclude_unset: bool = False,
         exclude_defaults: bool = False,
         exclude_none: bool = True,
         round_trip: bool = False,
         warnings: bool | Literal["none", "warn", "error"] = True,
+        fallback: Callable[[Any], Any] | None = None,
         serialize_as_any: bool = False,
-    ) -> dict:
+    ) -> dict[str, Any]:
         """
         Override model_dump method in order to set default values.
 
@@ -337,6 +339,8 @@ class Configuration(BaseModel):
             representation.
         warnings : bool | Literal['none', 'warn', 'error'], default=True
             Whether to emit warnings.
+        fallback : Callable[[Any], Any] | None, default=None
+            A function to call when an unknown value is encountered.
         serialize_as_any : bool, default=False
             Whether to serialize all types as Any.
 
@@ -356,6 +360,7 @@ class Configuration(BaseModel):
             exclude_none=exclude_none,
             round_trip=round_trip,
             warnings=warnings,
+            fallback=fallback,
             serialize_as_any=serialize_as_any,
         )
 
```
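The new signature tracks recent pydantic `model_dump` APIs (`IncEx` includes/excludes and the `fallback` hook), which are now passed through to `super().model_dump`. A small sketch of the pass-through behavior, assuming `config` is an existing `Configuration` instance (the field name in the `include` set is illustrative):

```python
# Serialize the configuration, converting any value pydantic cannot
# handle (e.g. a custom object) to a string via the fallback hook.
config_dict = config.model_dump(fallback=str)

# include/exclude accept the usual pydantic IncEx forms.
subset = config.model_dump(include={"experiment_name"})
```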
careamics/config/data/data_model.py CHANGED

```diff
@@ -2,6 +2,8 @@
 
 from __future__ import annotations
 
+import os
+import sys
 from collections.abc import Sequence
 from pprint import pformat
 from typing import Annotated, Any, Literal, Optional, Union
@@ -143,7 +145,7 @@ class DataConfig(BaseModel):
     should include the `shuffle` key, which is set to `True` by default. We strongly
     recommend to keep it as `True` to ensure the best training results."""
 
-    val_dataloader_params: dict[str, Any] = Field(default={})
+    val_dataloader_params: dict[str, Any] = Field(default={}, validate_default=True)
     """Dictionary of PyTorch validation dataloader parameters."""
 
     @field_validator("patch_size")
@@ -210,6 +212,41 @@ class DataConfig(BaseModel):
 
         return axes
 
+    @field_validator("train_dataloader_params", "val_dataloader_params", mode="before")
+    @classmethod
+    def set_default_dataloader_params(
+        cls, dataloader_params: dict[str, Any]
+    ) -> dict[str, Any]:
+        """
+        Set default dataloader parameters if not provided.
+
+        - If 'num_workers' is not set, it defaults to the number of available CPU cores.
+        - If 'pin_memory' is not set, it defaults to True if CUDA is available.
+
+        Parameters
+        ----------
+        dataloader_params : dict of {str: Any}
+            The dataloader parameters.
+
+        Returns
+        -------
+        dict of {str: Any}
+            The dataloader parameters with defaults applied.
+        """
+        if "num_workers" not in dataloader_params:
+            # Use no workers during tests, otherwise use all available CPU cores
+            if "pytest" in sys.modules:
+                dataloader_params["num_workers"] = 0
+            else:
+                dataloader_params["num_workers"] = os.cpu_count()
+
+        if "pin_memory" not in dataloader_params:
+            import torch
+
+            dataloader_params["pin_memory"] = torch.cuda.is_available()
+
+        return dataloader_params
+
     @field_validator("train_dataloader_params")
     @classmethod
     def shuffle_train_dataloader(
```
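Because both fields are now validated with `validate_default=True` and a `mode="before"` validator, the defaults are filled in as soon as a `DataConfig` is constructed. A sketch of the resulting behavior; the required fields used here (`data_type`, `axes`, `patch_size`) are the usual `DataConfig` fields but are shown only for illustration:

```python
import os

from careamics.config import DataConfig

config = DataConfig(data_type="tiff", axes="YX", patch_size=[64, 64])

# Outside a pytest session, num_workers defaults to os.cpu_count() and
# pin_memory to torch.cuda.is_available(), for both dataloaders.
assert config.train_dataloader_params["num_workers"] == os.cpu_count()
assert "pin_memory" in config.val_dataloader_params
```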
careamics/config/optimizer_models.py CHANGED

```diff
@@ -51,9 +51,7 @@ class OptimizerModel(BaseModel):
 
     # Optional parameters, empty dict default value to allow filtering dictionary
     parameters: dict = Field(
-        default={
-            "lr": 1e-4,
-        },
+        default={},
         validate_default=True,
     )
     """Parameters of the optimizer, see PyTorch documentation for more details."""
```
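Dropping the `{"lr": 1e-4}` default means the learning rate now falls through to PyTorch's own optimizer defaults unless set explicitly. A sketch, assuming the model's `name` field accepts torch optimizer names such as `"Adam"`:

```python
from careamics.config.optimizer_models import OptimizerModel

# No "lr" in the config anymore: torch.optim.Adam's own default
# (lr=1e-3) applies unless a learning rate is passed explicitly.
optimizer_config = OptimizerModel(name="Adam")
assert optimizer_config.parameters == {}

explicit = OptimizerModel(name="Adam", parameters={"lr": 1e-4})
```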
careamics/config/training_model.py CHANGED

```diff
@@ -39,8 +39,6 @@ class TrainingConfig(BaseModel):
     """Maximum number of steps to train for. -1 means no limit."""
     check_val_every_n_epoch: int = Field(default=1, ge=1)
     """Validation step frequency."""
-    enable_progress_bar: bool = Field(default=True)
-    """Whether to enable the progress bar."""
     accumulate_grad_batches: int = Field(default=1, ge=1)
     """Number of batches to accumulate gradients over before stepping the optimizer."""
     gradient_clip_val: Optional[Union[int, float]] = None
```
careamics/dataset/dataset_utils/running_stats.py CHANGED

```diff
@@ -21,9 +21,13 @@ def compute_normalization_stats(image: NDArray) -> tuple[NDArray, NDArray]:
     tuple of (list of floats, list of floats)
         Lists of mean and standard deviation values per channel.
     """
-    # Define the
-
-
+    # Define the lists for storing mean and std values
+    means, stds = [], []
+    # Iterate over the channels dimension and compute mean and std
+    for ax in range(image.shape[1]):
+        means.append(image[:, ax, ...].mean())
+        stds.append(image[:, ax, ...].std())
+    return np.stack(means), np.stack(stds)
 
 
 def update_iterative_stats(
```
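The rewritten body computes per-channel statistics over an array whose channel axis is axis 1 (i.e. S, C, (Z), Y, X). A quick check of the per-channel output shapes:

```python
import numpy as np

from careamics.dataset.dataset_utils.running_stats import (
    compute_normalization_stats,
)

# A (S, C, Y, X) stack: 4 samples with 3 channels of 64x64 pixels.
image = np.random.rand(4, 3, 64, 64)

means, stds = compute_normalization_stats(image)
assert means.shape == (3,) and stds.shape == (3,)  # one value per channel
```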
careamics/dataset_ng/README.md ADDED

@@ -0,0 +1,212 @@

````markdown
# The CAREamics Dataset

Welcome to the CAREamics dataset!

A PyTorch-based dataset, designed to be used with microscopy data. It is universal across the training, validation and prediction stages of a machine learning pipeline.

The key ethos is to create a modular and maintainable dataset composed of swappable components that interact through interfaces. This should facilitate a smooth development process when extending the dataset with new features, and also enable advanced users to easily customize the dataset to their needs by writing custom components. This is achieved by following a few key software engineering principles, detailed at the end of this README file.


## Dataset Component overview

```mermaid
---
title: CAREamicsDataset
---
classDiagram
    class CAREamicsDataset{
        +PatchExtractor input_extractor
        +Optional[PatchExtractor] target_extractor
        +PatchingStrategy patching_strategy
        +list~Transform~ transforms
        +\_\_getitem\_\_(int index) NDArray
    }
    class PatchingStrategy{
        <<interface>>
        +n_patches int
        +get_patch_spec(index: int) PatchSpecs
    }
    class RandomPatchingStrategy{
    }
    class FixedRandomPatchingStrategy{
    }
    class SequentialPatchingStrategy{
    }
    class TilingStrategy{
        +get_patch_spec(index: int) TileSpecs
    }

    class PatchExtractor{
        +list~ImageStack~ image_stacks
        +extract_patch(PatchSpecs) NDArray
    }
    class PatchSpecs {
        <<TypedDict>>
        +int data_idx
        +int sample_idx
        +Sequence~int~ coords
        +Sequence~int~ patch_size
    }
    class TileSpecs {
        <<TypedDict>>
        +Sequence~int~ crop_coords
        +Sequence~int~ crop_size
        +Sequence~int~ stitch_coords
    }

    class ImageStack{
        <<interface>>
        +Union[Path, Literal["array"]] source
        +Sequence~int~ data_shape
        +DTypeLike data_type
        +extract_patch(sample_idx, coords, patch_size) NDArray
    }
    class InMemoryImageStack {
    }
    class ZarrImageStack {
        +Path source
    }

    CAREamicsDataset --* PatchExtractor: Is composed of
    CAREamicsDataset --* PatchingStrategy: Is composed of
    PatchExtractor --o ImageStack: Aggregates
    ImageStack <|-- InMemoryImageStack: Implements
    ImageStack <|-- ZarrImageStack: Implements
    PatchingStrategy <|-- RandomPatchingStrategy: Implements
    PatchingStrategy <|-- FixedRandomPatchingStrategy: Implements
    PatchingStrategy <|-- SequentialPatchingStrategy: Implements
    PatchingStrategy <|-- TilingStrategy: Implements
    PatchSpecs <|-- TileSpecs: Inherits from
```

### `ImageStack` and implementations

This interface represents a set of image data, which can be saved with any subset of the
axes STCZYX, in any order; see below for a description of the dimensions. The `ImageStack`
interface's job is to act as an adapter for different data storage types, so that higher-level
classes can access the image data without having to know the implementation details of
how to load or read data from each storage type. This means we can decide to support new storage
types by implementing a new concrete `ImageStack` class without having to change anything
in the `CAREamicsDataset` class. Advanced users can also choose to create their own
`ImageStack` class if they want to work with their own data storage type.

The interface provides an `extract_patch` method which will produce a patch from the image,
as a NumPy array, with the dimensions C(Z)YX. This method should be thought of as simply
a wrapper for the equivalent of NumPy slicing for each of the storage types.

#### Concrete implementations

- `InMemoryImageStack`: The underlying data is stored as a NumPy array in memory. It has some
additional constructor methods to load the data from known file formats such as TIFF files.
- `ZarrImageStack`: The underlying data is stored as a Zarr file on disk.

#### Axes description

- S is a generic sample dimension,
- T is a time dimension,
- C is a channel dimension,
- Z is a spatial dimension,
- Y is a spatial dimension,
- X is a spatial dimension.

### `PatchExtractor`

The `PatchExtractor` class aggregates many `ImageStack` instances; this allows multiple
images with different dimensions, and possibly different storage types, to be treated as a single entity.
The class has an `extract_patch` method to extract a patch from any one of its `ImageStack`
objects. It can also be extended when extra logic to extract patches is needed,
for example when constructing lateral-context inputs for the MicroSplit LVAE models.

### `PatchingStrategy`

The `PatchingStrategy` class is an interface for generating patch specifications, where each of the
concrete implementations produces a set of patch specifications using a different strategy.

It has an `n_patches` attribute that can be accessed to find out how many patches the
strategy will produce, given the shapes of the image stacks it has been initialized with.
This is needed by the `CAREamicsDataset` to return its length.

Most importantly, it has a `get_patch_spec` method that takes an index and returns a
patch specification. For deterministic patching strategies, this method will always
return the same patch specification given the same index, but there are also random strategies
where the returned patch specification will change every time. The given index must always
be less than `n_patches`.

#### Concrete implementations

- `RandomPatchingStrategy`: this strategy will produce random patches that will change
even if the `get_patch_spec` method is called with the same index.
- `FixedRandomPatchingStrategy`: this strategy will produce random patches, but the patch
will be the same if the `get_patch_spec` method is called with the same index. This is
useful for making sure validation is comparable epoch to epoch.
- `SequentialPatchingStrategy`: this strategy is deterministic and the patches will be
sequential with some specified overlap.
- `TilingStrategy`: this strategy is deterministic and the patches will be
sequential with some specified overlap. Rather than a `PatchSpecs` dictionary, it will
produce a `TileSpecs` dictionary, which includes some extra fields that are used for
stitching the tiles back together.

#### PatchSpecs

The `get_patch_spec` method returns a dictionary containing the keys `data_idx`, `sample_idx`, `coords` and `patch_size`.
These are the exact arguments that the `PatchExtractor.extract_patch` method takes. The patch specification
produced by the patching strategy is passed to the `PatchExtractor`, which in turn produces an image patch.

For type hinting, `PatchSpecs` is defined as a `TypedDict`.

## Key Principles

The aim of all these principles is to create a system of interacting classes that have
low coupling. This allows one section to be changed or extended without breaking functionality
elsewhere in the codebase.

### Composition over inheritance

The principle of composition over inheritance is: rather than using inheritance to
extend or change the behavior of a class, a class can be composed of modules
that can be swapped to extend or change behavior.

The reason to use composition is that it promotes the easy reuse of the underlying
components, it can prevent a subclass explosion, and it leads to a maintainable and
easily extendable design. A software architecture based on composition is normally
maintainable and extendable because, if a component needs to change, the whole class
shouldn't have to be refactored, and if a new feature needs to be added, usually an additional
component can be added to the class.

The `CAREamicsDataset` is composed of `PatchExtractor`, `PatchingStrategy` and `Transform` components.
The `PatchingStrategy` classes implement an interface so the dataset can switch between
different strategies. The `PatchExtractor` is composed of many `ImageStack` instances;
new image stacks can be added to extend the type of data that the dataset can read from.

### Dependency Inversion

The dependency inversion principle states:

1. High-level modules should not depend on low-level modules. Both high-level and
low-level modules should depend on abstractions (e.g. interfaces).
2. Abstractions should not depend on details (concrete implementations). Details should
depend on abstractions.

In other words, high-level modules that provide complex logic should be easily reusable
and not depend on implementation details of low-level modules that provide utility functionality.
This can be achieved by introducing abstractions that decouple high- and low-level modules.

An example of the dependency inversion principle in use is how the `PatchExtractor` only
depends on the `ImageStack` interface, and does not have to have any knowledge of the
concrete implementations. The concrete `ImageStack` implementations also do not have
any knowledge of the `PatchExtractor` or any other higher-level functionality that the
dataset needs.

### Single Responsibility Principle

Each component should have a small scope of responsibility that is easily defined. This
should make the code easier to maintain and hopefully reduce the number of places in the
code that have to change when introducing a new feature.

- `ImageStack` responsibility: to act as an adapter for loading and reading image data
from different underlying storage.
- `PatchExtractor` responsibility: to extract patches from a set of image stacks.
- `PatchingStrategy` responsibility: to produce patch specifications given an index, through
an interface that hides the underlying implementation.
- `CAREamicsDataset` responsibility: to orchestrate the interactions of its underlying
components to produce an input patch (and target patch when required) given an index.
````
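The `ImageStack` protocol described in this README is small enough to implement against directly. Below is a hypothetical sketch of a custom stack backed by an HDF5 file; the attribute names follow the class diagram (`source`, `data_shape`, `data_type`), while `HDF5ImageStack` itself and the `h5py` usage are illustrative and not part of CAREamics:

```python
from collections.abc import Sequence
from pathlib import Path

import h5py
from numpy.typing import DTypeLike, NDArray


class HDF5ImageStack:
    """Hypothetical ImageStack over a single HDF5 dataset stored as SCYX."""

    def __init__(self, path: Path, dataset_name: str) -> None:
        self.source = path
        self._file = h5py.File(path, "r")
        self._dataset = self._file[dataset_name]
        self.data_shape: Sequence[int] = self._dataset.shape
        self.data_type: DTypeLike = self._dataset.dtype

    def extract_patch(
        self, sample_idx: int, coords: Sequence[int], patch_size: Sequence[int]
    ) -> NDArray:
        # The equivalent of NumPy slicing for this storage type: read only
        # the requested CYX region of one sample from disk.
        y, x = coords
        height, width = patch_size
        return self._dataset[sample_idx, :, y : y + height, x : x + width]
```

Because `PatchExtractor` depends only on this interface, such a class could sit alongside the built-in `InMemoryImageStack` and `ZarrImageStack` without any change to the dataset code.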
careamics/dataset_ng/dataset.py ADDED

@@ -0,0 +1,233 @@

```python
from collections.abc import Sequence
from enum import Enum
from pathlib import Path
from typing import Any, Generic, Literal, NamedTuple, Optional, Union

import numpy as np
from numpy.typing import NDArray
from torch.utils.data import Dataset
from tqdm.auto import tqdm

from careamics.config import DataConfig, InferenceConfig
from careamics.config.transformations import NormalizeModel
from careamics.dataset.dataset_utils.running_stats import WelfordStatistics
from careamics.dataset.patching.patching import Stats
from careamics.dataset_ng.patch_extractor import GenericImageStack, PatchExtractor
from careamics.dataset_ng.patching_strategies import (
    FixedRandomPatchingStrategy,
    PatchingStrategy,
    PatchSpecs,
    RandomPatchingStrategy,
    TilingStrategy,
    WholeSamplePatchingStrategy,
)
from careamics.transforms import Compose


class Mode(str, Enum):
    TRAINING = "training"
    VALIDATING = "validating"
    PREDICTING = "predicting"


class ImageRegionData(NamedTuple):
    data: NDArray
    source: Union[str, Literal["array"]]
    data_shape: Sequence[int]
    dtype: str  # dtype should be str for collate
    axes: str
    region_spec: PatchSpecs


InputType = Union[Sequence[NDArray[Any]], Sequence[Path]]


class CareamicsDataset(Dataset, Generic[GenericImageStack]):
    def __init__(
        self,
        data_config: Union[DataConfig, InferenceConfig],
        mode: Mode,
        input_extractor: PatchExtractor[GenericImageStack],
        target_extractor: Optional[PatchExtractor[GenericImageStack]] = None,
    ):
        self.config = data_config
        self.mode = mode

        self.input_extractor = input_extractor
        self.target_extractor = target_extractor

        self.patching_strategy = self._initialize_patching_strategy()

        self.input_stats, self.target_stats = self._initialize_statistics()

        self.transforms = self._initialize_transforms()

    def _initialize_patching_strategy(self) -> PatchingStrategy:
        patching_strategy: PatchingStrategy
        if self.mode == Mode.TRAINING:
            if isinstance(self.config, InferenceConfig):
                raise ValueError("Inference config cannot be used for training.")
            patching_strategy = RandomPatchingStrategy(
                data_shapes=self.input_extractor.shape,
                patch_size=self.config.patch_size,
                # TODO: Add random seed to dataconfig
                seed=getattr(self.config, "random_seed", 42),
            )
        elif self.mode == Mode.VALIDATING:
            if isinstance(self.config, InferenceConfig):
                raise ValueError("Inference config cannot be used for validating.")
            patching_strategy = FixedRandomPatchingStrategy(
                data_shapes=self.input_extractor.shape,
                patch_size=self.config.patch_size,
                # TODO: Add random seed to dataconfig
                seed=getattr(self.config, "random_seed", 42),
            )
        elif self.mode == Mode.PREDICTING:
            if not isinstance(self.config, InferenceConfig):
                raise ValueError("Inference config must be used for predicting.")
            if (self.config.tile_size is not None) and (
                self.config.tile_overlap is not None
            ):
                patching_strategy = TilingStrategy(
                    data_shapes=self.input_extractor.shape,
                    tile_size=self.config.tile_size,
                    overlaps=self.config.tile_overlap,
                )
            else:
                patching_strategy = WholeSamplePatchingStrategy(
                    data_shapes=self.input_extractor.shape
                )
        else:
            raise ValueError(f"Unrecognised dataset mode {self.mode}.")

        return patching_strategy

    def _initialize_transforms(self) -> Optional[Compose]:
        if isinstance(self.config, DataConfig):
            if self.mode == Mode.TRAINING:
                # TODO: initialize normalization separately depending on configuration
                return Compose(
                    transform_list=[
                        NormalizeModel(
                            image_means=self.input_stats.means,
                            image_stds=self.input_stats.stds,
                            target_means=self.target_stats.means,
                            target_stds=self.target_stats.stds,
                        )
                    ]
                    + list(self.config.transforms)
                )

        # TODO: add TTA
        return Compose(
            transform_list=[
                NormalizeModel(
                    image_means=self.input_stats.means,
                    image_stds=self.input_stats.stds,
                    target_means=self.target_stats.means,
                    target_stds=self.target_stats.stds,
                )
            ]
        )

    def _calculate_stats(
        self, data_extractor: PatchExtractor[GenericImageStack]
    ) -> Stats:
        image_stats = WelfordStatistics()
        n_patches = self.patching_strategy.n_patches

        for idx in tqdm(range(n_patches), desc="Computing statistics"):
            patch_spec = self.patching_strategy.get_patch_spec(idx)
            patch = data_extractor.extract_patch(
                data_idx=patch_spec["data_idx"],
                sample_idx=patch_spec["sample_idx"],
                coords=patch_spec["coords"],
                patch_size=patch_spec["patch_size"],
            )
            # TODO: statistics accept SCYX format, while patch is CYX
            image_stats.update(patch[None, ...], sample_idx=idx)

        image_means, image_stds = image_stats.finalize()
        return Stats(image_means, image_stds)

    # TODO: add running stats
    def _initialize_statistics(self) -> tuple[Stats, Stats]:
        if self.config.image_means is not None and self.config.image_stds is not None:
            input_stats = Stats(self.config.image_means, self.config.image_stds)
        else:
            input_stats = self._calculate_stats(self.input_extractor)

        target_stats = Stats((), ())
        if isinstance(self.config, DataConfig):
            if (
                self.config.target_means is not None
                and self.config.target_stds is not None
            ):
                target_stats = Stats(self.config.target_means, self.config.target_stds)
            elif self.target_extractor is not None:
                target_stats = self._calculate_stats(self.target_extractor)

        return input_stats, target_stats

    def __len__(self):
        return self.patching_strategy.n_patches

    def _create_image_region(
        self, patch: np.ndarray, patch_spec: PatchSpecs, extractor: PatchExtractor
    ) -> ImageRegionData:
        data_idx = patch_spec["data_idx"]
        source = extractor.image_stacks[data_idx].source
        return ImageRegionData(
            data=patch,
            source=str(source),
            dtype=str(extractor.image_stacks[data_idx].data_dtype),
            data_shape=extractor.image_stacks[data_idx].data_shape,
            # TODO: should it be axes of the original image instead?
            axes=self.config.axes,
            region_spec=patch_spec,
        )

    def __getitem__(
        self, index: int
    ) -> Union[tuple[ImageRegionData], tuple[ImageRegionData, ImageRegionData]]:
        patch_spec = self.patching_strategy.get_patch_spec(index)
        input_patch = self.input_extractor.extract_patch(
            data_idx=patch_spec["data_idx"],
            sample_idx=patch_spec["sample_idx"],
            coords=patch_spec["coords"],
            patch_size=patch_spec["patch_size"],
        )

        target_patch = (
            self.target_extractor.extract_patch(
                data_idx=patch_spec["data_idx"],
                sample_idx=patch_spec["sample_idx"],
                coords=patch_spec["coords"],
                patch_size=patch_spec["patch_size"],
            )
            if self.target_extractor is not None
            else None
        )

        if self.transforms is not None:
            if self.target_extractor is not None:
                input_patch, target_patch = self.transforms(input_patch, target_patch)
            else:
                # TODO: compose doesn't return None for target patch anymore
                # so have to do this annoying if else
                (input_patch,) = self.transforms(input_patch, target_patch)
                target_patch = None

        input_data = self._create_image_region(
            patch=input_patch, patch_spec=patch_spec, extractor=self.input_extractor
        )

        if target_patch is not None and self.target_extractor is not None:
            target_data = self._create_image_region(
                patch=target_patch,
                patch_spec=patch_spec,
                extractor=self.target_extractor,
            )
            return input_data, target_data
        else:
            return (input_data,)
```
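A rough sketch of how the modes drive `CareamicsDataset`, assuming the input and target `PatchExtractor` instances have already been built (for example via the factories added in `careamics/dataset_ng/factory.py`, whose signatures are not shown in this diff); `data_config`, `input_extractor` and `target_extractor` below are placeholders:

```python
from careamics.dataset_ng.dataset import CareamicsDataset, Mode

# TRAINING draws fresh random patches each epoch; VALIDATING fixes the
# random patches per index; PREDICTING tiles or returns whole samples.
train_dataset = CareamicsDataset(
    data_config=data_config,            # a DataConfig (placeholder)
    mode=Mode.TRAINING,
    input_extractor=input_extractor,    # PatchExtractor (placeholder)
    target_extractor=target_extractor,  # PatchExtractor or None
)

sample = train_dataset[0]
if len(sample) == 2:                    # input and target regions
    input_region, target_region = sample
else:                                   # input-only, e.g. Noise2Void
    (input_region,) = sample
print(input_region.data.shape, input_region.region_spec["coords"])
```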