careamics 0.0.6__py3-none-any.whl → 0.0.7__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.


@@ -1,10 +1,14 @@
  """CARE algorithm configuration."""
 
- from typing import Literal
+ from typing import Annotated, Literal
 
- from pydantic import field_validator
+ from pydantic import AfterValidator
 
  from careamics.config.architectures import UNetModel
+ from careamics.config.validators import (
+     model_without_final_activation,
+     model_without_n2v2,
+ )
 
  from .unet_algorithm_model import UNetBasedAlgorithm
 
@@ -26,25 +30,9 @@ class CAREAlgorithm(UNetBasedAlgorithm):
      loss: Literal["mae", "mse"] = "mae"
      """CARE-compatible loss function."""
 
-     @classmethod
-     @field_validator("model")
-     def model_without_n2v2(cls, value: UNetModel) -> UNetModel:
-         """Validate that the model does not have the n2v2 attribute.
-
-         Parameters
-         ----------
-         value : UNetModel
-             Model to validate.
-
-         Returns
-         -------
-         UNetModel
-             The validated model.
-         """
-         if value.n2v2:
-             raise ValueError(
-                 "The N2N algorithm does not support the `n2v2` attribute. "
-                 "Set it to `False`."
-             )
-
-         return value
+     model: Annotated[
+         UNetModel,
+         AfterValidator(model_without_n2v2),
+         AfterValidator(model_without_final_activation),
+     ]
+     """UNet without a final activation function and without the `n2v2` modifications."""
@@ -1,16 +1,20 @@
  """N2N Algorithm configuration."""
 
- from typing import Literal
+ from typing import Annotated, Literal
 
- from pydantic import field_validator
+ from pydantic import AfterValidator
 
  from careamics.config.architectures import UNetModel
+ from careamics.config.validators import (
+     model_without_final_activation,
+     model_without_n2v2,
+ )
 
  from .unet_algorithm_model import UNetBasedAlgorithm
 
 
  class N2NAlgorithm(UNetBasedAlgorithm):
-     """N2N Algorithm configuration."""
+     """Noise2Noise Algorithm configuration."""
 
      algorithm: Literal["n2n"] = "n2n"
      """N2N Algorithm name."""
@@ -18,25 +22,9 @@ class N2NAlgorithm(UNetBasedAlgorithm):
      loss: Literal["mae", "mse"] = "mae"
      """N2N-compatible loss function."""
 
-     @classmethod
-     @field_validator("model")
-     def model_without_n2v2(cls, value: UNetModel) -> UNetModel:
-         """Validate that the model does not have the n2v2 attribute.
-
-         Parameters
-         ----------
-         value : UNetModel
-             Model to validate.
-
-         Returns
-         -------
-         UNetModel
-             The validated model.
-         """
-         if value.n2v2:
-             raise ValueError(
-                 "The N2N algorithm does not support the `n2v2` attribute. "
-                 "Set it to `False`."
-             )
-
-         return value
+     model: Annotated[
+         UNetModel,
+         AfterValidator(model_without_n2v2),
+         AfterValidator(model_without_final_activation),
+     ]
+     """UNet without a final activation function and without the `n2v2` modifications."""
@@ -1,9 +1,14 @@
  """"N2V Algorithm configuration."""
 
- from typing import Literal
+ from typing import Annotated, Literal
 
- from pydantic import model_validator
- from typing_extensions import Self
+ from pydantic import AfterValidator
+
+ from careamics.config.architectures import UNetModel
+ from careamics.config.validators import (
+     model_matching_in_out_channels,
+     model_without_final_activation,
+ )
 
  from .unet_algorithm_model import UNetBasedAlgorithm
 
@@ -17,19 +22,8 @@ class N2VAlgorithm(UNetBasedAlgorithm):
      loss: Literal["n2v"] = "n2v"
      """N2V loss function."""
 
-     @model_validator(mode="after")
-     def algorithm_cross_validation(self: Self) -> Self:
-         """Validate the algorithm model for N2V.
-
-         Returns
-         -------
-         Self
-             The validated model.
-         """
-         if self.model.in_channels != self.model.num_classes:
-             raise ValueError(
-                 "N2V requires the same number of input and output channels. Make "
-                 "sure that `in_channels` and `num_classes` are the same."
-             )
-
-         return self
+     model: Annotated[
+         UNetModel,
+         AfterValidator(model_matching_in_out_channels),
+         AfterValidator(model_without_final_activation),
+     ]
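Here the class-level `model_validator(mode="after")` becomes a second `AfterValidator` on the `model` field itself. Validators listed in an `Annotated` chain compose: each receives the value returned by the previous one, so independent checks can be stacked. A toy illustration (not careamics code):

```python
from typing import Annotated

from pydantic import AfterValidator, BaseModel


def check_even(v: int) -> int:
    if v % 2 != 0:
        raise ValueError("must be even")
    return v


def check_small(v: int) -> int:
    # Receives the value returned by the other validator in the chain.
    if v > 100:
        raise ValueError("must be <= 100")
    return v


class Chained(BaseModel):
    x: Annotated[int, AfterValidator(check_even), AfterValidator(check_small)]


Chained(x=42)     # passes both checks
# Chained(x=3)    # fails check_even
# Chained(x=200)  # passes check_even, fails check_small
```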
@@ -1,8 +1,8 @@
  """Convenience functions to create configurations for training and inference."""
 
- from typing import Any, Literal, Optional, Union
+ from typing import Annotated, Any, Literal, Optional, Union
 
- from pydantic import TypeAdapter
+ from pydantic import Discriminator, Tag, TypeAdapter
 
  from careamics.config.algorithms import CAREAlgorithm, N2NAlgorithm, N2VAlgorithm
  from careamics.config.architectures import UNetModel
@@ -12,6 +12,7 @@ from careamics.config.data import DataConfig, N2VDataConfig
  from careamics.config.n2n_configuration import N2NConfiguration
  from careamics.config.n2v_configuration import N2VConfiguration
  from careamics.config.support import (
+     SupportedAlgorithm,
      SupportedArchitecture,
      SupportedPixelManipulation,
      SupportedTransform,
@@ -26,6 +27,24 @@ from careamics.config.transformations import (
  )
 
 
+ def _algorithm_config_discriminator(value: Union[dict, Configuration]) -> str:
+     """Discriminate algorithm-specific configurations based on the algorithm.
+
+     Parameters
+     ----------
+     value : Any
+         Value to discriminate.
+
+     Returns
+     -------
+     str
+         Discriminator value.
+     """
+     if isinstance(value, dict):
+         return value["algorithm_config"]["algorithm"]
+     return value.algorithm_config.algorithm
+
+
  def configuration_factory(
      configuration: dict[str, Any]
  ) -> Union[N2VConfiguration, N2NConfiguration, CAREConfiguration]:
@@ -43,7 +62,14 @@
      Configuration for training CAREamics.
      """
      adapter: TypeAdapter = TypeAdapter(
-         Union[N2VConfiguration, N2NConfiguration, CAREConfiguration]
+         Annotated[
+             Union[
+                 Annotated[N2VConfiguration, Tag(SupportedAlgorithm.N2V.value)],
+                 Annotated[N2NConfiguration, Tag(SupportedAlgorithm.N2N.value)],
+                 Annotated[CAREConfiguration, Tag(SupportedAlgorithm.CARE.value)],
+             ],
+             Discriminator(_algorithm_config_discriminator),
+         ]
      )
      return adapter.validate_python(configuration)
 
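The plain `Union` made Pydantic try each configuration class in turn, producing noisy errors from all three candidates on failure. The callable `Discriminator` reads the nested `algorithm` value and routes the input directly to the class whose `Tag` matches. A self-contained sketch of the same Pydantic pattern (toy classes and names):

```python
from typing import Annotated, Union

from pydantic import BaseModel, Discriminator, Tag, TypeAdapter


class N2VLike(BaseModel):
    algorithm: str = "n2v"
    roi_size: int = 11


class CARELike(BaseModel):
    algorithm: str = "care"
    loss: str = "mae"


def _discriminate(value: Union[dict, BaseModel]) -> str:
    # Must handle both raw dicts and already-instantiated models,
    # mirroring `_algorithm_config_discriminator` above.
    if isinstance(value, dict):
        return value["algorithm"]
    return value.algorithm


adapter: TypeAdapter = TypeAdapter(
    Annotated[
        Union[
            Annotated[N2VLike, Tag("n2v")],
            Annotated[CARELike, Tag("care")],
        ],
        Discriminator(_discriminate),
    ]
)

config = adapter.validate_python({"algorithm": "care", "loss": "mse"})
print(type(config).__name__)  # CARELike
```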
@@ -198,7 +224,8 @@
      logger: Literal["wandb", "tensorboard", "none"],
      use_n2v2: bool = False,
      model_params: Optional[dict] = None,
-     dataloader_params: Optional[dict] = None,
+     train_dataloader_params: Optional[dict[str, Any]] = None,
+     val_dataloader_params: Optional[dict[str, Any]] = None,
  ) -> Configuration:
      """
      Create a configuration for training N2V, CARE or Noise2Noise.
@@ -236,8 +263,10 @@
          Whether to use N2V2, by default False.
      model_params : dict
          UNetModel parameters.
-     dataloader_params : dict
-         Parameters for the dataloader, see PyTorch notes, by default None.
+     train_dataloader_params : dict
+         Parameters for the training dataloader, see PyTorch notes, by default None.
+     val_dataloader_params : dict
+         Parameters for the validation dataloader, see PyTorch notes, by default None.
 
      Returns
      -------
@@ -268,8 +297,12 @@
          "patch_size": patch_size,
          "batch_size": batch_size,
          "transforms": augmentations,
-         "dataloader_params": dataloader_params,
      }
+     # Don't override defaults set in DataConfig class
+     if train_dataloader_params is not None:
+         data["train_dataloader_params"] = train_dataloader_params
+     if val_dataloader_params is not None:
+         data["val_dataloader_params"] = val_dataloader_params
 
      # training model
      training = TrainingConfig(
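The keys are only set when the caller actually supplied parameters: writing `data["train_dataloader_params"] = None` would override the `Field` default declared on the data config and fail validation, whereas omitting the key lets Pydantic fill it in. A toy sketch of that distinction:

```python
from pydantic import BaseModel, Field, ValidationError


class ToyDataConfig(BaseModel):
    train_dataloader_params: dict = Field(default={"shuffle": True})


# Omitting the key lets the default apply:
print(ToyDataConfig().train_dataloader_params)  # {'shuffle': True}

# Passing the key explicitly, even as None, overrides the default and fails:
try:
    ToyDataConfig(train_dataloader_params=None)
except ValidationError as err:
    print(err)  # None is not a valid dictionary
```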
@@ -305,7 +338,8 @@ def _create_supervised_configuration(
      n_channels_out: Optional[int] = None,
      logger: Literal["wandb", "tensorboard", "none"] = "none",
      model_params: Optional[dict] = None,
-     dataloader_params: Optional[dict] = None,
+     train_dataloader_params: Optional[dict[str, Any]] = None,
+     val_dataloader_params: Optional[dict[str, Any]] = None,
  ) -> Configuration:
      """
      Create a configuration for training CARE or Noise2Noise.
@@ -342,8 +376,10 @@
          Logger to use, by default "none".
      model_params : dict, optional
          UNetModel parameters, by default {}.
-     dataloader_params : dict, optional
-         Parameters for the dataloader, see PyTorch notes, by default None.
+     train_dataloader_params : dict
+         Parameters for the training dataloader, see PyTorch notes, by default None.
+     val_dataloader_params : dict
+         Parameters for the validation dataloader, see PyTorch notes, by default None.
 
      Returns
      -------
@@ -390,7 +426,8 @@
          n_channels_out=n_channels_out,
          logger=logger,
          model_params=model_params,
-         dataloader_params=dataloader_params,
+         train_dataloader_params=train_dataloader_params,
+         val_dataloader_params=val_dataloader_params,
      )
 
 
@@ -408,7 +445,8 @@ def create_care_configuration(
      n_channels_out: Optional[int] = None,
      logger: Literal["wandb", "tensorboard", "none"] = "none",
      model_params: Optional[dict] = None,
-     dataloader_params: Optional[dict] = None,
+     train_dataloader_params: Optional[dict[str, Any]] = None,
+     val_dataloader_params: Optional[dict[str, Any]] = None,
  ) -> Configuration:
      """
      Create a configuration for training CARE.
@@ -461,8 +499,14 @@
          Logger to use.
      model_params : dict, default=None
          UNetModel parameters.
-     dataloader_params : dict, optional
-         Parameters for the dataloader, see PyTorch notes, by default None.
+     train_dataloader_params : dict, optional
+         Parameters for the training dataloader, see the PyTorch docs for
+         `DataLoader`. If left as `None`, the dict `{"shuffle": True}` will be
+         used; this is set in the `GeneralDataConfig`.
+     val_dataloader_params : dict, optional
+         Parameters for the validation dataloader, see the PyTorch docs for
+         `DataLoader`. If left as `None`, the empty dict `{}` will be used; this
+         is set in the `GeneralDataConfig`.
 
      Returns
      -------
@@ -551,7 +595,8 @@
          n_channels_out=n_channels_out,
          logger=logger,
          model_params=model_params,
-         dataloader_params=dataloader_params,
+         train_dataloader_params=train_dataloader_params,
+         val_dataloader_params=val_dataloader_params,
      )
 
 
@@ -569,7 +614,8 @@ def create_n2n_configuration(
      n_channels_out: Optional[int] = None,
      logger: Literal["wandb", "tensorboard", "none"] = "none",
      model_params: Optional[dict] = None,
-     dataloader_params: Optional[dict] = None,
+     train_dataloader_params: Optional[dict[str, Any]] = None,
+     val_dataloader_params: Optional[dict[str, Any]] = None,
  ) -> Configuration:
      """
      Create a configuration for training Noise2Noise.
@@ -622,8 +668,14 @@
          Logger to use, by default "none".
      model_params : dict, optional
          UNetModel parameters, by default {}.
-     dataloader_params : dict, optional
-         Parameters for the dataloader, see PyTorch notes, by default None.
+     train_dataloader_params : dict, optional
+         Parameters for the training dataloader, see the PyTorch docs for
+         `DataLoader`. If left as `None`, the dict `{"shuffle": True}` will be
+         used; this is set in the `GeneralDataConfig`.
+     val_dataloader_params : dict, optional
+         Parameters for the validation dataloader, see the PyTorch docs for
+         `DataLoader`. If left as `None`, the empty dict `{}` will be used; this
+         is set in the `GeneralDataConfig`.
 
      Returns
      -------
@@ -712,7 +764,8 @@
          n_channels_out=n_channels_out,
          logger=logger,
          model_params=model_params,
-         dataloader_params=dataloader_params,
+         train_dataloader_params=train_dataloader_params,
+         val_dataloader_params=val_dataloader_params,
      )
 
 
@@ -733,7 +786,8 @@ def create_n2v_configuration(
      struct_n2v_span: int = 5,
      logger: Literal["wandb", "tensorboard", "none"] = "none",
      model_params: Optional[dict] = None,
-     dataloader_params: Optional[dict] = None,
+     train_dataloader_params: Optional[dict[str, Any]] = None,
+     val_dataloader_params: Optional[dict[str, Any]] = None,
  ) -> Configuration:
      """
      Create a configuration for training Noise2Void.
@@ -812,8 +866,14 @@
          Logger to use, by default "none".
      model_params : dict, optional
          UNetModel parameters, by default None.
-     dataloader_params : dict, optional
-         Parameters for the dataloader, see PyTorch notes, by default None.
+     train_dataloader_params : dict, optional
+         Parameters for the training dataloader, see the PyTorch docs for
+         `DataLoader`. If left as `None`, the dict `{"shuffle": True}` will be
+         used; this is set in the `GeneralDataConfig`.
+     val_dataloader_params : dict, optional
+         Parameters for the validation dataloader, see the PyTorch docs for
+         `DataLoader`. If left as `None`, the empty dict `{}` will be used; this
+         is set in the `GeneralDataConfig`.
 
      Returns
      -------
@@ -953,5 +1013,6 @@
          n_channels_out=n_channels,
          logger=logger,
          model_params=model_params,
-         dataloader_params=dataloader_params,
+         train_dataloader_params=train_dataloader_params,
+         val_dataloader_params=val_dataloader_params,
      )
@@ -5,6 +5,7 @@ from __future__ import annotations
  from collections.abc import Sequence
  from pprint import pformat
  from typing import Annotated, Any, Literal, Optional, Union
+ from warnings import warn
 
  import numpy as np
  from numpy.typing import NDArray
@@ -100,8 +101,13 @@ class GeneralDataConfig(BaseModel):
      """List of transformations to apply to the data, available transforms are defined
      in SupportedTransform."""
 
-     dataloader_params: Optional[dict] = None
-     """Dictionary of PyTorch dataloader parameters."""
+     train_dataloader_params: dict[str, Any] = Field(
+         default={"shuffle": True}, validate_default=True
+     )
+     """Dictionary of PyTorch training dataloader parameters."""
+
+     val_dataloader_params: dict[str, Any] = Field(default={})
+     """Dictionary of PyTorch validation dataloader parameters."""
 
      @field_validator("patch_size")
      @classmethod
@@ -167,6 +173,45 @@
 
          return axes
 
+     @field_validator("train_dataloader_params")
+     @classmethod
+     def shuffle_train_dataloader(
+         cls, train_dataloader_params: dict[str, Any]
+     ) -> dict[str, Any]:
+         """
+         Validate that "shuffle" is included in the training dataloader params.
+
+         A warning will be raised if `shuffle=False`.
+
+         Parameters
+         ----------
+         train_dataloader_params : dict of {str: Any}
+             The training dataloader parameters.
+
+         Returns
+         -------
+         dict of {str: Any}
+             The validated training dataloader parameters.
+
+         Raises
+         ------
+         ValueError
+             If "shuffle" is not included in the training dataloader params.
+         """
+         if "shuffle" not in train_dataloader_params:
+             raise ValueError(
+                 "Value for 'shuffle' was not included in the `train_dataloader_params`."
+             )
+         elif ("shuffle" in train_dataloader_params) and (
+             not train_dataloader_params["shuffle"]
+         ):
+             warn(
+                 "Dataloader parameters include `shuffle=False`, this will be passed to "
+                 "the training dataloader and may result in bad results.",
+                 stacklevel=1,
+             )
+         return train_dataloader_params
+
      @model_validator(mode="after")
      def std_only_with_mean(self: Self) -> Self:
          """
@@ -6,7 +6,11 @@ from careamics.utils import BaseEnum
 
 
  class SupportedAlgorithm(str, BaseEnum):
-     """Algorithms available in CAREamics."""
+     """Algorithms available in CAREamics.
+
+     These definitions are the same as the keyword `name` of the algorithm
+     configurations.
+     """
 
      N2V = "n2v"
      """Noise2Void algorithm, a self-supervised approach based on blind denoising."""
@@ -1,5 +1,16 @@
  """Validator utilities."""
 
- __all__ = ["check_axes_validity", "patch_size_ge_than_8_power_of_2"]
+ __all__ = [
+     "check_axes_validity",
+     "model_matching_in_out_channels",
+     "model_without_final_activation",
+     "model_without_n2v2",
+     "patch_size_ge_than_8_power_of_2",
+ ]
 
+ from .model_validators import (
+     model_matching_in_out_channels,
+     model_without_final_activation,
+     model_without_n2v2,
+ )
  from .validator_utils import check_axes_validity, patch_size_ge_than_8_power_of_2
@@ -0,0 +1,84 @@
+ """Architecture model validators."""
+
+ from careamics.config.architectures import UNetModel
+
+
+ def model_without_n2v2(model: UNetModel) -> UNetModel:
+     """Validate that the Unet model does not have the n2v2 attribute.
+
+     Parameters
+     ----------
+     model : UNetModel
+         Model to validate.
+
+     Returns
+     -------
+     UNetModel
+         The validated model.
+
+     Raises
+     ------
+     ValueError
+         If the model has the `n2v2` attribute set to `True`.
+     """
+     if model.n2v2:
+         raise ValueError(
+             "The algorithm does not support the `n2v2` attribute in the model. "
+             "Set it to `False`."
+         )
+
+     return model
+
+
+ def model_without_final_activation(model: UNetModel) -> UNetModel:
+     """Validate that the UNet model does not have the final_activation.
+
+     Parameters
+     ----------
+     model : UNetModel
+         Model to validate.
+
+     Returns
+     -------
+     UNetModel
+         The validated model.
+
+     Raises
+     ------
+     ValueError
+         If the model has the final_activation attribute set.
+     """
+     if model.final_activation != "None":
+         raise ValueError(
+             "The algorithm does not support a `final_activation` in the model. "
+             'Set it to `"None"`.'
+         )
+
+     return model
+
+
+ def model_matching_in_out_channels(model: UNetModel) -> UNetModel:
+     """Validate that the UNet model has the same number of channel inputs and outputs.
+
+     Parameters
+     ----------
+     model : UNetModel
+         Model to validate.
+
+     Returns
+     -------
+     UNetModel
+         Validated model.
+
+     Raises
+     ------
+     ValueError
+         If the model has different number of input and output channels.
+     """
+     if model.num_classes != model.in_channels:
+         raise ValueError(
+             "The algorithm requires the same number of input and output channels. "
+             "Make sure that `in_channels` and `num_classes` are equal."
+         )
+
+     return model
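Because these validators are plain functions rather than methods bound to one configuration class, they can be called (and unit-tested) directly. A small sketch; `FakeUNet` is a hypothetical duck-typed stand-in carrying only the attribute the validator reads, since the real `UNetModel` has more required fields:

```python
from dataclasses import dataclass

from careamics.config.validators import model_without_n2v2


@dataclass
class FakeUNet:
    """Hypothetical stand-in; type hints are not enforced at runtime."""

    n2v2: bool


assert model_without_n2v2(FakeUNet(n2v2=False)).n2v2 is False

try:
    model_without_n2v2(FakeUNet(n2v2=True))
except ValueError as err:
    print(err)  # "The algorithm does not support the `n2v2` attribute ..."
```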
@@ -5,7 +5,7 @@ from typing import Union
 
  from pytorch_lightning import LightningModule, Trainer
  from pytorch_lightning.callbacks import TQDMProgressBar
- from tqdm import tqdm
+ from tqdm.auto import tqdm
 
 
  class ProgressBarCallback(TQDMProgressBar):
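A one-line change: `tqdm.auto` selects the ipywidgets-based progress bar when running inside a Jupyter notebook and falls back to the standard terminal bar otherwise, which is presumably the motivation here; the API is unchanged:

```python
from tqdm.auto import tqdm  # picks the notebook or terminal frontend at import

for _ in tqdm(range(100), desc="demo"):
    pass
```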
@@ -2,7 +2,6 @@
 
  from pathlib import Path
  from typing import Any, Callable, Literal, Optional, Union
- from warnings import warn
 
  import numpy as np
  import pytorch_lightning as L
@@ -261,11 +260,6 @@ class TrainDataModule(L.LightningDataModule):
 
          self.extension_filter: str = extension_filter
 
-         # Pytorch dataloader parameters
-         self.dataloader_params: dict[str, Any] = (
-             data_config.dataloader_params if data_config.dataloader_params else {}
-         )
-
      def prepare_data(self) -> None:
          """
          Hook used to prepare the data before calling `setup`.
@@ -447,21 +441,17 @@
          Any
              Training dataloader.
          """
-         # check because iterable dataset cannot be shuffled
-         if not isinstance(self.train_dataset, IterableDataset):
-             if ("shuffle" in self.dataloader_params) and (
-                 not self.dataloader_params["shuffle"]
-             ):
-                 warn(
-                     "Dataloader parameters include `shuffle=False`, this will be "
-                     "passed to the training dataloader and may result in bad results.",
-                     stacklevel=1,
-                 )
-             else:
-                 self.dataloader_params["shuffle"] = True
+         train_dataloader_params = self.data_config.train_dataloader_params.copy()
+
+         # NOTE: When next-gen datasets are completed this can be removed
+         # iterable dataset cannot be shuffled
+         if isinstance(self.train_dataset, IterableDataset):
+             del train_dataloader_params["shuffle"]
 
          return DataLoader(
-             self.train_dataset, batch_size=self.batch_size, **self.dataloader_params
+             self.train_dataset,
+             batch_size=self.batch_size,
+             **train_dataloader_params,
          )
 
      def val_dataloader(self) -> Any:
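The `del train_dataloader_params["shuffle"]` is required because `torch.utils.data.DataLoader` rejects an enabled shuffle option for iterable-style datasets, and the new defaults always contain `"shuffle": True`. A minimal reproduction:

```python
from torch.utils.data import DataLoader, IterableDataset


class Stream(IterableDataset):
    """Toy iterable-style dataset: no __len__, no random access, no shuffling."""

    def __iter__(self):
        yield from range(4)


try:
    DataLoader(Stream(), batch_size=2, shuffle=True)
except ValueError as err:
    print(err)  # DataLoader with IterableDataset: expected unspecified shuffle option

loader = DataLoader(Stream(), batch_size=2)  # fine once "shuffle" is dropped
print([batch.tolist() for batch in loader])  # [[0, 1], [2, 3]]
```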
@@ -476,6 +466,7 @@
          return DataLoader(
              self.val_dataset,
              batch_size=self.batch_size,
+             **self.data_config.val_dataloader_params,
          )
 