careamics-0.0.9-py3-none-any.whl → careamics-0.0.11-py3-none-any.whl

This diff compares the content of publicly available package versions released to one of the supported registries. It is provided for informational purposes only and reflects the changes between the versions as they appear in their respective public registries.

Potentially problematic release.
Files changed (63)
  1. careamics/__init__.py +0 -4
  2. careamics/careamist.py +0 -1
  3. careamics/config/__init__.py +1 -13
  4. careamics/config/algorithms/care_algorithm_model.py +84 -0
  5. careamics/config/algorithms/n2n_algorithm_model.py +85 -0
  6. careamics/config/algorithms/n2v_algorithm_model.py +269 -1
  7. careamics/config/configuration.py +21 -13
  8. careamics/config/configuration_factories.py +179 -187
  9. careamics/config/configuration_io.py +2 -2
  10. careamics/config/data/__init__.py +1 -4
  11. careamics/config/data/data_model.py +46 -62
  12. careamics/config/support/supported_transforms.py +1 -1
  13. careamics/config/transformations/__init__.py +0 -2
  14. careamics/config/transformations/n2v_manipulate_model.py +15 -0
  15. careamics/config/transformations/transform_unions.py +0 -13
  16. careamics/dataset/dataset_utils/iterate_over_files.py +2 -2
  17. careamics/dataset/dataset_utils/running_stats.py +7 -3
  18. careamics/dataset/in_memory_dataset.py +3 -10
  19. careamics/dataset/in_memory_pred_dataset.py +3 -5
  20. careamics/dataset/in_memory_tiled_pred_dataset.py +2 -2
  21. careamics/dataset/iterable_dataset.py +2 -2
  22. careamics/dataset/iterable_pred_dataset.py +3 -5
  23. careamics/dataset/iterable_tiled_pred_dataset.py +3 -3
  24. careamics/dataset_ng/dataset/__init__.py +3 -0
  25. careamics/dataset_ng/dataset/dataset.py +184 -0
  26. careamics/dataset_ng/demo_dataset.ipynb +271 -0
  27. careamics/dataset_ng/demo_patch_extractor.py +53 -0
  28. careamics/dataset_ng/demo_patch_extractor_factory.py +37 -0
  29. careamics/dataset_ng/patch_extractor/__init__.py +10 -0
  30. careamics/dataset_ng/patch_extractor/demo_custom_image_stack_loader.py +111 -0
  31. careamics/dataset_ng/patch_extractor/image_stack/__init__.py +9 -0
  32. careamics/dataset_ng/patch_extractor/image_stack/image_stack_protocol.py +53 -0
  33. careamics/dataset_ng/patch_extractor/image_stack/in_memory_image_stack.py +55 -0
  34. careamics/dataset_ng/patch_extractor/image_stack/zarr_image_stack.py +163 -0
  35. careamics/dataset_ng/patch_extractor/image_stack_loader.py +140 -0
  36. careamics/dataset_ng/patch_extractor/patch_extractor.py +29 -0
  37. careamics/dataset_ng/patch_extractor/patch_extractor_factory.py +208 -0
  38. careamics/dataset_ng/patching_strategies/__init__.py +11 -0
  39. careamics/dataset_ng/patching_strategies/patching_strategy_protocol.py +82 -0
  40. careamics/dataset_ng/patching_strategies/random_patching.py +338 -0
  41. careamics/dataset_ng/patching_strategies/sequential_patching.py +75 -0
  42. careamics/lightning/lightning_module.py +78 -27
  43. careamics/lightning/train_data_module.py +8 -39
  44. careamics/losses/fcn/losses.py +17 -10
  45. careamics/model_io/bioimage/bioimage_utils.py +5 -3
  46. careamics/model_io/bioimage/model_description.py +3 -3
  47. careamics/model_io/bmz_io.py +2 -2
  48. careamics/model_io/model_io_utils.py +2 -2
  49. careamics/transforms/__init__.py +2 -1
  50. careamics/transforms/compose.py +5 -15
  51. careamics/transforms/n2v_manipulate_torch.py +143 -0
  52. careamics/transforms/pixel_manipulation.py +1 -0
  53. careamics/transforms/pixel_manipulation_torch.py +418 -0
  54. careamics/utils/version.py +38 -0
  55. {careamics-0.0.9.dist-info → careamics-0.0.11.dist-info}/METADATA +7 -8
  56. {careamics-0.0.9.dist-info → careamics-0.0.11.dist-info}/RECORD +59 -42
  57. careamics/config/care_configuration.py +0 -100
  58. careamics/config/data/n2v_data_model.py +0 -193
  59. careamics/config/n2n_configuration.py +0 -101
  60. careamics/config/n2v_configuration.py +0 -266
  61. {careamics-0.0.9.dist-info → careamics-0.0.11.dist-info}/WHEEL +0 -0
  62. {careamics-0.0.9.dist-info → careamics-0.0.11.dist-info}/entry_points.txt +0 -0
  63. {careamics-0.0.9.dist-info → careamics-0.0.11.dist-info}/licenses/LICENSE +0 -0
careamics/config/configuration_factories.py

@@ -2,76 +2,25 @@
 
 from typing import Annotated, Any, Literal, Optional, Union
 
-from pydantic import Discriminator, Tag, TypeAdapter
+from pydantic import Field, TypeAdapter
 
 from careamics.config.algorithms import CAREAlgorithm, N2NAlgorithm, N2VAlgorithm
 from careamics.config.architectures import UNetModel
-from careamics.config.care_configuration import CAREConfiguration
-from careamics.config.configuration import Configuration
-from careamics.config.data import DataConfig, N2VDataConfig
-from careamics.config.n2n_configuration import N2NConfiguration
-from careamics.config.n2v_configuration import N2VConfiguration
+from careamics.config.data import DataConfig
 from careamics.config.support import (
-    SupportedAlgorithm,
     SupportedArchitecture,
     SupportedPixelManipulation,
     SupportedTransform,
 )
 from careamics.config.training_model import TrainingConfig
 from careamics.config.transformations import (
-    N2V_TRANSFORMS_UNION,
     SPATIAL_TRANSFORMS_UNION,
     N2VManipulateModel,
     XYFlipModel,
     XYRandomRotate90Model,
 )
 
-
-def _algorithm_config_discriminator(value: Union[dict, Configuration]) -> str:
-    """Discriminate algorithm-specific configurations based on the algorithm.
-
-    Parameters
-    ----------
-    value : Any
-        Value to discriminate.
-
-    Returns
-    -------
-    str
-        Discriminator value.
-    """
-    if isinstance(value, dict):
-        return value["algorithm_config"]["algorithm"]
-    return value.algorithm_config.algorithm
-
-
-def configuration_factory(
-    configuration: dict[str, Any]
-) -> Union[N2VConfiguration, N2NConfiguration, CAREConfiguration]:
-    """
-    Create a configuration for training CAREamics.
-
-    Parameters
-    ----------
-    configuration : dict
-        Configuration dictionary.
-
-    Returns
-    -------
-    N2VConfiguration or N2NConfiguration or CAREConfiguration
-        Configuration for training CAREamics.
-    """
-    adapter: TypeAdapter = TypeAdapter(
-        Annotated[
-            Union[
-                Annotated[N2VConfiguration, Tag(SupportedAlgorithm.N2V.value)],
-                Annotated[N2NConfiguration, Tag(SupportedAlgorithm.N2N.value)],
-                Annotated[CAREConfiguration, Tag(SupportedAlgorithm.CARE.value)],
-            ],
-            Discriminator(_algorithm_config_discriminator),
-        ]
-    )
-    return adapter.validate_python(configuration)
+from .configuration import Configuration
 
 
 def algorithm_factory(
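With `configuration_factory` and the per-algorithm configuration classes removed, dictionaries now validate directly through the base `Configuration` model (the `configuration_io.py` hunks below make the same switch). A minimal migration sketch, assuming `config_dict` is a valid configuration dictionary:

```python
from careamics.config import Configuration

# 0.0.9:  config = configuration_factory(config_dict)
# 0.0.11: the base model validates and dispatches the algorithm itself
config = Configuration(**config_dict)
```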
@@ -90,28 +39,15 @@ def algorithm_factory(
     N2VAlgorithm or N2NAlgorithm or CAREAlgorithm
         Algorithm model for training CAREamics.
     """
-    adapter: TypeAdapter = TypeAdapter(Union[N2VAlgorithm, N2NAlgorithm, CAREAlgorithm])
+    adapter: TypeAdapter = TypeAdapter(
+        Annotated[
+            Union[N2VAlgorithm, N2NAlgorithm, CAREAlgorithm],
+            Field(discriminator="algorithm"),
+        ]
+    )
     return adapter.validate_python(algorithm)
 
 
-def data_factory(data: dict[str, Any]) -> Union[DataConfig, N2VDataConfig]:
-    """
-    Create a data model for training CAREamics.
-
-    Parameters
-    ----------
-    data : dict
-        Data dictionary.
-
-    Returns
-    -------
-    DataConfig or N2VDataConfig
-        Data model for training CAREamics.
-    """
-    adapter: TypeAdapter = TypeAdapter(Union[DataConfig, N2VDataConfig])
-    return adapter.validate_python(data)
-
-
 def _list_spatial_augmentations(
     augmentations: Optional[list[SPATIAL_TRANSFORMS_UNION]],
 ) -> list[SPATIAL_TRANSFORMS_UNION]:
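The rewritten `algorithm_factory` replaces the callable `Discriminator`/`Tag` machinery with Pydantic's field-based discriminated union. A self-contained sketch of the dispatch mechanism, using simplified stand-in models (the real CAREamics algorithm models carry many more fields):

```python
from typing import Annotated, Literal, Union

from pydantic import BaseModel, Field, TypeAdapter

class N2VAlgorithm(BaseModel):  # stand-in for the CAREamics model
    algorithm: Literal["n2v"] = "n2v"

class CAREAlgorithm(BaseModel):  # stand-in for the CAREamics model
    algorithm: Literal["care"] = "care"

adapter: TypeAdapter = TypeAdapter(
    Annotated[
        Union[N2VAlgorithm, CAREAlgorithm],
        Field(discriminator="algorithm"),
    ]
)

# Pydantic reads the "algorithm" key and validates against the matching model.
assert isinstance(adapter.validate_python({"algorithm": "care"}), CAREAlgorithm)
```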
@@ -208,70 +144,42 @@ def _create_unet_configuration(
     )
 
 
-def _create_configuration(
-    algorithm: Literal["n2v", "care", "n2n"],
-    experiment_name: str,
-    data_type: Literal["array", "tiff", "custom"],
+def _create_algorithm_configuration(
     axes: str,
-    patch_size: list[int],
-    batch_size: int,
-    num_epochs: int,
-    augmentations: Union[list[N2V_TRANSFORMS_UNION], list[SPATIAL_TRANSFORMS_UNION]],
-    independent_channels: bool,
+    algorithm: Literal["n2v", "care", "n2n"],
     loss: Literal["n2v", "mae", "mse"],
+    independent_channels: bool,
     n_channels_in: int,
     n_channels_out: int,
-    logger: Literal["wandb", "tensorboard", "none"],
     use_n2v2: bool = False,
     model_params: Optional[dict] = None,
-    train_dataloader_params: Optional[dict[str, Any]] = None,
-    val_dataloader_params: Optional[dict[str, Any]] = None,
-) -> Configuration:
+) -> dict:
     """
-    Create a configuration for training N2V, CARE or Noise2Noise.
+    Create a dictionary with the parameters of the algorithm model.
 
     Parameters
     ----------
+    axes : str
+        Axes of the data.
     algorithm : {"n2v", "care", "n2n"}
         Algorithm to use.
-    experiment_name : str
-        Name of the experiment.
-    data_type : {"array", "tiff", "custom"}
-        Type of the data.
-    axes : str
-        Axes of the data (e.g. SYX).
-    patch_size : list of int
-        Size of the patches along the spatial dimensions (e.g. [64, 64]).
-    batch_size : int
-        Batch size.
-    num_epochs : int
-        Number of epochs.
-    augmentations : list of transforms
-        List of transforms to apply, either both or one of XYFlipModel and
-        XYRandomRotate90Model.
-    independent_channels : bool
-        Whether to train all channels independently.
     loss : {"n2v", "mae", "mse"}
         Loss function to use.
+    independent_channels : bool
+        Whether to train all channels independently.
     n_channels_in : int
-        Number of channels in.
+        Number of input channels.
     n_channels_out : int
-        Number of channels out.
-    logger : {"wandb", "tensorboard", "none"}
-        Logger to use.
+        Number of output channels.
     use_n2v2 : bool, optional
         Whether to use N2V2, by default False.
     model_params : dict
         UNetModel parameters.
-    train_dataloader_params : dict
-        Parameters for the training dataloader, see PyTorch notes, by default None.
-    val_dataloader_params : dict
-        Parameters for the validation dataloader, see PyTorch notes, by default None.
 
     Returns
     -------
-    Configuration
-        Configuration for training N2V, CARE or Noise2Noise.
+    dict
+        Algorithm model as dictionnary with the specified parameters.
     """
     # model
     unet_model = _create_unet_configuration(
@@ -283,13 +191,47 @@ def _create_configuration(
         model_params=model_params,
     )
 
-    # algorithm model
-    algorithm_config = {
+    return {
         "algorithm": algorithm,
         "loss": loss,
         "model": unet_model,
     }
 
+
+def _create_data_configuration(
+    data_type: Literal["array", "tiff", "custom"],
+    axes: str,
+    patch_size: list[int],
+    batch_size: int,
+    augmentations: Union[list[SPATIAL_TRANSFORMS_UNION]],
+    train_dataloader_params: Optional[dict[str, Any]] = None,
+    val_dataloader_params: Optional[dict[str, Any]] = None,
+) -> DataConfig:
+    """
+    Create a dictionary with the parameters of the data model.
+
+    Parameters
+    ----------
+    data_type : {"array", "tiff", "custom"}
+        Type of the data.
+    axes : str
+        Axes of the data.
+    patch_size : list of int
+        Size of the patches along the spatial dimensions.
+    batch_size : int
+        Batch size.
+    augmentations : list of transforms
+        List of transforms to apply.
+    train_dataloader_params : dict
+        Parameters for the training dataloader, see PyTorch notes, by default None.
+    val_dataloader_params : dict
+        Parameters for the validation dataloader, see PyTorch notes, by default None.
+
+    Returns
+    -------
+    DataConfig
+        Data model with the specified parameters.
+    """
     # data model
     data = {
         "data_type": data_type,
@@ -300,30 +242,44 @@
     }
     # Don't override defaults set in DataConfig class
     if train_dataloader_params is not None:
+        # DataConfig enforces the presence of `shuffle` key in the dataloader parameters
+        if "shuffle" not in train_dataloader_params:
+            train_dataloader_params["shuffle"] = True
+
         data["train_dataloader_params"] = train_dataloader_params
+
     if val_dataloader_params is not None:
         data["val_dataloader_params"] = val_dataloader_params
 
-    # training model
-    training = TrainingConfig(
+    return DataConfig(**data)
+
+
+def _create_training_configuration(
+    num_epochs: int, logger: Literal["wandb", "tensorboard", "none"]
+) -> TrainingConfig:
+    """
+    Create a dictionary with the parameters of the training model.
+
+    Parameters
+    ----------
+    num_epochs : int
+        Number of epochs.
+    logger : {"wandb", "tensorboard", "none"}
+        Logger to use.
+
+    Returns
+    -------
+    TrainingConfig
+        Training model with the specified parameters.
+    """
+    return TrainingConfig(
         num_epochs=num_epochs,
-        batch_size=batch_size,
         logger=None if logger == "none" else logger,
     )
 
-    # create configuration
-    configuration = {
-        "experiment_name": experiment_name,
-        "algorithm_config": algorithm_config,
-        "data_config": data,
-        "training_config": training,
-    }
-
-    return configuration_factory(configuration)
-
 
 # TODO reconsider naming once we officially support LVAE approaches
-def _create_supervised_configuration(
+def _create_supervised_config_dict(
     algorithm: Literal["care", "n2n"],
     experiment_name: str,
     data_type: Literal["array", "tiff", "custom"],
@@ -331,7 +287,7 @@ def _create_supervised_configuration(
     patch_size: list[int],
     batch_size: int,
     num_epochs: int,
-    augmentations: Optional[list[Union[XYFlipModel, XYRandomRotate90Model]]] = None,
+    augmentations: Optional[list[SPATIAL_TRANSFORMS_UNION]] = None,
    independent_channels: bool = True,
     loss: Literal["mae", "mse"] = "mae",
     n_channels_in: Optional[int] = None,
@@ -340,7 +296,7 @@
     model_params: Optional[dict] = None,
     train_dataloader_params: Optional[dict[str, Any]] = None,
     val_dataloader_params: Optional[dict[str, Any]] = None,
-) -> Configuration:
+) -> dict:
     """
     Create a configuration for training CARE or Noise2Noise.
 
@@ -411,25 +367,41 @@ def _create_supervised_configuration(
     # augmentations
     spatial_transform_list = _list_spatial_augmentations(augmentations)
 
-    return _create_configuration(
+    # algorithm
+    algorithm_params = _create_algorithm_configuration(
+        axes=axes,
         algorithm=algorithm,
-        experiment_name=experiment_name,
+        loss=loss,
+        independent_channels=independent_channels,
+        n_channels_in=n_channels_in,
+        n_channels_out=n_channels_out,
+        model_params=model_params,
+    )
+
+    # data
+    data_params = _create_data_configuration(
         data_type=data_type,
         axes=axes,
         patch_size=patch_size,
         batch_size=batch_size,
-        num_epochs=num_epochs,
         augmentations=spatial_transform_list,
-        independent_channels=independent_channels,
-        loss=loss,
-        n_channels_in=n_channels_in,
-        n_channels_out=n_channels_out,
-        logger=logger,
-        model_params=model_params,
         train_dataloader_params=train_dataloader_params,
         val_dataloader_params=val_dataloader_params,
     )
 
+    # training
+    training_params = _create_training_configuration(
+        num_epochs=num_epochs,
+        logger=logger,
+    )
+
+    return {
+        "experiment_name": experiment_name,
+        "algorithm_config": algorithm_params,
+        "data_config": data_params,
+        "training_config": training_params,
+    }
+
 
 def create_care_configuration(
     experiment_name: str,
@@ -580,23 +552,25 @@ def create_care_configuration(
     ...     n_channels_out=1 # if applicable
     ... )
     """
-    return _create_supervised_configuration(
-        algorithm="care",
-        experiment_name=experiment_name,
-        data_type=data_type,
-        axes=axes,
-        patch_size=patch_size,
-        batch_size=batch_size,
-        num_epochs=num_epochs,
-        augmentations=augmentations,
-        independent_channels=independent_channels,
-        loss=loss,
-        n_channels_in=n_channels_in,
-        n_channels_out=n_channels_out,
-        logger=logger,
-        model_params=model_params,
-        train_dataloader_params=train_dataloader_params,
-        val_dataloader_params=val_dataloader_params,
+    return Configuration(
+        **_create_supervised_config_dict(
+            algorithm="care",
+            experiment_name=experiment_name,
+            data_type=data_type,
+            axes=axes,
+            patch_size=patch_size,
+            batch_size=batch_size,
+            num_epochs=num_epochs,
+            augmentations=augmentations,
+            independent_channels=independent_channels,
+            loss=loss,
+            n_channels_in=n_channels_in,
+            n_channels_out=n_channels_out,
+            logger=logger,
+            model_params=model_params,
+            train_dataloader_params=train_dataloader_params,
+            val_dataloader_params=val_dataloader_params,
+        )
     )
 
 
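The public factories keep their signatures; only the return path changes, going through `_create_supervised_config_dict` into the single `Configuration` type. A usage sketch with illustrative arguments:

```python
from careamics.config import create_care_configuration

config = create_care_configuration(
    experiment_name="care_demo",
    data_type="array",
    axes="YX",
    patch_size=[64, 64],
    batch_size=2,
    num_epochs=10,
)
assert type(config).__name__ == "Configuration"
assert config.algorithm_config.algorithm == "care"
```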
@@ -749,23 +723,25 @@ def create_n2n_configuration(
     ...     n_channels_out=1 # if applicable
     ... )
     """
-    return _create_supervised_configuration(
-        algorithm="n2n",
-        experiment_name=experiment_name,
-        data_type=data_type,
-        axes=axes,
-        patch_size=patch_size,
-        batch_size=batch_size,
-        num_epochs=num_epochs,
-        augmentations=augmentations,
-        independent_channels=independent_channels,
-        loss=loss,
-        n_channels_in=n_channels_in,
-        n_channels_out=n_channels_out,
-        logger=logger,
-        model_params=model_params,
-        train_dataloader_params=train_dataloader_params,
-        val_dataloader_params=val_dataloader_params,
+    return Configuration(
+        **_create_supervised_config_dict(
+            algorithm="n2n",
+            experiment_name=experiment_name,
+            data_type=data_type,
+            axes=axes,
+            patch_size=patch_size,
+            batch_size=batch_size,
+            num_epochs=num_epochs,
+            augmentations=augmentations,
+            independent_channels=independent_channels,
+            loss=loss,
+            n_channels_in=n_channels_in,
+            n_channels_out=n_channels_out,
+            logger=logger,
+            model_params=model_params,
+            train_dataloader_params=train_dataloader_params,
+            val_dataloader_params=val_dataloader_params,
+        )
     )
 
 
@@ -995,24 +971,40 @@ def create_n2v_configuration(
         struct_mask_axis=struct_n2v_axis,
         struct_mask_span=struct_n2v_span,
     )
-    transform_list: list[N2V_TRANSFORMS_UNION] = spatial_transforms + [n2v_transform]
 
-    return _create_configuration(
-        algorithm="n2v",
-        experiment_name=experiment_name,
-        data_type=data_type,
+    # algorithm
+    algorithm_params = _create_algorithm_configuration(
         axes=axes,
-        patch_size=patch_size,
-        batch_size=batch_size,
-        num_epochs=num_epochs,
-        augmentations=transform_list,
-        independent_channels=independent_channels,
+        algorithm="n2v",
         loss="n2v",
-        use_n2v2=use_n2v2,
+        independent_channels=independent_channels,
         n_channels_in=n_channels,
         n_channels_out=n_channels,
-        logger=logger,
+        use_n2v2=use_n2v2,
         model_params=model_params,
+    )
+    algorithm_params["n2v_config"] = n2v_transform
+
+    # data
+    data_params = _create_data_configuration(
+        data_type=data_type,
+        axes=axes,
+        patch_size=patch_size,
+        batch_size=batch_size,
+        augmentations=spatial_transforms,
         train_dataloader_params=train_dataloader_params,
         val_dataloader_params=val_dataloader_params,
     )
+
+    # training
+    training_params = _create_training_configuration(
+        num_epochs=num_epochs,
+        logger=logger,
+    )
+
+    return Configuration(
+        experiment_name=experiment_name,
+        algorithm_config=algorithm_params,
+        data_config=data_params,
+        training_config=training_params,
+    )
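For N2V the flow is the same, except that the `N2VManipulateModel` now rides on the algorithm model (the `n2v_config` key) instead of being appended to the data transforms. A sketch with illustrative arguments; the attribute checks assume the field names visible in this diff:

```python
from careamics.config import create_n2v_configuration

config = create_n2v_configuration(
    experiment_name="n2v_demo",
    data_type="tiff",
    axes="SYX",
    patch_size=[64, 64],
    batch_size=8,
    num_epochs=30,
)
assert config.algorithm_config.algorithm == "n2v"
# Pixel masking lives on the algorithm model, not in data_config.transforms.
assert all(t.name != "N2VManipulate" for t in config.data_config.transforms)
```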
careamics/config/configuration_io.py

@@ -5,7 +5,7 @@ from typing import Union
 
 import yaml
 
-from careamics.config import Configuration, configuration_factory
+from careamics.config import Configuration
 
 
 def load_configuration(path: Union[str, Path]) -> Configuration:
@@ -35,7 +35,7 @@ def load_configuration(path: Union[str, Path]) -> Configuration:
 
     dictionary = yaml.load(Path(path).open("r"), Loader=yaml.SafeLoader)
 
-    return configuration_factory(dictionary)
+    return Configuration(**dictionary)
 
 
 def save_configuration(config: Configuration, path: Union[str, Path]) -> Path:
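`load_configuration` therefore round-trips through the same model. A short sketch, assuming `config` is a `Configuration` instance and both helpers remain exported from `careamics.config` as in 0.0.9:

```python
from careamics.config import load_configuration, save_configuration

path = save_configuration(config, "config.yml")
restored = load_configuration(path)  # yaml.SafeLoader + Configuration(**dict)
```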
careamics/config/data/__init__.py

@@ -2,9 +2,6 @@
 
 __all__ = [
     "DataConfig",
-    "GeneralDataConfig",
-    "N2VDataConfig",
 ]
 
-from .data_model import DataConfig, GeneralDataConfig
-from .n2v_data_model import N2VDataConfig
+from .data_model import DataConfig
careamics/config/data/data_model.py

@@ -19,7 +19,7 @@ from pydantic import (
 )
 from typing_extensions import Self
 
-from ..transformations import N2V_TRANSFORMS_UNION, XYFlipModel, XYRandomRotate90Model
+from ..transformations import XYFlipModel, XYRandomRotate90Model
 from ..validators import check_axes_validity, patch_size_ge_than_8_power_of_2
 
 
@@ -46,8 +46,46 @@
 """Annotated float type, used to serialize floats to strings."""
 
 
-class GeneralDataConfig(BaseModel):
-    """General data configuration."""
+class DataConfig(BaseModel):
+    """Data configuration.
+
+    If std is specified, mean must be specified as well. Note that setting the std first
+    and then the mean (if they were both `None` before) will raise a validation error.
+    Prefer instead `set_mean_and_std` to set both at once. Means and stds are expected
+    to be lists of floats, one for each channel. For supervised tasks, the mean and std
+    of the target could be different from the input data.
+
+    All supported transforms are defined in the SupportedTransform enum.
+
+    Examples
+    --------
+    Minimum example:
+
+    >>> data = DataConfig(
+    ...     data_type="array", # defined in SupportedData
+    ...     patch_size=[128, 128],
+    ...     batch_size=4,
+    ...     axes="YX"
+    ... )
+
+    To change the image_means and image_stds of the data:
+    >>> data.set_means_and_stds(image_means=[214.3], image_stds=[84.5])
+
+    One can pass also a list of transformations, by keyword, using the
+    SupportedTransform value:
+    >>> from careamics.config.support import SupportedTransform
+    >>> data = DataConfig(
+    ...     data_type="tiff",
+    ...     patch_size=[128, 128],
+    ...     batch_size=4,
+    ...     axes="YX",
+    ...     transforms=[
+    ...         {
+    ...             "name": "XYFlip",
+    ...         }
+    ...     ]
+    ... )
+    """
 
     # Pydantic class configuration
     model_config = ConfigDict(
@@ -88,10 +126,7 @@ class GeneralDataConfig(BaseModel):
     """Standard deviations of the target data across channels, used for
     normalization."""
 
-    # defining as Sequence allows assigning subclasses of TransformModel without mypy
-    # complaining, this is important for instance to differentiate N2VDataConfig and
-    # DataConfig
-    transforms: Sequence[N2V_TRANSFORMS_UNION] = Field(
+    transforms: Sequence[Union[XYFlipModel, XYRandomRotate90Model]] = Field(
         default=[
             XYFlipModel(),
             XYRandomRotate90Model(),
@@ -104,7 +139,9 @@
     train_dataloader_params: dict[str, Any] = Field(
         default={"shuffle": True}, validate_default=True
     )
-    """Dictionary of PyTorch training dataloader parameters."""
+    """Dictionary of PyTorch training dataloader parameters. The dataloader parameters,
+    should include the `shuffle` key, which is set to `True` by default. We strongly
+    recommend to keep it as `True` to ensure the best training results."""
 
     val_dataloader_params: dict[str, Any] = Field(default={})
     """Dictionary of PyTorch validation dataloader parameters."""
@@ -207,7 +244,7 @@
     ):
         warn(
             "Dataloader parameters include `shuffle=False`, this will be passed to "
-            "the training dataloader and may result in bad results.",
+            "the training dataloader and may lead to lower quality results.",
             stacklevel=1,
         )
     return train_dataloader_params
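The reworded warning is raised during `DataConfig` validation whenever `shuffle=False` is passed for the training dataloader. A quick demonstration sketch with illustrative field values:

```python
import warnings

from careamics.config.data import DataConfig

with warnings.catch_warnings(record=True) as caught:
    warnings.simplefilter("always")
    DataConfig(
        data_type="array",
        patch_size=[64, 64],
        batch_size=4,
        axes="YX",
        train_dataloader_params={"shuffle": False},
    )

assert any("shuffle=False" in str(w.message) for w in caught)
```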
@@ -363,56 +400,3 @@
         Patch size.
     """
     self._update(axes=axes, patch_size=patch_size)
-
-
-class DataConfig(GeneralDataConfig):
-    """
-    Data configuration.
-
-    If std is specified, mean must be specified as well. Note that setting the std first
-    and then the mean (if they were both `None` before) will raise a validation error.
-    Prefer instead `set_mean_and_std` to set both at once. Means and stds are expected
-    to be lists of floats, one for each channel. For supervised tasks, the mean and std
-    of the target could be different from the input data.
-
-    All supported transforms are defined in the SupportedTransform enum.
-
-    Examples
-    --------
-    Minimum example:
-
-    >>> data = DataConfig(
-    ...     data_type="array", # defined in SupportedData
-    ...     patch_size=[128, 128],
-    ...     batch_size=4,
-    ...     axes="YX"
-    ... )
-
-    To change the image_means and image_stds of the data:
-    >>> data.set_means_and_stds(image_means=[214.3], image_stds=[84.5])
-
-    One can pass also a list of transformations, by keyword, using the
-    SupportedTransform value:
-    >>> from careamics.config.support import SupportedTransform
-    >>> data = DataConfig(
-    ...     data_type="tiff",
-    ...     patch_size=[128, 128],
-    ...     batch_size=4,
-    ...     axes="YX",
-    ...     transforms=[
-    ...         {
-    ...             "name": "XYFlip",
-    ...         }
-    ...     ]
-    ... )
-    """
-
-    transforms: Sequence[Union[XYFlipModel, XYRandomRotate90Model]] = Field(
-        default=[
-            XYFlipModel(),
-            XYRandomRotate90Model(),
-        ],
-        validate_default=True,
-    )
-    """List of transformations to apply to the data, available transforms are defined
-    in SupportedTransform. This excludes N2V specific transformations."""
careamics/config/support/supported_transforms.py

@@ -8,5 +8,5 @@ class SupportedTransform(str, BaseEnum):
 
     XY_FLIP = "XYFlip"
     XY_RANDOM_ROTATE90 = "XYRandomRotate90"
-    N2V_MANIPULATE = "N2VManipulate"
     NORMALIZE = "Normalize"
+    N2V_MANIPULATE = "N2VManipulate"