qadence 1.8.0__py3-none-any.whl → 1.9.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (43)
  1. qadence/__init__.py +1 -1
  2. qadence/analog/parse_analog.py +1 -2
  3. qadence/backends/gpsr.py +8 -2
  4. qadence/backends/pulser/backend.py +7 -23
  5. qadence/backends/pyqtorch/backend.py +80 -5
  6. qadence/backends/pyqtorch/config.py +10 -3
  7. qadence/backends/pyqtorch/convert_ops.py +63 -2
  8. qadence/blocks/primitive.py +1 -0
  9. qadence/execution.py +0 -2
  10. qadence/log_config.yaml +10 -0
  11. qadence/measurements/shadow.py +97 -128
  12. qadence/measurements/utils.py +2 -2
  13. qadence/mitigations/readout.py +12 -6
  14. qadence/ml_tools/__init__.py +4 -8
  15. qadence/ml_tools/callbacks/__init__.py +30 -0
  16. qadence/ml_tools/callbacks/callback.py +451 -0
  17. qadence/ml_tools/callbacks/callbackmanager.py +214 -0
  18. qadence/ml_tools/{saveload.py → callbacks/saveload.py} +11 -11
  19. qadence/ml_tools/callbacks/writer_registry.py +430 -0
  20. qadence/ml_tools/config.py +132 -258
  21. qadence/ml_tools/data.py +7 -3
  22. qadence/ml_tools/loss/__init__.py +10 -0
  23. qadence/ml_tools/loss/loss.py +87 -0
  24. qadence/ml_tools/optimize_step.py +45 -10
  25. qadence/ml_tools/stages.py +46 -0
  26. qadence/ml_tools/train_utils/__init__.py +7 -0
  27. qadence/ml_tools/train_utils/base_trainer.py +548 -0
  28. qadence/ml_tools/train_utils/config_manager.py +184 -0
  29. qadence/ml_tools/trainer.py +692 -0
  30. qadence/model.py +1 -1
  31. qadence/noise/__init__.py +2 -2
  32. qadence/noise/protocols.py +18 -53
  33. qadence/operations/ham_evo.py +87 -26
  34. qadence/transpile/noise.py +12 -5
  35. qadence/types.py +15 -3
  36. {qadence-1.8.0.dist-info → qadence-1.9.0.dist-info}/METADATA +3 -4
  37. {qadence-1.8.0.dist-info → qadence-1.9.0.dist-info}/RECORD +39 -32
  38. {qadence-1.8.0.dist-info → qadence-1.9.0.dist-info}/WHEEL +1 -1
  39. qadence/ml_tools/printing.py +0 -154
  40. qadence/ml_tools/train_grad.py +0 -395
  41. qadence/ml_tools/train_no_grad.py +0 -199
  42. qadence/noise/readout.py +0 -218
  43. {qadence-1.8.0.dist-info → qadence-1.9.0.dist-info}/licenses/LICENSE +0 -0
qadence/ml_tools/config.py CHANGED
@@ -1,19 +1,14 @@
  from __future__ import annotations

- import datetime
- import os
  from dataclasses import dataclass, field, fields
  from logging import getLogger
  from pathlib import Path
- from typing import Any, Callable, Type
- from uuid import uuid4
+ from typing import Callable, Type

  from sympy import Basic
- from torch import Tensor

  from qadence.blocks.analog import AnalogBlock
  from qadence.blocks.primitive import ParametricBlock
- from qadence.ml_tools.data import OptimizeResult
  from qadence.operations import RX, AnalogRX
  from qadence.parameters import Parameter
  from qadence.types import (
@@ -28,306 +23,185 @@ from qadence.types import (

  logger = getLogger(__file__)

- CallbackFunction = Callable[[OptimizeResult], None]
- CallbackConditionFunction = Callable[[OptimizeResult], bool]
-
-
- class Callback:
- """Callback functions are calling in train functions.
-
- Each callback function should take at least as first input
- an OptimizeResult instance.
-
- Note: when setting call_after_opt to True, we skip
- verifying iteration % called_every == 0.
-
- Attributes:
- callback (CallbackFunction): Callback function accepting an
- OptimizeResult as first argument.
- callback_condition (CallbackConditionFunction | None, optional): Function that
- conditions the call to callback. Defaults to None.
- modify_optimize_result (CallbackFunction | dict[str, Any] | None, optional):
- Function that modify the OptimizeResult before callback.
- For instance, one can change the `extra` (dict) argument to be used in callback.
- If a dict is provided, the `extra` field of OptimizeResult is updated with the dict.
- called_every (int, optional): Callback to be called each `called_every` epoch.
- Defaults to 1.
- If callback_condition is None, we set
- callback_condition to returns True when iteration % called_every == 0.
- call_before_opt (bool, optional): If true, callback is applied before training.
- Defaults to False.
- call_end_epoch (bool, optional): If true, callback is applied during training,
- after an epoch is performed. Defaults to True.
- call_after_opt (bool, optional): If true, callback is applied after training.
- Defaults to False.
- call_during_eval (bool, optional): If true, callback is applied during evaluation.
- Defaults to False.
- """
-
- def __init__(
- self,
- callback: CallbackFunction,
- callback_condition: CallbackConditionFunction | None = None,
- modify_optimize_result: CallbackFunction | dict[str, Any] | None = None,
- called_every: int = 1,
- call_before_opt: bool = False,
- call_end_epoch: bool = True,
- call_after_opt: bool = False,
- call_during_eval: bool = False,
- ) -> None:
- """Initialized Callback.
-
- Args:
- callback (CallbackFunction): Callback function accepting an
- OptimizeResult as ifrst argument.
- callback_condition (CallbackConditionFunction | None, optional): Function that
- conditions the call to callback. Defaults to None.
- modify_optimize_result (CallbackFunction | dict[str, Any] | None , optional):
- Function that modify the OptimizeResult before callback. If a dict
- is provided, this updates the `extra` field of OptimizeResult.
- called_every (int, optional): Callback to be called each `called_every` epoch.
- Defaults to 1.
- If callback_condition is None, we set
- callback_condition to returns True when iteration % called_every == 0.
- call_before_opt (bool, optional): If true, callback is applied before training.
- Defaults to False.
- call_end_epoch (bool, optional): If true, callback is applied during training,
- after an epoch is performed. Defaults to True.
- call_after_opt (bool, optional): If true, callback is applied after training.
- Defaults to False.
- call_during_eval (bool, optional): If true, callback is applied during evaluation.
- Defaults to False.
- """
- self.callback = callback
- self.call_before_opt = call_before_opt
- self.call_end_epoch = call_end_epoch
- self.call_after_opt = call_after_opt
- self.call_during_eval = call_during_eval
-
- if called_every <= 0:
- raise ValueError("Please provide a strictly positive `called_every` argument.")
- self.called_every = called_every
-
- if callback_condition is None:
- self.callback_condition = lambda opt_result: True
- else:
- self.callback_condition = callback_condition
-
- if modify_optimize_result is None:
- self.modify_optimize_result = lambda opt_result: opt_result
- elif isinstance(modify_optimize_result, dict):
-
- def update_extra(opt_result: OptimizeResult) -> OptimizeResult:
- opt_result.extra.update(modify_optimize_result)
- return opt_result
-
- self.modify_optimize_result = update_extra
- else:
- self.modify_optimize_result = modify_optimize_result
-
- def __call__(self, opt_result: OptimizeResult, is_last_iteration: bool = False) -> Any:
- """Apply callback if conditions are met.
-
- Note that the current result may be modified by specifying a function
- `modify_optimize_result` for instance to add inputs to the `extra` argument
- of the current OptimizeResult.
-
- Args:
- opt_result (OptimizeResult): Current result.
- is_last_iteration (bool, optional): When True,
- avoid verifying modulo. Defaults to False.
- Useful when call_after_opt is True.
-
- Returns:
- Any: The result of the callback.
- """
- opt_result = self.modify_optimize_result(opt_result)
- if opt_result.iteration % self.called_every == 0 and self.callback_condition(opt_result):
- return self.callback(opt_result)
- if is_last_iteration and self.callback_condition(opt_result):
- return self.callback(opt_result)
-
-
- def run_callbacks(
- callback_iterable: list[Callback], opt_res: OptimizeResult, is_last_iteration: bool = False
- ) -> None:
- """Run a list of Callback given the current OptimizeResult.
-
- Used in train functions.
-
- Args:
- callback_iterable (list[Callback]): Iterable of Callbacks
- opt_res (OptimizeResult): Current optimization result,
- is_last_iteration (bool, optional): Whether we reached the last iteration or not.
- Defaults to False.
- """
- for callback in callback_iterable:
- callback(opt_res, is_last_iteration)
-

  @dataclass
  class TrainConfig:
- """Default config for the train function.
+ """Default configuration for the training process.

- The default value of
- each field can be customized with the constructor:
+ This class provides default settings for various aspects of the training loop,
+ such as logging, checkpointing, and validation. The default values for these
+ fields can be customized when an instance of `TrainConfig` is created.

+ Example:
  ```python exec="on" source="material-block" result="json"
  from qadence.ml_tools import TrainConfig
- c = TrainConfig(folder="/tmp/train")
+ c = TrainConfig(root_folder="/tmp/train")
  print(str(c)) # markdown-exec: hide
  ```
  """

  max_iter: int = 10000
- """Number of training iterations."""
- print_every: int = 1000
- """Print loss/metrics.
+ """Number of training iterations (epochs) to perform.
+
+ This defines the total number
+ of times the model will be updated.

- Set to 0 to disable
+ In case of InfiniteTensorDataset, each epoch will have 1 batch.
+ In case of TensorDataset, each epoch will have len(dataloader) batches.
  """
- write_every: int = 50
- """Write loss and metrics with the tracking tool.

- Set to 0 to disable
+ print_every: int = 0
+ """Frequency (in epochs) for printing loss and metrics to the console during training.
+
+ Set to 0 to disable this output, meaning that metrics and loss will not be printed
+ during training.
  """
- checkpoint_every: int = 5000
- """Write model/optimizer checkpoint.

- Set to 0 to disable
+ write_every: int = 0
+ """Frequency (in epochs) for writing loss and metrics using the tracking tool during training.
+
+ Set to 0 to disable this logging, which prevents metrics from being logged to the tracking tool.
+ Note that the metrics will always be written at the end of training regardless of this setting.
  """
- plot_every: int = 5000
- """Write figures.

- Set to 0 to disable
+ checkpoint_every: int = 0
+ """Frequency (in epochs) for saving model and optimizer checkpoints during training.
+
+ Set to 0 to disable checkpointing. This helps in resuming training or recovering
+ models.
+ Note that setting checkpoint_best_only = True will disable this and only best checkpoints will
+ be saved.
+ """
+
+ plot_every: int = 0
+ """Frequency (in epochs) for generating and saving figures during training.
+
+ Set to 0 to disable plotting.
  """
- callbacks: list[Callback] = field(default_factory=lambda: list())
- """List of callbacks."""
+
+ callbacks: list = field(default_factory=lambda: list())
+ """List of callbacks to execute during training.
+
+ Callbacks can be used for
+ custom behaviors, such as early stopping, custom logging, or other actions
+ triggered at specific events.
+ """
+
  log_model: bool = False
- """Logs a serialised version of the model."""
- folder: Path | None = None
- """Checkpoint/tensorboard logs folder."""
+ """Whether to log a serialized version of the model.
+
+ When set to `True`, the
+ model's state will be logged, useful for model versioning and reproducibility.
+ """
+
+ root_folder: Path = Path("./qml_logs")
+ """The root folder for saving checkpoints and tensorboard logs.
+
+ The default path is "./qml_logs"
+
+ This can be set to a specific directory where training artifacts are to be stored.
+ Checkpoints will be saved inside a subfolder in this directory. Subfolders will be
+ created based on `create_subfolder_per_run` argument.
+ """
+
  create_subfolder_per_run: bool = False
- """Checkpoint/tensorboard logs stored in subfolder with name `<timestamp>_<PID>`.
+ """Whether to create a subfolder for each run, named `<id>_<timestamp>_<PID>`.
+
+ This ensures logs and checkpoints from different runs do not overwrite each other,
+ which is helpful for rapid prototyping. If `False`, training will resume from
+ the latest checkpoint if one exists in the specified log folder.
+ """
+
+ log_folder: Path = Path("./")
+ """The log folder for saving checkpoints and tensorboard logs.

- Prevents continuing from previous checkpoint, useful for fast prototyping.
+ This stores the path where all logs and checkpoints are being saved
+ for this training session. `log_folder` takes precedence over `root_folder` and
+ `create_subfolder_per_run` arguments. If the user specifies a log_folder,
+ all checkpoints will be saved in this folder and `root_folder` argument
+ will not be used.
  """
+
  checkpoint_best_only: bool = False
- """Write model/optimizer checkpoint only if a metric has improved."""
- val_every: int | None = None
- """Calculate validation metric.
+ """If `True`, checkpoints are only saved if there is an improvement in the.

- If None, validation check is not performed.
+ validation metric. This conserves storage by only keeping the best models.
+
+ validation_criterion is required when this is set to True.
  """
+
+ val_every: int = 0
+ """Frequency (in epochs) for performing validation.
+
+ If set to 0, validation is not performed.
+ Note that metrics from validation are always written, regardless of the `write_every` setting.
+ Note that initial validation happens at the start of training (when val_every > 0)
+ For initial validation - initial metrics are written.
+ - checkpoint is saved (when checkpoint_best_only = False)
+ """
+
  val_epsilon: float = 1e-5
- """Safety margin to check if validation loss is smaller than the lowest.
+ """A small safety margin used to compare the current validation loss with the.

- validation loss across previous iterations.
+ best previous validation loss. This is used to determine improvements in metrics.
  """
+
  validation_criterion: Callable | None = None
- """A boolean function which evaluates a given validation metric is satisfied."""
+ """A function to evaluate whether a given validation metric meets a desired condition.
+
+ The validation_criterion has the following format:
+ def validation_criterion(val_loss: float, best_val_loss: float, val_epsilon: float) -> bool:
+ # process
+
+ If `None`, no custom validation criterion is applied.
+ """
+
  trainstop_criterion: Callable | None = None
- """A boolean function which evaluates a given training stopping metric is satisfied."""
- batch_size: int = 1
- """The batch_size to use when passing a list/tuple of torch.Tensors."""
- verbose: bool = True
- """Whether or not to print out metrics values during training."""
- tracking_tool: ExperimentTrackingTool = ExperimentTrackingTool.TENSORBOARD
- """The tracking tool of choice."""
- hyperparams: dict = field(default_factory=dict)
- """Hyperparameters to track."""
- plotting_functions: tuple[LoggablePlotFunction, ...] = field(default_factory=tuple) # type: ignore
- """Functions for in-train plotting."""
-
- # tensorboard only allows for certain types as hyperparameters
- _tb_allowed_hyperparams_types: tuple = field(
- default=(int, float, str, bool, Tensor), init=False, repr=False
- )
-
- def _filter_tb_hyperparams(self) -> None:
- keys_to_remove = [
- key
- for key, value in self.hyperparams.items()
- if not isinstance(value, TrainConfig._tb_allowed_hyperparams_types)
- ]
- if keys_to_remove:
- logger.warning(
- f"Tensorboard cannot log the following hyperparameters: {keys_to_remove}."
- )
- for key in keys_to_remove:
- self.hyperparams.pop(key)
+ """A function to determine if the training process should stop based on a.

- def __post_init__(self) -> None:
- if self.folder:
- if isinstance(self.folder, str): # type: ignore [unreachable]
- self.folder = Path(self.folder) # type: ignore [unreachable]
- if self.create_subfolder_per_run:
- subfoldername = (
- datetime.datetime.now().strftime("%Y%m%dT%H%M%S") + "_" + hex(os.getpid())[2:]
- )
- self.folder = self.folder / subfoldername
- if self.trainstop_criterion is None:
- self.trainstop_criterion = lambda x: x <= self.max_iter
- if self.validation_criterion is None:
- self.validation_criterion = lambda *x: False
- if self.hyperparams and self.tracking_tool == ExperimentTrackingTool.TENSORBOARD:
- self._filter_tb_hyperparams()
- if self.tracking_tool == ExperimentTrackingTool.MLFLOW:
- self._mlflow_config = MLFlowConfig()
- if self.plotting_functions and self.tracking_tool != ExperimentTrackingTool.MLFLOW:
- logger.warning("In-training plots are only available with mlflow tracking.")
- if not self.plotting_functions and self.tracking_tool == ExperimentTrackingTool.MLFLOW:
- logger.warning("Tracking with mlflow, but no plotting functions provided.")
-
- @property
- def mlflow_config(self) -> MLFlowConfig:
- if self.tracking_tool == ExperimentTrackingTool.MLFLOW:
- return self._mlflow_config
- else:
- raise AttributeError(
- "mlflow_config is available only for with the mlflow tracking tool."
- )
+ specific stopping metric. If `None`, training continues until `max_iter` is reached.
+ """

+ batch_size: int = 1
+ """The batch size to use when processing a list or tuple of torch.Tensors.

- class MLFlowConfig:
+ This specifies how many samples are processed in each training iteration.
  """
- Configuration for mlflow tracking.

- Example:
+ verbose: bool = True
+ """Whether to print metrics and status messages during training.

- export MLFLOW_TRACKING_URI=tracking_uri
- export MLFLOW_EXPERIMENT=experiment_name
- export MLFLOW_RUN_NAME=run_name
+ If `True`, detailed metrics and status updates will be displayed in the console.
  """

- def __init__(self) -> None:
- import mlflow
+ tracking_tool: ExperimentTrackingTool = ExperimentTrackingTool.TENSORBOARD
+ """The tool used for tracking training progress and logging metrics.

- self.tracking_uri: str = os.getenv("MLFLOW_TRACKING_URI", "")
- """The URI of the mlflow tracking server.
+ Options include tools like TensorBoard, which help visualize and monitor
+ model training.
+ """

- An empty string, or a local file path, prefixed with file:/.
- Data is stored locally at the provided file (or ./mlruns if empty).
- """
+ hyperparams: dict = field(default_factory=dict)
+ """A dictionary of hyperparameters to be tracked.

- self.experiment_name: str = os.getenv("MLFLOW_EXPERIMENT", str(uuid4()))
- """The name of the experiment.
+ This can include learning rates,
+ regularization parameters, or any other training-related configurations.
+ """

- If None or empty, a new experiment is created with a random UUID.
- """
+ plotting_functions: tuple[LoggablePlotFunction, ...] = field(default_factory=tuple) # type: ignore
+ """Functions used for in-training plotting.

- self.run_name: str = os.getenv("MLFLOW_RUN_NAME", str(uuid4()))
- """The name of the run."""
+ These are called to generate
+ plots that are logged or saved at specified intervals.
+ """

- mlflow.set_tracking_uri(self.tracking_uri)
+ _subfolders: list = field(default_factory=list)
+ """List of subfolders used for logging different runs using the same config inside the.

- # activate existing or create experiment
- exp_filter_string = f"name = '{self.experiment_name}'"
- if not mlflow.search_experiments(filter_string=exp_filter_string):
- mlflow.create_experiment(name=self.experiment_name)
+ root folder.

- self.experiment = mlflow.set_experiment(self.experiment_name)
- self.run = mlflow.start_run(run_name=self.run_name, nested=False)
+ Each subfolder is of structure `<id>_<timestamp>_<PID>`.
+ """


  @dataclass
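Editor's note: a minimal sketch of how the reworked `TrainConfig` fields fit together. The field names and the `validation_criterion` signature come from the diff above; the concrete values and the improvement rule are illustrative only.

```python
from qadence.ml_tools import TrainConfig

# Improvement rule matching the signature documented for `validation_criterion`;
# the exact comparison used here is illustrative, not a library default.
def validation_criterion(val_loss: float, best_val_loss: float, val_epsilon: float) -> bool:
    return val_loss < (best_val_loss - val_epsilon)

config = TrainConfig(
    max_iter=1000,
    print_every=100,                # new default is 0, i.e. no console output
    checkpoint_every=500,           # new default is 0, i.e. no periodic checkpoints
    root_folder="/tmp/train",       # replaces the removed `folder` field
    create_subfolder_per_run=True,  # one `<id>_<timestamp>_<PID>` subfolder per run
    val_every=50,                   # new default is 0, i.e. no validation (was `None`)
    validation_criterion=validation_criterion,
)
```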
qadence/ml_tools/data.py CHANGED
@@ -1,8 +1,8 @@
  from __future__ import annotations

+ import random
  from dataclasses import dataclass, field
  from functools import singledispatch
- from itertools import cycle
  from typing import Any, Iterator

  from nevergrad.optimization.base import Optimizer as NGOptimizer
@@ -72,13 +72,17 @@ class InfiniteTensorDataset(IterableDataset):
  ```
  """
  self.tensors = tensors
+ self.indices = list(range(self.tensors[0].size(0)))

  def __iter__(self) -> Iterator:
  if len(set([t.size(0) for t in self.tensors])) != 1:
  raise ValueError("Size of first dimension must be the same for all tensors.")

- for idx in cycle(range(self.tensors[0].size(0))):
- yield tuple(t[idx] for t in self.tensors)
+ # Shuffle the indices for every full pass
+ random.shuffle(self.indices)
+ while True:
+ for idx in self.indices:
+ yield tuple(t[idx] for t in self.tensors)


  def to_dataloader(*tensors: Tensor, batch_size: int = 1, infinite: bool = False) -> DataLoader:
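Editor's note: a short sketch of the behavioural change. With `infinite=True`, `to_dataloader` iterates indefinitely, and `InfiniteTensorDataset` now draws samples in a shuffled index order (shuffled when iteration starts) instead of cycling through indices in their original order. The tensors below are placeholders.

```python
import torch
from qadence.ml_tools.data import to_dataloader

x = torch.linspace(0, 1, 10).reshape(-1, 1)
y = torch.sin(x)

# infinite=True yields batches forever; indices are now shuffled rather than
# visited in the fixed order produced by itertools.cycle in 1.8.0.
loader = to_dataloader(x, y, batch_size=5, infinite=True)
inputs, targets = next(iter(loader))
```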
qadence/ml_tools/loss/__init__.py ADDED
@@ -0,0 +1,10 @@
+ from __future__ import annotations
+
+ from .loss import cross_entropy_loss, get_loss_fn, mse_loss
+
+ # Modules to be automatically added to the qadence.ml_tools.loss namespace
+ __all__ = [
+ "cross_entropy_loss",
+ "get_loss_fn",
+ "mse_loss",
+ ]
qadence/ml_tools/loss/loss.py ADDED
@@ -0,0 +1,87 @@
+ from __future__ import annotations
+
+ from typing import Callable
+
+ import torch
+ import torch.nn as nn
+
+
+ def mse_loss(
+ model: nn.Module, batch: tuple[torch.Tensor, torch.Tensor]
+ ) -> tuple[torch.Tensor, dict[str, float]]:
+ """Computes the Mean Squared Error (MSE) loss between model predictions and targets.
+
+ Args:
+ model (nn.Module): The PyTorch model used for generating predictions.
+ batch (Tuple[torch.Tensor, torch.Tensor]): A tuple containing:
+ - inputs (torch.Tensor): The input data.
+ - targets (torch.Tensor): The ground truth labels.
+
+ Returns:
+ Tuple[torch.Tensor, dict[str, float]]:
+ - loss (torch.Tensor): The computed MSE loss value.
+ - metrics (dict[str, float]): A dictionary with the MSE loss value.
+ """
+ criterion = nn.MSELoss()
+ inputs, targets = batch
+ outputs = model(inputs)
+ loss = criterion(outputs, targets)
+
+ metrics = {"mse": loss}
+ return loss, metrics
+
+
+ def cross_entropy_loss(
+ model: nn.Module, batch: tuple[torch.Tensor, torch.Tensor]
+ ) -> tuple[torch.Tensor, dict[str, float]]:
+ """Computes the Cross Entropy loss between model predictions and targets.
+
+ Args:
+ model (nn.Module): The PyTorch model used for generating predictions.
+ batch (Tuple[torch.Tensor, torch.Tensor]): A tuple containing:
+ - inputs (torch.Tensor): The input data.
+ - targets (torch.Tensor): The ground truth labels.
+
+ Returns:
+ Tuple[torch.Tensor, dict[str, float]]:
+ - loss (torch.Tensor): The computed Cross Entropy loss value.
+ - metrics (dict[str, float]): A dictionary with the Cross Entropy loss value.
+ """
+ criterion = nn.CrossEntropyLoss()
+ inputs, targets = batch
+ outputs = model(inputs)
+ loss = criterion(outputs, targets)
+
+ metrics = {"cross_entropy": loss}
+ return loss, metrics
+
+
+ def get_loss_fn(loss_fn: str | Callable | None) -> Callable:
+ """
+ Returns the appropriate loss function based on the input argument.
+
+ Args:
+ loss_fn (str | Callable | None): The loss function to use.
+ - If `loss_fn` is a callable, it will be returned directly.
+ - If `loss_fn` is a string, it should be one of:
+ - "mse": Returns the `mse_loss` function.
+ - "cross_entropy": Returns the `cross_entropy_loss` function.
+ - If `loss_fn` is `None`, the default `mse_loss` function will be returned.
+
+ Returns:
+ Callable: The corresponding loss function.
+
+ Raises:
+ ValueError: If `loss_fn` is a string but not a supported loss function name.
+ """
+ if callable(loss_fn):
+ return loss_fn
+ elif isinstance(loss_fn, str):
+ if loss_fn == "mse":
+ return mse_loss
+ elif loss_fn == "cross_entropy":
+ return cross_entropy_loss
+ else:
+ raise ValueError(f"Unsupported loss function: {loss_fn}")
+ else:
+ return mse_loss
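Editor's note: a usage sketch for the new loss helpers, assuming they are importable as `qadence.ml_tools.loss` (as the new `__init__.py` above indicates); the model and data are placeholders.

```python
import torch
import torch.nn as nn
from qadence.ml_tools.loss import get_loss_fn

model = nn.Linear(1, 1)
batch = (torch.randn(8, 1), torch.randn(8, 1))  # (inputs, targets)

loss_fn = get_loss_fn("mse")   # "cross_entropy", a callable, or None (-> mse_loss) also work
loss, metrics = loss_fn(model, batch)
print(metrics)                 # {"mse": tensor(...)}
```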
qadence/ml_tools/optimize_step.py CHANGED
@@ -2,11 +2,14 @@ from __future__ import annotations

  from typing import Any, Callable

+ import nevergrad as ng
  import torch
  from torch.nn import Module
  from torch.optim import Optimizer

  from qadence.ml_tools.data import data_to_device
+ from qadence.ml_tools.parameters import set_parameters
+ from qadence.ml_tools.tensors import promote_to_tensor


  def optimize_step(
@@ -19,21 +22,21 @@
  ) -> tuple[torch.Tensor | float, dict | None]:
  """Default Torch optimize step with closure.

- This is the default optimization step which should work for most
- of the standard use cases of optimization of Torch models
+ This is the default optimization step.

  Args:
- model (Module): The input model
- optimizer (Optimizer): The chosen Torch optimizer
+ model (Module): The input model to be optimized.
+ optimizer (Optimizer): The chosen Torch optimizer.
  loss_fn (Callable): A custom loss function
- xs (dict | list | torch.Tensor | None): the input data. If None it means
- that the given model does not require any input data
- device (torch.device): A target device to run computation on.
- dtype (torch.dtype): Data type for xs conversion.
+ that returns the loss value and a dictionary of metrics.
+ xs (dict | list | Tensor | None): The input data. If None, it means
+ the given model does not require any input data.
+ device (torch.device): A target device to run computations on.
+ dtype (torch.dtype): Data type for `xs` conversion.

  Returns:
- tuple: tuple containing the computed loss value, and a dictionary with
- the collected metrics.
+ tuple[Tensor | float, dict | None]: A tuple containing the computed loss value
+ and a dictionary with collected metrics.
  """

  loss, metrics = None, {}
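Editor's note: a minimal sketch of calling `optimize_step` with a `(loss, metrics)`-returning loss function, as the updated docstring requires. It assumes the remaining arguments (`device`, `dtype`) keep their defaults; the model and data are placeholders.

```python
import torch
from torch.optim import Adam
from qadence.ml_tools.loss import mse_loss
from qadence.ml_tools.optimize_step import optimize_step

model = torch.nn.Linear(1, 1)
optimizer = Adam(model.parameters(), lr=0.01)
batch = (torch.randn(16, 1), torch.randn(16, 1))

# One gradient-based update; the loss function receives (model, xs) and
# must return the loss tensor plus a metrics dictionary.
loss, metrics = optimize_step(model, optimizer, mse_loss, batch)
```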
@@ -52,3 +55,35 @@
  optimizer.step(closure)
  # return the loss/metrics that are being mutated inside the closure...
  return loss, metrics
+
+
+ def update_ng_parameters(
+ model: Module,
+ optimizer: ng.optimizers.Optimizer,
+ loss_fn: Callable[[Module, torch.Tensor | None], tuple[float, dict]],
+ data: torch.Tensor | None,
+ ng_params: ng.p.Array,
+ ) -> tuple[float, dict, ng.p.Array]:
+ """Update the model parameters using Nevergrad.
+
+ This function integrates Nevergrad for derivative-free optimization.
+
+ Args:
+ model (Module): The PyTorch model to be optimized.
+ optimizer (ng.optimizers.Optimizer): A Nevergrad optimizer instance.
+ loss_fn (Callable[[Module, Tensor | None], tuple[float, dict]]): A custom loss function
+ that returns the loss value and a dictionary of metrics.
+ data (Tensor | None): Input data for the model. If None, it means the model does
+ not require input data.
+ ng_params (ng.p.Array): The current set of parameters managed by Nevergrad.
+
+ Returns:
+ tuple[float, dict, ng.p.Array]: A tuple containing the computed loss value,
+ a dictionary of metrics, and the updated Nevergrad parameters.
+ """
+ loss, metrics = loss_fn(model, data) # type: ignore[misc]
+ optimizer.tell(ng_params, float(loss))
+ ng_params = optimizer.ask() # type: ignore[assignment]
+ params = promote_to_tensor(ng_params.value, requires_grad=False)
+ set_parameters(model, params)
+ return loss, metrics, ng_params
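Editor's note: a gradient-free loop built around the new `update_ng_parameters` helper. The initial candidate is obtained from the Nevergrad optimizer itself; `get_parameters` is assumed here as the counterpart of the imported `set_parameters`, and the optimizer choice and budget are illustrative.

```python
import nevergrad as ng
import torch
from qadence.ml_tools.loss import mse_loss
from qadence.ml_tools.optimize_step import update_ng_parameters
from qadence.ml_tools.parameters import get_parameters  # assumed counterpart of set_parameters

model = torch.nn.Linear(1, 1)
batch = (torch.randn(16, 1), torch.randn(16, 1))

parametrization = ng.p.Array(init=get_parameters(model).detach().numpy())
optimizer = ng.optimizers.NGOpt(parametrization=parametrization, budget=100)
ng_params = optimizer.ask()

for _ in range(100):
    # Evaluates the current parameters, tells Nevergrad the loss, asks for a new
    # candidate, and writes it back into the model.
    loss, metrics, ng_params = update_ng_parameters(model, optimizer, mse_loss, batch, ng_params)
```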