qadence 1.8.0__py3-none-any.whl → 1.9.0__py3-none-any.whl

This diff shows the changes between the publicly released 1.8.0 and 1.9.0 versions of the qadence package, as published to the public registry. It is provided for informational purposes only.
Files changed (43)
  1. qadence/__init__.py +1 -1
  2. qadence/analog/parse_analog.py +1 -2
  3. qadence/backends/gpsr.py +8 -2
  4. qadence/backends/pulser/backend.py +7 -23
  5. qadence/backends/pyqtorch/backend.py +80 -5
  6. qadence/backends/pyqtorch/config.py +10 -3
  7. qadence/backends/pyqtorch/convert_ops.py +63 -2
  8. qadence/blocks/primitive.py +1 -0
  9. qadence/execution.py +0 -2
  10. qadence/log_config.yaml +10 -0
  11. qadence/measurements/shadow.py +97 -128
  12. qadence/measurements/utils.py +2 -2
  13. qadence/mitigations/readout.py +12 -6
  14. qadence/ml_tools/__init__.py +4 -8
  15. qadence/ml_tools/callbacks/__init__.py +30 -0
  16. qadence/ml_tools/callbacks/callback.py +451 -0
  17. qadence/ml_tools/callbacks/callbackmanager.py +214 -0
  18. qadence/ml_tools/{saveload.py → callbacks/saveload.py} +11 -11
  19. qadence/ml_tools/callbacks/writer_registry.py +430 -0
  20. qadence/ml_tools/config.py +132 -258
  21. qadence/ml_tools/data.py +7 -3
  22. qadence/ml_tools/loss/__init__.py +10 -0
  23. qadence/ml_tools/loss/loss.py +87 -0
  24. qadence/ml_tools/optimize_step.py +45 -10
  25. qadence/ml_tools/stages.py +46 -0
  26. qadence/ml_tools/train_utils/__init__.py +7 -0
  27. qadence/ml_tools/train_utils/base_trainer.py +548 -0
  28. qadence/ml_tools/train_utils/config_manager.py +184 -0
  29. qadence/ml_tools/trainer.py +692 -0
  30. qadence/model.py +1 -1
  31. qadence/noise/__init__.py +2 -2
  32. qadence/noise/protocols.py +18 -53
  33. qadence/operations/ham_evo.py +87 -26
  34. qadence/transpile/noise.py +12 -5
  35. qadence/types.py +15 -3
  36. {qadence-1.8.0.dist-info → qadence-1.9.0.dist-info}/METADATA +3 -4
  37. {qadence-1.8.0.dist-info → qadence-1.9.0.dist-info}/RECORD +39 -32
  38. {qadence-1.8.0.dist-info → qadence-1.9.0.dist-info}/WHEEL +1 -1
  39. qadence/ml_tools/printing.py +0 -154
  40. qadence/ml_tools/train_grad.py +0 -395
  41. qadence/ml_tools/train_no_grad.py +0 -199
  42. qadence/noise/readout.py +0 -218
  43. {qadence-1.8.0.dist-info → qadence-1.9.0.dist-info}/licenses/LICENSE +0 -0
qadence/ml_tools/stages.py (added)
@@ -0,0 +1,46 @@
+ from __future__ import annotations
+
+ from qadence.types import StrEnum
+
+
+ class TrainingStage(StrEnum):
+     """Different stages in the training, validation, and testing process."""
+
+     IDLE = "idle"
+     """An 'idle' stage for scenarios where no training, validation, or testing is involved."""
+
+     TRAIN_START = "train_start"
+     """Marks the start of the training process."""
+
+     TRAIN_END = "train_end"
+     """Marks the end of the training process."""
+
+     TRAIN_EPOCH_START = "train_epoch_start"
+     """Indicates the start of a training epoch."""
+
+     TRAIN_EPOCH_END = "train_epoch_end"
+     """Indicates the end of a training epoch."""
+
+     TRAIN_BATCH_START = "train_batch_start"
+     """Marks the start of processing a training batch."""
+
+     TRAIN_BATCH_END = "train_batch_end"
+     """Marks the end of processing a training batch."""
+
+     VAL_EPOCH_START = "val_epoch_start"
+     """Indicates the start of a validation epoch."""
+
+     VAL_EPOCH_END = "val_epoch_end"
+     """Indicates the end of a validation epoch."""
+
+     VAL_BATCH_START = "val_batch_start"
+     """Marks the start of processing a validation batch."""
+
+     VAL_BATCH_END = "val_batch_end"
+     """Marks the end of processing a validation batch."""
+
+     TEST_BATCH_START = "test_batch_start"
+     """Marks the start of processing a test batch."""
+
+     TEST_BATCH_END = "test_batch_end"
+     """Marks the end of processing a test batch."""
qadence/ml_tools/train_utils/__init__.py (added)
@@ -0,0 +1,7 @@
+ from __future__ import annotations
+
+ from .base_trainer import BaseTrainer
+ from .config_manager import ConfigManager
+
+ # Modules to be automatically added to the qadence.ml_tools.train_utils namespace
+ __all__ = ["BaseTrainer", "ConfigManager"]
qadence/ml_tools/train_utils/base_trainer.py (added)
@@ -0,0 +1,548 @@
+ from __future__ import annotations
+
+ from contextlib import contextmanager
+ from logging import getLogger
+ from typing import Any, Callable, Iterator
+
+ import nevergrad as ng
+ import torch
+ from nevergrad.optimization.base import Optimizer as NGOptimizer
+ from torch import nn, optim
+ from torch.utils.data import DataLoader
+
+ from qadence.ml_tools.callbacks import CallbacksManager
+ from qadence.ml_tools.config import TrainConfig
+ from qadence.ml_tools.data import InfiniteTensorDataset
+ from qadence.ml_tools.loss import get_loss_fn
+ from qadence.ml_tools.optimize_step import optimize_step
+ from qadence.ml_tools.parameters import get_parameters
+ from qadence.ml_tools.stages import TrainingStage
+
+ from .config_manager import ConfigManager
+
+ logger = getLogger("ml_tools")
+
+
+ class BaseTrainer:
+     """Base class for training machine learning models using a given optimizer.
+
+     The base class implements context managers for gradient-based and
+     gradient-free optimization, properties and property setters, input
+     validation, a callback decorator generator, and empty hooks for the
+     different training steps.
+
+     This class provides:
+     - Context managers for enabling/disabling gradient-based optimization
+     - Properties for managing models, optimizers, and dataloaders
+     - Input validations and a callback decorator generator
+     - Config and callback managers using the provided `TrainConfig`
+
+     Attributes:
+         use_grad (bool): Indicates if gradients are used for optimization. Default is True.
+
+         model (nn.Module): The neural network model.
+         optimizer (optim.Optimizer | NGOptimizer | None): The optimizer for training.
+         config (TrainConfig): The configuration settings for training.
+         train_dataloader (DataLoader | None): DataLoader for training data.
+         val_dataloader (DataLoader | None): DataLoader for validation data.
+         test_dataloader (DataLoader | None): DataLoader for testing data.
+
+         optimize_step (Callable): Function for performing an optimization step.
+         loss_fn (Callable | str): Loss function to use. Defaults to 'mse'.
+
+         num_training_batches (int): Number of training batches. If an
+             InfiniteTensorDataset is used, only 1 batch per epoch is used.
+         num_validation_batches (int): Number of validation batches. If an
+             InfiniteTensorDataset is used, only 1 batch per epoch is used.
+         num_test_batches (int): Number of test batches. If an
+             InfiniteTensorDataset is used, only 1 batch per epoch is used.
+
+         state (str): Current state in the training process.
+     """
+
+     _use_grad: bool = True
+
+     def __init__(
+         self,
+         model: nn.Module,
+         optimizer: optim.Optimizer | NGOptimizer | None,
+         config: TrainConfig,
+         loss_fn: str | Callable = "mse",
+         optimize_step: Callable = optimize_step,
+         train_dataloader: DataLoader | None = None,
+         val_dataloader: DataLoader | None = None,
+         test_dataloader: DataLoader | None = None,
+         max_batches: int | None = None,
+     ):
+         """
+         Initializes the BaseTrainer.
+
+         Args:
+             model (nn.Module): The model to train.
+             optimizer (optim.Optimizer | NGOptimizer | None): The optimizer
+                 for training.
+             config (TrainConfig): The TrainConfig settings for training.
+             loss_fn (str | Callable): The loss function to use. Pass a string
+                 to select a default loss function; currently supported strings
+                 are 'mse' and 'cross_entropy'. If not specified, the default
+                 mse loss is used.
+             train_dataloader (DataLoader | None): DataLoader for training data.
+                 If the model does not need data to evaluate the loss, no
+                 dataset should be provided.
+             val_dataloader (DataLoader | None): DataLoader for validation data.
+             test_dataloader (DataLoader | None): DataLoader for testing data.
+             max_batches (int | None): Maximum number of batches to process per
+                 epoch. Only valid for finite TensorDataset dataloaders: if
+                 max_batches is not None, the number of batches used is
+                 min(max_batches, len(dataloader.dataset)). If an
+                 InfiniteTensorDataset is used, only 1 batch per epoch is used.
+         """
+         self._model: nn.Module
+         self._optimizer: optim.Optimizer | NGOptimizer | None
+         self._config: TrainConfig
+         self._train_dataloader: DataLoader | None = None
+         self._val_dataloader: DataLoader | None = None
+         self._test_dataloader: DataLoader | None = None
+
+         self.config = config
+         self.model = model
+         self.optimizer = optimizer
+         self.max_batches = max_batches
+
+         self.num_training_batches: int
+         self.num_validation_batches: int
+         self.num_test_batches: int
+
+         self.train_dataloader = train_dataloader
+         self.val_dataloader = val_dataloader
+         self.test_dataloader = test_dataloader
+
+         self.loss_fn: Callable = get_loss_fn(loss_fn)
+         self.optimize_step: Callable = optimize_step
+         self.ng_params: ng.p.Array
+         self.training_stage: TrainingStage = TrainingStage("idle")
+
+     @property
+     def use_grad(self) -> bool:
+         """
+         Returns the optimization framework for the trainer.
+
+         use_grad = True : gradient-based optimization
+         use_grad = False : gradient-free optimization
+
+         Returns:
+             bool: Whether gradient-based optimization is used.
+         """
+         return self._use_grad
+
+     @use_grad.setter
+     def use_grad(self, use_grad: bool) -> None:
+         """
+         Sets the optimization framework for the trainer.
+
+         use_grad = True : gradient-based optimization
+         use_grad = False : gradient-free optimization
+
+         Args:
+             use_grad (bool): Whether to use gradient-based optimization.
+         """
+         if not isinstance(use_grad, bool):
+             raise TypeError("use_grad must be True or False.")
+         self._use_grad = use_grad
+
+     @classmethod
+     def set_use_grad(cls, value: bool) -> None:
+         """
+         Sets the global use_grad flag.
+
+         Args:
+             value (bool): Whether to use gradient-based optimization.
+         """
+         if not isinstance(value, bool):
+             raise TypeError("use_grad must be a boolean value.")
+         cls._use_grad = value
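Since _use_grad is a class attribute, set_use_grad flips the default for every trainer of that class, while the property setter overrides it per instance. A hedged sketch of the intended behavior (MyTrainer is a hypothetical subclass; model and config are assumed to be built elsewhere, and optimizer=None sidesteps the optimizer type check shown below):

    class MyTrainer(BaseTrainer):
        pass

    MyTrainer.set_use_grad(False)  # class-level default: gradient-free
    trainer = MyTrainer(model=model, optimizer=None, config=config)
    assert trainer.use_grad is False

    trainer.use_grad = True  # per-instance override via the validated setter
    assert trainer.use_grad is True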
+
+     @property
+     def model(self) -> nn.Module:
+         """
+         Returns the model if set, otherwise raises an error.
+
+         Returns:
+             nn.Module: The model.
+         """
+         if self._model is None:
+             raise ValueError("Model has not been set.")
+         return self._model
+
+     @model.setter
+     def model(self, model: nn.Module) -> None:
+         """
+         Sets the model, ensuring it is an instance of nn.Module.
+
+         Args:
+             model (nn.Module): The neural network model.
+         """
+         if model is not None and not isinstance(model, nn.Module):
+             raise TypeError("model must be an instance of nn.Module or None.")
+         self._model = model
+
+     @property
+     def optimizer(self) -> optim.Optimizer | NGOptimizer | None:
+         """
+         Returns the optimizer.
+
+         Returns:
+             optim.Optimizer | NGOptimizer | None: The optimizer.
+         """
+         return self._optimizer
+
+     @optimizer.setter
+     def optimizer(self, optimizer: optim.Optimizer | NGOptimizer | None) -> None:
+         """
+         Sets the optimizer, checking compatibility with gradient use.
+
+         The budget/behavior of the different optimizers is also set up here.
+
+         Args:
+             optimizer (optim.Optimizer | NGOptimizer | None): The optimizer for training.
+         """
+         if optimizer is not None:
+             if self.use_grad:
+                 if not isinstance(optimizer, optim.Optimizer):
+                     raise TypeError("use_grad=True requires a PyTorch optimizer instance.")
+             else:
+                 if not isinstance(optimizer, NGOptimizer):
+                     raise TypeError("use_grad=False requires a Nevergrad optimizer instance.")
+                 else:
+                     optimizer.budget = self.config.max_iter
+                     optimizer.enable_pickling()
+                     params = get_parameters(self.model).detach().numpy()
+                     self.ng_params = ng.p.Array(init=params)
+
+         self._optimizer = optimizer
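For gradient-free runs, the setter above wires a Nevergrad optimizer to the training budget from TrainConfig and snapshots the model parameters as an ng.p.Array. A hedged sketch of constructing a compatible optimizer; the NGOpt choice and the parametrization shape are illustrative assumptions, not required by the trainer:

    import nevergrad as ng

    # Flattened count of trainable parameters, mirroring what
    # get_parameters(model) returns as a 1-D tensor (assumption).
    n_params = sum(p.numel() for p in model.parameters() if p.requires_grad)

    ng_optimizer = ng.optimizers.NGOpt(
        parametrization=ng.p.Array(shape=(n_params,)),
        budget=config.max_iter,  # the setter reassigns this from TrainConfig anyway
    )
    trainer.use_grad = False
    trainer.optimizer = ng_optimizer  # passes the NGOptimizer isinstance check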
+
+     @property
+     def train_dataloader(self) -> DataLoader:
+         """
+         Returns the training DataLoader.
+
+         Returns:
+             DataLoader: The DataLoader for training data.
+         """
+         return self._train_dataloader
+
+     @train_dataloader.setter
+     def train_dataloader(self, dataloader: DataLoader) -> None:
+         """
+         Sets the training DataLoader, validating its type, and computes the number of batches.
+
+         Args:
+             dataloader (DataLoader): The DataLoader for training data.
+         """
+         self._validate_dataloader(dataloader, "train")
+         self._train_dataloader = dataloader
+         self.num_training_batches = self._compute_num_batches(dataloader)
+
+     @property
+     def val_dataloader(self) -> DataLoader:
+         """
+         Returns the validation DataLoader.
+
+         Returns:
+             DataLoader: The DataLoader for validation data.
+         """
+         return self._val_dataloader
+
+     @val_dataloader.setter
+     def val_dataloader(self, dataloader: DataLoader) -> None:
+         """
+         Sets the validation DataLoader, validating its type, and computes the number of batches.
+
+         Args:
+             dataloader (DataLoader): The DataLoader for validation data.
+         """
+         self._validate_dataloader(dataloader, "val")
+         self._val_dataloader = dataloader
+         self.num_validation_batches = self._compute_num_batches(dataloader)
+
+     @property
+     def test_dataloader(self) -> DataLoader:
+         """
+         Returns the test DataLoader.
+
+         Returns:
+             DataLoader: The DataLoader for testing data.
+         """
+         return self._test_dataloader
+
+     @test_dataloader.setter
+     def test_dataloader(self, dataloader: DataLoader) -> None:
+         """
+         Sets the test DataLoader, validating its type, and computes the number of batches.
+
+         Args:
+             dataloader (DataLoader): The DataLoader for testing data.
+         """
+         self._validate_dataloader(dataloader, "test")
+         self._test_dataloader = dataloader
+         self.num_test_batches = self._compute_num_batches(dataloader)
+
+     @property
+     def config(self) -> TrainConfig:
+         """
+         Returns the training configuration.
+
+         Returns:
+             TrainConfig: The configuration object.
+         """
+         return self._config
+
+     @config.setter
+     def config(self, value: TrainConfig) -> None:
+         """
+         Sets the training configuration and initializes callback and config managers.
+
+         Args:
+             value (TrainConfig): The configuration object.
+         """
+         if value and not isinstance(value, TrainConfig):
+             raise TypeError("config must be an instance of TrainConfig.")
+         self._config = value
+         self.callback_manager = CallbacksManager(value)
+         self.config_manager = ConfigManager(value)
+
+     def _compute_num_batches(self, dataloader: DataLoader) -> int:
+         """
+         Computes the number of batches for the given DataLoader.
+
+         Args:
+             dataloader (DataLoader): The DataLoader for which to compute
+                 the number of batches.
+         """
+         if dataloader is None:
+             return 1
+         dataset = dataloader.dataset
+         if isinstance(dataset, InfiniteTensorDataset):
+             return 1
+         else:
+             n_batches = int(
+                 (dataset.tensors[0].size(0) + dataloader.batch_size - 1) // dataloader.batch_size
+             )
+             return min(self.max_batches, n_batches) if self.max_batches is not None else n_batches
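The batch count is an integer ceiling division, capped by max_batches when set; for example, 103 samples at batch_size 10 give 11 batches:

    n_samples, batch_size = 103, 10
    assert (n_samples + batch_size - 1) // batch_size == 11  # ceil(103 / 10)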
+
+     def _validate_dataloader(self, dataloader: DataLoader, dataloader_type: str) -> None:
+         """
+         Validates the type of the DataLoader and raises errors for unsupported types.
+
+         Args:
+             dataloader (DataLoader): The DataLoader to validate.
+             dataloader_type (str): The type of DataLoader ("train", "val", or "test").
+         """
+         if dataloader is not None:
+             if not isinstance(dataloader, DataLoader):
+                 raise NotImplementedError(
+                     f"Unsupported dataloader type: {type(dataloader)}. "
+                     "The dataloader must be an instance of DataLoader."
+                 )
+         if dataloader_type == "val" and self.config.val_every > 0:
+             if not isinstance(dataloader, DataLoader):
+                 raise ValueError(
+                     "If `config.val_every` is provided as an integer > 0, validation_dataloader "
+                     "must be an instance of `DataLoader`."
+                 )
+
+     @staticmethod
+     def callback(phase: str) -> Callable:
+         """
+         Decorator for executing callbacks before and after a phase.
+
+         Phases are the different hooks during training; the list of valid
+         phases is defined in Callbacks. The decorator also updates the
+         current state of the training process.
+
+         Args:
+             phase (str): The phase for which the callback is executed (e.g., "train",
+                 "train_epoch", "train_batch").
+
+         Returns:
+             Callable: The decorated function.
+         """
+
+         def decorator(method: Callable) -> Callable:
+             def wrapper(self: Any, *args: Any, **kwargs: Any) -> Any:
+                 start_event = f"{phase}_start"
+                 end_event = f"{phase}_end"
+
+                 self.training_stage = TrainingStage(start_event)
+                 self.callback_manager.run_callbacks(trainer=self)
+                 result = method(self, *args, **kwargs)
+
+                 self.training_stage = TrainingStage(end_event)
+                 # The build_optimize_result method is defined in the trainer.
+                 self.build_optimize_result(result)
+                 self.callback_manager.run_callbacks(trainer=self)
+
+                 return result
+
+             return wrapper
+
+         return decorator
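The concrete trainer wraps its phase methods with this decorator so the matching *_start/*_end callbacks fire around each phase. A hedged sketch of the pattern (run_train_epoch is a hypothetical method name; build_optimize_result is expected to exist on the trainer, as noted in the wrapper):

    class MyTrainer(BaseTrainer):
        def build_optimize_result(self, result):
            self.opt_result = result  # required by the callback wrapper

        @BaseTrainer.callback("train_epoch")
        def run_train_epoch(self):
            # By the time the body runs, training_stage is "train_epoch_start"
            # and the start callbacks have already fired.
            batch_metrics = []  # one (loss, metrics) tuple per batch would go here
            return batch_metrics
            # On return, the wrapper sets the stage to "train_epoch_end", calls
            # build_optimize_result(result), and fires the end callbacks.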
+
+     @contextmanager
+     def enable_grad_opt(self, optimizer: optim.Optimizer | None = None) -> Iterator[None]:
+         """
+         Context manager to temporarily enable gradient-based optimization.
+
+         Args:
+             optimizer (optim.Optimizer): The PyTorch optimizer to use.
+                 If no optimizer is provided, the trainer's default optimizer
+                 is used.
+         """
+         original_mode = self.use_grad
+         original_optimizer = self._optimizer
+         try:
+             self.use_grad = True
+             self.callback_manager.use_grad = True
+             self.optimizer = optimizer if optimizer else self.optimizer
+             yield
+         finally:
+             self.use_grad = original_mode
+             self.callback_manager.use_grad = original_mode
+             self.optimizer = original_optimizer
+
+     @contextmanager
+     def disable_grad_opt(self, optimizer: NGOptimizer | None = None) -> Iterator[None]:
+         """
+         Context manager to temporarily disable gradient-based optimization.
+
+         Args:
+             optimizer (NGOptimizer): The Nevergrad optimizer to use.
+                 If no optimizer is provided, the trainer's default optimizer
+                 is used.
+         """
+         original_mode = self.use_grad
+         original_optimizer = self._optimizer
+         try:
+             self.use_grad = False
+             self.callback_manager.use_grad = False
+             self.optimizer = optimizer if optimizer else self.optimizer
+             yield
+         finally:
+             self.use_grad = original_mode
+             self.callback_manager.use_grad = original_mode
+             self.optimizer = original_optimizer
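These two context managers allow mixing optimization modes in a single run, e.g. a gradient-free warm-up followed by gradient-based fine-tuning; both restore the original mode and optimizer on exit. A usage sketch (trainer and ng_optimizer as in the earlier sketches):

    with trainer.disable_grad_opt(ng_optimizer):
        ...  # gradient-free phase: use_grad is False for trainer and callbacks

    with trainer.enable_grad_opt():  # falls back to the trainer's own optimizer
        ...  # gradient-based phase; original state is restored afterwards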
+
+     def on_train_start(self) -> None:
+         """Called at the start of training."""
+         pass
+
+     def on_train_end(
+         self,
+         train_losses: list[list[tuple[torch.Tensor, Any]]],
+         val_losses: list[list[tuple[torch.Tensor, Any]]] | None = None,
+     ) -> None:
+         """
+         Called at the end of training.
+
+         Args:
+             train_losses (list[list[tuple[torch.Tensor, Any]]]):
+                 Metrics for the training losses.
+                 list -> list -> tuples
+                 Epochs -> Training Batches -> (loss, metrics)
+             val_losses (list[list[tuple[torch.Tensor, Any]]] | None):
+                 Metrics for the validation losses.
+                 list -> list -> tuples
+                 Epochs -> Validation Batches -> (loss, metrics)
+         """
+         pass
+
+     def on_train_epoch_start(self) -> None:
+         """Called at the start of each training epoch."""
+         pass
+
+     def on_train_epoch_end(self, train_epoch_loss_metrics: list[tuple[torch.Tensor, Any]]) -> None:
+         """
+         Called at the end of each training epoch.
+
+         Args:
+             train_epoch_loss_metrics: Metrics for the training epoch losses.
+                 list -> tuples
+                 Training Batches -> (loss, metrics)
+         """
+         pass
+
+     def on_val_epoch_start(self) -> None:
+         """Called at the start of each validation epoch."""
+         pass
+
+     def on_val_epoch_end(self, val_epoch_loss_metrics: list[tuple[torch.Tensor, Any]]) -> None:
+         """
+         Called at the end of each validation epoch.
+
+         Args:
+             val_epoch_loss_metrics: Metrics for the validation epoch loss.
+                 list -> tuples
+                 Validation Batches -> (loss, metrics)
+         """
+         pass
+
+     def on_train_batch_start(self, batch: tuple[torch.Tensor, ...] | None) -> None:
+         """
+         Called at the start of each training batch.
+
+         Args:
+             batch: A batch of data from the DataLoader. Typically a tuple containing
+                 input tensors and corresponding target tensors.
+         """
+         pass
+
+     def on_train_batch_end(self, train_batch_loss_metrics: tuple[torch.Tensor, Any]) -> None:
+         """
+         Called at the end of each training batch.
+
+         Args:
+             train_batch_loss_metrics: Metrics for the training batch loss.
+                 tuple of (loss, metrics)
+         """
+         pass
+
+     def on_val_batch_start(self, batch: tuple[torch.Tensor, ...] | None) -> None:
+         """
+         Called at the start of each validation batch.
+
+         Args:
+             batch: A batch of data from the DataLoader. Typically a tuple containing
+                 input tensors and corresponding target tensors.
+         """
+         pass
+
+     def on_val_batch_end(self, val_batch_loss_metrics: tuple[torch.Tensor, Any]) -> None:
+         """
+         Called at the end of each validation batch.
+
+         Args:
+             val_batch_loss_metrics: Metrics for the validation batch loss.
+                 tuple of (loss, metrics)
+         """
+         pass
+
+     def on_test_batch_start(self, batch: tuple[torch.Tensor, ...] | None) -> None:
+         """
+         Called at the start of each testing batch.
+
+         Args:
+             batch: A batch of data from the DataLoader. Typically a tuple containing
+                 input tensors and corresponding target tensors.
+         """
+         pass
+
+     def on_test_batch_end(self, test_batch_loss_metrics: tuple[torch.Tensor, Any]) -> None:
+         """
+         Called at the end of each testing batch.
+
+         Args:
+             test_batch_loss_metrics: Metrics for the testing batch loss.
+                 tuple of (loss, metrics)
+         """
+         pass
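All of the on_* hooks above are deliberate no-ops: BaseTrainer is meant to be subclassed (the new qadence/ml_tools/trainer.py in this release does exactly that), with hooks overridden selectively. A minimal hypothetical subclass:

    from logging import getLogger

    from qadence.ml_tools.train_utils import BaseTrainer

    logger = getLogger("ml_tools")

    class LoggingTrainer(BaseTrainer):
        """Hypothetical subclass that only logs epoch boundaries."""

        def on_train_epoch_start(self) -> None:
            logger.info("epoch starting")

        def on_train_epoch_end(self, train_epoch_loss_metrics) -> None:
            # train_epoch_loss_metrics: list of (loss, metrics), one per batch
            loss, _ = train_epoch_loss_metrics[-1]
            logger.info("epoch done, last batch loss: %.4f", float(loss))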