careamics 0.0.12__py3-none-any.whl → 0.0.13__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of careamics might be problematic. Click here for more details.
- careamics/careamist.py +4 -3
- careamics/cli/utils.py +1 -1
- careamics/config/algorithms/n2v_algorithm_model.py +1 -1
- careamics/config/architectures/unet_model.py +3 -0
- careamics/config/callback_model.py +23 -34
- careamics/config/configuration.py +47 -1
- careamics/config/configuration_factories.py +288 -23
- careamics/config/data/__init__.py +2 -0
- careamics/config/data/data_model.py +3 -3
- careamics/config/data/ng_data_model.py +381 -0
- careamics/config/data/patching_strategies/__init__.py +14 -0
- careamics/config/data/patching_strategies/_overlapping_patched_model.py +103 -0
- careamics/config/data/patching_strategies/_patched_model.py +56 -0
- careamics/config/data/patching_strategies/random_patching_model.py +21 -0
- careamics/config/data/patching_strategies/sequential_patching_model.py +25 -0
- careamics/config/data/patching_strategies/tiled_patching_model.py +40 -0
- careamics/config/data/patching_strategies/whole_patching_model.py +12 -0
- careamics/config/inference_model.py +6 -3
- careamics/config/support/supported_data.py +7 -0
- careamics/config/support/supported_patching_strategies.py +22 -0
- careamics/config/validators/validator_utils.py +4 -3
- careamics/dataset/dataset_utils/iterate_over_files.py +2 -2
- careamics/dataset/in_memory_dataset.py +2 -1
- careamics/dataset/iterable_dataset.py +2 -2
- careamics/dataset/iterable_pred_dataset.py +2 -2
- careamics/dataset/iterable_tiled_pred_dataset.py +2 -2
- careamics/dataset/patching/patching.py +3 -2
- careamics/dataset/tiling/lvae_tiled_patching.py +16 -6
- careamics/dataset/tiling/tiled_patching.py +2 -1
- careamics/dataset_ng/dataset.py +46 -50
- careamics/dataset_ng/demos/bsd68_demo.ipynb +28 -23
- careamics/dataset_ng/demos/care_U2OS_demo.ipynb +1 -1
- careamics/dataset_ng/demos/demo_custom_image_stack.ipynb +1 -1
- careamics/dataset_ng/demos/demo_datamodule.ipynb +50 -46
- careamics/dataset_ng/demos/demo_dataset.ipynb +32 -49
- careamics/dataset_ng/factory.py +58 -15
- careamics/dataset_ng/legacy_interoperability.py +3 -1
- careamics/dataset_ng/patch_extractor/demo_custom_image_stack_loader.py +1 -1
- careamics/dataset_ng/patch_extractor/image_stack/__init__.py +2 -0
- careamics/dataset_ng/patch_extractor/image_stack/czi_image_stack.py +360 -0
- careamics/dataset_ng/patch_extractor/image_stack/in_memory_image_stack.py +1 -1
- careamics/dataset_ng/patch_extractor/patch_extractor_factory.py +43 -1
- careamics/dataset_ng/patching_strategies/random_patching.py +3 -1
- careamics/dataset_ng/patching_strategies/sequential_patching.py +5 -5
- careamics/dataset_ng/patching_strategies/tiling_strategy.py +2 -1
- careamics/file_io/read/get_func.py +2 -1
- careamics/lightning/dataset_ng/__init__.py +1 -0
- careamics/lightning/dataset_ng/data_module.py +218 -28
- careamics/lightning/dataset_ng/lightning_modules/care_module.py +44 -5
- careamics/lightning/dataset_ng/lightning_modules/n2v_module.py +42 -3
- careamics/lightning/dataset_ng/lightning_modules/unet_module.py +73 -4
- careamics/lightning/lightning_module.py +2 -1
- careamics/lightning/predict_data_module.py +2 -1
- careamics/lightning/train_data_module.py +2 -1
- careamics/losses/loss_factory.py +2 -1
- careamics/lvae_training/dataset/multicrop_dset.py +1 -1
- careamics/model_io/bioimage/bioimage_utils.py +1 -1
- careamics/model_io/bioimage/model_description.py +1 -1
- careamics/model_io/bmz_io.py +1 -1
- careamics/model_io/model_io_utils.py +2 -2
- careamics/models/activation.py +2 -1
- careamics/prediction_utils/prediction_outputs.py +1 -1
- careamics/prediction_utils/stitch_prediction.py +1 -1
- careamics/transforms/n2v_manipulate_torch.py +15 -9
- careamics/transforms/pixel_manipulation_torch.py +59 -92
- careamics/utils/lightning_utils.py +2 -2
- careamics/utils/metrics.py +2 -1
- careamics/utils/torch_utils.py +23 -0
- {careamics-0.0.12.dist-info → careamics-0.0.13.dist-info}/METADATA +10 -9
- {careamics-0.0.12.dist-info → careamics-0.0.13.dist-info}/RECORD +73 -62
- {careamics-0.0.12.dist-info → careamics-0.0.13.dist-info}/WHEEL +0 -0
- {careamics-0.0.12.dist-info → careamics-0.0.13.dist-info}/entry_points.txt +0 -0
- {careamics-0.0.12.dist-info → careamics-0.0.13.dist-info}/licenses/LICENSE +0 -0
|
@@ -0,0 +1,381 @@
|
|
|
1
|
+
"""Data configuration."""
|
|
2
|
+
|
|
3
|
+
from __future__ import annotations
|
|
4
|
+
|
|
5
|
+
from collections.abc import Sequence
|
|
6
|
+
from pprint import pformat
|
|
7
|
+
from typing import Annotated, Any, Literal, Optional, Union
|
|
8
|
+
from warnings import warn
|
|
9
|
+
|
|
10
|
+
import numpy as np
|
|
11
|
+
from numpy.typing import NDArray
|
|
12
|
+
from pydantic import (
|
|
13
|
+
BaseModel,
|
|
14
|
+
ConfigDict,
|
|
15
|
+
Field,
|
|
16
|
+
PlainSerializer,
|
|
17
|
+
field_validator,
|
|
18
|
+
model_validator,
|
|
19
|
+
)
|
|
20
|
+
from typing_extensions import Self
|
|
21
|
+
|
|
22
|
+
from ..transformations import XYFlipModel, XYRandomRotate90Model
|
|
23
|
+
from ..validators import check_axes_validity
|
|
24
|
+
from .patching_strategies import (
|
|
25
|
+
RandomPatchingModel,
|
|
26
|
+
TiledPatchingModel,
|
|
27
|
+
WholePatchingModel,
|
|
28
|
+
)
|
|
29
|
+
|
|
30
|
+
# TODO: Validate the specific sizes of tiles and overlaps given UNet constraints
|
|
31
|
+
# - needs to be done in the Configuration
|
|
32
|
+
# - patches and overlaps sizes must also be checked against dimensionality
|
|
33
|
+
|
|
34
|
+
# TODO: is 3D updated anywhere in the code in CAREamist/downstream?
|
|
35
|
+
# - this will be important when swapping the data config in Configuration
|
|
36
|
+
# - `set_3D` currently not implemented here
|
|
37
|
+
# TODO: we can't tell that the patching strategy is correct
|
|
38
|
+
# - or is it the responsibility of the creator (e.g. convenience functions)?
|
|
39
|
+
|
|
40
|
+
|
|
41
|
+
def np_float_to_scientific_str(x: float) -> str:
    """Serialize a float as a string in scientific notation.

    Used as the Pydantic serializer of the `Float` annotated type, so that values
    such as ``numpy.float32`` can be stored in the model and written to a yaml file
    as plain strings.

    Parameters
    ----------
    x : float
        Value to format.

    Returns
    -------
    str
        `x` formatted by ``numpy.format_float_scientific`` with ``precision=7``.
    """
    formatted: str = np.format_float_scientific(x, precision=7)
    return formatted
|
|
58
|
+
|
|
59
|
+
|
|
60
|
+
Float = Annotated[
    float,
    PlainSerializer(np_float_to_scientific_str, return_type=str),
]
"""Float that is serialized to a scientific-notation string."""

# `SequentialPatchingModel` is intentionally excluded: it is not supported yet.
PatchingStrategies = Union[RandomPatchingModel, TiledPatchingModel, WholePatchingModel]
"""Patching strategy models accepted by the data configuration."""
|
|
70
|
+
|
|
71
|
+
|
|
72
|
+
class NGDataConfig(BaseModel):
    """Next-Generation Dataset configuration.

    NGDataConfig are used for both training and prediction, with the patching strategy
    determining how the data is processed. Note that `random` is the only patching
    strategy compatible with training, while `tiled` and `whole` are only used for
    prediction.

    If std is specified, mean must be specified as well. Note that setting the std first
    and then the mean (if they were both `None` before) will raise a validation error.
    Prefer instead `set_means_and_stds` to set both at once. Means and stds are expected
    to be lists of floats, one for each channel. For supervised tasks, the mean and std
    of the target could be different from the input data.

    All supported transforms are defined in the SupportedTransform enum.
    """

    # Pydantic class configuration
    model_config = ConfigDict(
        validate_assignment=True,
    )

    # Dataset configuration
    data_type: Literal["array", "tiff", "zarr", "custom"]
    """Type of input data."""

    axes: str
    """Axes of the data, as defined in SupportedAxes."""

    patching: PatchingStrategies = Field(..., discriminator="name")
    """Patching strategy to use. Note that `random` is the only supported strategy for
    training, while `tiled` and `whole` are only used for prediction."""

    # Optional fields
    batch_size: int = Field(default=1, ge=1, validate_default=True)
    """Batch size for training."""

    image_means: Optional[list[Float]] = Field(
        default=None, min_length=0, max_length=32
    )
    """Means of the data across channels, used for normalization."""

    image_stds: Optional[list[Float]] = Field(default=None, min_length=0, max_length=32)
    """Standard deviations of the data across channels, used for normalization."""

    target_means: Optional[list[Float]] = Field(
        default=None, min_length=0, max_length=32
    )
    """Means of the target data across channels, used for normalization."""

    target_stds: Optional[list[Float]] = Field(
        default=None, min_length=0, max_length=32
    )
    """Standard deviations of the target data across channels, used for
    normalization."""

    transforms: Sequence[Union[XYFlipModel, XYRandomRotate90Model]] = Field(
        default=(
            XYFlipModel(),
            XYRandomRotate90Model(),
        ),
        validate_default=True,
    )
    """List of transformations to apply to the data, available transforms are defined
    in SupportedTransform."""

    train_dataloader_params: dict[str, Any] = Field(
        default={"shuffle": True}, validate_default=True
    )
    """Dictionary of PyTorch training dataloader parameters. The dataloader parameters,
    should include the `shuffle` key, which is set to `True` by default. We strongly
    recommend to keep it as `True` to ensure the best training results."""

    val_dataloader_params: dict[str, Any] = Field(default={})
    """Dictionary of PyTorch validation dataloader parameters."""

    test_dataloader_params: dict[str, Any] = Field(default={})
    """Dictionary of PyTorch test dataloader parameters."""

    seed: Optional[int] = Field(default=None, gt=0)
    """Random seed for reproducibility."""

    @field_validator("axes")
    @classmethod
    def axes_valid(cls, axes: str) -> str:
        """
        Validate axes.

        Axes must:
        - be a combination of 'STCZYX'
        - not contain duplicates
        - contain at least 2 contiguous axes: X and Y
        - contain at most 4 axes
        - not contain both S and T axes

        Parameters
        ----------
        axes : str
            Axes to validate.

        Returns
        -------
        str
            Validated axes.

        Raises
        ------
        ValueError
            If axes are not valid.
        """
        # `check_axes_validity` raises on invalid axes; the value is returned as is.
        check_axes_validity(axes)

        return axes

    @field_validator("train_dataloader_params")
    @classmethod
    def shuffle_train_dataloader(
        cls, train_dataloader_params: dict[str, Any]
    ) -> dict[str, Any]:
        """
        Validate that "shuffle" is included in the training dataloader params.

        A warning will be raised if `shuffle=False`.

        Parameters
        ----------
        train_dataloader_params : dict of {str: Any}
            The training dataloader parameters.

        Returns
        -------
        dict of {str: Any}
            The validated training dataloader parameters.

        Raises
        ------
        ValueError
            If "shuffle" is not included in the training dataloader params.
        """
        if "shuffle" not in train_dataloader_params:
            raise ValueError(
                "Value for 'shuffle' was not included in the `train_dataloader_params`."
            )

        # The key is known to be present at this point; only its value matters.
        if not train_dataloader_params["shuffle"]:
            warn(
                "Dataloader parameters include `shuffle=False`, this will be passed to "
                "the training dataloader and may lead to lower quality results.",
                stacklevel=1,
            )
        return train_dataloader_params

    @model_validator(mode="after")
    def std_only_with_mean(self: Self) -> Self:
        """
        Check that mean and std are either both None, or both specified.

        The check is applied independently to the image statistics and to the
        target statistics. Note that empty lists are treated like `None` by the
        "both specified" check, because the check relies on truthiness.

        Returns
        -------
        Self
            Validated data model.

        Raises
        ------
        ValueError
            If only one of mean and std is specified, or if they have different
            lengths.
        """
        # check that mean and std are either both None, or both specified
        if (self.image_means and not self.image_stds) or (
            self.image_stds and not self.image_means
        ):
            raise ValueError(
                "Mean and std must be either both None, or both specified."
            )

        elif (self.image_means is not None and self.image_stds is not None) and (
            len(self.image_means) != len(self.image_stds)
        ):
            raise ValueError("Mean and std must be specified for each input channel.")

        if (self.target_means and not self.target_stds) or (
            self.target_stds and not self.target_means
        ):
            raise ValueError(
                "Mean and std must be either both None, or both specified "
            )

        elif (self.target_means is not None and self.target_stds is not None) and (
            len(self.target_means) != len(self.target_stds)
        ):
            raise ValueError(
                "Mean and std must be either both None, or both specified for each "
                "target channel."
            )

        return self

    @model_validator(mode="after")
    def validate_dimensions(self: Self) -> Self:
        """
        Validate 2D/3D dimensions between axes and patch size.

        Data whose axes contain "Z" is treated as 3D and requires a 3-element
        `patch_size`; otherwise a 2-element `patch_size` is required. Patching
        strategies without a `patch_size` attribute (e.g. `whole`) are not checked.

        Returns
        -------
        Self
            Validated data model.

        Raises
        ------
        ValueError
            If the patch size dimension is not compatible with the axes.
        """
        if not hasattr(self.patching, "patch_size"):
            return self

        expected_dims = 3 if "Z" in self.axes else 2
        if len(self.patching.patch_size) != expected_dims:
            raise ValueError(
                f"`patch_size` in `patching` must have {expected_dims} dimensions if "
                f"the data is {expected_dims}D, got axes {self.axes}."
            )

        return self

    def __str__(self) -> str:
        """
        Pretty string representing the configuration.

        Returns
        -------
        str
            Pretty string.
        """
        return pformat(self.model_dump())

    def _update(self, **kwargs: Any) -> None:
        """
        Update multiple arguments at once.

        The fields are set together, bypassing per-field `validate_assignment`, so
        that interdependent fields (e.g. means and stds) can change simultaneously
        without tripping the model validators in between. A merged copy is validated
        first: if validation fails, this instance is left unchanged. Otherwise, the
        validated (coerced) values are stored.

        Parameters
        ----------
        **kwargs : Any
            Keyword arguments to update.

        Raises
        ------
        pydantic.ValidationError
            If the updated configuration is not valid.
        """
        validated = self.__class__.model_validate({**self.__dict__, **kwargs})

        # Unknown keys are ignored by validation; keep them as given, as before.
        self.__dict__.update(
            {key: validated.__dict__.get(key, value) for key, value in kwargs.items()}
        )

    def set_means_and_stds(
        self,
        image_means: Union[NDArray, tuple, list, None],
        image_stds: Union[NDArray, tuple, list, None],
        target_means: Optional[Union[NDArray, tuple, list, None]] = None,
        target_stds: Optional[Union[NDArray, tuple, list, None]] = None,
    ) -> None:
        """
        Set mean and standard deviation of the data across channels.

        This method should be used instead of setting the fields directly, as it
        would otherwise trigger a validation error.

        Parameters
        ----------
        image_means : numpy.ndarray, tuple, list or None
            Mean values for normalization.
        image_stds : numpy.ndarray, tuple, list or None
            Standard deviation values for normalization.
        target_means : numpy.ndarray, tuple, list or None, optional
            Target mean values for normalization, by default None.
        target_stds : numpy.ndarray, tuple, list or None, optional
            Target standard deviation values for normalization, by default None.
        """
        # make sure we pass a list
        if image_means is not None:
            image_means = list(image_means)
        if image_stds is not None:
            image_stds = list(image_stds)
        if target_means is not None:
            target_means = list(target_means)
        if target_stds is not None:
            target_stds = list(target_stds)

        self._update(
            image_means=image_means,
            image_stds=image_stds,
            target_means=target_means,
            target_stds=target_stds,
        )

    # def set_3D(self, axes: str, patch_size: list[int]) -> None:
    #     """
    #     Set 3D parameters.

    #     Parameters
    #     ----------
    #     axes : str
    #         Axes.
    #     patch_size : list of int
    #         Patch size.
    #     """
    #     self._update(axes=axes, patch_size=patch_size)
|
|
@@ -0,0 +1,14 @@
|
|
|
1
|
+
"""Patching strategies Pydantic models."""
|
|
2
|
+
|
|
3
|
+
# Public patching strategy models re-exported by this package.
__all__ = [
    "RandomPatchingModel", "SequentialPatchingModel", "TiledPatchingModel",
    "WholePatchingModel",
]
|
|
9
|
+
|
|
10
|
+
|
|
11
|
+
from .random_patching_model import RandomPatchingModel
|
|
12
|
+
from .sequential_patching_model import SequentialPatchingModel
|
|
13
|
+
from .tiled_patching_model import TiledPatchingModel
|
|
14
|
+
from .whole_patching_model import WholePatchingModel
|
|
@@ -0,0 +1,103 @@
|
|
|
1
|
+
"""Overlapping patching Pydantic model."""
|
|
2
|
+
|
|
3
|
+
from collections.abc import Sequence
|
|
4
|
+
from typing import Optional
|
|
5
|
+
|
|
6
|
+
from pydantic import Field, ValidationInfo, field_validator
|
|
7
|
+
|
|
8
|
+
from ._patched_model import _PatchedModel
|
|
9
|
+
|
|
10
|
+
|
|
11
|
+
class _OverlappingPatchedModel(_PatchedModel):
    """Overlapping patching Pydantic model.

    This model is only used for inheritance and validation purposes.

    Attributes
    ----------
    patch_size : list of int
        The size of the patch in each spatial dimension, each patch size must be a power
        of 2 and at least 8.
    overlaps : sequence of int, optional
        The overlaps between patches in each spatial dimension. If `None`, no overlap is
        applied. The overlaps must be even, smaller than the patch size in each spatial
        dimension, and the number of dimensions be either 2 or 3.
    """

    overlaps: Optional[Sequence[int]] = Field(
        default=None,
        min_length=2,
        max_length=3,
    )
    """The overlaps between patches in each spatial dimension. If `None`, no overlap is
    applied. The overlaps must be smaller than the patch size in each spatial dimension,
    and the number of dimensions be either 2 or 3.
    """

    @field_validator("overlaps")
    @classmethod
    def overlap_smaller_than_patch_size(
        cls, overlaps: Optional[Sequence[int]], values: ValidationInfo
    ) -> Optional[Sequence[int]]:
        """
        Validate overlap.

        Overlaps must have as many dimensions as the patch size and be smaller than
        the patch size in each spatial dimension.

        Parameters
        ----------
        overlaps : Sequence of int
            Overlap in each dimension.
        values : ValidationInfo
            Dictionary of values.

        Returns
        -------
        Sequence of int
            Validated overlap.

        Raises
        ------
        ValueError
            If the overlaps and patch size have different lengths, or if any overlap
            is not smaller than the corresponding patch size.
        """
        if overlaps is None:
            return None

        # `patch_size` is missing from `values.data` when its own validation failed.
        # That error is already reported by Pydantic, so the cross-field check is
        # skipped instead of raising a bare KeyError (which Pydantic would not wrap
        # into a ValidationError).
        patch_size = values.data.get("patch_size")
        if patch_size is None:
            return overlaps

        if len(overlaps) != len(patch_size):
            raise ValueError(
                f"Overlaps must have the same number of dimensions as the patch size. "
                f"Got {len(overlaps)} dimensions for overlaps and {len(patch_size)} "
                f"dimensions for patch size."
            )

        if any(o >= p for o, p in zip(overlaps, patch_size, strict=False)):
            raise ValueError(
                f"Overlap must be smaller than the patch size, got {overlaps} versus "
                f"{patch_size}."
            )

        return overlaps

    @field_validator("overlaps")
    @classmethod
    def overlap_even(cls, overlaps: Optional[Sequence[int]]) -> Optional[Sequence[int]]:
        """
        Validate overlaps.

        Overlap must be even.

        Parameters
        ----------
        overlaps : Sequence of int
            Overlaps.

        Returns
        -------
        Sequence of int
            Validated overlap.

        Raises
        ------
        ValueError
            If any overlap is odd.
        """
        if overlaps is None:
            return None

        if any(o % 2 != 0 for o in overlaps):
            raise ValueError(f"Overlaps must be even, got {overlaps}.")

        return overlaps
|
|
@@ -0,0 +1,56 @@
|
|
|
1
|
+
"""Generic patching Pydantic model."""
|
|
2
|
+
|
|
3
|
+
from collections.abc import Sequence
|
|
4
|
+
|
|
5
|
+
from pydantic import BaseModel, ConfigDict, Field, field_validator
|
|
6
|
+
|
|
7
|
+
from careamics.config.validators import patch_size_ge_than_8_power_of_2
|
|
8
|
+
|
|
9
|
+
|
|
10
|
+
class _PatchedModel(BaseModel):
    """Base Pydantic model shared by the patch-based strategies.

    Used only for inheritance and validation: subclasses provide a literal
    `name` and inherit the validated `patch_size` field.
    """

    # Unknown keys are silently dropped (Pydantic's default, spelled out explicitly).
    model_config = ConfigDict(extra="ignore")

    name: str
    """Identifier of the patching strategy."""

    patch_size: Sequence[int] = Field(default=..., max_length=3, min_length=2)
    """Patch extent along each spatial dimension (2 or 3 values); every entry must be
    a power of 2 and at least 8."""

    @field_validator("patch_size")
    @classmethod
    def all_elements_power_of_2_minimum_8(
        cls, patch_list: Sequence[int]
    ) -> Sequence[int]:
        """
        Check that every patch dimension is a power of 2 and at least 8.

        Parameters
        ----------
        patch_list : Sequence of int
            Candidate patch size.

        Returns
        -------
        Sequence of int
            The unchanged patch size.

        Raises
        ------
        ValueError
            If an entry is smaller than 8.
        ValueError
            If an entry is not a power of 2.
        """
        # The helper raises on invalid input; nothing is transformed here.
        patch_size_ge_than_8_power_of_2(patch_list)
        return patch_list
|
|
@@ -0,0 +1,21 @@
|
|
|
1
|
+
"""Random patching Pydantic model."""
|
|
2
|
+
|
|
3
|
+
from typing import Literal
|
|
4
|
+
|
|
5
|
+
from ._patched_model import _PatchedModel
|
|
6
|
+
|
|
7
|
+
|
|
8
|
+
class RandomPatchingModel(_PatchedModel):
    """Random patching Pydantic model.

    Inherits the `patch_size` field and its validation from `_PatchedModel`; the
    literal `name` is what identifies this strategy in a discriminated union.

    Attributes
    ----------
    name : "random"
        The name of the patching strategy.
    patch_size : sequence of int
        The size of the patch in each spatial dimension, each patch size must be a power
        of 2 and at least 8.
    """

    # Fixed discriminator value for this strategy.
    name: Literal["random"] = "random"
    """The name of the patching strategy."""
|
|
@@ -0,0 +1,25 @@
|
|
|
1
|
+
"""Sequential patching Pydantic model."""
|
|
2
|
+
|
|
3
|
+
from typing import Literal
|
|
4
|
+
|
|
5
|
+
from ._overlapping_patched_model import _OverlappingPatchedModel
|
|
6
|
+
|
|
7
|
+
|
|
8
|
+
class SequentialPatchingModel(_OverlappingPatchedModel):
    """Sequential patching Pydantic model.

    Inherits `patch_size` and the optional `overlaps` field (with their validation)
    from `_OverlappingPatchedModel`; only the literal `name` is defined here.

    Attributes
    ----------
    name : "sequential"
        The name of the patching strategy.
    patch_size : sequence of int
        The size of the patch in each spatial dimension, each patch size must be a power
        of 2 and at least 8.
    overlaps : sequence of int, optional
        The overlaps between patches in each spatial dimension. If `None`, no overlap is
        applied. The overlaps must be smaller than the patch size in each spatial
        dimension, and the number of dimensions be either 2 or 3.
    """

    # Fixed discriminator value for this strategy.
    name: Literal["sequential"] = "sequential"
    """The name of the patching strategy."""
|
|
@@ -0,0 +1,40 @@
|
|
|
1
|
+
"""Tiled patching Pydantic model."""
|
|
2
|
+
|
|
3
|
+
from collections.abc import Sequence
|
|
4
|
+
from typing import Literal
|
|
5
|
+
|
|
6
|
+
from pydantic import Field
|
|
7
|
+
|
|
8
|
+
from ._overlapping_patched_model import _OverlappingPatchedModel
|
|
9
|
+
|
|
10
|
+
|
|
11
|
+
# TODO with UNet tiling must obey different rules than sequential tiling
|
|
12
|
+
# - needs to be validated at the level of the configuration
|
|
13
|
+
class TiledPatchingModel(_OverlappingPatchedModel):
    """Tiled patching Pydantic model.

    Same fields as `_OverlappingPatchedModel`, except that `overlaps` is mandatory
    here instead of optional.

    Attributes
    ----------
    name : "tiled"
        The name of the patching strategy.
    patch_size : sequence of int
        Extent of a tile along each spatial dimension; every entry must be a power
        of 2 and at least 8.
    overlaps : sequence of int
        Overlap between neighbouring tiles along each spatial dimension. Each value
        must be smaller than the matching patch size, with 2 or 3 dimensions in total.
    """

    # Fixed discriminator value for this strategy.
    name: Literal["tiled"] = "tiled"
    """The name of the patching strategy."""

    # Overrides the optional parent field to make it required.
    overlaps: Sequence[int] = Field(default=..., max_length=3, min_length=2)
    """Overlap between neighbouring tiles along each spatial dimension; each value
    must be smaller than the matching patch size, with 2 or 3 dimensions in total.
    """
|
|
@@ -0,0 +1,12 @@
|
|
|
1
|
+
"""Whole image patching Pydantic model."""
|
|
2
|
+
|
|
3
|
+
from typing import Literal
|
|
4
|
+
|
|
5
|
+
from pydantic import BaseModel
|
|
6
|
+
|
|
7
|
+
|
|
8
|
+
class WholePatchingModel(BaseModel):
    """Whole image patching Pydantic model.

    Unlike the other strategies, it derives directly from `BaseModel` and defines no
    `patch_size`: only the strategy `name` is configurable.
    """

    # Fixed discriminator value for this strategy.
    name: Literal["whole"] = "whole"
    """The name of the patching strategy."""
|
|
@@ -15,8 +15,8 @@ class InferenceConfig(BaseModel):
|
|
|
15
15
|
|
|
16
16
|
model_config = ConfigDict(validate_assignment=True, arbitrary_types_allowed=True)
|
|
17
17
|
|
|
18
|
-
data_type: Literal["array", "tiff", "custom"] # As defined in SupportedData
|
|
19
|
-
"""Type of input data: numpy.ndarray (array) or path (tiff or custom)."""
|
|
18
|
+
data_type: Literal["array", "tiff", "czi", "custom"] # As defined in SupportedData
|
|
19
|
+
"""Type of input data: numpy.ndarray (array) or path (tiff, czi, or custom)."""
|
|
20
20
|
|
|
21
21
|
tile_size: Optional[Union[list[int]]] = Field(
|
|
22
22
|
default=None, min_length=2, max_length=3
|
|
@@ -171,7 +171,10 @@ class InferenceConfig(BaseModel):
|
|
|
171
171
|
f"{self.axes} (got {self.tile_overlap})."
|
|
172
172
|
)
|
|
173
173
|
|
|
174
|
-
if any(
|
|
174
|
+
if any(
|
|
175
|
+
(i >= j)
|
|
176
|
+
for i, j in zip(self.tile_overlap, self.tile_size, strict=False)
|
|
177
|
+
):
|
|
175
178
|
raise ValueError("Tile overlap must be smaller than tile size.")
|
|
176
179
|
|
|
177
180
|
return self
|