qadence 1.8.0__py3-none-any.whl → 1.9.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (43)
  1. qadence/__init__.py +1 -1
  2. qadence/analog/parse_analog.py +1 -2
  3. qadence/backends/gpsr.py +8 -2
  4. qadence/backends/pulser/backend.py +7 -23
  5. qadence/backends/pyqtorch/backend.py +80 -5
  6. qadence/backends/pyqtorch/config.py +10 -3
  7. qadence/backends/pyqtorch/convert_ops.py +63 -2
  8. qadence/blocks/primitive.py +1 -0
  9. qadence/execution.py +0 -2
  10. qadence/log_config.yaml +10 -0
  11. qadence/measurements/shadow.py +97 -128
  12. qadence/measurements/utils.py +2 -2
  13. qadence/mitigations/readout.py +12 -6
  14. qadence/ml_tools/__init__.py +4 -8
  15. qadence/ml_tools/callbacks/__init__.py +30 -0
  16. qadence/ml_tools/callbacks/callback.py +451 -0
  17. qadence/ml_tools/callbacks/callbackmanager.py +214 -0
  18. qadence/ml_tools/{saveload.py → callbacks/saveload.py} +11 -11
  19. qadence/ml_tools/callbacks/writer_registry.py +441 -0
  20. qadence/ml_tools/config.py +132 -258
  21. qadence/ml_tools/data.py +7 -3
  22. qadence/ml_tools/loss/__init__.py +10 -0
  23. qadence/ml_tools/loss/loss.py +87 -0
  24. qadence/ml_tools/optimize_step.py +45 -10
  25. qadence/ml_tools/stages.py +46 -0
  26. qadence/ml_tools/train_utils/__init__.py +7 -0
  27. qadence/ml_tools/train_utils/base_trainer.py +555 -0
  28. qadence/ml_tools/train_utils/config_manager.py +184 -0
  29. qadence/ml_tools/trainer.py +708 -0
  30. qadence/model.py +1 -1
  31. qadence/noise/__init__.py +2 -2
  32. qadence/noise/protocols.py +18 -53
  33. qadence/operations/ham_evo.py +87 -26
  34. qadence/transpile/noise.py +12 -5
  35. qadence/types.py +15 -3
  36. {qadence-1.8.0.dist-info → qadence-1.9.1.dist-info}/METADATA +3 -4
  37. {qadence-1.8.0.dist-info → qadence-1.9.1.dist-info}/RECORD +39 -32
  38. {qadence-1.8.0.dist-info → qadence-1.9.1.dist-info}/WHEEL +1 -1
  39. qadence/ml_tools/printing.py +0 -154
  40. qadence/ml_tools/train_grad.py +0 -395
  41. qadence/ml_tools/train_no_grad.py +0 -199
  42. qadence/noise/readout.py +0 -218
  43. {qadence-1.8.0.dist-info → qadence-1.9.1.dist-info}/licenses/LICENSE +0 -0
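Note: the largest change in this release is the rewrite of `qadence.ml_tools` — the old training entry points in `qadence/ml_tools/train_grad.py` and `qadence/ml_tools/train_no_grad.py` are removed and replaced by the callback-driven `Trainer` class added in `qadence/ml_tools/trainer.py` (shown in the hunk below). The following sketch is adapted from that class's own docstring example; it is illustrative only, and keyword arguments should be verified against the released 1.9.1 API.

```python
# Minimal usage sketch of the new Trainer API, adapted from the class
# docstring in qadence/ml_tools/trainer.py below (not an official recipe).
import torch
from torch.optim import SGD

from qadence import QNN, QuantumCircuit, Z, feature_map, hamiltonian_factory, hea
from qadence.ml_tools import TrainConfig
from qadence.ml_tools.data import to_dataloader
from qadence.ml_tools.trainer import Trainer

# Build a small QNN model.
n_qubits = 2
circuit = QuantumCircuit(n_qubits, feature_map(n_qubits), hea(n_qubits=n_qubits, depth=2))
observable = hamiltonian_factory(n_qubits, detuning=Z)
model = QNN(circuit, observable, backend="pyqtorch", diff_mode="ad")

# Configure training (options taken from the docstring example).
optimizer = SGD(model.parameters(), lr=0.001)
config = TrainConfig(max_iter=100, checkpoint_every=10, val_every=10)

# Prepare data and run the fit loop.
x = torch.linspace(0, 1, 32).reshape(-1, 1)
y = torch.sin(x)
train_loader = to_dataloader(x, y, batch_size=16, infinite=True)
val_loader = to_dataloader(x, y, batch_size=16, infinite=False)

trainer = Trainer(model=model, optimizer=optimizer, config=config, loss_fn="mse")
model, optimizer = trainer.fit(train_loader, val_loader)
```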
qadence/ml_tools/trainer.py
@@ -0,0 +1,708 @@
1
+ from __future__ import annotations
2
+
3
+ import copy
4
+ from itertools import islice
5
+ from logging import getLogger
6
+ from typing import Any, Callable, Iterable, cast
7
+
8
+ import torch
9
+ from nevergrad.optimization.base import Optimizer as NGOptimizer
10
+ from rich.progress import BarColumn, Progress, TaskProgressColumn, TextColumn, TimeRemainingColumn
11
+ from torch import complex128, float32, float64, nn, optim
12
+ from torch import device as torch_device
13
+ from torch import dtype as torch_dtype
14
+ from torch.utils.data import DataLoader
15
+
16
+ from qadence.ml_tools.config import TrainConfig
17
+ from qadence.ml_tools.data import DictDataLoader, OptimizeResult
18
+ from qadence.ml_tools.optimize_step import optimize_step, update_ng_parameters
19
+ from qadence.ml_tools.stages import TrainingStage
20
+
21
+ from .train_utils.base_trainer import BaseTrainer
22
+
23
+ logger = getLogger("ml_tools")
24
+
25
+
26
+ class Trainer(BaseTrainer):
27
+ """Trainer class to manage and execute training, validation, and testing loops for a model (eg.
28
+
29
+ QNN).
30
+
31
+ This class handles the overall training process, including:
32
+ - Managing epochs and steps
33
+ - Handling data loading and batching
34
+ - Computing and updating gradients
35
+ - Logging and monitoring training metrics
36
+
37
+ Attributes:
38
+ current_epoch (int): The current epoch number.
39
+ global_step (int): The global step across all epochs.
40
+ log_device (str): Device for logging, default is "cpu".
41
+ device (torch_device): Device used for computation.
42
+ dtype (torch_dtype | None): Data type used for computation.
43
+ data_dtype (torch_dtype | None): Data type for data.
44
+ Depends on the model's data type.
45
+
46
+ Inherited Attributes:
47
+ use_grad (bool): Indicates if gradients are used for optimization. Default is True.
48
+
49
+ model (nn.Module): The neural network model.
50
+ optimizer (optim.Optimizer | NGOptimizer | None): The optimizer for training.
51
+ config (TrainConfig): The configuration settings for training.
52
+ train_dataloader (DataLoader | DictDataLoader | None): DataLoader for training data.
53
+ val_dataloader (DataLoader | DictDataLoader | None): DataLoader for validation data.
54
+ test_dataloader (DataLoader | DictDataLoader | None): DataLoader for testing data.
55
+
56
+ optimize_step (Callable): Function for performing an optimization step.
57
+ loss_fn (Callable): Loss function to use.
58
+
59
+ num_training_batches (int): Number of training batches.
60
+ num_validation_batches (int): Number of validation batches.
61
+ num_test_batches (int): Number of test batches.
62
+
63
+ state (str): Current state in the training process.
64
+
65
+ Default training routine:
66
+ ```
67
+ for epoch in range(max_iter + 1):
68
+     # Training
69
+     for batch in train_batches:
70
+         train model
71
+     # Validation
72
+     if epoch % val_every == 0:
73
+         for batch in val_batches:
74
+             validate model
75
+ ```
76
+
77
+ Notes:
78
+ - In case of an InfiniteTensorDataset, the number of batches per epoch is 1.
79
+ - In case of a TensorDataset, the number of batches is determined by the dataloader.
80
+ - Training is run for max_iter + 1 epochs; epoch 0 logs the untrained model.
81
+ - Please look at the CallbackManager initialize_callbacks method to review the default
82
+ logging behavior.
83
+
84
+ Examples:
85
+
86
+ ```python
87
+ import torch
88
+ from torch.optim import SGD
89
+ from qadence import (
90
+ feature_map,
91
+ hamiltonian_factory,
92
+ hea,
93
+ QNN,
94
+ QuantumCircuit,
95
+ TrainConfig,
96
+ Z,
97
+ )
98
+ from qadence.ml_tools.trainer import Trainer
99
+ from qadence.ml_tools.optimize_step import optimize_step
100
+ from qadence.ml_tools import TrainConfig
101
+ from qadence.ml_tools.data import to_dataloader
102
+
103
+ # Initialize the model
104
+ n_qubits = 2
105
+ fm = feature_map(n_qubits)
106
+ ansatz = hea(n_qubits=n_qubits, depth=2)
107
+ observable = hamiltonian_factory(n_qubits, detuning=Z)
108
+ circuit = QuantumCircuit(n_qubits, fm, ansatz)
109
+ model = QNN(circuit, observable, backend="pyqtorch", diff_mode="ad")
110
+
111
+ # Set up the optimizer
112
+ optimizer = SGD(model.parameters(), lr=0.001)
113
+
114
+ # Use TrainConfig for configuring the training process
115
+ config = TrainConfig(
116
+ max_iter=100,
117
+ print_every=10,
118
+ write_every=10,
119
+ checkpoint_every=10,
120
+ val_every=10
121
+ )
122
+
123
+ # Create the Trainer instance with TrainConfig
124
+ trainer = Trainer(
125
+ model=model,
126
+ optimizer=optimizer,
127
+ config=config,
128
+ loss_fn="mse",
129
+ optimize_step=optimize_step
130
+ )
131
+
132
+ batch_size = 25
133
+ x = torch.linspace(0, 1, 32).reshape(-1, 1)
134
+ y = torch.sin(x)
135
+ train_loader = to_dataloader(x, y, batch_size=batch_size, infinite=True)
136
+ val_loader = to_dataloader(x, y, batch_size=batch_size, infinite=False)
137
+
138
+ # Train the model
139
+ model, optimizer = trainer.fit(train_loader, val_loader)
140
+ ```
141
+
142
+ The Trainer supports both gradient-based and gradient-free optimization.
143
+ Gradient-based optimization is the default.
144
+
145
+ Notes:
146
+
147
+ - **set_use_grad()** (*class level*): This method sets the global `use_grad` flag,
148
+ controlling whether the trainer uses gradient-based optimization.
149
+ ```python
150
+ # gradient based
151
+ Trainer.set_use_grad(True)
152
+
153
+ # gradient free
154
+ Trainer.set_use_grad(False)
155
+ ```
156
+ - **Context Managers** (*instance level*): `enable_grad_opt()` and `disable_grad_opt()` are
157
+ context managers that temporarily switch the optimization mode for specific code blocks.
158
+ This is useful when you want to mix gradient-based and gradient-free optimization
159
+ in the same training process.
160
+ ```python
161
+ # gradient based
162
+ with trainer.enable_grad_opt(optimizer):
163
+ trainer.fit()
164
+
165
+ # gradient free
166
+ with trainer.disable_grad_opt(ng_optimizer):
167
+ trainer.fit()
168
+ ```
169
+
170
+ Examples:
171
+
172
+ *Gradient-based optimization example usage*:
173
+ ```python
174
+ from torch import optim
175
+ optimizer = optim.SGD(model.parameters(), lr=0.01)
176
+
177
+ Trainer.set_use_grad(True)
178
+ trainer = Trainer(
179
+ model=model,
180
+ optimizer=optimizer,
181
+ config=config,
182
+ loss_fn="mse"
183
+ )
184
+ trainer.fit(train_loader, val_loader)
185
+ ```
186
+ or
187
+ ```python
188
+ trainer = Trainer(
189
+ model=model,
190
+ config=config,
191
+ loss_fn="mse"
192
+ )
193
+ with trainer.enable_grad_opt(optimizer):
194
+ trainer.fit(train_loader, val_loader)
195
+ ```
196
+
197
+ *Gradient-free optimization example usage*:
198
+ ```python
199
+ import nevergrad as ng
200
+ from qadence.ml_tools.parameters import num_parameters
201
+ ng_optimizer = ng.optimizers.NGOpt(
202
+ budget=config.max_iter, parametrization=num_parameters(model)
203
+ )
204
+
205
+ Trainer.set_use_grad(False)
206
+ trainer = Trainer(
207
+ model=model,
208
+ optimizer=ng_optimizer,
209
+ config=config,
210
+ loss_fn="mse"
211
+ )
212
+ trainer.fit(train_loader, val_loader)
213
+ ```
214
+ or
215
+ ```python
216
+ import nevergrad as ng
217
+ from qadence.ml_tools.parameters import num_parameters
218
+ ng_optimizer = ng.optimizers.NGOpt(
219
+ budget=config.max_iter, parametrization=num_parameters(model)
220
+ )
221
+
222
+ trainer = Trainer(
223
+ model=model,
224
+ config=config,
225
+ loss_fn="mse"
226
+ )
227
+ with trainer.disable_grad_opt(ng_optimizer):
228
+ trainer.fit(train_loader, val_loader)
229
+ ```
230
+ """
231
+
232
+ def __init__(
233
+ self,
234
+ model: nn.Module,
235
+ optimizer: optim.Optimizer | NGOptimizer | None,
236
+ config: TrainConfig,
237
+ loss_fn: str | Callable = "mse",
238
+ train_dataloader: DataLoader | DictDataLoader | None = None,
239
+ val_dataloader: DataLoader | DictDataLoader | None = None,
240
+ test_dataloader: DataLoader | DictDataLoader | None = None,
241
+ optimize_step: Callable = optimize_step,
242
+ device: torch_device | None = None,
243
+ dtype: torch_dtype | None = None,
244
+ max_batches: int | None = None,
245
+ ):
246
+ """
247
+ Initializes the Trainer class.
248
+
249
+ Args:
250
+ model (nn.Module): The PyTorch model to train.
251
+ optimizer (optim.Optimizer | NGOptimizer | None): The optimizer for training.
252
+ config (TrainConfig): Training configuration object.
253
+ loss_fn (str | Callable): Loss function used for training.
254
+ If not specified, the default MSE loss is used.
255
+ train_dataloader (DataLoader | DictDataLoader | None): DataLoader for training data.
256
+ val_dataloader (DataLoader | DictDataLoader | None): DataLoader for validation data.
257
+ test_dataloader (DataLoader | DictDataLoader | None): DataLoader for test data.
258
+ optimize_step (Callable): Function to execute an optimization step.
259
+ device (torch_device): Device to use for computation.
260
+ dtype (torch_dtype): Data type for computation.
261
+ max_batches (int | None): Maximum number of batches to process per epoch.
262
+ This is only valid for dataloaders over a finite TensorDataset.
263
+ If max_batches is not None, the maximum number of batches used will
264
+ be min(max_batches, len(dataloader.dataset)).
265
+ In case of an InfiniteTensorDataset, only 1 batch per epoch is used.
266
+ """
267
+ super().__init__(
268
+ model=model,
269
+ optimizer=optimizer,
270
+ config=config,
271
+ loss_fn=loss_fn,
272
+ optimize_step=optimize_step,
273
+ train_dataloader=train_dataloader,
274
+ val_dataloader=val_dataloader,
275
+ test_dataloader=test_dataloader,
276
+ max_batches=max_batches,
277
+ )
278
+ self.current_epoch: int = 0
279
+ self.global_step: int = 0
280
+ self.log_device: str = "cpu" if device is None else device
281
+ self.device: torch_device | None = device
282
+ self.dtype: torch_dtype | None = dtype
283
+ self.data_dtype: torch_dtype | None = None
284
+ if self.dtype:
285
+ self.data_dtype = float64 if (self.dtype == complex128) else float32
286
+
287
+ def fit(
288
+ self,
289
+ train_dataloader: DataLoader | DictDataLoader | None = None,
290
+ val_dataloader: DataLoader | DictDataLoader | None = None,
291
+ ) -> tuple[nn.Module, optim.Optimizer]:
292
+ """
293
+ Fits the model using the specified training configuration.
294
+
295
+ The dataloaders can be provided to train on new datasets, or the default dataloaders
296
+ provided in the trainer will be used.
297
+
298
+ Args:
299
+ train_dataloader (DataLoader | DictDataLoader | None): DataLoader for training data.
300
+ val_dataloader (DataLoader | DictDataLoader | None): DataLoader for validation data.
301
+
302
+ Returns:
303
+ tuple[nn.Module, optim.Optimizer]: The trained model and optimizer.
304
+ """
305
+ if train_dataloader is not None:
306
+ self.train_dataloader = train_dataloader
307
+ if val_dataloader is not None:
308
+ self.val_dataloader = val_dataloader
309
+
310
+ self._fit_setup()
311
+ self._train()
312
+ self._fit_end()
313
+ self.training_stage = TrainingStage("idle")
314
+ return self.model, self.optimizer
315
+
316
+ def _fit_setup(self) -> None:
317
+ """
318
+ Sets up the training environment, initializes configurations,
319
+
320
+ and moves the model to the specified device and data type.
321
+ The callback_manager.start_training call takes care of loading the checkpoint
322
+ and setting up the writer.
323
+ """
324
+ self.config_manager.initialize_config()
325
+ self.callback_manager.start_training(trainer=self)
326
+
327
+ # Move model to device
328
+ if isinstance(self.model, nn.DataParallel):
329
+ self.model = self.model.module.to(device=self.device, dtype=self.dtype)
330
+ else:
331
+ self.model = self.model.to(device=self.device, dtype=self.dtype)
332
+
333
+ # Progress bar for training visualization
334
+ self.progress: Progress = Progress(
335
+ TextColumn("[progress.description]{task.description}"),
336
+ BarColumn(),
337
+ TaskProgressColumn(),
338
+ TimeRemainingColumn(elapsed_when_finished=True),
339
+ )
340
+
341
+ # Quick Fix for iteration 0
342
+ self._reset_model_and_opt()
343
+
344
+ # Run validation at the start if specified in the configuration
345
+ self.perform_val = self.config.val_every > 0
346
+ if self.perform_val:
347
+ self.run_validation(self.val_dataloader)
348
+
349
+ def _fit_end(self) -> None:
350
+ """Finalizes the training and closes the writer."""
351
+ self.callback_manager.end_training(trainer=self)
352
+
353
+ @BaseTrainer.callback("train")
354
+ def _train(self) -> list[list[tuple[torch.Tensor, dict[str, Any]]]]:
355
+ """
356
+ Runs the main training loop, iterating over epochs.
357
+
358
+ Returns:
359
+ list[list[tuple[torch.Tensor, dict[str, Any]]]]: Training loss
360
+ metrics for all epochs.
361
+ list -> list -> tuples
362
+ Epochs -> Training Batches -> (loss, metrics)
363
+ """
364
+ self.on_train_start()
365
+ train_losses = []
366
+ val_losses = []
367
+
368
+ with self.progress:
369
+ train_task = self.progress.add_task(
370
+ "Training", total=self.config_manager.config.max_iter
371
+ )
372
+ if self.perform_val:
373
+ val_task = self.progress.add_task(
374
+ "Validation",
375
+ total=(self.config_manager.config.max_iter + 1) / self.config.val_every,
376
+ )
377
+ for epoch in range(
378
+ self.global_step, self.global_step + self.config_manager.config.max_iter + 1
379
+ ):
380
+ try:
381
+ self.current_epoch = epoch
382
+ self.on_train_epoch_start()
383
+ train_epoch_loss_metrics = self.run_training(self.train_dataloader)
384
+ train_losses.append(train_epoch_loss_metrics)
385
+ self.on_train_epoch_end(train_epoch_loss_metrics)
386
+
387
+ # Run validation periodically if specified
388
+ if self.perform_val and self.current_epoch % self.config.val_every == 0:
389
+ self.on_val_epoch_start()
390
+ val_epoch_loss_metrics = self.run_validation(self.val_dataloader)
391
+ val_losses.append(val_epoch_loss_metrics)
392
+ self.on_val_epoch_end(val_epoch_loss_metrics)
393
+ self.progress.update(val_task, advance=1)
394
+
395
+ self.progress.update(train_task, advance=1)
396
+ except KeyboardInterrupt:
397
+ logger.info("Terminating training gracefully after the current iteration.")
398
+ break
399
+
400
+ self.on_train_end(train_losses, val_losses)
401
+ return train_losses
402
+
403
+ @BaseTrainer.callback("train_epoch")
404
+ def run_training(self, dataloader: DataLoader) -> list[tuple[torch.Tensor, dict[str, Any]]]:
405
+ """
406
+ Runs the training for a single epoch, iterating over multiple batches.
407
+
408
+ Args:
409
+ dataloader (DataLoader): DataLoader for training data.
410
+
411
+ Returns:
412
+ list[tuple[torch.Tensor, dict[str, Any]]]: Loss and metrics for each batch.
413
+ list -> tuples
414
+ Training Batches -> (loss, metrics)
415
+ """
416
+ self.model.train()
417
+ train_epoch_loss_metrics = []
418
+ # Quick Fix for iteration 0
419
+ self._reset_model_and_opt()
420
+
421
+ for batch in self._batch_iter(dataloader, self.num_training_batches):
422
+ self.on_train_batch_start(batch)
423
+ train_batch_loss_metrics = self.run_train_batch(batch)
424
+ train_epoch_loss_metrics.append(train_batch_loss_metrics)
425
+ self.on_train_batch_end(train_batch_loss_metrics)
426
+
427
+ return train_epoch_loss_metrics
428
+
429
+ @BaseTrainer.callback("train_batch")
430
+ def run_train_batch(
431
+ self, batch: tuple[torch.Tensor, ...]
432
+ ) -> tuple[torch.Tensor, dict[str, Any]]:
433
+ """
434
+ Runs a single training batch, performing optimization.
435
+
436
+ We use the step function to optimize the model based on use_grad.
437
+ use_grad = True entails gradient-based optimization, for which we use
438
+ the optimize_step function.
439
+ use_grad = False entails gradient-free optimization, for which we use
440
+ the update_ng_parameters function.
441
+
442
+ Args:
443
+ batch (tuple[torch.Tensor, ...]): Batch of data from the DataLoader.
444
+
445
+ Returns:
446
+ tuple[torch.Tensor, dict[str, Any]]: Loss and metrics for the batch.
447
+ tuple of (loss, metrics)
448
+ """
449
+
450
+ if self.use_grad:
451
+ # Perform gradient-based optimization
452
+ loss_metrics = self.optimize_step(
453
+ model=self.model,
454
+ optimizer=self.optimizer,
455
+ loss_fn=self.loss_fn,
456
+ xs=batch,
457
+ device=self.device,
458
+ dtype=self.data_dtype,
459
+ )
460
+ else:
461
+ # Perform optimization using Nevergrad
462
+ loss, metrics, ng_params = update_ng_parameters(
463
+ model=self.model,
464
+ optimizer=self.optimizer,
465
+ loss_fn=self.loss_fn,
466
+ data=batch,
467
+ ng_params=self.ng_params, # type: ignore[arg-type]
468
+ )
469
+ self.ng_params = ng_params
470
+ loss_metrics = loss, metrics
471
+
472
+ return self._modify_batch_end_loss_metrics(loss_metrics)
473
+
474
+ @BaseTrainer.callback("val_epoch")
475
+ def run_validation(self, dataloader: DataLoader) -> list[tuple[torch.Tensor, dict[str, Any]]]:
476
+ """
477
+ Runs the validation loop for a single epoch, iterating over multiple batches.
478
+
479
+ Args:
480
+ dataloader (DataLoader): DataLoader for validation data.
481
+
482
+ Returns:
483
+ list[tuple[torch.Tensor, dict[str, Any]]]: Loss and metrics for each batch.
484
+ list -> tuples
485
+ Validation Batches -> (loss, metrics)
486
+ """
487
+ self.model.eval()
488
+ val_epoch_loss_metrics = []
489
+
490
+ for batch in self._batch_iter(dataloader, self.num_validation_batches):
491
+ self.on_val_batch_start(batch)
492
+ val_batch_loss_metrics = self.run_val_batch(batch)
493
+ val_epoch_loss_metrics.append(val_batch_loss_metrics)
494
+ self.on_val_batch_end(val_batch_loss_metrics)
495
+
496
+ return val_epoch_loss_metrics
497
+
498
+ @BaseTrainer.callback("val_batch")
499
+ def run_val_batch(self, batch: tuple[torch.Tensor, ...]) -> tuple[torch.Tensor, dict[str, Any]]:
500
+ """
501
+ Runs a single validation batch.
502
+
503
+ Args:
504
+ batch (tuple[torch.Tensor, ...]): Batch of data from the DataLoader.
505
+
506
+ Returns:
507
+ tuple[torch.Tensor, dict[str, Any]]: Loss and metrics for the batch.
508
+ """
509
+ with torch.no_grad():
510
+ loss_metrics = self.loss_fn(self.model, batch)
511
+ return self._modify_batch_end_loss_metrics(loss_metrics)
512
+
513
+ def test(self, test_dataloader: DataLoader = None) -> list[tuple[torch.Tensor, dict[str, Any]]]:
514
+ """
515
+ Runs the testing loop if a test DataLoader is provided.
516
+
517
+ If the test_dataloader is not provided, the default test_dataloader defined
518
+ in the Trainer class is used.
519
+
520
+ Args:
521
+ test_dataloader (DataLoader): DataLoader for test data.
522
+
523
+ Returns:
524
+ list[tuple[torch.Tensor, dict[str, Any]]]: Loss and metrics for each batch.
525
+ list -> tuples
526
+ Test Batches -> (loss, metrics)
527
+ """
528
+ if test_dataloader is not None:
529
+ self.test_dataloader = test_dataloader
530
+
531
+ self.model.eval()
532
+ test_loss_metrics = []
533
+
534
+ for batch in self._batch_iter(test_dataloader, self.num_training_batches):
535
+ self.on_test_batch_start(batch)
536
+ loss_metrics = self.run_test_batch(batch)
537
+ test_loss_metrics.append(loss_metrics)
538
+ self.on_test_batch_end(loss_metrics)
539
+
540
+ return test_loss_metrics
541
+
542
+ @BaseTrainer.callback("test_batch")
543
+ def run_test_batch(
544
+ self, batch: tuple[torch.Tensor, ...]
545
+ ) -> tuple[torch.Tensor, dict[str, Any]]:
546
+ """
547
+ Runs a single test batch.
548
+
549
+ Args:
550
+ batch (tuple[torch.Tensor, ...]): Batch of data from the DataLoader.
551
+
552
+ Returns:
553
+ tuple[torch.Tensor, dict[str, Any]]: Loss and metrics for the batch.
554
+ """
555
+ with torch.no_grad():
556
+ loss_metrics = self.loss_fn(self.model, batch)
557
+ return self._modify_batch_end_loss_metrics(loss_metrics)
558
+
559
+ def _batch_iter(
560
+ self,
561
+ dataloader: DataLoader | DictDataLoader,
562
+ num_batches: int,
563
+ ) -> Iterable[tuple[torch.Tensor, ...] | None]:
564
+ """
565
+ Yields batches from the provided dataloader.
566
+
567
+ Args:
568
+ dataloader (DataLoader | DictDataLoader): The dataloader to iterate over.
569
+ num_batches (int): The maximum number of batches to yield.
570
+
571
+ Yields:
572
+ Iterable[tuple[torch.Tensor, ...] | None]: A batch from the dataloader moved to the
573
+ specified device and dtype.
574
+ """
575
+ if dataloader is None:
576
+ for _ in range(num_batches):
577
+ yield None
578
+ else:
579
+ for batch in islice(dataloader, num_batches):
580
+ # batch is moved to device inside optimize step
581
+ # batch = data_to_device(batch, device=self.device, dtype=self.data_dtype)
582
+ yield batch
583
+
584
+ def _modify_batch_end_loss_metrics(
585
+ self, loss_metrics: tuple[torch.Tensor, dict[str, Any]]
586
+ ) -> tuple[torch.Tensor, dict[str, Any]]:
587
+ """
588
+ Modifies the loss and metrics at the end of a batch for proper logging.
590
+
591
+ All metrics are prefixed with the current stage of the training process
592
+ - "train_", "val_", or "test_".
593
+ A "{state}_loss" entry is added to the metrics.
593
+
594
+ Args:
595
+ loss_metrics (tuple[torch.Tensor, dict[str, Any]]): Original loss and metrics.
596
+
597
+ Returns:
598
+ tuple[None | torch.Tensor, dict[str, Any]]: Modified loss and metrics.
599
+ """
600
+ for phase in ["train", "val", "test"]:
601
+ if phase in self.training_stage:
602
+ loss, metrics = loss_metrics
603
+ updated_metrics = {f"{phase}_{key}": value for key, value in metrics.items()}
604
+ updated_metrics[f"{phase}_loss"] = loss
605
+ return loss, updated_metrics
606
+ return loss_metrics
607
+
608
+ def _reset_model_and_opt(self) -> None:
609
+ """
610
+ Save model_old and optimizer_old for epoch 0.
611
+
612
+ This allows us to create a copy of the model
614
+ and optimizer before running the optimization.
615
+
616
+ We do this because optimize_step provides the loss and metrics
617
+ before the optimization step is taken.
618
+ To align them with the model/optimizer correctly, we checkpoint
619
+ the older copy of the model.
619
+ """
620
+
621
+ # TODO: review optimize_step to provide iteration aligned model and loss.
622
+ try:
623
+ # Deep copy model and optimizer to maintain checkpoints
624
+ self.model_old = copy.deepcopy(self.model)
625
+ self.optimizer_old = copy.deepcopy(self.optimizer)
626
+ except Exception:
627
+ self.model_old = self.model
628
+ self.optimizer_old = self.optimizer
629
+
630
+ def build_optimize_result(
631
+ self,
632
+ result: None
633
+ | tuple[torch.Tensor, dict[Any, Any]]
634
+ | list[tuple[torch.Tensor, dict[Any, Any]]]
635
+ | list[list[tuple[torch.Tensor, dict[Any, Any]]]],
636
+ ) -> None:
637
+ """
638
+ Builds and stores the optimization result by calculating the average loss and metrics.
639
+
640
+ Result (or loss_metrics) can have multiple formats:
641
+ - `None`: Indicates that no loss or metrics data is provided.
643
+ - `tuple[torch.Tensor, dict[str, Any]]`: A single tuple containing the loss tensor
644
+ and metrics dictionary, at the end of a batch.
645
+ - `list[tuple[torch.Tensor, dict[str, Any]]]`: A list of tuples for
646
+ multiple batches.
647
+ - `list[list[tuple[torch.Tensor, dict[str, Any]]]]`: A list of lists of tuples,
647
+ where each inner list represents metrics across multiple batches within an epoch.
648
+
649
+ Args:
650
+ result: (None |
651
+ tuple[torch.Tensor, dict[Any, Any]] |
652
+ list[tuple[torch.Tensor, dict[Any, Any]]] |
653
+ list[list[tuple[torch.Tensor, dict[Any, Any]]]])
654
+ The loss and metrics data, which can have multiple formats
655
+
656
+ Returns:
657
+ None: This method does not return anything. It sets `self.opt_result` with
658
+ the computed average loss and metrics.
659
+ """
660
+ loss_metrics = result
661
+ if loss_metrics is None:
662
+ loss = None
663
+ metrics: dict[Any, Any] = {}
664
+ elif isinstance(loss_metrics, tuple):
665
+ # Single tuple case
666
+ loss, metrics = loss_metrics
667
+ else:
668
+ last_epoch: list[tuple[torch.Tensor, dict[Any, Any]]] = []
669
+ if isinstance(loss_metrics, list):
670
+ # Check if it's a list of tuples
671
+ if all(isinstance(item, tuple) for item in loss_metrics):
672
+ last_epoch = cast(list[tuple[torch.Tensor, dict[Any, Any]]], loss_metrics)
673
+ # Check if it's a list of lists of tuples
674
+ elif all(isinstance(item, list) for item in loss_metrics):
675
+ last_epoch = cast(
676
+ list[tuple[torch.Tensor, dict[Any, Any]]],
677
+ loss_metrics[-1] if loss_metrics else [],
678
+ )
679
+ else:
680
+ raise ValueError(
681
+ "Invalid format for result: Expected None, tuple, list of tuples,"
682
+ " or list of lists of tuples."
683
+ )
684
+
685
+ if not last_epoch:
686
+ loss, metrics = None, {}
687
+ else:
688
+ # Compute the average loss over the batches
689
+ loss_tensor = torch.stack([loss_batch for loss_batch, _ in last_epoch])
690
+ avg_loss = loss_tensor.mean()
691
+
692
+ # Collect and average metrics for all batches
693
+ metric_keys = last_epoch[0][1].keys()
694
+ metrics_stacked: dict = {key: [] for key in metric_keys}
695
+
696
+ for _, metrics_batch in last_epoch:
697
+ for key in metric_keys:
698
+ value = metrics_batch[key]
699
+ metrics_stacked[key].append(value)
700
+
701
+ avg_metrics = {key: torch.stack(metrics_stacked[key]).mean() for key in metric_keys}
702
+
703
+ loss, metrics = avg_loss, avg_metrics
704
+
705
+ # Store the optimization result
706
+ self.opt_result = OptimizeResult(
707
+ self.current_epoch, self.model_old, self.optimizer_old, loss, metrics
708
+ )
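For reference, the validation and test paths above call `self.loss_fn(self.model, batch)` and expect a `(loss, metrics)` pair back, which `_modify_batch_end_loss_metrics` then prefixes with the current stage. A custom callable loss therefore has to follow that contract. The sketch below is a minimal example of such a callable; the function name and the `(x, y)` batch layout are illustrative assumptions, matching the `to_dataloader(x, y, ...)` example in the class docstring, not part of the released API.

```python
from __future__ import annotations

from typing import Any

import torch
import torch.nn as nn


def custom_mse_loss(
    model: nn.Module, batch: tuple[torch.Tensor, torch.Tensor]
) -> tuple[torch.Tensor, dict[str, Any]]:
    # Assumes the (x, y) batch layout produced by to_dataloader(x, y, ...).
    x, y = batch
    y_pred = model(x)
    loss = nn.functional.mse_loss(y_pred, y)
    # Detached metrics are what get prefixed as "train_"/"val_"/"test_"
    # by _modify_batch_end_loss_metrics above.
    return loss, {"mse": loss.detach()}


# Hypothetical usage:
# trainer = Trainer(model=model, optimizer=optimizer, config=config, loss_fn=custom_mse_loss)
```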