qadence 1.8.0__py3-none-any.whl → 1.9.1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- qadence/__init__.py +1 -1
- qadence/analog/parse_analog.py +1 -2
- qadence/backends/gpsr.py +8 -2
- qadence/backends/pulser/backend.py +7 -23
- qadence/backends/pyqtorch/backend.py +80 -5
- qadence/backends/pyqtorch/config.py +10 -3
- qadence/backends/pyqtorch/convert_ops.py +63 -2
- qadence/blocks/primitive.py +1 -0
- qadence/execution.py +0 -2
- qadence/log_config.yaml +10 -0
- qadence/measurements/shadow.py +97 -128
- qadence/measurements/utils.py +2 -2
- qadence/mitigations/readout.py +12 -6
- qadence/ml_tools/__init__.py +4 -8
- qadence/ml_tools/callbacks/__init__.py +30 -0
- qadence/ml_tools/callbacks/callback.py +451 -0
- qadence/ml_tools/callbacks/callbackmanager.py +214 -0
- qadence/ml_tools/{saveload.py → callbacks/saveload.py} +11 -11
- qadence/ml_tools/callbacks/writer_registry.py +441 -0
- qadence/ml_tools/config.py +132 -258
- qadence/ml_tools/data.py +7 -3
- qadence/ml_tools/loss/__init__.py +10 -0
- qadence/ml_tools/loss/loss.py +87 -0
- qadence/ml_tools/optimize_step.py +45 -10
- qadence/ml_tools/stages.py +46 -0
- qadence/ml_tools/train_utils/__init__.py +7 -0
- qadence/ml_tools/train_utils/base_trainer.py +555 -0
- qadence/ml_tools/train_utils/config_manager.py +184 -0
- qadence/ml_tools/trainer.py +708 -0
- qadence/model.py +1 -1
- qadence/noise/__init__.py +2 -2
- qadence/noise/protocols.py +18 -53
- qadence/operations/ham_evo.py +87 -26
- qadence/transpile/noise.py +12 -5
- qadence/types.py +15 -3
- {qadence-1.8.0.dist-info → qadence-1.9.1.dist-info}/METADATA +3 -4
- {qadence-1.8.0.dist-info → qadence-1.9.1.dist-info}/RECORD +39 -32
- {qadence-1.8.0.dist-info → qadence-1.9.1.dist-info}/WHEEL +1 -1
- qadence/ml_tools/printing.py +0 -154
- qadence/ml_tools/train_grad.py +0 -395
- qadence/ml_tools/train_no_grad.py +0 -199
- qadence/noise/readout.py +0 -218
- {qadence-1.8.0.dist-info → qadence-1.9.1.dist-info}/licenses/LICENSE +0 -0
qadence/ml_tools/stages.py (new file) @@ -0,0 +1,46 @@

```python
from __future__ import annotations

from qadence.types import StrEnum


class TrainingStage(StrEnum):
    """Different stages in the training, validation, and testing process."""

    IDLE = "idle"
    """An 'idle' stage for scenarios where no training, validation, or testing is involved."""

    TRAIN_START = "train_start"
    """Marks the start of the training process."""

    TRAIN_END = "train_end"
    """Marks the end of the training process."""

    TRAIN_EPOCH_START = "train_epoch_start"
    """Indicates the start of a training epoch."""

    TRAIN_EPOCH_END = "train_epoch_end"
    """Indicates the end of a training epoch."""

    TRAIN_BATCH_START = "train_batch_start"
    """Marks the start of processing a training batch."""

    TRAIN_BATCH_END = "train_batch_end"
    """Marks the end of processing a training batch."""

    VAL_EPOCH_START = "val_epoch_start"
    """Indicates the start of a validation epoch."""

    VAL_EPOCH_END = "val_epoch_end"
    """Indicates the end of a validation epoch."""

    VAL_BATCH_START = "val_batch_start"
    """Marks the start of processing a validation batch."""

    VAL_BATCH_END = "val_batch_end"
    """Marks the end of processing a validation batch."""

    TEST_BATCH_START = "test_batch_start"
    """Marks the start of processing a test batch."""

    TEST_BATCH_END = "test_batch_end"
    """Marks the end of processing a test batch."""
```
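The trainer below constructs these stages directly from their string values (e.g. `TrainingStage("idle")`). A minimal usage sketch, assuming `qadence.types.StrEnum` is a `str`-based enum so members compare equal to their plain-string values:

```python
# Illustrative sketch only (not part of the diff). Assumes StrEnum members
# behave like plain strings, as qadence's StrEnum is a str-based Enum.
from qadence.ml_tools.stages import TrainingStage

stage = TrainingStage("train_epoch_start")  # construct from the string value

# Dispatch on the current stage, e.g. inside a custom callback:
if stage == TrainingStage.TRAIN_EPOCH_START:
    print(f"entering stage: {stage}")
```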
qadence/ml_tools/train_utils/base_trainer.py (new file) @@ -0,0 +1,555 @@

```python
from __future__ import annotations

from contextlib import contextmanager
from logging import getLogger
from typing import Any, Callable, Iterator

import nevergrad as ng
import torch
from nevergrad.optimization.base import Optimizer as NGOptimizer
from torch import nn, optim
from torch.utils.data import DataLoader, TensorDataset

from qadence.ml_tools.callbacks import CallbacksManager
from qadence.ml_tools.config import TrainConfig
from qadence.ml_tools.data import DictDataLoader
from qadence.ml_tools.loss import get_loss_fn
from qadence.ml_tools.optimize_step import optimize_step
from qadence.ml_tools.parameters import get_parameters
from qadence.ml_tools.stages import TrainingStage

from .config_manager import ConfigManager

logger = getLogger("ml_tools")


class BaseTrainer:
    """Base class for training machine learning models using a given optimizer.

    The base class implements context managers for gradient-based/gradient-free
    optimization, properties, property setters, input validations, a callback
    decorator generator, and empty hooks for the different training steps.

    This class provides:
        - Context managers for enabling/disabling gradient-based optimization
        - Properties for managing models, optimizers, and dataloaders
        - Input validations and a callback decorator generator
        - Config and callback managers using the provided `TrainConfig`

    Attributes:
        use_grad (bool): Indicates if gradients are used for optimization. Default is True.

        model (nn.Module): The neural network model.
        optimizer (optim.Optimizer | NGOptimizer | None): The optimizer for training.
        config (TrainConfig): The configuration settings for training.
        train_dataloader (DataLoader | DictDataLoader | None): DataLoader for training data.
        val_dataloader (DataLoader | DictDataLoader | None): DataLoader for validation data.
        test_dataloader (DataLoader | DictDataLoader | None): DataLoader for testing data.

        optimize_step (Callable): Function for performing an optimization step.
        loss_fn (Callable | str): Loss function to use. The default loss function
            is 'mse'.

        num_training_batches (int): Number of training batches. For an
            InfiniteTensorDataset only 1 batch per epoch is used.
        num_validation_batches (int): Number of validation batches. For an
            InfiniteTensorDataset only 1 batch per epoch is used.
        num_test_batches (int): Number of test batches. For an
            InfiniteTensorDataset only 1 batch per epoch is used.

        state (str): Current state in the training process.
    """

    _use_grad: bool = True

    def __init__(
        self,
        model: nn.Module,
        optimizer: optim.Optimizer | NGOptimizer | None,
        config: TrainConfig,
        loss_fn: str | Callable = "mse",
        optimize_step: Callable = optimize_step,
        train_dataloader: DataLoader | DictDataLoader | None = None,
        val_dataloader: DataLoader | DictDataLoader | None = None,
        test_dataloader: DataLoader | DictDataLoader | None = None,
        max_batches: int | None = None,
    ):
        """
        Initializes the BaseTrainer.

        Args:
            model (nn.Module): The model to train.
            optimizer (optim.Optimizer | NGOptimizer | None): The optimizer
                for training.
            config (TrainConfig): The TrainConfig settings for training.
            loss_fn (str | Callable): The loss function to use.
                A string selects one of the default loss functions;
                currently supported: 'mse' and 'cross_entropy'.
                If not specified, the default mse loss is used.
            train_dataloader (DataLoader | DictDataLoader | None): DataLoader for training data.
                If the model does not need data to evaluate the loss, no dataset
                should be provided.
            val_dataloader (DataLoader | DictDataLoader | None): DataLoader for validation data.
            test_dataloader (DataLoader | DictDataLoader | None): DataLoader for testing data.
            max_batches (int | None): Maximum number of batches to process per epoch.
                This is only valid for finite TensorDataset dataloaders.
                If max_batches is not None, the number of batches used will
                be min(max_batches, len(dataloader.dataset)).
                For an InfiniteTensorDataset only 1 batch per epoch is used.
        """
        self._model: nn.Module
        self._optimizer: optim.Optimizer | NGOptimizer | None
        self._config: TrainConfig
        self._train_dataloader: DataLoader | DictDataLoader | None = None
        self._val_dataloader: DataLoader | DictDataLoader | None = None
        self._test_dataloader: DataLoader | DictDataLoader | None = None

        self.config = config
        self.model = model
        self.optimizer = optimizer
        self.max_batches = max_batches

        self.num_training_batches: int
        self.num_validation_batches: int
        self.num_test_batches: int

        self.train_dataloader = train_dataloader
        self.val_dataloader = val_dataloader
        self.test_dataloader = test_dataloader

        self.loss_fn: Callable = get_loss_fn(loss_fn)
        self.optimize_step: Callable = optimize_step
        self.ng_params: ng.p.Array
        self.training_stage: TrainingStage = TrainingStage("idle")

    @property
    def use_grad(self) -> bool:
        """
        Returns the optimization framework for the trainer.

        use_grad = True : gradient-based optimization
        use_grad = False : gradient-free optimization

        Returns:
            bool: Bool value for using gradients.
        """
        return self._use_grad

    @use_grad.setter
    def use_grad(self, use_grad: bool) -> None:
        """
        Sets the optimization framework for the trainer.

        use_grad = True : gradient-based optimization
        use_grad = False : gradient-free optimization

        Args:
            use_grad (bool): Bool value for using gradients.
        """
        if not isinstance(use_grad, bool):
            raise TypeError("use_grad must be True or False.")
        self._use_grad = use_grad

    @classmethod
    def set_use_grad(cls, value: bool) -> None:
        """
        Sets the global use_grad flag.

        Args:
            value (bool): Whether to use gradient-based optimization.
        """
        if not isinstance(value, bool):
            raise TypeError("use_grad must be a boolean value.")
        cls._use_grad = value

    @property
    def model(self) -> nn.Module:
        """
        Returns the model if set, otherwise raises an error.

        Returns:
            nn.Module: The model.
        """
        if self._model is None:
            raise ValueError("Model has not been set.")
        return self._model

    @model.setter
    def model(self, model: nn.Module) -> None:
        """
        Sets the model, ensuring it is an instance of nn.Module.

        Args:
            model (nn.Module): The neural network model.
        """
        if model is not None and not isinstance(model, nn.Module):
            raise TypeError("model must be an instance of nn.Module or None.")
        self._model = model

    @property
    def optimizer(self) -> optim.Optimizer | NGOptimizer | None:
        """
        Returns the optimizer if set, otherwise raises an error.

        Returns:
            optim.Optimizer | NGOptimizer | None: The optimizer.
        """
        return self._optimizer

    @optimizer.setter
    def optimizer(self, optimizer: optim.Optimizer | NGOptimizer | None) -> None:
        """
        Sets the optimizer, checking compatibility with gradient use.

        We also set up the budget/behavior of different optimizers here.

        Args:
            optimizer (optim.Optimizer | NGOptimizer | None): The optimizer for training.
        """
        if optimizer is not None:
            if self.use_grad:
                if not isinstance(optimizer, optim.Optimizer):
                    raise TypeError("use_grad=True requires a PyTorch optimizer instance.")
            else:
                if not isinstance(optimizer, NGOptimizer):
                    raise TypeError("use_grad=False requires a Nevergrad optimizer instance.")
                else:
                    optimizer.budget = self.config.max_iter
                    optimizer.enable_pickling()
                    params = get_parameters(self.model).detach().numpy()
                    self.ng_params = ng.p.Array(init=params)

        self._optimizer = optimizer

    @property
    def train_dataloader(self) -> DataLoader:
        """
        Returns the training DataLoader, validating its type.

        Returns:
            DataLoader: The DataLoader for training data.
        """
        return self._train_dataloader

    @train_dataloader.setter
    def train_dataloader(self, dataloader: DataLoader) -> None:
        """
        Sets the training DataLoader and computes the number of batches.

        Args:
            dataloader (DataLoader): The DataLoader for training data.
        """
        self._validate_dataloader(dataloader, "train")
        self._train_dataloader = dataloader
        self.num_training_batches = self._compute_num_batches(dataloader)

    @property
    def val_dataloader(self) -> DataLoader:
        """
        Returns the validation DataLoader, validating its type.

        Returns:
            DataLoader: The DataLoader for validation data.
        """
        return self._val_dataloader

    @val_dataloader.setter
    def val_dataloader(self, dataloader: DataLoader) -> None:
        """
        Sets the validation DataLoader and computes the number of batches.

        Args:
            dataloader (DataLoader): The DataLoader for validation data.
        """
        self._validate_dataloader(dataloader, "val")
        self._val_dataloader = dataloader
        self.num_validation_batches = self._compute_num_batches(dataloader)

    @property
    def test_dataloader(self) -> DataLoader:
        """
        Returns the test DataLoader, validating its type.

        Returns:
            DataLoader: The DataLoader for testing data.
        """
        return self._test_dataloader

    @test_dataloader.setter
    def test_dataloader(self, dataloader: DataLoader) -> None:
        """
        Sets the test DataLoader and computes the number of batches.

        Args:
            dataloader (DataLoader): The DataLoader for testing data.
        """
        self._validate_dataloader(dataloader, "test")
        self._test_dataloader = dataloader
        self.num_test_batches = self._compute_num_batches(dataloader)

    @property
    def config(self) -> TrainConfig:
        """
        Returns the training configuration.

        Returns:
            TrainConfig: The configuration object.
        """
        return self._config

    @config.setter
    def config(self, value: TrainConfig) -> None:
        """
        Sets the training configuration and initializes callback and config managers.

        Args:
            value (TrainConfig): The configuration object.
        """
        if value and not isinstance(value, TrainConfig):
            raise TypeError("config must be an instance of TrainConfig.")
        self._config = value
        self.callback_manager = CallbacksManager(value)
        self.config_manager = ConfigManager(value)

    def _compute_num_batches(self, dataloader: DataLoader | DictDataLoader) -> int:
        """
        Computes the number of batches for the given DataLoader.

        Args:
            dataloader (DataLoader): The DataLoader for which to compute
                the number of batches.
        """
        if dataloader is None:
            return 1
        if isinstance(dataloader, DictDataLoader):
            dataloader_name, dataloader_value = list(dataloader.dataloaders.items())[0]
            dataset = dataloader_value.dataset
            batch_size = dataloader_value.batch_size
        else:
            dataset = dataloader.dataset
            batch_size = dataloader.batch_size

        if isinstance(dataset, TensorDataset):
            n_batches = int((dataset.tensors[0].size(0) + batch_size - 1) // batch_size)
            return min(self.max_batches, n_batches) if self.max_batches is not None else n_batches
        else:
            return 1

    def _validate_dataloader(
        self, dataloader: DataLoader | DictDataLoader, dataloader_type: str
    ) -> None:
        """
        Validates the type of the DataLoader and raises errors for unsupported types.

        Args:
            dataloader (DataLoader | DictDataLoader): The DataLoader to validate.
            dataloader_type (str): The type of DataLoader ("train", "val", or "test").
        """
        if dataloader is not None:
            if not isinstance(dataloader, (DataLoader, DictDataLoader)):
                raise NotImplementedError(
                    f"Unsupported dataloader type: {type(dataloader)}."
                    "The dataloader must be an instance of DataLoader."
                )
        if dataloader_type == "val" and self.config.val_every > 0:
            if not isinstance(dataloader, (DataLoader, DictDataLoader)):
                raise ValueError(
                    "If `config.val_every` is provided as an integer > 0, validation_dataloader"
                    "must be an instance of `DataLoader` or `DictDataLoader`."
                )

    @staticmethod
    def callback(phase: str) -> Callable:
        """
        Decorator for executing callbacks before and after a phase.

        Phases are the different hooks during training; the list of valid
        phases is defined in Callbacks.
        We also update the current state of the training process in
        the callback decorator.

        Args:
            phase (str): The phase for which the callback is executed (e.g., "train",
                "train_epoch", "train_batch").

        Returns:
            Callable: The decorated function.
        """

        def decorator(method: Callable) -> Callable:
            def wrapper(self: Any, *args: Any, **kwargs: Any) -> Any:
                start_event = f"{phase}_start"
                end_event = f"{phase}_end"

                self.training_stage = TrainingStage(start_event)
                self.callback_manager.run_callbacks(trainer=self)
                result = method(self, *args, **kwargs)

                self.training_stage = TrainingStage(end_event)
                # build_optimize_result method is defined in the trainer.
                self.build_optimize_result(result)
                self.callback_manager.run_callbacks(trainer=self)

                return result

            return wrapper

        return decorator

    @contextmanager
    def enable_grad_opt(self, optimizer: optim.Optimizer | None = None) -> Iterator[None]:
        """
        Context manager to temporarily enable gradient-based optimization.

        Args:
            optimizer (optim.Optimizer): The PyTorch optimizer to use.
                If no optimizer is provided, the trainer's default optimizer
                will be used.
        """
        original_mode = self.use_grad
        original_optimizer = self._optimizer
        try:
            self.use_grad = True
            self.callback_manager.use_grad = True
            self.optimizer = optimizer if optimizer else self.optimizer
            yield
        finally:
            self.use_grad = original_mode
            self.callback_manager.use_grad = original_mode
            self.optimizer = original_optimizer

    @contextmanager
    def disable_grad_opt(self, optimizer: NGOptimizer | None = None) -> Iterator[None]:
        """
        Context manager to temporarily disable gradient-based optimization.

        Args:
            optimizer (NGOptimizer): The Nevergrad optimizer to use.
                If no optimizer is provided, the trainer's default optimizer
                will be used.
        """
        original_mode = self.use_grad
        original_optimizer = self._optimizer
        try:
            self.use_grad = False
            self.callback_manager.use_grad = False
            self.optimizer = optimizer if optimizer else self.optimizer
            yield
        finally:
            self.use_grad = original_mode
            self.callback_manager.use_grad = original_mode
            self.optimizer = original_optimizer

    def on_train_start(self) -> None:
        """Called at the start of training."""
        pass

    def on_train_end(
        self,
        train_losses: list[list[tuple[torch.Tensor, Any]]],
        val_losses: list[list[tuple[torch.Tensor, Any]]] | None = None,
    ) -> None:
        """
        Called at the end of training.

        Args:
            train_losses (list[list[tuple[torch.Tensor, Any]]]):
                Metrics for the training losses.
                list -> list -> tuples
                Epochs -> Training Batches -> (loss, metrics)
            val_losses (list[list[tuple[torch.Tensor, Any]]] | None):
                Metrics for the validation losses.
                list -> list -> tuples
                Epochs -> Validation Batches -> (loss, metrics)
        """
        pass

    def on_train_epoch_start(self) -> None:
        """Called at the start of each training epoch."""
        pass

    def on_train_epoch_end(self, train_epoch_loss_metrics: list[tuple[torch.Tensor, Any]]) -> None:
        """
        Called at the end of each training epoch.

        Args:
            train_epoch_loss_metrics: Metrics for the training epoch losses.
                list -> tuples
                Training Batches -> (loss, metrics)
        """
        pass

    def on_val_epoch_start(self) -> None:
        """Called at the start of each validation epoch."""
        pass

    def on_val_epoch_end(self, val_epoch_loss_metrics: list[tuple[torch.Tensor, Any]]) -> None:
        """
        Called at the end of each validation epoch.

        Args:
            val_epoch_loss_metrics: Metrics for the validation epoch loss.
                list -> tuples
                Validation Batches -> (loss, metrics)
        """
        pass

    def on_train_batch_start(self, batch: tuple[torch.Tensor, ...] | None) -> None:
        """
        Called at the start of each training batch.

        Args:
            batch: A batch of data from the DataLoader. Typically a tuple containing
                input tensors and corresponding target tensors.
        """
        pass

    def on_train_batch_end(self, train_batch_loss_metrics: tuple[torch.Tensor, Any]) -> None:
        """
        Called at the end of each training batch.

        Args:
            train_batch_loss_metrics: Metrics for the training batch loss.
                tuple of (loss, metrics)
        """
        pass

    def on_val_batch_start(self, batch: tuple[torch.Tensor, ...] | None) -> None:
        """
        Called at the start of each validation batch.

        Args:
            batch: A batch of data from the DataLoader. Typically a tuple containing
                input tensors and corresponding target tensors.
        """
        pass

    def on_val_batch_end(self, val_batch_loss_metrics: tuple[torch.Tensor, Any]) -> None:
        """
        Called at the end of each validation batch.

        Args:
            val_batch_loss_metrics: Metrics for the validation batch loss.
                tuple of (loss, metrics)
        """
        pass

    def on_test_batch_start(self, batch: tuple[torch.Tensor, ...] | None) -> None:
        """
        Called at the start of each testing batch.

        Args:
            batch: A batch of data from the DataLoader. Typically a tuple containing
                input tensors and corresponding target tensors.
        """
        pass

    def on_test_batch_end(self, test_batch_loss_metrics: tuple[torch.Tensor, Any]) -> None:
        """
        Called at the end of each testing batch.

        Args:
            test_batch_loss_metrics: Metrics for the testing batch loss.
                tuple of (loss, metrics)
        """
        pass
```
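`BaseTrainer` only provides the plumbing; the concrete `Trainer` added in `qadence/ml_tools/trainer.py` (+708 lines above) builds on it. Below is a hedged sketch of how a subclass might hook into the callback decorator and the optimization-mode context managers. Everything beyond the imports and methods shown in this diff (the subclass name, its `fit` loop, the stored `opt_result`, the example `TrainConfig` arguments) is illustrative, not the package's actual API.

```python
# Illustrative sketch (not part of the diff): subclassing BaseTrainer.
from typing import Any

from torch import nn, optim

from qadence.ml_tools.config import TrainConfig
from qadence.ml_tools.train_utils.base_trainer import BaseTrainer


class SketchTrainer(BaseTrainer):
    """Hypothetical subclass; the real implementation lives in qadence/ml_tools/trainer.py."""

    def build_optimize_result(self, result: Any) -> None:
        # Required by the callback decorator, which calls it at the end of each phase.
        self.opt_result = result

    @BaseTrainer.callback("train")
    def fit(self) -> list:
        # Runs between the "train_start" and "train_end" stages; callbacks
        # registered with the CallbacksManager fire at both boundaries.
        losses: list = []
        for _ in range(self.config.max_iter):
            ...  # one optimization step per iteration (omitted here)
        return losses


model = nn.Linear(2, 1)
trainer = SketchTrainer(
    model=model,
    optimizer=optim.Adam(model.parameters(), lr=1e-2),
    config=TrainConfig(max_iter=10),
)

# Temporarily switch to gradient-free optimization with a Nevergrad optimizer,
# restoring the original optimizer and use_grad flag afterwards:
# with trainer.disable_grad_opt(ng_optimizer):
#     trainer.fit()
```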