careamics 0.0.12__py3-none-any.whl → 0.0.14__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of careamics might be problematic. Click here for more details.

Files changed (74) hide show
  1. careamics/careamist.py +4 -3
  2. careamics/cli/utils.py +1 -1
  3. careamics/config/algorithms/n2v_algorithm_model.py +1 -1
  4. careamics/config/architectures/unet_model.py +3 -0
  5. careamics/config/callback_model.py +23 -34
  6. careamics/config/configuration.py +47 -1
  7. careamics/config/configuration_factories.py +288 -23
  8. careamics/config/data/__init__.py +2 -0
  9. careamics/config/data/data_model.py +3 -3
  10. careamics/config/data/ng_data_model.py +381 -0
  11. careamics/config/data/patching_strategies/__init__.py +14 -0
  12. careamics/config/data/patching_strategies/_overlapping_patched_model.py +103 -0
  13. careamics/config/data/patching_strategies/_patched_model.py +56 -0
  14. careamics/config/data/patching_strategies/random_patching_model.py +21 -0
  15. careamics/config/data/patching_strategies/sequential_patching_model.py +25 -0
  16. careamics/config/data/patching_strategies/tiled_patching_model.py +40 -0
  17. careamics/config/data/patching_strategies/whole_patching_model.py +12 -0
  18. careamics/config/inference_model.py +6 -3
  19. careamics/config/support/supported_data.py +7 -0
  20. careamics/config/support/supported_patching_strategies.py +22 -0
  21. careamics/config/validators/validator_utils.py +4 -3
  22. careamics/dataset/dataset_utils/iterate_over_files.py +2 -2
  23. careamics/dataset/in_memory_dataset.py +2 -1
  24. careamics/dataset/iterable_dataset.py +2 -2
  25. careamics/dataset/iterable_pred_dataset.py +2 -2
  26. careamics/dataset/iterable_tiled_pred_dataset.py +2 -2
  27. careamics/dataset/patching/patching.py +3 -2
  28. careamics/dataset/tiling/lvae_tiled_patching.py +16 -6
  29. careamics/dataset/tiling/tiled_patching.py +2 -1
  30. careamics/dataset_ng/dataset.py +46 -50
  31. careamics/dataset_ng/demos/bsd68_demo.ipynb +28 -23
  32. careamics/dataset_ng/demos/care_U2OS_demo.ipynb +1 -1
  33. careamics/dataset_ng/demos/demo_custom_image_stack.ipynb +1 -1
  34. careamics/dataset_ng/demos/demo_datamodule.ipynb +50 -46
  35. careamics/dataset_ng/demos/demo_dataset.ipynb +32 -49
  36. careamics/dataset_ng/factory.py +58 -15
  37. careamics/dataset_ng/legacy_interoperability.py +3 -1
  38. careamics/dataset_ng/patch_extractor/demo_custom_image_stack_loader.py +1 -1
  39. careamics/dataset_ng/patch_extractor/image_stack/__init__.py +2 -0
  40. careamics/dataset_ng/patch_extractor/image_stack/czi_image_stack.py +360 -0
  41. careamics/dataset_ng/patch_extractor/image_stack/in_memory_image_stack.py +1 -1
  42. careamics/dataset_ng/patch_extractor/patch_extractor_factory.py +43 -1
  43. careamics/dataset_ng/patching_strategies/random_patching.py +4 -2
  44. careamics/dataset_ng/patching_strategies/sequential_patching.py +5 -5
  45. careamics/dataset_ng/patching_strategies/tiling_strategy.py +2 -1
  46. careamics/file_io/read/get_func.py +2 -1
  47. careamics/lightning/dataset_ng/__init__.py +1 -0
  48. careamics/lightning/dataset_ng/data_module.py +218 -28
  49. careamics/lightning/dataset_ng/lightning_modules/care_module.py +44 -5
  50. careamics/lightning/dataset_ng/lightning_modules/n2v_module.py +42 -3
  51. careamics/lightning/dataset_ng/lightning_modules/unet_module.py +73 -4
  52. careamics/lightning/lightning_module.py +2 -1
  53. careamics/lightning/predict_data_module.py +2 -1
  54. careamics/lightning/train_data_module.py +2 -1
  55. careamics/losses/loss_factory.py +2 -1
  56. careamics/lvae_training/dataset/multicrop_dset.py +1 -1
  57. careamics/model_io/bioimage/bioimage_utils.py +1 -1
  58. careamics/model_io/bioimage/model_description.py +1 -1
  59. careamics/model_io/bmz_io.py +1 -1
  60. careamics/model_io/model_io_utils.py +2 -2
  61. careamics/models/activation.py +2 -1
  62. careamics/models/unet.py +16 -10
  63. careamics/prediction_utils/prediction_outputs.py +1 -1
  64. careamics/prediction_utils/stitch_prediction.py +1 -1
  65. careamics/transforms/n2v_manipulate_torch.py +15 -9
  66. careamics/transforms/pixel_manipulation_torch.py +59 -92
  67. careamics/utils/lightning_utils.py +2 -2
  68. careamics/utils/metrics.py +2 -1
  69. careamics/utils/torch_utils.py +23 -0
  70. {careamics-0.0.12.dist-info → careamics-0.0.14.dist-info}/METADATA +10 -9
  71. {careamics-0.0.12.dist-info → careamics-0.0.14.dist-info}/RECORD +74 -63
  72. {careamics-0.0.12.dist-info → careamics-0.0.14.dist-info}/WHEEL +0 -0
  73. {careamics-0.0.12.dist-info → careamics-0.0.14.dist-info}/entry_points.txt +0 -0
  74. {careamics-0.0.12.dist-info → careamics-0.0.14.dist-info}/licenses/LICENSE +0 -0
@@ -0,0 +1,381 @@
1
+ """Data configuration."""
2
+
3
+ from __future__ import annotations
4
+
5
+ from collections.abc import Sequence
6
+ from pprint import pformat
7
+ from typing import Annotated, Any, Literal, Optional, Union
8
+ from warnings import warn
9
+
10
+ import numpy as np
11
+ from numpy.typing import NDArray
12
+ from pydantic import (
13
+ BaseModel,
14
+ ConfigDict,
15
+ Field,
16
+ PlainSerializer,
17
+ field_validator,
18
+ model_validator,
19
+ )
20
+ from typing_extensions import Self
21
+
22
+ from ..transformations import XYFlipModel, XYRandomRotate90Model
23
+ from ..validators import check_axes_validity
24
+ from .patching_strategies import (
25
+ RandomPatchingModel,
26
+ TiledPatchingModel,
27
+ WholePatchingModel,
28
+ )
29
+
30
+ # TODO: Validate the specific sizes of tiles and overlaps given UNet constraints
31
+ # - needs to be done in the Configuration
32
+ # - patches and overlaps sizes must also be checked against dimensionality
33
+
34
+ # TODO: is 3D updated anywhere in the code in CAREamist/downstream?
35
+ # - this will be important when swapping the data config in Configuration
36
+ # - `set_3D` currently not implemented here
37
+ # TODO: we can't tell that the patching strategy is correct
38
+ # - or is it the responsibility of the creator (e.g. convenience functions)?
39
+
40
+
41
def np_float_to_scientific_str(x: float) -> str:
    """Format a float as a string in scientific notation.

    Used as a Pydantic serializer so that float values, including numpy scalars
    such as `numpy.float32`, can be stored in the model and written to a yaml
    file as strings.

    Parameters
    ----------
    x : float
        Value to format.

    Returns
    -------
    str
        `x` formatted by `numpy.format_float_scientific` with `precision=7`.
    """
    formatted: str = np.format_float_scientific(x, precision=7)
    return formatted
58
+
59
+
60
# Float that serializes to a scientific-notation string, so that numpy floats
# (e.g. numpy.float32) can be held by the model and dumped to yaml as str.
Float = Annotated[
    float,
    PlainSerializer(np_float_to_scientific_str, return_type=str),
]
"""Annotated float type, used to serialize floats to strings."""

# Union of the patching strategy models accepted by the NG data configuration.
# Note: SequentialPatchingModel is intentionally left out (not supported yet).
PatchingStrategies = Union[
    RandomPatchingModel,
    TiledPatchingModel,
    WholePatchingModel,
]
"""Patching strategies."""
70
+
71
+
72
class NGDataConfig(BaseModel):
    """Next-Generation Dataset configuration.

    NGDataConfig are used for both training and prediction, with the patching strategy
    determining how the data is processed. Note that `random` is the only patching
    strategy compatible with training, while `tiled` and `whole` are only used for
    prediction.

    If std is specified, mean must be specified as well. Note that setting the std first
    and then the mean (if they were both `None` before) will raise a validation error.
    Prefer instead `set_means_and_stds` to set both at once. Means and stds are expected
    to be lists of floats, one for each channel. For supervised tasks, the mean and std
    of the target could be different from the input data.

    All supported transforms are defined in the SupportedTransform enum.
    """

    # Pydantic class configuration
    model_config = ConfigDict(
        validate_assignment=True,
    )

    # Dataset configuration
    data_type: Literal["array", "tiff", "zarr", "custom"]
    """Type of input data."""

    axes: str
    """Axes of the data, as defined in SupportedAxes."""

    patching: PatchingStrategies = Field(..., discriminator="name")
    """Patching strategy to use. Note that `random` is the only supported strategy for
    training, while `tiled` and `whole` are only used for prediction."""

    # Optional fields
    batch_size: int = Field(default=1, ge=1, validate_default=True)
    """Batch size for training."""

    image_means: Optional[list[Float]] = Field(
        default=None, min_length=0, max_length=32
    )
    """Means of the data across channels, used for normalization."""

    image_stds: Optional[list[Float]] = Field(default=None, min_length=0, max_length=32)
    """Standard deviations of the data across channels, used for normalization."""

    target_means: Optional[list[Float]] = Field(
        default=None, min_length=0, max_length=32
    )
    """Means of the target data across channels, used for normalization."""

    target_stds: Optional[list[Float]] = Field(
        default=None, min_length=0, max_length=32
    )
    """Standard deviations of the target data across channels, used for
    normalization."""

    transforms: Sequence[Union[XYFlipModel, XYRandomRotate90Model]] = Field(
        default=(
            XYFlipModel(),
            XYRandomRotate90Model(),
        ),
        validate_default=True,
    )
    """List of transformations to apply to the data, available transforms are defined
    in SupportedTransform."""

    train_dataloader_params: dict[str, Any] = Field(
        default={"shuffle": True}, validate_default=True
    )
    """Dictionary of PyTorch training dataloader parameters. The dataloader parameters,
    should include the `shuffle` key, which is set to `True` by default. We strongly
    recommend to keep it as `True` to ensure the best training results."""

    val_dataloader_params: dict[str, Any] = Field(default={})
    """Dictionary of PyTorch validation dataloader parameters."""

    test_dataloader_params: dict[str, Any] = Field(default={})
    """Dictionary of PyTorch test dataloader parameters."""

    seed: Optional[int] = Field(default=None, gt=0)
    """Random seed for reproducibility."""

    @field_validator("axes")
    @classmethod
    def axes_valid(cls, axes: str) -> str:
        """
        Validate axes.

        Axes must:
        - be a combination of 'STCZYX'
        - not contain duplicates
        - contain at least 2 contiguous axes: X and Y
        - contain at most 4 axes
        - not contain both S and T axes

        Parameters
        ----------
        axes : str
            Axes to validate.

        Returns
        -------
        str
            Validated axes.

        Raises
        ------
        ValueError
            If axes are not valid.
        """
        # Validate axes
        check_axes_validity(axes)

        return axes

    @field_validator("train_dataloader_params")
    @classmethod
    def shuffle_train_dataloader(
        cls, train_dataloader_params: dict[str, Any]
    ) -> dict[str, Any]:
        """
        Validate that "shuffle" is included in the training dataloader params.

        A warning will be raised if `shuffle=False`.

        Parameters
        ----------
        train_dataloader_params : dict of {str: Any}
            The training dataloader parameters.

        Returns
        -------
        dict of {str: Any}
            The validated training dataloader parameters.

        Raises
        ------
        ValueError
            If "shuffle" is not included in the training dataloader params.
        """
        if "shuffle" not in train_dataloader_params:
            raise ValueError(
                "Value for 'shuffle' was not included in the `train_dataloader_params`."
            )

        # The key is guaranteed to be present here; only warn on a falsy value.
        if not train_dataloader_params["shuffle"]:
            warn(
                "Dataloader parameters include `shuffle=False`, this will be passed to "
                "the training dataloader and may lead to lower quality results.",
                stacklevel=1,
            )
        return train_dataloader_params

    @model_validator(mode="after")
    def std_only_with_mean(self: Self) -> Self:
        """
        Check that mean and std are either both None, or both specified.

        The check is applied independently to the image statistics
        (`image_means`/`image_stds`) and to the target statistics
        (`target_means`/`target_stds`). When both members of a pair are given,
        they must also have the same length (one value per channel).

        Returns
        -------
        Self
            Validated data model.

        Raises
        ------
        ValueError
            If only one of mean and std is specified, or if their lengths differ.
        """
        # Truthiness is used on purpose: an empty list counts as "not specified".
        if bool(self.image_means) != bool(self.image_stds):
            raise ValueError(
                "Mean and std must be either both None, or both specified."
            )
        if (
            self.image_means is not None
            and self.image_stds is not None
            and len(self.image_means) != len(self.image_stds)
        ):
            raise ValueError("Mean and std must be specified for each input channel.")

        if bool(self.target_means) != bool(self.target_stds):
            raise ValueError(
                "Mean and std must be either both None, or both specified."
            )
        if (
            self.target_means is not None
            and self.target_stds is not None
            and len(self.target_means) != len(self.target_stds)
        ):
            raise ValueError(
                "Mean and std must be either both None, or both specified for each "
                "target channel."
            )

        return self

    @model_validator(mode="after")
    def validate_dimensions(self: Self) -> Self:
        """
        Validate 2D/3D dimensions between axes and patch size.

        Data whose axes contain `Z` is considered 3D and requires a 3-dimensional
        `patch_size`; otherwise a 2-dimensional one is required. Patching
        strategies without a `patch_size` (e.g. `whole`) are not checked.

        Returns
        -------
        Self
            Validated data model.

        Raises
        ------
        ValueError
            If the patch size dimension is not compatible with the axes.
        """
        patch_size = getattr(self.patching, "patch_size", None)
        if patch_size is None:
            return self

        is_3d = "Z" in self.axes
        expected_dims = 3 if is_3d else 2
        if len(patch_size) != expected_dims:
            raise ValueError(
                f"`patch_size` in `patching` must have {expected_dims} dimensions if "
                f"the data is {'3D' if is_3d else '2D'}, got axes {self.axes}."
            )

        return self

    def __str__(self) -> str:
        """
        Pretty string representing the configuration.

        Returns
        -------
        str
            Pretty string.
        """
        return pformat(self.model_dump())

    def _update(self, **kwargs: Any) -> None:
        """
        Update multiple arguments at once.

        The new values are validated together with the current ones, as a whole
        model, so that interdependent fields (e.g. means and stds) can be changed
        simultaneously. If validation fails, the instance is left unchanged.

        Parameters
        ----------
        **kwargs : Any
            Keyword arguments to update.

        Raises
        ------
        pydantic.ValidationError
            If the updated configuration is not valid.
        """
        # Validate a merged copy first: mutating `self.__dict__` before validating
        # would leave the instance in an invalid state if validation fails.
        self.__class__.model_validate({**self.__dict__, **kwargs})
        # Write directly to `__dict__` to bypass `validate_assignment`, which would
        # validate each field in isolation (e.g. stds set before means).
        self.__dict__.update(kwargs)

    def set_means_and_stds(
        self,
        image_means: Union[NDArray, tuple, list, None],
        image_stds: Union[NDArray, tuple, list, None],
        target_means: Optional[Union[NDArray, tuple, list, None]] = None,
        target_stds: Optional[Union[NDArray, tuple, list, None]] = None,
    ) -> None:
        """
        Set mean and standard deviation of the data across channels.

        This method should be used instead of setting the fields directly, as that
        would otherwise trigger a validation error. All four fields are set, so
        omitting the target statistics resets them to `None`.

        Parameters
        ----------
        image_means : numpy.ndarray, tuple, list or None
            Mean values for normalization.
        image_stds : numpy.ndarray, tuple, list or None
            Standard deviation values for normalization.
        target_means : numpy.ndarray, tuple, list or None, optional
            Target mean values for normalization, by default None.
        target_stds : numpy.ndarray, tuple, list or None, optional
            Target standard deviation values for normalization, by default None.
        """
        # make sure we pass a list
        if image_means is not None:
            image_means = list(image_means)
        if image_stds is not None:
            image_stds = list(image_stds)
        if target_means is not None:
            target_means = list(target_means)
        if target_stds is not None:
            target_stds = list(target_stds)

        self._update(
            image_means=image_means,
            image_stds=image_stds,
            target_means=target_means,
            target_stds=target_stds,
        )

    # def set_3D(self, axes: str, patch_size: list[int]) -> None:
    #     """
    #     Set 3D parameters.

    #     Parameters
    #     ----------
    #     axes : str
    #         Axes.
    #     patch_size : list of int
    #         Patch size.
    #     """
    #     self._update(axes=axes, patch_size=patch_size)
@@ -0,0 +1,14 @@
1
+ """Patching strategies Pydantic models."""
2
+
3
# Public API of the patching strategy configuration models.
__all__ = [
    "RandomPatchingModel",  # random patches
    "SequentialPatchingModel",  # sequential patches, optional overlaps
    "TiledPatchingModel",  # tiles with mandatory overlaps
    "WholePatchingModel",  # whole images, no patch size
]
9
+
10
+
11
+ from .random_patching_model import RandomPatchingModel
12
+ from .sequential_patching_model import SequentialPatchingModel
13
+ from .tiled_patching_model import TiledPatchingModel
14
+ from .whole_patching_model import WholePatchingModel
@@ -0,0 +1,103 @@
1
+ """Overlapping patching Pydantic model."""
2
+
3
+ from collections.abc import Sequence
4
+ from typing import Optional
5
+
6
+ from pydantic import Field, ValidationInfo, field_validator
7
+
8
+ from ._patched_model import _PatchedModel
9
+
10
+
11
class _OverlappingPatchedModel(_PatchedModel):
    """Overlapping patching Pydantic model.

    This model is only used for inheritance and validation purposes.

    Attributes
    ----------
    patch_size : sequence of int
        The size of the patch in each spatial dimension, each patch size must be a power
        of 2 and at least 8.
    overlaps : sequence of int, optional
        The overlaps between patches in each spatial dimension. If `None`, no overlap is
        applied. The overlaps must be non-negative, even, smaller than the patch size in
        each spatial dimension, and the number of dimensions be either 2 or 3.
    """

    overlaps: Optional[Sequence[int]] = Field(
        default=None,
        min_length=2,
        max_length=3,
    )
    """The overlaps between patches in each spatial dimension. If `None`, no overlap is
    applied. The overlaps must be smaller than the patch size in each spatial dimension,
    and the number of dimensions be either 2 or 3.
    """

    @field_validator("overlaps")
    @classmethod
    def overlap_smaller_than_patch_size(
        cls, overlaps: Optional[Sequence[int]], values: ValidationInfo
    ) -> Optional[Sequence[int]]:
        """
        Validate overlap.

        Overlaps must have as many dimensions as the patch size, be non-negative,
        and be smaller than the patch size in each spatial dimension.

        Parameters
        ----------
        overlaps : Sequence of int
            Overlap in each dimension.
        values : ValidationInfo
            Dictionary of values.

        Returns
        -------
        Sequence of int
            Validated overlap.

        Raises
        ------
        ValueError
            If the overlaps do not match the patch size dimensions, are negative, or
            are not smaller than the patch size.
        """
        if overlaps is None:
            return None

        # `patch_size` is missing from `values.data` when its own validation failed.
        # That error is already reported, so skip the comparison rather than raise a
        # KeyError, which Pydantic would not turn into a ValidationError.
        patch_size = values.data.get("patch_size")
        if patch_size is None:
            return overlaps

        if len(overlaps) != len(patch_size):
            raise ValueError(
                f"Overlaps must have the same number of dimensions as the patch size. "
                f"Got {len(overlaps)} dimensions for overlaps and {len(patch_size)} "
                f"dimensions for patch size."
            )

        if any(o < 0 for o in overlaps):
            raise ValueError(f"Overlaps must be non-negative, got {overlaps}.")

        # Lengths are equal at this point, so a strict zip cannot fail.
        if any(o >= p for o, p in zip(overlaps, patch_size, strict=True)):
            raise ValueError(
                f"Overlap must be smaller than the patch size, got {overlaps} versus "
                f"{patch_size}."
            )

        return overlaps

    @field_validator("overlaps")
    @classmethod
    def overlap_even(cls, overlaps: Optional[Sequence[int]]) -> Optional[Sequence[int]]:
        """
        Validate overlaps.

        Overlap must be even.

        Parameters
        ----------
        overlaps : Sequence of int
            Overlaps.

        Returns
        -------
        Sequence of int
            Validated overlap.

        Raises
        ------
        ValueError
            If any overlap is odd.
        """
        if overlaps is None:
            return None

        if any(o % 2 != 0 for o in overlaps):
            raise ValueError(f"Overlaps must be even, got {overlaps}.")

        return overlaps
@@ -0,0 +1,56 @@
1
+ """Generic patching Pydantic model."""
2
+
3
+ from collections.abc import Sequence
4
+
5
+ from pydantic import BaseModel, ConfigDict, Field, field_validator
6
+
7
+ from careamics.config.validators import patch_size_ge_than_8_power_of_2
8
+
9
+
10
class _PatchedModel(BaseModel):
    """Base Pydantic model for patch-based patching strategies.

    Not meant to be used directly: subclasses narrow `name` to a literal and
    inherit the validated `patch_size` field.
    """

    # Explicitly state Pydantic's default handling of extra fields.
    model_config = ConfigDict(extra="ignore")

    name: str
    """The name of the patching strategy."""

    patch_size: Sequence[int] = Field(..., min_length=2, max_length=3)
    """The size of the patch in each spatial dimensions, each patch size must be a power
    of 2 and larger than 8."""

    @field_validator("patch_size")
    @classmethod
    def all_elements_power_of_2_minimum_8(
        cls, patch_list: Sequence[int]
    ) -> Sequence[int]:
        """
        Check that every patch dimension is a power of 2 no smaller than 8.

        Parameters
        ----------
        patch_list : Sequence of int
            Patch size to check.

        Returns
        -------
        Sequence of int
            The unchanged patch size, once validated.

        Raises
        ------
        ValueError
            If a dimension is smaller than 8.
        ValueError
            If a dimension is not a power of 2.
        """
        # The shared validator raises on failure; its return value is unused.
        patch_size_ge_than_8_power_of_2(patch_list)
        validated = patch_list
        return validated
@@ -0,0 +1,21 @@
1
+ """Random patching Pydantic model."""
2
+
3
+ from typing import Literal
4
+
5
+ from ._patched_model import _PatchedModel
6
+
7
+
8
class RandomPatchingModel(_PatchedModel):
    """Pydantic model of the `random` patching strategy.

    Attributes
    ----------
    name : "random"
        Identifier of the patching strategy.
    patch_size : sequence of int
        Patch size in each spatial dimension (2 or 3 values), each one a power
        of 2 and at least 8.
    """

    name: Literal["random"] = "random"
    """The name of the patching strategy."""
@@ -0,0 +1,25 @@
1
+ """Sequential patching Pydantic model."""
2
+
3
+ from typing import Literal
4
+
5
+ from ._overlapping_patched_model import _OverlappingPatchedModel
6
+
7
+
8
class SequentialPatchingModel(_OverlappingPatchedModel):
    """Pydantic model of the `sequential` patching strategy.

    Attributes
    ----------
    name : "sequential"
        Identifier of the patching strategy.
    patch_size : sequence of int
        Patch size in each spatial dimension (2 or 3 values), each one a power
        of 2 and at least 8.
    overlaps : sequence of int, optional
        Overlap between neighbouring patches in each spatial dimension, or
        `None` for no overlap. Each overlap must be smaller than the matching
        patch size, with 2 or 3 dimensions.
    """

    name: Literal["sequential"] = "sequential"
    """The name of the patching strategy."""
@@ -0,0 +1,40 @@
1
+ """Tiled patching Pydantic model."""
2
+
3
+ from collections.abc import Sequence
4
+ from typing import Literal
5
+
6
+ from pydantic import Field
7
+
8
+ from ._overlapping_patched_model import _OverlappingPatchedModel
9
+
10
+
11
+ # TODO with UNet tiling must obey different rules than sequential tiling
12
+ # - needs to be validated at the level of the configuration
13
class TiledPatchingModel(_OverlappingPatchedModel):
    """Pydantic model of the `tiled` patching strategy.

    Unlike its parent, this model makes `overlaps` mandatory.

    Attributes
    ----------
    name : "tiled"
        Identifier of the patching strategy.
    patch_size : sequence of int
        Patch size in each spatial dimension (2 or 3 values), each one a power
        of 2 and at least 8.
    overlaps : sequence of int
        Overlap between neighbouring tiles in each spatial dimension (2 or 3
        values), each smaller than the matching patch size.
    """

    name: Literal["tiled"] = "tiled"
    """The name of the patching strategy."""

    # Required field (`...`): overrides the optional `overlaps` of the parent.
    overlaps: Sequence[int] = Field(..., min_length=2, max_length=3)
    """The overlaps between patches in each spatial dimension. The overlaps must be
    smaller than the patch size in each spatial dimension, and the number of dimensions
    be either 2 or 3.
    """
@@ -0,0 +1,12 @@
1
+ """Whole image patching Pydantic model."""
2
+
3
+ from typing import Literal
4
+
5
+ from pydantic import BaseModel
6
+
7
+
8
class WholePatchingModel(BaseModel):
    """Pydantic model of the `whole` patching strategy.

    Images are used whole, so the only field is the strategy name (there is
    no `patch_size`).
    """

    name: Literal["whole"] = "whole"
    """The name of the patching strategy."""
@@ -15,8 +15,8 @@ class InferenceConfig(BaseModel):
15
15
 
16
16
  model_config = ConfigDict(validate_assignment=True, arbitrary_types_allowed=True)
17
17
 
18
- data_type: Literal["array", "tiff", "custom"] # As defined in SupportedData
19
- """Type of input data: numpy.ndarray (array) or path (tiff or custom)."""
18
+ data_type: Literal["array", "tiff", "czi", "custom"] # As defined in SupportedData
19
+ """Type of input data: numpy.ndarray (array) or path (tiff, czi, or custom)."""
20
20
 
21
21
  tile_size: Optional[Union[list[int]]] = Field(
22
22
  default=None, min_length=2, max_length=3
@@ -171,7 +171,10 @@ class InferenceConfig(BaseModel):
171
171
  f"{self.axes} (got {self.tile_overlap})."
172
172
  )
173
173
 
174
- if any((i >= j) for i, j in zip(self.tile_overlap, self.tile_size)):
174
+ if any(
175
+ (i >= j)
176
+ for i, j in zip(self.tile_overlap, self.tile_size, strict=False)
177
+ ):
175
178
  raise ValueError("Tile overlap must be smaller than tile size.")
176
179
 
177
180
  return self