qadence 1.8.0__py3-none-any.whl → 1.9.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (43)
  1. qadence/__init__.py +1 -1
  2. qadence/analog/parse_analog.py +1 -2
  3. qadence/backends/gpsr.py +8 -2
  4. qadence/backends/pulser/backend.py +7 -23
  5. qadence/backends/pyqtorch/backend.py +80 -5
  6. qadence/backends/pyqtorch/config.py +10 -3
  7. qadence/backends/pyqtorch/convert_ops.py +63 -2
  8. qadence/blocks/primitive.py +1 -0
  9. qadence/execution.py +0 -2
  10. qadence/log_config.yaml +10 -0
  11. qadence/measurements/shadow.py +97 -128
  12. qadence/measurements/utils.py +2 -2
  13. qadence/mitigations/readout.py +12 -6
  14. qadence/ml_tools/__init__.py +4 -8
  15. qadence/ml_tools/callbacks/__init__.py +30 -0
  16. qadence/ml_tools/callbacks/callback.py +451 -0
  17. qadence/ml_tools/callbacks/callbackmanager.py +214 -0
  18. qadence/ml_tools/{saveload.py → callbacks/saveload.py} +11 -11
  19. qadence/ml_tools/callbacks/writer_registry.py +430 -0
  20. qadence/ml_tools/config.py +132 -258
  21. qadence/ml_tools/data.py +7 -3
  22. qadence/ml_tools/loss/__init__.py +10 -0
  23. qadence/ml_tools/loss/loss.py +87 -0
  24. qadence/ml_tools/optimize_step.py +45 -10
  25. qadence/ml_tools/stages.py +46 -0
  26. qadence/ml_tools/train_utils/__init__.py +7 -0
  27. qadence/ml_tools/train_utils/base_trainer.py +548 -0
  28. qadence/ml_tools/train_utils/config_manager.py +184 -0
  29. qadence/ml_tools/trainer.py +692 -0
  30. qadence/model.py +1 -1
  31. qadence/noise/__init__.py +2 -2
  32. qadence/noise/protocols.py +18 -53
  33. qadence/operations/ham_evo.py +87 -26
  34. qadence/transpile/noise.py +12 -5
  35. qadence/types.py +15 -3
  36. {qadence-1.8.0.dist-info → qadence-1.9.0.dist-info}/METADATA +3 -4
  37. {qadence-1.8.0.dist-info → qadence-1.9.0.dist-info}/RECORD +39 -32
  38. {qadence-1.8.0.dist-info → qadence-1.9.0.dist-info}/WHEEL +1 -1
  39. qadence/ml_tools/printing.py +0 -154
  40. qadence/ml_tools/train_grad.py +0 -395
  41. qadence/ml_tools/train_no_grad.py +0 -199
  42. qadence/noise/readout.py +0 -218
  43. {qadence-1.8.0.dist-info → qadence-1.9.0.dist-info}/licenses/LICENSE +0 -0
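The headline change in this release is the restructuring of `qadence.ml_tools`: the function-style training loops (`train_grad.py`, `train_no_grad.py`) and `printing.py` are removed, replaced by a `Trainer` class (`trainer.py`), a callback system (`callbacks/`), pluggable losses (`loss/`), and explicit training stages (`stages.py`). As rough orientation, here is a hedged sketch of what driving the new trainer might look like; the `Trainer` constructor arguments, the `loss_fn="mse"` shorthand, and the `fit()` method are assumptions inferred from this file layout, not verified against the released API.

```python
# Hypothetical usage sketch of the restructured ml_tools; the Trainer names
# below are assumptions inferred from the new trainer.py / loss/ / callbacks/
# layout in this diff, not a verified 1.9.0 API.
import torch
from torch.utils.data import DataLoader, TensorDataset

from qadence.ml_tools import TrainConfig, Trainer  # Trainer export assumed

x = torch.rand(64, 1)
loader = DataLoader(TensorDataset(x, torch.sin(x)), batch_size=16)

model = torch.nn.Linear(1, 1)
optimizer = torch.optim.Adam(model.parameters(), lr=1e-2)
config = TrainConfig(max_iter=500, print_every=100, checkpoint_every=250)

trainer = Trainer(model=model, optimizer=optimizer, config=config, loss_fn="mse")
trainer.fit(train_dataloader=loader)
```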
qadence/ml_tools/callbacks/callback.py
@@ -0,0 +1,451 @@
+ from __future__ import annotations
+
+ from typing import Any, Callable
+
+ from qadence.ml_tools.callbacks.saveload import load_checkpoint, write_checkpoint
+ from qadence.ml_tools.callbacks.writer_registry import BaseWriter
+ from qadence.ml_tools.config import TrainConfig
+ from qadence.ml_tools.data import OptimizeResult
+ from qadence.ml_tools.stages import TrainingStage
+
+ # Define callback types
+ CallbackFunction = Callable[..., Any]
+ CallbackConditionFunction = Callable[..., bool]
+
+
+ class Callback:
+     """Base class for defining various training callbacks.
+
+     Attributes:
+         on (str): The event on which to trigger the callback.
+             Must be a valid on value from: ["train_start", "train_end",
+             "train_epoch_start", "train_epoch_end", "train_batch_start",
+             "train_batch_end", "val_epoch_start", "val_epoch_end",
+             "val_batch_start", "val_batch_end", "test_batch_start",
+             "test_batch_end"]
+         called_every (int): Frequency of callback calls in terms of iterations.
+         callback (CallbackFunction | None): The function to call if the condition is met.
+         callback_condition (CallbackConditionFunction | None): Condition to check before calling.
+         modify_optimize_result (CallbackFunction | dict[str, Any] | None):
+             Function to modify `OptimizeResult`.
+
+     A callback can be defined in two ways:
+
+     1. **By providing a callback function directly in the base class**:
+         This is useful for simple callbacks that don't require subclassing.
+
+         Example:
+         ```python exec="on" source="material-block" result="json"
+         from qadence.ml_tools.callbacks import Callback
+
+         def custom_callback_function(trainer, config, writer):
+             print("Custom callback executed.")
+
+         custom_callback = Callback(
+             on="train_end",
+             called_every=5,
+             callback=custom_callback_function
+         )
+         ```
+
+     2. **By inheriting and implementing the `run_callback` method**:
+         This is suitable for more complex callbacks that require customization.
+
+         Example:
+         ```python exec="on" source="material-block" result="json"
+         from qadence.ml_tools.callbacks import Callback
+         class CustomCallback(Callback):
+             def run_callback(self, trainer, config, writer):
+                 print("Custom behavior in the inherited run_callback method.")
+
+         custom_callback = CustomCallback(on="train_end", called_every=10)
+         ```
+     """
+
+     VALID_ON_VALUES = [
+         "train_start",
+         "train_end",
+         "train_epoch_start",
+         "train_epoch_end",
+         "train_batch_start",
+         "train_batch_end",
+         "val_epoch_start",
+         "val_epoch_end",
+         "val_batch_start",
+         "val_batch_end",
+         "test_batch_start",
+         "test_batch_end",
+     ]
+
+     def __init__(
+         self,
+         on: str | TrainingStage = "idle",
+         called_every: int = 1,
+         callback: CallbackFunction | None = None,
+         callback_condition: CallbackConditionFunction | None = None,
+         modify_optimize_result: CallbackFunction | dict[str, Any] | None = None,
+     ):
+         if not isinstance(called_every, int):
+             raise ValueError("called_every must be a positive integer or 0")
+
+         self.callback: CallbackFunction | None = callback
+         self.on: str | TrainingStage = on
+         self.called_every: int = called_every
+         self.callback_condition = callback_condition or (lambda _: True)
+
+         if isinstance(modify_optimize_result, dict):
+             self.modify_optimize_result = (
+                 lambda opt_res: opt_res.extra.update(modify_optimize_result) or opt_res
+             )
+         else:
+             self.modify_optimize_result = modify_optimize_result or (lambda opt_res: opt_res)
+
+     @property
+     def on(self) -> TrainingStage | str:
+         """
+         Returns the TrainingStage.
+
+         Returns:
+             TrainingStage: TrainingStage for the callback
+         """
+         return self._on
+
+     @on.setter
+     def on(self, on: str | TrainingStage) -> None:
+         """
+         Sets the training stage on for the callback.
+
+         Args:
+             on (str | TrainingStage): TrainingStage for the callback
+         """
+         if isinstance(on, str):
+             if on not in self.VALID_ON_VALUES:
+                 raise ValueError(f"Invalid value for 'on'. Must be one of {self.VALID_ON_VALUES}.")
+             self._on = TrainingStage(on)
+         elif isinstance(on, TrainingStage):
+             self._on = on
+         else:
+             raise ValueError("Invalid value for 'on'. Must be `str` or `TrainingStage`.")
+
+     def _should_call(self, when: str, opt_result: OptimizeResult) -> bool:
+         """Checks if the callback should be called.
+
+         Args:
+             when (str): The event when the callback is considered for execution.
+             opt_result (OptimizeResult): The current optimization results.
+
+         Returns:
+             bool: Whether the callback should be called.
+         """
+         if when in [TrainingStage("train_start"), TrainingStage("train_end")]:
+             return True
+         if self.called_every == 0 or opt_result.iteration == 0:
+             return False
+         if opt_result.iteration % self.called_every == 0 and self.callback_condition(opt_result):
+             return True
+         return False
+
+     def __call__(
+         self, when: TrainingStage, trainer: Any, config: TrainConfig, writer: BaseWriter
+     ) -> Any:
+         """Executes the callback if conditions are met.
+
+         Args:
+             when (str): The event when the callback is triggered.
+             trainer (Any): The training object.
+             config (TrainConfig): The configuration object.
+             writer (BaseWriter): The writer object for logging.
+
+         Returns:
+             Any: Result of the callback function if executed.
+         """
+         opt_result = trainer.opt_result
+         if self.on == when:
+             if opt_result:
+                 opt_result = self.modify_optimize_result(opt_result)
+             if self._should_call(when, opt_result):
+                 return self.run_callback(trainer, config, writer)
+
+     def run_callback(self, trainer: Any, config: TrainConfig, writer: BaseWriter) -> Any:
+         """Executes the defined callback.
+
+         Args:
+             trainer (Any): The training object.
+             config (TrainConfig): The configuration object.
+             writer (BaseWriter): The writer object for logging.
+
+         Returns:
+             Any: Result of the callback execution.
+
+         Raises:
+             NotImplementedError: If not implemented in subclasses.
+         """
+         if self.callback is not None:
+             return self.callback(trainer, config, writer)
+         raise NotImplementedError("Subclasses should override the run_callback method.")
+
+
+ class PrintMetrics(Callback):
+     """Callback to print metrics using the writer.
+
+     The `PrintMetrics` callback can be added to the `TrainConfig`
+     callbacks as a custom user defined callback.
+
+     Example Usage in `TrainConfig`:
+     To use `PrintMetrics`, include it in the `callbacks` list when
+     setting up your `TrainConfig`:
+     ```python exec="on" source="material-block" result="json"
+     from qadence.ml_tools import TrainConfig
+     from qadence.ml_tools.callbacks import PrintMetrics
+
+     # Create an instance of the PrintMetrics callback
+     print_metrics_callback = PrintMetrics(on = "val_batch_end", called_every = 100)
+
+     config = TrainConfig(
+         max_iter=10000,
+         # Print metrics every 1000 training epochs
+         print_every=1000,
+         # Add the custom callback that runs every 100 val_batch_end
+         callbacks=[print_metrics_callback]
+     )
+     ```
+     """
+
+     def run_callback(self, trainer: Any, config: TrainConfig, writer: BaseWriter) -> Any:
+         """Prints metrics using the writer.
+
+         Args:
+             trainer (Any): The training object.
+             config (TrainConfig): The configuration object.
+             writer (BaseWriter): The writer object for logging.
+         """
+         opt_result = trainer.opt_result
+         writer.print_metrics(opt_result)
+
+
+ class WriteMetrics(Callback):
+     """Callback to write metrics using the writer.
+
+     The `WriteMetrics` callback can be added to the `TrainConfig` callbacks as
+     a custom user defined callback.
+
+     Example Usage in `TrainConfig`:
+     To use `WriteMetrics`, include it in the `callbacks` list when setting up your
+     `TrainConfig`:
+     ```python exec="on" source="material-block" result="json"
+     from qadence.ml_tools import TrainConfig
+     from qadence.ml_tools.callbacks import WriteMetrics
+
+     # Create an instance of the WriteMetrics callback
+     write_metrics_callback = WriteMetrics(on = "val_batch_end", called_every = 100)
+
+     config = TrainConfig(
+         max_iter=10000,
+         # Print metrics every 1000 training epochs
+         print_every=1000,
+         # Add the custom callback that runs every 100 val_batch_end
+         callbacks=[write_metrics_callback]
+     )
+     ```
+     """
+
+     def run_callback(self, trainer: Any, config: TrainConfig, writer: BaseWriter) -> Any:
+         """Writes metrics using the writer.
+
+         Args:
+             trainer (Any): The training object.
+             config (TrainConfig): The configuration object.
+             writer (BaseWriter): The writer object for logging.
+         """
+         opt_result = trainer.opt_result
+         writer.write(opt_result)
+
+
+ class PlotMetrics(Callback):
+     """Callback to plot metrics using the writer.
+
+     The `PlotMetrics` callback can be added to the `TrainConfig` callbacks as
+     a custom user defined callback.
+
+     Example Usage in `TrainConfig`:
+     To use `PlotMetrics`, include it in the `callbacks` list when setting up your
+     `TrainConfig`:
+     ```python exec="on" source="material-block" result="json"
+     from qadence.ml_tools import TrainConfig
+     from qadence.ml_tools.callbacks import PlotMetrics
+
+     # Create an instance of the PlotMetrics callback
+     plot_metrics_callback = PlotMetrics(on = "val_batch_end", called_every = 100)
+
+     config = TrainConfig(
+         max_iter=10000,
+         # Print metrics every 1000 training epochs
+         print_every=1000,
+         # Add the custom callback that runs every 100 val_batch_end
+         callbacks=[plot_metrics_callback]
+     )
+     ```
+     """
+
+     def run_callback(self, trainer: Any, config: TrainConfig, writer: BaseWriter) -> Any:
+         """Plots metrics using the writer.
+
+         Args:
+             trainer (Any): The training object.
+             config (TrainConfig): The configuration object.
+             writer (BaseWriter): The writer object for logging.
+         """
+         opt_result = trainer.opt_result
+         plotting_functions = config.plotting_functions
+         writer.plot(trainer.model, opt_result.iteration, plotting_functions)
+
+
+ class LogHyperparameters(Callback):
+     """Callback to log hyperparameters using the writer.
+
+     The `LogHyperparameters` callback can be added to the `TrainConfig` callbacks
+     as a custom user defined callback.
+
+     Example Usage in `TrainConfig`:
+     To use `LogHyperparameters`, include it in the `callbacks` list when setting up your
+     `TrainConfig`:
+     ```python exec="on" source="material-block" result="json"
+     from qadence.ml_tools import TrainConfig
+     from qadence.ml_tools.callbacks import LogHyperparameters
+
+     # Create an instance of the LogHyperparameters callback
+     log_hyper_callback = LogHyperparameters(on = "val_batch_end", called_every = 100)
+
+     config = TrainConfig(
+         max_iter=10000,
+         # Print metrics every 1000 training epochs
+         print_every=1000,
+         # Add the custom callback that runs every 100 val_batch_end
+         callbacks=[log_hyper_callback]
+     )
+     ```
+     """
+
+     def run_callback(self, trainer: Any, config: TrainConfig, writer: BaseWriter) -> Any:
+         """Logs hyperparameters using the writer.
+
+         Args:
+             trainer (Any): The training object.
+             config (TrainConfig): The configuration object.
+             writer (BaseWriter): The writer object for logging.
+         """
+         hyperparams = config.hyperparams
+         writer.log_hyperparams(hyperparams)
+
+
+ class SaveCheckpoint(Callback):
+     """Callback to save a model checkpoint.
+
+     The `SaveCheckpoint` callback can be added to the `TrainConfig` callbacks
+     as a custom user defined callback.
+
+     Example Usage in `TrainConfig`:
+     To use `SaveCheckpoint`, include it in the `callbacks` list when setting up your
+     `TrainConfig`:
+     ```python exec="on" source="material-block" result="json"
+     from qadence.ml_tools import TrainConfig
+     from qadence.ml_tools.callbacks import SaveCheckpoint
+
+     # Create an instance of the SaveCheckpoint callback
+     save_checkpoint_callback = SaveCheckpoint(on = "val_batch_end", called_every = 100)
+
+     config = TrainConfig(
+         max_iter=10000,
+         # Print metrics every 1000 training epochs
+         print_every=1000,
+         # Add the custom callback that runs every 100 val_batch_end
+         callbacks=[save_checkpoint_callback]
+     )
+     ```
+     """
+
+     def run_callback(self, trainer: Any, config: TrainConfig, writer: BaseWriter) -> Any:
+         """Saves a model checkpoint.
+
+         Args:
+             trainer (Any): The training object.
+             config (TrainConfig): The configuration object.
+             writer (BaseWriter): The writer object for logging.
+         """
+         folder = config.log_folder
+         model = trainer.model
+         optimizer = trainer.optimizer
+         opt_result = trainer.opt_result
+         write_checkpoint(folder, model, optimizer, opt_result.iteration)
+
+
+ class SaveBestCheckpoint(SaveCheckpoint):
+     """Callback to save the best model checkpoint based on a validation criterion."""
+
+     def __init__(self, on: str, called_every: int):
+         """Initializes the SaveBestCheckpoint callback.
+
+         Args:
+             on (str): The event to trigger the callback.
+             called_every (int): Frequency of callback calls in terms of iterations.
+         """
+         super().__init__(on=on, called_every=called_every)
+         self.best_loss = float("inf")
+
+     def run_callback(self, trainer: Any, config: TrainConfig, writer: BaseWriter) -> Any:
+         """Saves the checkpoint if the current loss is better than the best loss.
+
+         Args:
+             trainer (Any): The training object.
+             config (TrainConfig): The configuration object.
+             writer (BaseWriter): The writer object for logging.
+         """
+         opt_result = trainer.opt_result
+         if config.validation_criterion and config.validation_criterion(
+             opt_result.loss, self.best_loss, config.val_epsilon
+         ):
+             self.best_loss = opt_result.loss
+
+             folder = config.log_folder
+             model = trainer.model
+             optimizer = trainer.optimizer
+             opt_result = trainer.opt_result
+             write_checkpoint(folder, model, optimizer, "best")
+
+
+ class LoadCheckpoint(Callback):
+     """Callback to load a model checkpoint."""
+
+     def run_callback(self, trainer: Any, config: TrainConfig, writer: BaseWriter) -> Any:
+         """Loads a model checkpoint.
+
+         Args:
+             trainer (Any): The training object.
+             config (TrainConfig): The configuration object.
+             writer (BaseWriter): The writer object for logging.
+
+         Returns:
+             Any: The result of loading the checkpoint.
+         """
+         folder = config.log_folder
+         model = trainer.model
+         optimizer = trainer.optimizer
+         device = trainer.log_device
+         return load_checkpoint(folder, model, optimizer, device=device)
+
+
+ class LogModelTracker(Callback):
+     """Callback to log the model using the writer."""
+
+     def run_callback(self, trainer: Any, config: TrainConfig, writer: BaseWriter) -> Any:
+         """Logs the model using the writer.
+
+         Args:
+             trainer (Any): The training object.
+             config (TrainConfig): The configuration object.
+             writer (BaseWriter): The writer object for logging.
+         """
+         model = trainer.model
+         writer.log_model(
+             model, trainer.train_dataloader, trainer.val_dataloader, trainer.test_dataloader
+         )
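Taken together, the pieces above mean a callback fires only when the trainer's stage matches `on`, the iteration is a nonzero multiple of `called_every`, and `callback_condition(opt_result)` holds; the dict form of `modify_optimize_result` is merged into `opt_result.extra` first. A minimal sketch composing these, using only names that appear in this file and its docstring examples (it assumes `OptimizeResult.loss` is a plain number or `None` when the condition runs):

```python
# A throttled, conditional callback built from the constructor shown above.
from qadence.ml_tools import TrainConfig
from qadence.ml_tools.callbacks import Callback

log_converged = Callback(
    on="train_epoch_end",
    called_every=50,  # considered only every 50th iteration
    callback=lambda trainer, config, writer: print(trainer.opt_result.loss),
    # fires only once the loss has dropped below 0.1 (loss assumed float | None)
    callback_condition=lambda opt_result: opt_result.loss is not None and opt_result.loss < 0.1,
    modify_optimize_result={"regime": "converged"},  # merged into opt_result.extra
)

config = TrainConfig(max_iter=10000, callbacks=[log_converged])
```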
qadence/ml_tools/callbacks/callbackmanager.py
@@ -0,0 +1,214 @@
+ from __future__ import annotations
+
+ import copy
+ import logging
+ from typing import Any
+
+ from qadence.ml_tools.callbacks.callback import (
+     Callback,
+     LoadCheckpoint,
+     LogHyperparameters,
+     LogModelTracker,
+     PlotMetrics,
+     PrintMetrics,
+     SaveBestCheckpoint,
+     SaveCheckpoint,
+     WriteMetrics,
+ )
+ from qadence.ml_tools.config import TrainConfig
+ from qadence.ml_tools.data import OptimizeResult
+ from qadence.ml_tools.stages import TrainingStage
+
+ from .writer_registry import get_writer
+
+ logger = logging.getLogger("ml_tools")
+
+
+ class CallbacksManager:
+     """Manages and orchestrates the execution of various training callbacks.
+
+     Provides the start training and end training methods.
+
+     Attributes:
+         use_grad (bool): Indicates whether to use gradients in callbacks.
+         config (TrainConfig): The training configuration object.
+         callbacks (List[Callback]): List of callback instances to be executed.
+         writer (Optional[BaseWriter]): The writer instance for logging metrics and information.
+     """
+
+     use_grad: bool = True
+
+     callback_map = {
+         "PrintMetrics": PrintMetrics,
+         "WriteMetrics": WriteMetrics,
+         "PlotMetrics": PlotMetrics,
+         "SaveCheckpoint": SaveCheckpoint,
+         "LoadCheckpoint": LoadCheckpoint,
+         "LogModelTracker": LogModelTracker,
+         "LogHyperparameters": LogHyperparameters,
+         "SaveBestCheckpoint": SaveBestCheckpoint,
+     }
+
+     def __init__(self, config: TrainConfig):
+         """
+         Initializes the CallbacksManager with a training configuration.
+
+         Args:
+             config (TrainConfig): The training configuration object.
+         """
+         self.config = config
+         tracking_tool = self.config.tracking_tool
+         self.writer = get_writer(tracking_tool)
+         self.callbacks: list[Callback] = []
+
+     @classmethod
+     def set_use_grad(cls, use_grad: bool) -> None:
+         """
+         Sets whether gradients should be used in callbacks.
+
+         Args:
+             use_grad (bool): A boolean indicating whether to use gradients.
+         """
+         if not isinstance(use_grad, bool):
+             raise ValueError("use_grad must be a boolean value.")
+         cls.use_grad = use_grad
+
+     def initialize_callbacks(self) -> None:
+         """Initializes and adds the necessary callbacks based on the configuration."""
+         # Train Start
+         self.callbacks = copy.deepcopy(self.config.callbacks)
+         self.add_callback("PlotMetrics", "train_start")
+         if self.config.val_every:
+             self.add_callback("WriteMetrics", "train_start")
+         # only save the first checkpoint if not checkpoint_best_only
+         if not self.config.checkpoint_best_only:
+             self.add_callback("SaveCheckpoint", "train_start")
+
+         # Checkpointing
+         if self.config.checkpoint_best_only:
+             self.add_callback("SaveBestCheckpoint", "val_epoch_end", self.config.val_every)
+         elif self.config.checkpoint_every:
+             self.add_callback("SaveCheckpoint", "train_epoch_end", self.config.checkpoint_every)
+
+         # Printing
+         if self.config.verbose and self.config.print_every:
+             self.add_callback("PrintMetrics", "train_epoch_end", self.config.print_every)
+
+         # Plotting
+         if self.config.plot_every:
+             self.add_callback("PlotMetrics", "train_epoch_end", self.config.plot_every)
+
+         # Writing
+         if self.config.write_every:
+             self.add_callback("WriteMetrics", "train_epoch_end", self.config.write_every)
+         if self.config.val_every:
+             self.add_callback("WriteMetrics", "val_epoch_end", self.config.val_every)
+
+         # Train End
+         # Hyperparameters
+         if self.config.hyperparams:
+             self.add_callback("LogHyperparameters", "train_end")
+         # Log model
+         if self.config.log_model:
+             self.add_callback("LogModelTracker", "train_end")
+         if self.config.plot_every:
+             self.add_callback("PlotMetrics", "train_end")
+         # only save the last checkpoint if not checkpoint_best_only
+         if not self.config.checkpoint_best_only:
+             self.add_callback("SaveCheckpoint", "train_end")
+         self.add_callback("WriteMetrics", "train_end")
+
+     def add_callback(
+         self, callback: str | Callback, on: str | TrainingStage, called_every: int = 1
+     ) -> None:
+         """
+         Adds a callback to the manager.
+
+         Args:
+             callback (str | Callback): The callback instance or name.
+             on (str | TrainingStage): The event on which to trigger the callback.
+             called_every (int): Frequency of callback calls in terms of iterations.
+         """
+         if isinstance(callback, str):
+             callback_class = self.callback_map.get(callback)
+             if callback_class:
+                 # Create an instance of the callback class
+                 callback_instance = callback_class(on=on, called_every=called_every)
+                 self.callbacks.append(callback_instance)
+             else:
+                 logger.warning(f"Callback '{callback}' not recognized and will be skipped.")
+         elif isinstance(callback, Callback):
+             callback.on = on
+             callback.called_every = called_every
+             self.callbacks.append(callback)
+         else:
+             logger.warning(
+                 f"Invalid callback type: {type(callback)}. Expected str or Callback instance."
+             )
+
+     def run_callbacks(self, trainer: Any) -> Any:
+         """
+         Runs callbacks that match the current training state.
+
+         Args:
+             trainer (Any): The training object managing the training process.
+
+         Returns:
+             Any: Results of the executed callbacks.
+         """
+         return [
+             callback(
+                 when=trainer.training_stage, trainer=trainer, config=self.config, writer=self.writer
+             )
+             for callback in self.callbacks
+             if callback.on == trainer.training_stage
+         ]
+
+     def start_training(self, trainer: Any) -> None:
+         """
+         Initializes callbacks and starts the training process.
+
+         Args:
+             trainer (Any): The training object managing the training process.
+         """
+         # Clear all handlers from the logger
+         self.initialize_callbacks()
+
+         trainer.opt_result = OptimizeResult(trainer.global_step, trainer.model, trainer.optimizer)
+         trainer.is_last_iteration = False
+
+         # Load checkpoint only if a new subfolder was NOT recently added
+         if not trainer.config_manager._added_new_subfolder:
+             load_checkpoint_callback = LoadCheckpoint(on="train_start", called_every=1)
+             loaded_result = load_checkpoint_callback.run_callback(
+                 trainer=trainer,
+                 config=self.config,
+                 writer=None,  # type: ignore[arg-type]
+             )
+
+             if loaded_result:
+                 model, optimizer, init_iter = loaded_result
+                 if isinstance(init_iter, (int, str)):
+                     trainer.model = model
+                     trainer.optimizer = optimizer
+                     trainer.global_step = (
+                         init_iter if isinstance(init_iter, int) else trainer.global_step
+                     )
+                     trainer.current_epoch = (
+                         init_iter if isinstance(init_iter, int) else trainer.current_epoch
+                     )
+                     trainer.opt_result = OptimizeResult(trainer.current_epoch, model, optimizer)
+                     logger.debug(f"Loaded model and optimizer from {self.config.log_folder}")
+
+         # Setup writer
+         self.writer.open(self.config, iteration=trainer.global_step)
+
+     def end_training(self, trainer: Any) -> None:
+         """
+         Cleans up and finalizes the training process.
+
+         Args:
+             trainer (Any): The training object managing the training process.
+         """
+         if self.writer:
+             self.writer.close()
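For orientation, here is a short sketch of the manager's configuration step: `initialize_callbacks()` expands the flags on a `TrainConfig` into concrete `Callback` instances bound to training stages, exactly as implemented above. It assumes only attributes referenced in this file, plus the fact that `get_writer` merely constructs the writer (it is opened later, in `start_training`):

```python
# Inspect how TrainConfig flags expand into stage-bound callbacks.
from qadence.ml_tools import TrainConfig
from qadence.ml_tools.callbacks.callbackmanager import CallbacksManager

config = TrainConfig(max_iter=1000, print_every=100, checkpoint_every=500)
manager = CallbacksManager(config)
manager.initialize_callbacks()

for cb in manager.callbacks:
    # e.g. "SaveCheckpoint train_epoch_end 500"
    print(type(cb).__name__, cb.on, cb.called_every)
```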