careamics 0.0.1__py3-none-any.whl → 0.1.0rc2__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of careamics might be problematic. Click here for more details.
- careamics/__init__.py +7 -1
- careamics/bioimage/__init__.py +15 -0
- careamics/bioimage/docs/Noise2Void.md +5 -0
- careamics/bioimage/docs/__init__.py +1 -0
- careamics/bioimage/io.py +182 -0
- careamics/bioimage/rdf.py +105 -0
- careamics/config/__init__.py +11 -0
- careamics/config/algorithm.py +231 -0
- careamics/config/config.py +297 -0
- careamics/config/config_filter.py +44 -0
- careamics/config/data.py +194 -0
- careamics/config/torch_optim.py +118 -0
- careamics/config/training.py +534 -0
- careamics/dataset/__init__.py +1 -0
- careamics/dataset/dataset_utils.py +111 -0
- careamics/dataset/extraction_strategy.py +21 -0
- careamics/dataset/in_memory_dataset.py +202 -0
- careamics/dataset/patching.py +492 -0
- careamics/dataset/prepare_dataset.py +175 -0
- careamics/dataset/tiff_dataset.py +212 -0
- careamics/engine.py +1014 -0
- careamics/losses/__init__.py +4 -0
- careamics/losses/loss_factory.py +38 -0
- careamics/losses/losses.py +34 -0
- careamics/manipulation/__init__.py +4 -0
- careamics/manipulation/pixel_manipulation.py +158 -0
- careamics/models/__init__.py +4 -0
- careamics/models/layers.py +152 -0
- careamics/models/model_factory.py +251 -0
- careamics/models/unet.py +322 -0
- careamics/prediction/__init__.py +9 -0
- careamics/prediction/prediction_utils.py +106 -0
- careamics/utils/__init__.py +20 -0
- careamics/utils/ascii_logo.txt +9 -0
- careamics/utils/augment.py +65 -0
- careamics/utils/context.py +45 -0
- careamics/utils/logging.py +321 -0
- careamics/utils/metrics.py +160 -0
- careamics/utils/normalization.py +55 -0
- careamics/utils/torch_utils.py +89 -0
- careamics/utils/validators.py +170 -0
- careamics/utils/wandb.py +121 -0
- careamics-0.1.0rc2.dist-info/METADATA +81 -0
- careamics-0.1.0rc2.dist-info/RECORD +47 -0
- {careamics-0.0.1.dist-info → careamics-0.1.0rc2.dist-info}/WHEEL +1 -1
- {careamics-0.0.1.dist-info → careamics-0.1.0rc2.dist-info}/licenses/LICENSE +1 -1
- careamics-0.0.1.dist-info/METADATA +0 -46
- careamics-0.0.1.dist-info/RECORD +0 -6
|
@@ -0,0 +1,534 @@
|
|
|
1
|
+
"""Training configuration."""
|
|
2
|
+
from __future__ import annotations
|
|
3
|
+
|
|
4
|
+
from typing import Dict, List
|
|
5
|
+
|
|
6
|
+
from pydantic import (
|
|
7
|
+
BaseModel,
|
|
8
|
+
ConfigDict,
|
|
9
|
+
Field,
|
|
10
|
+
FieldValidationInfo,
|
|
11
|
+
field_validator,
|
|
12
|
+
model_validator,
|
|
13
|
+
)
|
|
14
|
+
from torch import optim
|
|
15
|
+
|
|
16
|
+
from .config_filter import remove_default_optionals
|
|
17
|
+
from .torch_optim import TorchLRScheduler, TorchOptimizer, get_parameters
|
|
18
|
+
|
|
19
|
+
|
|
20
|
+
class Optimizer(BaseModel):
    """
    Configuration of a torch optimizer.

    The optimizer is chosen through `name` and configured through `parameters`.
    During validation, entries of `parameters` are filtered against the
    signature of the selected torch optimizer, see:
    https://pytorch.org/docs/stable/optim.html#algorithms

    Arguments that the torch optimizer requires (see its signature in the link
    above) have to be provided in `parameters`; SGD, for instance, needs `lr`.

    Attributes
    ----------
    name : TorchOptimizer
        Name of the optimizer.
    parameters : dict
        Parameters of the optimizer (see torch documentation).
    """

    # Enum fields are stored as their values, and assignments are re-validated.
    model_config = ConfigDict(
        use_enum_values=True,
        validate_assignment=True,
    )

    # Mandatory field (declared first so it is validated before `parameters`)
    name: TorchOptimizer

    # Optional parameters
    parameters: dict = {}

    @field_validator("parameters")
    def filter_parameters(cls, user_params: dict, values: FieldValidationInfo) -> Dict:
        """
        Keep only the parameters accepted by the selected torch optimizer.

        Parameters
        ----------
        user_params : dict
            Parameters passed on to the torch optimizer.
        values : FieldValidationInfo
            Pydantic field validation info, from which the already validated
            optimizer name is read.

        Returns
        -------
        Dict
            Filtered optimizer parameters, as returned by `get_parameters`.

        Raises
        ------
        ValueError
            If the optimizer name is absent from the validated data.
        """
        if "name" not in values.data:
            raise ValueError(
                "Cannot validate optimizer parameters without `name`, check that it "
                "has correctly been specified."
            )

        # look the optimizer class up by name in `torch.optim`
        torch_optimizer_cls = getattr(optim, values.data["name"])

        return get_parameters(torch_optimizer_cls, user_params)

    @model_validator(mode="after")
    def sgd_lr_parameter(cls, optimizer: Optimizer) -> Optimizer:
        """
        Ensure that an SGD optimizer comes with an `lr` parameter.

        Parameters
        ----------
        optimizer : Optimizer
            Optimizer to validate.

        Returns
        -------
        Optimizer
            The same optimizer, once validated.

        Raises
        ------
        ValueError
            If the optimizer is SGD and `lr` is missing from `parameters`.
        """
        is_sgd = optimizer.name == TorchOptimizer.SGD
        if is_sgd and "lr" not in optimizer.parameters:
            raise ValueError(
                "SGD optimizer requires `lr` parameter, check that it has correctly "
                "been specified in `parameters`."
            )

        return optimizer

    def model_dump(
        self, exclude_optionals: bool = True, *args: List, **kwargs: Dict
    ) -> Dict:
        """
        Dump the model to a dictionary suited for yaml export.

        Entries whose value is None are always dropped. When
        `exclude_optionals` is True, optional entries that still hold their
        default value are dropped as well.

        Parameters
        ----------
        exclude_optionals : bool, optional
            Whether to exclude optional arguments if they are default, by default True.
        *args : List
            Positional arguments, unused.
        **kwargs : Dict
            Keyword arguments, unused.

        Returns
        -------
        dict
            Dictionary containing the model parameters.
        """
        dumped = super().model_dump(exclude_none=True)

        if not exclude_optionals:
            return dumped

        # drop `parameters` if it is still the default (empty dict)
        remove_default_optionals(dumped, {"parameters": {}})
        return dumped
|
|
151
|
+
|
|
152
|
+
|
|
153
|
+
class LrScheduler(BaseModel):
    """
    Configuration of a torch learning rate scheduler.

    The scheduler is chosen through `name` and configured through
    `parameters`. During validation, entries of `parameters` are filtered
    against the signature of the selected torch scheduler, see:
    https://pytorch.org/docs/stable/optim.html#how-to-adjust-learning-rate

    Arguments that the torch scheduler requires (see its signature in the link
    above) have to be provided in `parameters`; StepLR, for instance, needs
    `step_size`.

    Attributes
    ----------
    name : TorchLRScheduler
        Name of the learning rate scheduler.
    parameters : dict
        Parameters of the learning rate scheduler (see torch documentation).
    """

    # Enum fields are stored as their values, and assignments are re-validated.
    model_config = ConfigDict(
        use_enum_values=True,
        validate_assignment=True,
    )

    # Mandatory field (declared first so it is validated before `parameters`)
    name: TorchLRScheduler

    # Optional parameters
    parameters: dict = {}

    @field_validator("parameters")
    def filter_parameters(cls, user_params: dict, values: FieldValidationInfo) -> Dict:
        """
        Keep only the parameters accepted by the selected lr scheduler.

        Parameters
        ----------
        user_params : dict
            Parameters passed on to the torch lr scheduler.
        values : FieldValidationInfo
            Pydantic field validation info, from which the already validated
            lr scheduler name is read.

        Returns
        -------
        Dict
            Filtered lr scheduler parameters, as returned by `get_parameters`.

        Raises
        ------
        ValueError
            If the lr scheduler name is absent from the validated data.
        """
        if "name" not in values.data:
            raise ValueError(
                "Cannot validate lr scheduler parameters without `name`, check that it "
                "has correctly been specified."
            )

        # look the scheduler class up by name in `torch.optim.lr_scheduler`
        scheduler_cls = getattr(optim.lr_scheduler, values.data["name"])

        return get_parameters(scheduler_cls, user_params)

    @model_validator(mode="after")
    def step_lr_step_size_parameter(cls, lr_scheduler: LrScheduler) -> LrScheduler:
        """
        Ensure that a StepLR scheduler comes with a `step_size` parameter.

        Parameters
        ----------
        lr_scheduler : LrScheduler
            Lr scheduler to validate.

        Returns
        -------
        LrScheduler
            The same lr scheduler, once validated.

        Raises
        ------
        ValueError
            If the scheduler is StepLR and `step_size` is missing from
            `parameters`.
        """
        is_step_lr = lr_scheduler.name == TorchLRScheduler.StepLR
        if is_step_lr and "step_size" not in lr_scheduler.parameters:
            raise ValueError(
                "StepLR lr scheduler requires `step_size` parameter, check that it has "
                "correctly been specified in `parameters`."
            )

        return lr_scheduler

    def model_dump(
        self, exclude_optionals: bool = True, *args: List, **kwargs: Dict
    ) -> Dict:
        """
        Dump the model to a dictionary suited for yaml export.

        Entries whose value is None are always dropped. When
        `exclude_optionals` is True, optional entries that still hold their
        default value are dropped as well.

        Parameters
        ----------
        exclude_optionals : bool, optional
            Whether to exclude optional arguments if they are default, by default True.
        *args : List
            Positional arguments, unused.
        **kwargs : Dict
            Keyword arguments, unused.

        Returns
        -------
        dict
            Dictionary containing the model parameters.
        """
        dumped = super().model_dump(exclude_none=True)

        if not exclude_optionals:
            return dumped

        # drop `parameters` if it is still the default (empty dict)
        remove_default_optionals(dumped, {"parameters": {}})
        return dumped
|
|
285
|
+
|
|
286
|
+
|
|
287
|
+
class AMP(BaseModel):
    """
    Parameters of automatic mixed precision (AMP).

    See: https://pytorch.org/docs/stable/amp.html.

    Attributes
    ----------
    use : bool, optional
        Whether to use AMP or not, default False.
    init_scale : int, optional
        Initial scale used for loss scaling, default 1024. Must be a power of
        two within [512, 65536].
    """

    # re-validate fields whenever they are assigned
    model_config = ConfigDict(
        validate_assignment=True,
    )

    use: bool = False

    # TODO review init_scale and document better
    init_scale: int = Field(default=1024, ge=512, le=65536)

    @field_validator("init_scale")
    def power_of_two(cls, scale: int) -> int:
        """
        Check that `init_scale` is a power of two.

        Parameters
        ----------
        scale : int
            Initial scale used for loss scaling.

        Returns
        -------
        int
            The unchanged scale, once validated.

        Raises
        ------
        ValueError
            If `scale` is not a power of two.
        """
        # a power of two has a single bit set, so clearing its lowest set bit
        # (scale & (scale - 1)) leaves zero
        is_power_of_two = (scale & (scale - 1)) == 0
        if not is_power_of_two:
            raise ValueError(f"Init scale must be a power of two (got {scale}).")

        return scale

    def model_dump(
        self, exclude_optionals: bool = True, *args: List, **kwargs: Dict
    ) -> Dict:
        """
        Dump the model to a dictionary suited for yaml export.

        Entries whose value is None are always dropped. When
        `exclude_optionals` is True, optional entries that still hold their
        default value are dropped as well.

        Parameters
        ----------
        exclude_optionals : bool, optional
            Whether to exclude optional arguments if they are default, by default True.
        *args : List
            Positional arguments, unused.
        **kwargs : Dict
            Keyword arguments, unused.

        Returns
        -------
        dict
            Dictionary containing the model parameters.
        """
        dumped = super().model_dump(exclude_none=True)

        if exclude_optionals:
            # drop `init_scale` if it still holds its default value
            remove_default_optionals(dumped, {"init_scale": 1024})

        return dumped
|
|
370
|
+
|
|
371
|
+
|
|
372
|
+
class Training(BaseModel):
    """
    Parameters related to the training.

    Mandatory parameters are:
        - num_epochs: number of epochs, greater than 0.
        - patch_size: patch size, 2D or 3D, non-zero and divisible by 2.
        - batch_size: batch size, greater than 0.
        - optimizer: optimizer, see `Optimizer`.
        - lr_scheduler: learning rate scheduler, see `LrScheduler`.
        - augmentation: whether to use data augmentation or not (True or False).

    The other fields are optional:
        - use_wandb: whether to use wandb or not (default False).
        - num_workers: number of workers (default 0).
        - amp: automatic mixed precision parameters (disabled by default).

    Attributes
    ----------
    num_epochs : int
        Number of epochs, greater than 0.
    patch_size : List[int]
        Patch size with 2 or 3 elements (2D or 3D), each non-zero, positive
        and divisible by 2.
    batch_size : int
        Batch size, greater than 0.
    optimizer : Optimizer
        Optimizer.
    lr_scheduler : LrScheduler
        Learning rate scheduler.
    augmentation : bool
        Whether to use data augmentation or not.
    use_wandb : bool
        Optional, whether to use wandb or not (default False).
    num_workers : int
        Optional, number of workers, non-negative (default 0).
    amp : AMP
        Optional, automatic mixed precision parameters (disabled by default).
    """

    # Pydantic class configuration
    model_config = ConfigDict(
        use_enum_values=True,
        validate_assignment=True,
    )

    # Mandatory fields
    num_epochs: int
    patch_size: List[int] = Field(..., min_length=2, max_length=3)
    batch_size: int

    optimizer: Optimizer
    lr_scheduler: LrScheduler

    augmentation: bool

    # Optional fields
    use_wandb: bool = False
    num_workers: int = Field(default=0, ge=0)
    amp: AMP = AMP()

    @field_validator("num_epochs", "batch_size")
    def greater_than_0(cls, val: int, info: FieldValidationInfo) -> int:
        """
        Validate that the number of epochs or the batch size is greater than 0.

        Parameters
        ----------
        val : int
            Number of epochs or batch size.
        info : FieldValidationInfo
            Pydantic field validation info, used to name the offending field
            in the error message.

        Returns
        -------
        int
            Validated value.

        Raises
        ------
        ValueError
            If the value is smaller than 1.
        """
        if val < 1:
            # This validator is shared by two fields: report the one actually
            # being validated instead of always blaming the number of epochs.
            label = {
                "num_epochs": "Number of epochs",
                "batch_size": "Batch size",
            }.get(info.field_name, str(info.field_name))
            raise ValueError(f"{label} must be greater than 0 (got {val}).")

        return val

    @field_validator("patch_size")
    def all_elements_non_zero_divisible_by_2(cls, patch_list: List[int]) -> List[int]:
        """
        Validate patch size.

        Patch size must be non-zero, positive and divisible by 2.

        Parameters
        ----------
        patch_list : List[int]
            Patch size.

        Returns
        -------
        List[int]
            Validated patch size.

        Raises
        ------
        ValueError
            If a patch dimension is smaller than 1.
        ValueError
            If a patch dimension is not divisible by 2.
        """
        for dim in patch_list:
            if dim < 1:
                raise ValueError(f"Patch size must be non-zero positive (got {dim}).")

            if dim % 2 != 0:
                raise ValueError(f"Patch size must be divisible by 2 (got {dim}).")

        return patch_list

    def model_dump(
        self, exclude_optionals: bool = True, *args: List, **kwargs: Dict
    ) -> Dict:
        """
        Override model_dump method.

        The purpose is to ensure smooth export to yaml. It includes:
            - remove entries with None value.
            - remove optional values if they have the default value.

        Nested models (optimizer, lr scheduler, AMP) are dumped with their own
        `model_dump` so that their defaults are filtered consistently.

        Parameters
        ----------
        exclude_optionals : bool, optional
            Whether to exclude optional arguments if they are default, by default True.
        *args : List
            Positional arguments, unused.
        **kwargs : Dict
            Keyword arguments, unused.

        Returns
        -------
        dict
            Dictionary containing the model parameters.
        """
        dictionary = super().model_dump(exclude_none=True)

        dictionary["optimizer"] = self.optimizer.model_dump(exclude_optionals)
        dictionary["lr_scheduler"] = self.lr_scheduler.model_dump(exclude_optionals)

        if self.amp is not None:
            dictionary["amp"] = self.amp.model_dump(exclude_optionals)

        if exclude_optionals:
            # remove optional arguments if they are default
            defaults = {
                "use_wandb": False,
                "num_workers": 0,
                "amp": AMP().model_dump(),
            }

            remove_default_optionals(dictionary, defaults)

        return dictionary
|
|
@@ -0,0 +1 @@
|
|
|
1
|
+
"""Dataset module."""
|
|
@@ -0,0 +1,111 @@
|
|
|
1
|
+
"""Convenience methods for datasets."""
|
|
2
|
+
import logging
|
|
3
|
+
from pathlib import Path
|
|
4
|
+
from typing import List, Union
|
|
5
|
+
|
|
6
|
+
import numpy as np
|
|
7
|
+
import tifffile
|
|
8
|
+
|
|
9
|
+
|
|
10
|
+
def list_files(data_path: Union[str, Path], data_format: str) -> List[Path]:
    """
    Recursively collect the files of a given format below a folder.

    Parameters
    ----------
    data_path : Union[str, Path]
        Folder to search, searched recursively.
    data_format : str
        File extension without the leading period, e.g. `tif`.

    Returns
    -------
    List[Path]
        Sorted list of matching paths.
    """
    # The trailing wildcard makes e.g. "tif" match ".tiff" files as well.
    pattern = f"*.{data_format}*"
    root = Path(data_path)
    return sorted(root.rglob(pattern))
|
|
28
|
+
|
|
29
|
+
|
|
30
|
+
def _update_axes(array: np.ndarray, axes: str) -> np.ndarray:
    """
    Reshape a sample so that it has a single leading batch axis.

    When `axes` contains S or T, every leading axis other than Z, Y and X is
    merged into one leading axis. Otherwise a new leading axis of size 1 is
    added. In both cases the result is cast to float32.

    Parameters
    ----------
    array : np.ndarray
        Input array.
    axes : str
        Description of axes in format STCZYX.

    Returns
    -------
    np.ndarray
        Updated array.
    """
    if not ("S" in axes or "T" in axes):
        return np.expand_dims(array, axis=0).astype(np.float32)

    # number of leading axes to merge: all axes except Z and the YX pair
    n_merged = len(axes.replace("Z", "").replace("YX", ""))
    # TODO test reshape as it can scramble data, moveaxis is probably better
    trailing_shape = array.shape[n_merged:]
    return array.reshape(-1, *trailing_shape).astype(np.float32)
|
|
58
|
+
|
|
59
|
+
|
|
60
|
+
def read_tiff(file_path: Union[str, Path], axes: str) -> np.ndarray:
    """
    Read a tiff file and return a numpy array.

    Parameters
    ----------
    file_path : Union[str, Path]
        Path to a file. Strings are converted to `pathlib.Path`.
    axes : str
        Description of axes in format STCZYX.

    Returns
    -------
    np.ndarray
        Resulting array, reshaped by `_update_axes`.

    Raises
    ------
    ValueError
        If the file suffix does not start with `.tif`.
    ValueError
        If `tifffile` fails to read the file (logged, then re-raised).
    OSError
        If the file failed to open (logged, then re-raised).
    ValueError
        If the squeezed data does not have 2, 3 or 4 dimensions.
    ValueError
        If the axes length does not match the number of data dimensions.
    """
    # accept plain strings as well as Path objects
    file_path = Path(file_path)

    # matches both ".tif" and ".tiff"
    if file_path.suffix[:4] != ".tif":
        raise ValueError(f"File {file_path} is not a valid tiff.")

    try:
        sample = tifffile.imread(file_path)
    except (ValueError, OSError) as err:
        # the error is propagated to the caller, so the file is not skipped
        logging.exception("Exception in file %s: %s", file_path, err)
        raise

    # drop singleton dimensions before checking the dimensionality
    sample = sample.squeeze()

    if len(sample.shape) < 2 or len(sample.shape) > 4:
        raise ValueError(
            f"Incorrect data dimensions. Must be 2, 3 or 4 (got {sample.shape} for "
            f"file {file_path})."
        )

    # check number of axes
    if len(axes) != len(sample.shape):
        raise ValueError(f"Incorrect axes length (got {axes} for file {file_path}).")

    return _update_axes(sample, axes)
|
|
@@ -0,0 +1,21 @@
|
|
|
1
|
+
"""
|
|
2
|
+
Extraction strategy module.
|
|
3
|
+
|
|
4
|
+
This module defines the various extraction strategies available in CAREamics.
|
|
5
|
+
"""
|
|
6
|
+
from enum import Enum
|
|
7
|
+
|
|
8
|
+
|
|
9
|
+
class ExtractionStrategy(str, Enum):
    """
    Patch extraction strategies available in CAREamics.

    Members
    -------
    RANDOM : "random"
        Random extraction.
    SEQUENTIAL : "sequential"
        Grid extraction; values at the image edges can be missed.
    TILED : "tiled"
        Tiled extraction, covering the whole image.
    """

    RANDOM = "random"  # random positions
    SEQUENTIAL = "sequential"  # regular grid, may miss edges
    TILED = "tiled"  # tiles covering the whole image
|