qadence 1.8.0__py3-none-any.whl → 1.9.1__py3-none-any.whl
This diff compares the contents of two package versions publicly released to a supported registry. It is provided for informational purposes only and reflects the packages as they appear in their public registries.
- qadence/__init__.py +1 -1
- qadence/analog/parse_analog.py +1 -2
- qadence/backends/gpsr.py +8 -2
- qadence/backends/pulser/backend.py +7 -23
- qadence/backends/pyqtorch/backend.py +80 -5
- qadence/backends/pyqtorch/config.py +10 -3
- qadence/backends/pyqtorch/convert_ops.py +63 -2
- qadence/blocks/primitive.py +1 -0
- qadence/execution.py +0 -2
- qadence/log_config.yaml +10 -0
- qadence/measurements/shadow.py +97 -128
- qadence/measurements/utils.py +2 -2
- qadence/mitigations/readout.py +12 -6
- qadence/ml_tools/__init__.py +4 -8
- qadence/ml_tools/callbacks/__init__.py +30 -0
- qadence/ml_tools/callbacks/callback.py +451 -0
- qadence/ml_tools/callbacks/callbackmanager.py +214 -0
- qadence/ml_tools/{saveload.py → callbacks/saveload.py} +11 -11
- qadence/ml_tools/callbacks/writer_registry.py +441 -0
- qadence/ml_tools/config.py +132 -258
- qadence/ml_tools/data.py +7 -3
- qadence/ml_tools/loss/__init__.py +10 -0
- qadence/ml_tools/loss/loss.py +87 -0
- qadence/ml_tools/optimize_step.py +45 -10
- qadence/ml_tools/stages.py +46 -0
- qadence/ml_tools/train_utils/__init__.py +7 -0
- qadence/ml_tools/train_utils/base_trainer.py +555 -0
- qadence/ml_tools/train_utils/config_manager.py +184 -0
- qadence/ml_tools/trainer.py +708 -0
- qadence/model.py +1 -1
- qadence/noise/__init__.py +2 -2
- qadence/noise/protocols.py +18 -53
- qadence/operations/ham_evo.py +87 -26
- qadence/transpile/noise.py +12 -5
- qadence/types.py +15 -3
- {qadence-1.8.0.dist-info → qadence-1.9.1.dist-info}/METADATA +3 -4
- {qadence-1.8.0.dist-info → qadence-1.9.1.dist-info}/RECORD +39 -32
- {qadence-1.8.0.dist-info → qadence-1.9.1.dist-info}/WHEEL +1 -1
- qadence/ml_tools/printing.py +0 -154
- qadence/ml_tools/train_grad.py +0 -395
- qadence/ml_tools/train_no_grad.py +0 -199
- qadence/noise/readout.py +0 -218
- {qadence-1.8.0.dist-info → qadence-1.9.1.dist-info}/licenses/LICENSE +0 -0
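The headline change in this release is the restructuring of `qadence.ml_tools`: the monolithic training loops (`train_grad.py`, `train_no_grad.py`) and `printing.py` are removed in favor of a `Trainer` class (`trainer.py`, `train_utils/base_trainer.py`), a callback system (`callbacks/`), pluggable experiment writers (`callbacks/writer_registry.py`), and standalone loss functions (`loss/`). The following is a minimal migration sketch, assuming `Trainer` is exported from `qadence.ml_tools`, accepts a `loss_fn` argument, and exposes a `fit` method, as the file list suggests; check the 1.9.1 documentation for the exact signature.

```python
# Hypothetical migration sketch for the 1.8 -> 1.9 ml_tools rewrite.
# The exact Trainer arguments are an assumption; verify against qadence 1.9.1.
import torch
from torch.utils.data import DataLoader, TensorDataset

from qadence.ml_tools import TrainConfig, Trainer

model = torch.nn.Linear(2, 1)  # stand-in for a quantum model
optimizer = torch.optim.Adam(model.parameters(), lr=0.01)
config = TrainConfig(max_iter=100, print_every=10)

x = torch.rand(32, 2)
y = x.sum(dim=1, keepdim=True)
loader = DataLoader(TensorDataset(x, y), batch_size=8)

# 1.8: train_with_grad(model, loader, optimizer, config, loss_fn=...)
# 1.9: the Trainer object owns the loop and dispatches callbacks.
trainer = Trainer(model=model, optimizer=optimizer, config=config, loss_fn="mse")
trainer.fit(loader)
```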
qadence/ml_tools/callbacks/callback.py
@@ -0,0 +1,451 @@
+from __future__ import annotations
+
+from typing import Any, Callable
+
+from qadence.ml_tools.callbacks.saveload import load_checkpoint, write_checkpoint
+from qadence.ml_tools.callbacks.writer_registry import BaseWriter
+from qadence.ml_tools.config import TrainConfig
+from qadence.ml_tools.data import OptimizeResult
+from qadence.ml_tools.stages import TrainingStage
+
+# Define callback types
+CallbackFunction = Callable[..., Any]
+CallbackConditionFunction = Callable[..., bool]
+
+
+class Callback:
+    """Base class for defining various training callbacks.
+
+    Attributes:
+        on (str): The event on which to trigger the callback.
+            Must be a valid on value from: ["train_start", "train_end",
+            "train_epoch_start", "train_epoch_end", "train_batch_start",
+            "train_batch_end","val_epoch_start", "val_epoch_end",
+            "val_batch_start", "val_batch_end", "test_batch_start",
+            "test_batch_end"]
+        called_every (int): Frequency of callback calls in terms of iterations.
+        callback (CallbackFunction | None): The function to call if the condition is met.
+        callback_condition (CallbackConditionFunction | None): Condition to check before calling.
+        modify_optimize_result (CallbackFunction | dict[str, Any] | None):
+            Function to modify `OptimizeResult`.
+
+    A callback can be defined in two ways:
+
+    1. **By providing a callback function directly in the base class**:
+       This is useful for simple callbacks that don't require subclassing.
+
+        Example:
+        ```python exec="on" source="material-block" result="json"
+        from qadence.ml_tools.callbacks import Callback
+
+        def custom_callback_function(trainer, config, writer):
+            print("Custom callback executed.")
+
+        custom_callback = Callback(
+            on="train_end",
+            called_every=5,
+            callback=custom_callback_function
+        )
+        ```
+
+    2. **By inheriting and implementing the `run_callback` method**:
+       This is suitable for more complex callbacks that require customization.
+
+        Example:
+        ```python exec="on" source="material-block" result="json"
+        from qadence.ml_tools.callbacks import Callback
+        class CustomCallback(Callback):
+            def run_callback(self, trainer, config, writer):
+                print("Custom behavior in the inherited run_callback method.")
+
+        custom_callback = CustomCallback(on="train_end", called_every=10)
+        ```
+    """
+
+    VALID_ON_VALUES = [
+        "train_start",
+        "train_end",
+        "train_epoch_start",
+        "train_epoch_end",
+        "train_batch_start",
+        "train_batch_end",
+        "val_epoch_start",
+        "val_epoch_end",
+        "val_batch_start",
+        "val_batch_end",
+        "test_batch_start",
+        "test_batch_end",
+    ]
+
+    def __init__(
+        self,
+        on: str | TrainingStage = "idle",
+        called_every: int = 1,
+        callback: CallbackFunction | None = None,
+        callback_condition: CallbackConditionFunction | None = None,
+        modify_optimize_result: CallbackFunction | dict[str, Any] | None = None,
+    ):
+        if not isinstance(called_every, int):
+            raise ValueError("called_every must be a positive integer or 0")
+
+        self.callback: CallbackFunction | None = callback
+        self.on: str | TrainingStage = on
+        self.called_every: int = called_every
+        self.callback_condition = callback_condition or (lambda _: True)
+
+        if isinstance(modify_optimize_result, dict):
+            self.modify_optimize_result = (
+                lambda opt_res: opt_res.extra.update(modify_optimize_result) or opt_res
+            )
+        else:
+            self.modify_optimize_result = modify_optimize_result or (lambda opt_res: opt_res)
+
+    @property
+    def on(self) -> TrainingStage | str:
+        """
+        Returns the TrainingStage.
+
+        Returns:
+            TrainingStage: TrainingStage for the callback
+        """
+        return self._on
+
+    @on.setter
+    def on(self, on: str | TrainingStage) -> None:
+        """
+        Sets the training stage on for the callback.
+
+        Args:
+            on (str | TrainingStage): TrainingStage for the callback
+        """
+        if isinstance(on, str):
+            if on not in self.VALID_ON_VALUES:
+                raise ValueError(f"Invalid value for 'on'. Must be one of {self.VALID_ON_VALUES}.")
+            self._on = TrainingStage(on)
+        elif isinstance(on, TrainingStage):
+            self._on = on
+        else:
+            raise ValueError("Invalid value for 'on'. Must be `str` or `TrainingStage`.")
+
+    def _should_call(self, when: str, opt_result: OptimizeResult) -> bool:
+        """Checks if the callback should be called.
+
+        Args:
+            when (str): The event when the callback is considered for execution.
+            opt_result (OptimizeResult): The current optimization results.
+
+        Returns:
+            bool: Whether the callback should be called.
+        """
+        if when in [TrainingStage("train_start"), TrainingStage("train_end")]:
+            return True
+        if self.called_every == 0 or opt_result.iteration == 0:
+            return False
+        if opt_result.iteration % self.called_every == 0 and self.callback_condition(opt_result):
+            return True
+        return False
+
+    def __call__(
+        self, when: TrainingStage, trainer: Any, config: TrainConfig, writer: BaseWriter
+    ) -> Any:
+        """Executes the callback if conditions are met.
+
+        Args:
+            when (str): The event when the callback is triggered.
+            trainer (Any): The training object.
+            config (TrainConfig): The configuration object.
+            writer (BaseWriter ): The writer object for logging.
+
+        Returns:
+            Any: Result of the callback function if executed.
+        """
+        opt_result = trainer.opt_result
+        if self.on == when:
+            if opt_result:
+                opt_result = self.modify_optimize_result(opt_result)
+            if self._should_call(when, opt_result):
+                return self.run_callback(trainer, config, writer)
+
+    def run_callback(self, trainer: Any, config: TrainConfig, writer: BaseWriter) -> Any:
+        """Executes the defined callback.
+
+        Args:
+            trainer (Any): The training object.
+            config (TrainConfig): The configuration object.
+            writer (BaseWriter ): The writer object for logging.
+
+        Returns:
+            Any: Result of the callback execution.
+
+        Raises:
+            NotImplementedError: If not implemented in subclasses.
+        """
+        if self.callback is not None:
+            return self.callback(trainer, config, writer)
+        raise NotImplementedError("Subclasses should override the run_callback method.")
+
+
+class PrintMetrics(Callback):
+    """Callback to print metrics using the writer.
+
+    The `PrintMetrics` callback can be added to the `TrainConfig`
+    callbacks as a custom user defined callback.
+
+    Example Usage in `TrainConfig`:
+    To use `PrintMetrics`, include it in the `callbacks` list when
+    setting up your `TrainConfig`:
+    ```python exec="on" source="material-block" result="json"
+    from qadence.ml_tools import TrainConfig
+    from qadence.ml_tools.callbacks import PrintMetrics
+
+    # Create an instance of the PrintMetrics callback
+    print_metrics_callback = PrintMetrics(on = "val_batch_end", called_every = 100)
+
+    config = TrainConfig(
+        max_iter=10000,
+        # Print metrics every 1000 training epochs
+        print_every=1000,
+        # Add the custom callback that runs every 100 val_batch_end
+        callbacks=[print_metrics_callback]
+    )
+    ```
+    """
+
+    def run_callback(self, trainer: Any, config: TrainConfig, writer: BaseWriter) -> Any:
+        """Prints metrics using the writer.
+
+        Args:
+            trainer (Any): The training object.
+            config (TrainConfig): The configuration object.
+            writer (BaseWriter ): The writer object for logging.
+        """
+        opt_result = trainer.opt_result
+        writer.print_metrics(opt_result)
+
+
+class WriteMetrics(Callback):
+    """Callback to write metrics using the writer.
+
+    The `WriteMetrics` callback can be added to the `TrainConfig` callbacks as
+    a custom user defined callback.
+
+    Example Usage in `TrainConfig`:
+    To use `WriteMetrics`, include it in the `callbacks` list when setting up your
+    `TrainConfig`:
+    ```python exec="on" source="material-block" result="json"
+    from qadence.ml_tools import TrainConfig
+    from qadence.ml_tools.callbacks import WriteMetrics
+
+    # Create an instance of the WriteMetrics callback
+    write_metrics_callback = WriteMetrics(on = "val_batch_end", called_every = 100)
+
+    config = TrainConfig(
+        max_iter=10000,
+        # Print metrics every 1000 training epochs
+        print_every=1000,
+        # Add the custom callback that runs every 100 val_batch_end
+        callbacks=[write_metrics_callback]
+    )
+    ```
+    """
+
+    def run_callback(self, trainer: Any, config: TrainConfig, writer: BaseWriter) -> Any:
+        """Writes metrics using the writer.
+
+        Args:
+            trainer (Any): The training object.
+            config (TrainConfig): The configuration object.
+            writer (BaseWriter ): The writer object for logging.
+        """
+        opt_result = trainer.opt_result
+        writer.write(opt_result)
+
+
+class PlotMetrics(Callback):
+    """Callback to plot metrics using the writer.
+
+    The `PlotMetrics` callback can be added to the `TrainConfig` callbacks as
+    a custom user defined callback.
+
+    Example Usage in `TrainConfig`:
+    To use `PlotMetrics`, include it in the `callbacks` list when setting up your
+    `TrainConfig`:
+    ```python exec="on" source="material-block" result="json"
+    from qadence.ml_tools import TrainConfig
+    from qadence.ml_tools.callbacks import PlotMetrics
+
+    # Create an instance of the PlotMetrics callback
+    plot_metrics_callback = PlotMetrics(on = "val_batch_end", called_every = 100)
+
+    config = TrainConfig(
+        max_iter=10000,
+        # Print metrics every 1000 training epochs
+        print_every=1000,
+        # Add the custom callback that runs every 100 val_batch_end
+        callbacks=[plot_metrics_callback]
+    )
+    ```
+    """
+
+    def run_callback(self, trainer: Any, config: TrainConfig, writer: BaseWriter) -> Any:
+        """Plots metrics using the writer.
+
+        Args:
+            trainer (Any): The training object.
+            config (TrainConfig): The configuration object.
+            writer (BaseWriter ): The writer object for logging.
+        """
+        opt_result = trainer.opt_result
+        plotting_functions = config.plotting_functions
+        writer.plot(trainer.model, opt_result.iteration, plotting_functions)
+
+
+class LogHyperparameters(Callback):
+    """Callback to log hyperparameters using the writer.
+
+    The `LogHyperparameters` callback can be added to the `TrainConfig` callbacks
+    as a custom user defined callback.
+
+    Example Usage in `TrainConfig`:
+    To use `LogHyperparameters`, include it in the `callbacks` list when setting up your
+    `TrainConfig`:
+    ```python exec="on" source="material-block" result="json"
+    from qadence.ml_tools import TrainConfig
+    from qadence.ml_tools.callbacks import LogHyperparameters
+
+    # Create an instance of the LogHyperparameters callback
+    log_hyper_callback = LogHyperparameters(on = "val_batch_end", called_every = 100)
+
+    config = TrainConfig(
+        max_iter=10000,
+        # Print metrics every 1000 training epochs
+        print_every=1000,
+        # Add the custom callback that runs every 100 val_batch_end
+        callbacks=[log_hyper_callback]
+    )
+    ```
+    """
+
+    def run_callback(self, trainer: Any, config: TrainConfig, writer: BaseWriter) -> Any:
+        """Logs hyperparameters using the writer.
+
+        Args:
+            trainer (Any): The training object.
+            config (TrainConfig): The configuration object.
+            writer (BaseWriter ): The writer object for logging.
+        """
+        hyperparams = config.hyperparams
+        writer.log_hyperparams(hyperparams)
+
+
+class SaveCheckpoint(Callback):
+    """Callback to save a model checkpoint.
+
+    The `SaveCheckpoint` callback can be added to the `TrainConfig` callbacks
+    as a custom user defined callback.
+
+    Example Usage in `TrainConfig`:
+    To use `SaveCheckpoint`, include it in the `callbacks` list when setting up your
+    `TrainConfig`:
+    ```python exec="on" source="material-block" result="json"
+    from qadence.ml_tools import TrainConfig
+    from qadence.ml_tools.callbacks import SaveCheckpoint
+
+    # Create an instance of the SaveCheckpoint callback
+    save_checkpoint_callback = SaveCheckpoint(on = "val_batch_end", called_every = 100)
+
+    config = TrainConfig(
+        max_iter=10000,
+        # Print metrics every 1000 training epochs
+        print_every=1000,
+        # Add the custom callback that runs every 100 val_batch_end
+        callbacks=[save_checkpoint_callback]
+    )
+    ```
+    """
+
+    def run_callback(self, trainer: Any, config: TrainConfig, writer: BaseWriter) -> Any:
+        """Saves a model checkpoint.
+
+        Args:
+            trainer (Any): The training object.
+            config (TrainConfig): The configuration object.
+            writer (BaseWriter ): The writer object for logging.
+        """
+        folder = config.log_folder
+        model = trainer.model
+        optimizer = trainer.optimizer
+        opt_result = trainer.opt_result
+        write_checkpoint(folder, model, optimizer, opt_result.iteration)
+
+
+class SaveBestCheckpoint(SaveCheckpoint):
+    """Callback to save the best model checkpoint based on a validation criterion."""
+
+    def __init__(self, on: str, called_every: int):
+        """Initializes the SaveBestCheckpoint callback.
+
+        Args:
+            on (str): The event to trigger the callback.
+            called_every (int): Frequency of callback calls in terms of iterations.
+        """
+        super().__init__(on=on, called_every=called_every)
+        self.best_loss = float("inf")
+
+    def run_callback(self, trainer: Any, config: TrainConfig, writer: BaseWriter) -> Any:
+        """Saves the checkpoint if the current loss is better than the best loss.
+
+        Args:
+            trainer (Any): The training object.
+            config (TrainConfig): The configuration object.
+            writer (BaseWriter ): The writer object for logging.
+        """
+        opt_result = trainer.opt_result
+        if config.validation_criterion and config.validation_criterion(
+            opt_result.loss, self.best_loss, config.val_epsilon
+        ):
+            self.best_loss = opt_result.loss
+
+            folder = config.log_folder
+            model = trainer.model
+            optimizer = trainer.optimizer
+            opt_result = trainer.opt_result
+            write_checkpoint(folder, model, optimizer, "best")
+
+
+class LoadCheckpoint(Callback):
+    """Callback to load a model checkpoint."""
+
+    def run_callback(self, trainer: Any, config: TrainConfig, writer: BaseWriter) -> Any:
+        """Loads a model checkpoint.
+
+        Args:
+            trainer (Any): The training object.
+            config (TrainConfig): The configuration object.
+            writer (BaseWriter ): The writer object for logging.
+
+        Returns:
+            Any: The result of loading the checkpoint.
+        """
+        folder = config.log_folder
+        model = trainer.model
+        optimizer = trainer.optimizer
+        device = trainer.log_device
+        return load_checkpoint(folder, model, optimizer, device=device)
+
+
+class LogModelTracker(Callback):
+    """Callback to log the model using the writer."""
+
+    def run_callback(self, trainer: Any, config: TrainConfig, writer: BaseWriter) -> Any:
+        """Logs the model using the writer.
+
+        Args:
+            trainer (Any): The training object.
+            config (TrainConfig): The configuration object.
+            writer (BaseWriter ): The writer object for logging.
+        """
+        model = trainer.model
+        writer.log_model(
+            model, trainer.train_dataloader, trainer.val_dataloader, trainer.test_dataloader
+        )
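Worth noting in the `Callback` implementation above: the dict form of `modify_optimize_result` is merged into `opt_result.extra` before `callback_condition` is evaluated, `called_every=0` disables periodic triggering, and `train_start`/`train_end` callbacks always fire. A minimal sketch combining these options, built only on the constructor shown above (the loss threshold and the extra tag are illustrative values):

```python
from qadence.ml_tools.callbacks import Callback

def log_slow_convergence(trainer, config, writer):
    print(f"Loss still above threshold at iteration {trainer.opt_result.iteration}")

# Fires at most every 50 epochs, and only while the loss stays above 0.1.
# The dict below is merged into opt_result.extra before the condition runs.
slow_convergence_alert = Callback(
    on="train_epoch_end",
    called_every=50,
    callback=log_slow_convergence,
    callback_condition=lambda opt_res: opt_res.loss is not None and opt_res.loss > 0.1,
    modify_optimize_result={"tag": "slow-convergence-check"},
)
```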
qadence/ml_tools/callbacks/callbackmanager.py
@@ -0,0 +1,214 @@
+from __future__ import annotations
+
+import copy
+import logging
+from typing import Any
+
+from qadence.ml_tools.callbacks.callback import (
+    Callback,
+    LoadCheckpoint,
+    LogHyperparameters,
+    LogModelTracker,
+    PlotMetrics,
+    PrintMetrics,
+    SaveBestCheckpoint,
+    SaveCheckpoint,
+    WriteMetrics,
+)
+from qadence.ml_tools.config import TrainConfig
+from qadence.ml_tools.data import OptimizeResult
+from qadence.ml_tools.stages import TrainingStage
+
+from .writer_registry import get_writer
+
+logger = logging.getLogger("ml_tools")
+
+
+class CallbacksManager:
+    """Manages and orchestrates the execution of various training callbacks.
+
+    Provides the start training and end training methods.
+
+    Attributes:
+        use_grad (bool): Indicates whether to use gradients in callbacks.
+        config (TrainConfig): The training configuration object.
+        callbacks (List[Callback]): List of callback instances to be executed.
+        writer (Optional[BaseWriter]): The writer instance for logging metrics and information.
+    """
+
+    use_grad: bool = True
+
+    callback_map = {
+        "PrintMetrics": PrintMetrics,
+        "WriteMetrics": WriteMetrics,
+        "PlotMetrics": PlotMetrics,
+        "SaveCheckpoint": SaveCheckpoint,
+        "LoadCheckpoint": LoadCheckpoint,
+        "LogModelTracker": LogModelTracker,
+        "LogHyperparameters": LogHyperparameters,
+        "SaveBestCheckpoint": SaveBestCheckpoint,
+    }
+
+    def __init__(self, config: TrainConfig):
+        """
+        Initializes the CallbacksManager with a training configuration.
+
+        Args:
+            config (TrainConfig): The training configuration object.
+        """
+        self.config = config
+        tracking_tool = self.config.tracking_tool
+        self.writer = get_writer(tracking_tool)
+        self.callbacks: list[Callback] = []
+
+    @classmethod
+    def set_use_grad(cls, use_grad: bool) -> None:
+        """
+        Sets whether gradients should be used in callbacks.
+
+        Args:
+            use_grad (bool): A boolean indicating whether to use gradients.
+        """
+        if not isinstance(use_grad, bool):
+            raise ValueError("use_grad must be a boolean value.")
+        cls.use_grad = use_grad
+
+    def initialize_callbacks(self) -> None:
+        """Initializes and adds the necessary callbacks based on the configuration."""
+        # Train Start
+        self.callbacks = copy.deepcopy(self.config.callbacks)
+        self.add_callback("PlotMetrics", "train_start")
+        if self.config.val_every:
+            self.add_callback("WriteMetrics", "train_start")
+        # only save the first checkpoint if not checkpoint_best_only
+        if not self.config.checkpoint_best_only:
+            self.add_callback("SaveCheckpoint", "train_start")
+
+        # Checkpointing
+        if self.config.checkpoint_best_only:
+            self.add_callback("SaveBestCheckpoint", "val_epoch_end", self.config.val_every)
+        elif self.config.checkpoint_every:
+            self.add_callback("SaveCheckpoint", "train_epoch_end", self.config.checkpoint_every)
+
+        # Printing
+        if self.config.verbose and self.config.print_every:
+            self.add_callback("PrintMetrics", "train_epoch_end", self.config.print_every)
+
+        # Plotting
+        if self.config.plot_every:
+            self.add_callback("PlotMetrics", "train_epoch_end", self.config.plot_every)
+
+        # Writing
+        if self.config.write_every:
+            self.add_callback("WriteMetrics", "train_epoch_end", self.config.write_every)
+        if self.config.val_every:
+            self.add_callback("WriteMetrics", "val_epoch_end", self.config.val_every)
+
+        # Train End
+        # Hyperparameters
+        if self.config.hyperparams:
+            self.add_callback("LogHyperparameters", "train_end")
+        # Log model
+        if self.config.log_model:
+            self.add_callback("LogModelTracker", "train_end")
+        if self.config.plot_every:
+            self.add_callback("PlotMetrics", "train_end")
+        # only save the last checkpoint if not checkpoint_best_only
+        if not self.config.checkpoint_best_only:
+            self.add_callback("SaveCheckpoint", "train_end")
+        self.add_callback("WriteMetrics", "train_end")
+
+    def add_callback(
+        self, callback: str | Callback, on: str | TrainingStage, called_every: int = 1
+    ) -> None:
+        """
+        Adds a callback to the manager.
+
+        Args:
+            callback (str | Callback): The callback instance or name.
+            on (str | TrainingStage): The event on which to trigger the callback.
+            called_every (int): Frequency of callback calls in terms of iterations.
+        """
+        if isinstance(callback, str):
+            callback_class = self.callback_map.get(callback)
+            if callback_class:
+                # Create an instance of the callback class
+                callback_instance = callback_class(on=on, called_every=called_every)
+                self.callbacks.append(callback_instance)
+            else:
+                logger.warning(f"Callback '{callback}' not recognized and will be skipped.")
+        elif isinstance(callback, Callback):
+            callback.on = on
+            callback.called_every = called_every
+            self.callbacks.append(callback)
+        else:
+            logger.warning(
+                f"Invalid callback type: {type(callback)}. Expected str or Callback instance."
+            )
+
+    def run_callbacks(self, trainer: Any) -> Any:
+        """
+        Runs callbacks that match the current training state.
+
+        Args:
+            trainer (Any): The training object managing the training process.
+
+        Returns:
+            Any: Results of the executed callbacks.
+        """
+        return [
+            callback(
+                when=trainer.training_stage, trainer=trainer, config=self.config, writer=self.writer
+            )
+            for callback in self.callbacks
+            if callback.on == trainer.training_stage
+        ]
+
+    def start_training(self, trainer: Any) -> None:
+        """
+        Initializes callbacks and starts the training process.
+
+        Args:
+            trainer (Any): The training object managing the training process.
+        """
+        # Clear all handlers from the logger
+        self.initialize_callbacks()
+
+        trainer.opt_result = OptimizeResult(trainer.global_step, trainer.model, trainer.optimizer)
+        trainer.is_last_iteration = False
+
+        # Load checkpoint only if a new subfolder was NOT recently added
+        if not trainer.config_manager._added_new_subfolder:
+            load_checkpoint_callback = LoadCheckpoint(on="train_start", called_every=1)
+            loaded_result = load_checkpoint_callback.run_callback(
+                trainer=trainer,
+                config=self.config,
+                writer=None,  # type: ignore[arg-type]
+            )
+
+            if loaded_result:
+                model, optimizer, init_iter = loaded_result
+                if isinstance(init_iter, (int, str)):
+                    trainer.model = model
+                    trainer.optimizer = optimizer
+                    trainer.global_step = (
+                        init_iter if isinstance(init_iter, int) else trainer.global_step
+                    )
+                    trainer.current_epoch = (
+                        init_iter if isinstance(init_iter, int) else trainer.current_epoch
+                    )
+                    trainer.opt_result = OptimizeResult(trainer.current_epoch, model, optimizer)
+                    logger.debug(f"Loaded model and optimizer from {self.config.log_folder}")
+
+        # Setup writer
+        self.writer.open(self.config, iteration=trainer.global_step)
+
+    def end_training(self, trainer: Any) -> None:
+        """
+        Cleans up and finalizes the training process.
+
+        Args:
+            trainer (Any): The training object managing the training process.
+        """
+        if self.writer:
+            self.writer.close()
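In normal use the new `Trainer` owns a `CallbacksManager`, but the public surface shown above can also be exercised directly. A small sketch using only the methods defined in this hunk, assuming a `TrainConfig` with default tracking tool; the specific field values are illustrative:

```python
from qadence.ml_tools.callbacks.callbackmanager import CallbacksManager
from qadence.ml_tools.config import TrainConfig

# Illustrative configuration values.
config = TrainConfig(max_iter=1000, print_every=100, checkpoint_every=500)

manager = CallbacksManager(config)
manager.initialize_callbacks()  # builds the default callback set from the config

# Named registration goes through callback_map; unrecognized names are
# logged as a warning and skipped rather than raising.
manager.add_callback("WriteMetrics", "train_epoch_end", called_every=10)

for cb in manager.callbacks:
    print(type(cb).__name__, cb.on, cb.called_every)
```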