qadence 1.7.4__py3-none-any.whl → 1.7.6__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- qadence/analog/addressing.py +7 -3
- qadence/backends/api.py +9 -8
- qadence/backends/gpsr.py +18 -2
- qadence/backends/horqrux/convert_ops.py +1 -1
- qadence/backends/pyqtorch/backend.py +8 -11
- qadence/backends/pyqtorch/convert_ops.py +100 -123
- qadence/backends/utils.py +1 -1
- qadence/blocks/composite.py +5 -3
- qadence/blocks/utils.py +36 -2
- qadence/constructors/utils.py +26 -26
- qadence/engines/jax/differentiable_expectation.py +1 -1
- qadence/engines/torch/differentiable_expectation.py +17 -6
- qadence/extensions.py +28 -8
- qadence/ml_tools/__init__.py +2 -1
- qadence/ml_tools/config.py +131 -25
- qadence/ml_tools/constructors.py +39 -33
- qadence/ml_tools/data.py +27 -1
- qadence/ml_tools/optimize_step.py +3 -2
- qadence/ml_tools/saveload.py +3 -2
- qadence/ml_tools/train_grad.py +154 -94
- qadence/ml_tools/train_no_grad.py +86 -40
- qadence/model.py +47 -3
- qadence/types.py +2 -2
- {qadence-1.7.4.dist-info → qadence-1.7.6.dist-info}/METADATA +4 -4
- {qadence-1.7.4.dist-info → qadence-1.7.6.dist-info}/RECORD +27 -27
- {qadence-1.7.4.dist-info → qadence-1.7.6.dist-info}/WHEEL +0 -0
- {qadence-1.7.4.dist-info → qadence-1.7.6.dist-info}/licenses/LICENSE +0 -0
@@ -29,10 +29,11 @@ def optimize_step(
|
|
29
29
|
xs (dict | list | torch.Tensor | None): the input data. If None it means
|
30
30
|
that the given model does not require any input data
|
31
31
|
device (torch.device): A target device to run computation on.
|
32
|
+
dtype (torch.dtype): Data type for xs conversion.
|
32
33
|
|
33
34
|
Returns:
|
34
|
-
tuple: tuple containing the
|
35
|
-
the collected metrics
|
35
|
+
tuple: tuple containing the computed loss value, and a dictionary with
|
36
|
+
the collected metrics.
|
36
37
|
"""
|
37
38
|
|
38
39
|
loss, metrics = None, {}
|
qadence/ml_tools/saveload.py
CHANGED
@@ -72,7 +72,8 @@ def write_checkpoint(
|
|
72
72
|
device = None
|
73
73
|
try:
|
74
74
|
# We extract the device from the pyqtorch native circuit
|
75
|
-
device =
|
75
|
+
device = model.device if isinstance(QuantumModel, QNN) else next(model.parameters()).device
|
76
|
+
device = str(device).split(":")[0] # in case of using several CUDA devices
|
76
77
|
except Exception as e:
|
77
78
|
msg = (
|
78
79
|
f"Unable to identify in which device the QuantumModel is stored due to {e}."
|
@@ -132,7 +133,7 @@ def load_model(
|
|
132
133
|
try:
|
133
134
|
iteration, model_dict = torch.load(folder / model_ckpt_name, *args, **kwargs)
|
134
135
|
if isinstance(model, (QuantumModel, QNN)):
|
135
|
-
model.
|
136
|
+
model.load_params_from_dict(model_dict)
|
136
137
|
elif isinstance(model, Module):
|
137
138
|
model.load_state_dict(model_dict, strict=True)
|
138
139
|
# Load model to a specific gpu device if specified
|
qadence/ml_tools/train_grad.py
CHANGED
@@ -3,16 +3,10 @@ from __future__ import annotations
|
|
3
3
|
import importlib
|
4
4
|
import math
|
5
5
|
from logging import getLogger
|
6
|
-
from typing import Callable, Union
|
7
|
-
|
8
|
-
from rich.progress import
|
9
|
-
|
10
|
-
Progress,
|
11
|
-
TaskProgressColumn,
|
12
|
-
TextColumn,
|
13
|
-
TimeRemainingColumn,
|
14
|
-
)
|
15
|
-
from torch import complex128, float32, float64
|
6
|
+
from typing import Any, Callable, Union
|
7
|
+
|
8
|
+
from rich.progress import BarColumn, Progress, TaskProgressColumn, TextColumn, TimeRemainingColumn
|
9
|
+
from torch import Tensor, complex128, float32, float64
|
16
10
|
from torch import device as torch_device
|
17
11
|
from torch import dtype as torch_dtype
|
18
12
|
from torch.nn import DataParallel, Module
|
@@ -20,8 +14,8 @@ from torch.optim import Optimizer
|
|
20
14
|
from torch.utils.data import DataLoader
|
21
15
|
from torch.utils.tensorboard import SummaryWriter
|
22
16
|
|
23
|
-
from qadence.ml_tools.config import TrainConfig
|
24
|
-
from qadence.ml_tools.data import DictDataLoader, data_to_device
|
17
|
+
from qadence.ml_tools.config import Callback, TrainConfig
|
18
|
+
from qadence.ml_tools.data import DictDataLoader, OptimizeResult, data_to_device
|
25
19
|
from qadence.ml_tools.optimize_step import optimize_step
|
26
20
|
from qadence.ml_tools.printing import (
|
27
21
|
log_model_tracker,
|
@@ -166,107 +160,178 @@ def train(
|
|
166
160
|
|
167
161
|
best_val_loss = math.inf
|
168
162
|
|
163
|
+
if not ((dataloader is None) or isinstance(dataloader, (DictDataLoader, DataLoader))):
|
164
|
+
raise NotImplementedError(
|
165
|
+
f"Unsupported dataloader type: {type(dataloader)}. "
|
166
|
+
"You can use e.g. `qadence.ml_tools.to_dataloader` to build a dataloader."
|
167
|
+
)
|
168
|
+
|
169
|
+
def next_loss_iter(dl_iter: Union[None, DataLoader, DictDataLoader]) -> Any:
|
170
|
+
"""Get loss on the next batch of a dataloader.
|
171
|
+
|
172
|
+
loaded on device if not None.
|
173
|
+
|
174
|
+
Args:
|
175
|
+
dl_iter (Union[None, DataLoader, DictDataLoader]): Dataloader.
|
176
|
+
|
177
|
+
Returns:
|
178
|
+
Any: Loss value
|
179
|
+
"""
|
180
|
+
xs = next(dl_iter) if dl_iter is not None else None
|
181
|
+
xs_to_device = data_to_device(xs, device=device, dtype=data_dtype)
|
182
|
+
return loss_fn(model, xs_to_device)
|
183
|
+
|
184
|
+
# populate callbacks with already available internal functions
|
185
|
+
# printing, writing and plotting
|
186
|
+
callbacks = config.callbacks
|
187
|
+
|
188
|
+
# printing
|
189
|
+
if config.verbose and config.print_every > 0:
|
190
|
+
# Note that the loss returned by optimize_step
|
191
|
+
# is the value before doing the training step
|
192
|
+
# which is printed accordingly by the previous iteration number
|
193
|
+
callbacks += [
|
194
|
+
Callback(
|
195
|
+
lambda opt_res: print_metrics(opt_res.loss, opt_res.metrics, opt_res.iteration - 1),
|
196
|
+
called_every=config.print_every,
|
197
|
+
call_after_opt=True,
|
198
|
+
)
|
199
|
+
]
|
200
|
+
|
201
|
+
# plotting
|
202
|
+
callbacks += [
|
203
|
+
Callback(
|
204
|
+
lambda opt_res: plot_tracker(
|
205
|
+
writer,
|
206
|
+
opt_res.model,
|
207
|
+
opt_res.iteration,
|
208
|
+
config.plotting_functions,
|
209
|
+
tracking_tool=config.tracking_tool,
|
210
|
+
),
|
211
|
+
called_every=config.plot_every,
|
212
|
+
call_before_opt=True,
|
213
|
+
)
|
214
|
+
]
|
215
|
+
|
216
|
+
# writing metrics
|
217
|
+
callbacks += [
|
218
|
+
Callback(
|
219
|
+
lambda opt_res: write_tracker(
|
220
|
+
writer,
|
221
|
+
opt_res.loss,
|
222
|
+
opt_res.metrics,
|
223
|
+
opt_res.iteration,
|
224
|
+
tracking_tool=config.tracking_tool,
|
225
|
+
),
|
226
|
+
called_every=config.write_every,
|
227
|
+
call_before_opt=False,
|
228
|
+
call_after_opt=True,
|
229
|
+
call_during_eval=True,
|
230
|
+
)
|
231
|
+
]
|
232
|
+
|
233
|
+
# checkpointing
|
234
|
+
if config.folder and config.checkpoint_every > 0 and not config.checkpoint_best_only:
|
235
|
+
callbacks += [
|
236
|
+
Callback(
|
237
|
+
lambda opt_res: write_checkpoint(
|
238
|
+
config.folder, # type: ignore[arg-type]
|
239
|
+
opt_res.model,
|
240
|
+
opt_res.optimizer,
|
241
|
+
opt_res.iteration,
|
242
|
+
),
|
243
|
+
called_every=config.checkpoint_every,
|
244
|
+
call_before_opt=False,
|
245
|
+
call_after_opt=True,
|
246
|
+
)
|
247
|
+
]
|
248
|
+
|
249
|
+
if config.folder and config.checkpoint_best_only:
|
250
|
+
callbacks += [
|
251
|
+
Callback(
|
252
|
+
lambda opt_res: write_checkpoint(
|
253
|
+
config.folder, # type: ignore[arg-type]
|
254
|
+
opt_res.model,
|
255
|
+
opt_res.optimizer,
|
256
|
+
"best",
|
257
|
+
),
|
258
|
+
called_every=config.checkpoint_every,
|
259
|
+
call_before_opt=True,
|
260
|
+
call_after_opt=True,
|
261
|
+
call_during_eval=True,
|
262
|
+
)
|
263
|
+
]
|
264
|
+
|
265
|
+
def run_callbacks(callback_iterable: list[Callback], opt_res: OptimizeResult) -> None:
|
266
|
+
for callback in callback_iterable:
|
267
|
+
callback(opt_res)
|
268
|
+
|
269
|
+
callbacks_before_opt = [
|
270
|
+
callback
|
271
|
+
for callback in callbacks
|
272
|
+
if callback.call_before_opt and not callback.call_during_eval
|
273
|
+
]
|
274
|
+
callbacks_before_opt_eval = [
|
275
|
+
callback for callback in callbacks if callback.call_before_opt and callback.call_during_eval
|
276
|
+
]
|
277
|
+
|
169
278
|
with progress:
|
170
279
|
dl_iter = iter(dataloader) if dataloader is not None else None
|
171
280
|
|
172
281
|
# Initial validation evaluation
|
173
282
|
try:
|
283
|
+
opt_result = OptimizeResult(init_iter, model, optimizer)
|
174
284
|
if perform_val:
|
175
285
|
dl_iter_val = iter(val_dataloader) if val_dataloader is not None else None
|
176
|
-
|
177
|
-
xs_to_device = data_to_device(xs, device=device, dtype=data_dtype)
|
178
|
-
best_val_loss, metrics = loss_fn(model, xs_to_device)
|
179
|
-
|
286
|
+
best_val_loss, metrics, *_ = next_loss_iter(dl_iter_val)
|
180
287
|
metrics["val_loss"] = best_val_loss
|
181
|
-
|
182
|
-
|
183
|
-
if config.folder:
|
184
|
-
if config.checkpoint_best_only:
|
185
|
-
write_checkpoint(config.folder, model, optimizer, iteration="best")
|
186
|
-
else:
|
187
|
-
write_checkpoint(config.folder, model, optimizer, init_iter)
|
288
|
+
opt_result.metrics = metrics
|
289
|
+
run_callbacks(callbacks_before_opt_eval, opt_result)
|
188
290
|
|
189
|
-
|
190
|
-
writer,
|
191
|
-
model,
|
192
|
-
init_iter,
|
193
|
-
config.plotting_functions,
|
194
|
-
tracking_tool=config.tracking_tool,
|
195
|
-
)
|
291
|
+
run_callbacks(callbacks_before_opt, opt_result)
|
196
292
|
|
197
293
|
except KeyboardInterrupt:
|
198
294
|
logger.info("Terminating training gracefully after the current iteration.")
|
199
295
|
|
200
296
|
# outer epoch loop
|
201
297
|
init_iter += 1
|
298
|
+
callbacks_end_epoch = [
|
299
|
+
callback
|
300
|
+
for callback in callbacks
|
301
|
+
if callback.call_end_epoch and not callback.call_during_eval
|
302
|
+
]
|
303
|
+
callbacks_end_epoch_eval = [
|
304
|
+
callback
|
305
|
+
for callback in callbacks
|
306
|
+
if callback.call_end_epoch and callback.call_during_eval
|
307
|
+
]
|
202
308
|
for iteration in progress.track(range(init_iter, init_iter + config.max_iter)):
|
203
309
|
try:
|
204
310
|
# in case there is not data needed by the model
|
205
311
|
# this is the case, for example, of quantum models
|
206
312
|
# which do not have classical input data (e.g. chemistry)
|
207
|
-
|
208
|
-
|
209
|
-
|
210
|
-
|
211
|
-
|
212
|
-
|
213
|
-
|
214
|
-
|
215
|
-
|
313
|
+
loss, metrics = optimize_step(
|
314
|
+
model=model,
|
315
|
+
optimizer=optimizer,
|
316
|
+
loss_fn=loss_fn,
|
317
|
+
xs=None if dataloader is None else next(dl_iter), # type: ignore[arg-type]
|
318
|
+
device=device,
|
319
|
+
dtype=data_dtype,
|
320
|
+
)
|
321
|
+
if isinstance(loss, Tensor):
|
216
322
|
loss = loss.item()
|
323
|
+
opt_result = OptimizeResult(iteration, model, optimizer, loss, metrics)
|
324
|
+
run_callbacks(callbacks_end_epoch, opt_result)
|
217
325
|
|
218
|
-
elif isinstance(dataloader, (DictDataLoader, DataLoader)):
|
219
|
-
loss, metrics = optimize_step(
|
220
|
-
model=model,
|
221
|
-
optimizer=optimizer,
|
222
|
-
loss_fn=loss_fn,
|
223
|
-
xs=next(dl_iter), # type: ignore[arg-type]
|
224
|
-
device=device,
|
225
|
-
dtype=data_dtype,
|
226
|
-
)
|
227
|
-
|
228
|
-
else:
|
229
|
-
raise NotImplementedError(
|
230
|
-
f"Unsupported dataloader type: {type(dataloader)}. "
|
231
|
-
"You can use e.g. `qadence.ml_tools.to_dataloader` to build a dataloader."
|
232
|
-
)
|
233
|
-
|
234
|
-
if iteration % config.print_every == 0 and config.verbose:
|
235
|
-
# Note that the loss returned by optimize_step
|
236
|
-
# is the value before doing the training step
|
237
|
-
# which is printed accordingly by the previous iteration number
|
238
|
-
print_metrics(loss, metrics, iteration - 1)
|
239
|
-
|
240
|
-
if iteration % config.write_every == 0:
|
241
|
-
write_tracker(
|
242
|
-
writer, loss, metrics, iteration, tracking_tool=config.tracking_tool
|
243
|
-
)
|
244
|
-
|
245
|
-
if iteration % config.plot_every == 0:
|
246
|
-
plot_tracker(
|
247
|
-
writer,
|
248
|
-
model,
|
249
|
-
iteration,
|
250
|
-
config.plotting_functions,
|
251
|
-
tracking_tool=config.tracking_tool,
|
252
|
-
)
|
253
326
|
if perform_val:
|
254
327
|
if iteration % config.val_every == 0:
|
255
|
-
|
256
|
-
xs_to_device = data_to_device(xs, device=device, dtype=data_dtype)
|
257
|
-
val_loss, *_ = loss_fn(model, xs_to_device)
|
328
|
+
val_loss, *_ = next_loss_iter(dl_iter_val)
|
258
329
|
if config.validation_criterion(val_loss, best_val_loss, config.val_epsilon): # type: ignore[misc]
|
259
330
|
best_val_loss = val_loss
|
260
|
-
if config.folder and config.checkpoint_best_only:
|
261
|
-
write_checkpoint(config.folder, model, optimizer, iteration="best")
|
262
331
|
metrics["val_loss"] = val_loss
|
263
|
-
|
264
|
-
writer, loss, metrics, iteration, tracking_tool=config.tracking_tool
|
265
|
-
)
|
332
|
+
opt_result.metrics = metrics
|
266
333
|
|
267
|
-
|
268
|
-
if iteration % config.checkpoint_every == 0 and not config.checkpoint_best_only:
|
269
|
-
write_checkpoint(config.folder, model, optimizer, iteration)
|
334
|
+
run_callbacks(callbacks_end_epoch_eval, opt_result)
|
270
335
|
|
271
336
|
except KeyboardInterrupt:
|
272
337
|
logger.info("Terminating training gracefully after the current iteration.")
|
@@ -275,21 +340,16 @@ def train(
|
|
275
340
|
# Handling printing the last training loss
|
276
341
|
# as optimize_step does not give the loss value at the last iteration
|
277
342
|
try:
|
278
|
-
|
279
|
-
xs_to_device = data_to_device(xs, device=device, dtype=data_dtype)
|
280
|
-
loss, metrics, *_ = loss_fn(model, xs_to_device)
|
281
|
-
if dataloader is None:
|
282
|
-
loss = loss.item()
|
343
|
+
loss, metrics, *_ = next_loss_iter(dl_iter)
|
283
344
|
if iteration % config.print_every == 0 and config.verbose:
|
284
345
|
print_metrics(loss, metrics, iteration)
|
285
346
|
|
286
347
|
except KeyboardInterrupt:
|
287
348
|
logger.info("Terminating training gracefully after the current iteration.")
|
288
349
|
|
289
|
-
# Final checkpointing and writing
|
290
|
-
|
291
|
-
|
292
|
-
write_tracker(writer, loss, metrics, iteration, tracking_tool=config.tracking_tool)
|
350
|
+
# Final callbacks, by default checkpointing and writing
|
351
|
+
callbacks_after_opt = [callback for callback in callbacks if callback.call_after_opt]
|
352
|
+
run_callbacks(callbacks_after_opt, opt_result)
|
293
353
|
|
294
354
|
# writing hyperparameters
|
295
355
|
if config.hyperparams:
|
@@ -6,20 +6,14 @@ from typing import Callable
|
|
6
6
|
|
7
7
|
import nevergrad as ng
|
8
8
|
from nevergrad.optimization.base import Optimizer as NGOptimizer
|
9
|
-
from rich.progress import
|
10
|
-
BarColumn,
|
11
|
-
Progress,
|
12
|
-
TaskProgressColumn,
|
13
|
-
TextColumn,
|
14
|
-
TimeRemainingColumn,
|
15
|
-
)
|
9
|
+
from rich.progress import BarColumn, Progress, TaskProgressColumn, TextColumn, TimeRemainingColumn
|
16
10
|
from torch import Tensor
|
17
11
|
from torch.nn import Module
|
18
12
|
from torch.utils.data import DataLoader
|
19
13
|
from torch.utils.tensorboard import SummaryWriter
|
20
14
|
|
21
|
-
from qadence.ml_tools.config import TrainConfig
|
22
|
-
from qadence.ml_tools.data import DictDataLoader
|
15
|
+
from qadence.ml_tools.config import Callback, TrainConfig
|
16
|
+
from qadence.ml_tools.data import DictDataLoader, OptimizeResult
|
23
17
|
from qadence.ml_tools.parameters import get_parameters, set_parameters
|
24
18
|
from qadence.ml_tools.printing import (
|
25
19
|
log_model_tracker,
|
@@ -92,6 +86,12 @@ def train(
|
|
92
86
|
params = get_parameters(model).detach().numpy()
|
93
87
|
ng_params = ng.p.Array(init=params)
|
94
88
|
|
89
|
+
if not ((dataloader is None) or isinstance(dataloader, (DictDataLoader, DataLoader))):
|
90
|
+
raise NotImplementedError(
|
91
|
+
f"Unsupported dataloader type: {type(dataloader)}. "
|
92
|
+
"You can use e.g. `qadence.ml_tools.to_dataloader` to build a dataloader."
|
93
|
+
)
|
94
|
+
|
95
95
|
# serial training
|
96
96
|
# TODO: Add a parallelization using the num_workers argument in Nevergrad
|
97
97
|
progress = Progress(
|
@@ -100,38 +100,85 @@ def train(
|
|
100
100
|
TaskProgressColumn(),
|
101
101
|
TimeRemainingColumn(elapsed_when_finished=True),
|
102
102
|
)
|
103
|
-
with progress:
|
104
|
-
dl_iter = iter(dataloader) if dataloader is not None else None
|
105
|
-
|
106
|
-
for iteration in progress.track(range(init_iter, init_iter + config.max_iter)):
|
107
|
-
if dataloader is None:
|
108
|
-
loss, metrics, ng_params = _update_parameters(None, ng_params)
|
109
|
-
|
110
|
-
elif isinstance(dataloader, (DictDataLoader, DataLoader)):
|
111
|
-
data = next(dl_iter) # type: ignore[arg-type]
|
112
|
-
loss, metrics, ng_params = _update_parameters(data, ng_params)
|
113
|
-
|
114
|
-
else:
|
115
|
-
raise NotImplementedError("Unsupported dataloader type!")
|
116
|
-
|
117
|
-
if iteration % config.print_every == 0 and config.verbose:
|
118
|
-
print_metrics(loss, metrics, iteration)
|
119
103
|
|
120
|
-
|
121
|
-
|
122
|
-
|
123
|
-
|
124
|
-
|
104
|
+
# populate callbacks with already available internal functions
|
105
|
+
# printing, writing and plotting
|
106
|
+
callbacks = config.callbacks
|
107
|
+
|
108
|
+
# printing
|
109
|
+
if config.verbose and config.print_every > 0:
|
110
|
+
callbacks += [
|
111
|
+
Callback(
|
112
|
+
lambda opt_res: print_metrics(opt_res.loss, opt_res.metrics, opt_res.iteration),
|
113
|
+
called_every=config.print_every,
|
114
|
+
)
|
115
|
+
]
|
116
|
+
|
117
|
+
# writing metrics
|
118
|
+
if config.write_every > 0:
|
119
|
+
callbacks += [
|
120
|
+
Callback(
|
121
|
+
lambda opt_res: write_tracker(
|
122
|
+
writer,
|
123
|
+
opt_res.loss,
|
124
|
+
opt_res.metrics,
|
125
|
+
opt_res.iteration,
|
126
|
+
tracking_tool=config.tracking_tool,
|
127
|
+
),
|
128
|
+
called_every=config.write_every,
|
129
|
+
call_after_opt=True,
|
130
|
+
)
|
131
|
+
]
|
132
|
+
|
133
|
+
# plot tracker
|
134
|
+
if config.plot_every > 0:
|
135
|
+
callbacks += [
|
136
|
+
Callback(
|
137
|
+
lambda opt_res: plot_tracker(
|
125
138
|
writer,
|
126
|
-
model,
|
127
|
-
iteration,
|
139
|
+
opt_res.model,
|
140
|
+
opt_res.iteration,
|
128
141
|
config.plotting_functions,
|
129
142
|
tracking_tool=config.tracking_tool,
|
130
|
-
)
|
143
|
+
),
|
144
|
+
called_every=config.plot_every,
|
145
|
+
)
|
146
|
+
]
|
147
|
+
|
148
|
+
# checkpointing
|
149
|
+
if config.folder and config.checkpoint_every > 0:
|
150
|
+
callbacks += [
|
151
|
+
Callback(
|
152
|
+
lambda opt_res: write_checkpoint(
|
153
|
+
config.folder, # type: ignore[arg-type]
|
154
|
+
opt_res.model,
|
155
|
+
opt_res.optimizer,
|
156
|
+
opt_res.iteration,
|
157
|
+
),
|
158
|
+
called_every=config.checkpoint_every,
|
159
|
+
call_after_opt=True,
|
160
|
+
)
|
161
|
+
]
|
162
|
+
|
163
|
+
def run_callbacks(callback_iterable: list[Callback], opt_res: OptimizeResult) -> None:
|
164
|
+
for callback in callback_iterable:
|
165
|
+
callback(opt_res)
|
166
|
+
|
167
|
+
callbacks_end_opt = [
|
168
|
+
callback
|
169
|
+
for callback in callbacks
|
170
|
+
if callback.call_end_epoch and not callback.call_during_eval
|
171
|
+
]
|
172
|
+
|
173
|
+
with progress:
|
174
|
+
dl_iter = iter(dataloader) if dataloader is not None else None
|
131
175
|
|
132
|
-
|
133
|
-
|
134
|
-
|
176
|
+
for iteration in progress.track(range(init_iter, init_iter + config.max_iter)):
|
177
|
+
loss, metrics, ng_params = _update_parameters(
|
178
|
+
None if dataloader is None else next(dl_iter), ng_params # type: ignore[arg-type]
|
179
|
+
)
|
180
|
+
opt_result = OptimizeResult(iteration, model, optimizer, loss, metrics)
|
181
|
+
run_callbacks(callbacks_end_opt, opt_result)
|
135
182
|
|
136
183
|
if iteration >= init_iter + config.max_iter:
|
137
184
|
break
|
@@ -143,10 +190,9 @@ def train(
|
|
143
190
|
if config.log_model:
|
144
191
|
log_model_tracker(writer, model, dataloader, tracking_tool=config.tracking_tool)
|
145
192
|
|
146
|
-
# Final
|
147
|
-
if
|
148
|
-
|
149
|
-
write_tracker(writer, loss, metrics, iteration, tracking_tool=config.tracking_tool)
|
193
|
+
# Final callbacks
|
194
|
+
callbacks_after_opt = [callback for callback in callbacks if callback.call_after_opt]
|
195
|
+
run_callbacks(callbacks_after_opt, opt_result)
|
150
196
|
|
151
197
|
# close tracker
|
152
198
|
if config.tracking_tool == ExperimentTrackingTool.TENSORBOARD:
|
qadence/model.py
CHANGED
@@ -353,11 +353,11 @@ class QuantumModel(nn.Module):
|
|
353
353
|
"""
|
354
354
|
raise NotImplementedError("The overlap method is not implemented for this model.")
|
355
355
|
|
356
|
-
def _to_dict(self, save_params: bool =
|
356
|
+
def _to_dict(self, save_params: bool = True) -> dict[str, Any]:
|
357
357
|
"""Convert QuantumModel to a dictionary for serialization.
|
358
358
|
|
359
359
|
Arguments:
|
360
|
-
save_params:
|
360
|
+
save_params: Save parameters. Defaults to True.
|
361
361
|
|
362
362
|
Returns:
|
363
363
|
The dictionary
|
@@ -382,7 +382,7 @@ class QuantumModel(nn.Module):
|
|
382
382
|
}
|
383
383
|
param_dict_conv = {}
|
384
384
|
if save_params:
|
385
|
-
param_dict_conv = {name: param
|
385
|
+
param_dict_conv = {name: param for name, param in self._params.items()}
|
386
386
|
d = {self.__class__.__name__: d, "param_dict": param_dict_conv}
|
387
387
|
logger.debug(f"{self.__class__.__name__} serialized to {d}.")
|
388
388
|
except Exception as e:
|
@@ -432,6 +432,50 @@ class QuantumModel(nn.Module):
|
|
432
432
|
|
433
433
|
return qm
|
434
434
|
|
435
|
+
def load_params_from_dict(self, d: dict, strict: bool = True) -> None:
|
436
|
+
"""Copy parameters from dictionary into this QuantumModel.
|
437
|
+
|
438
|
+
Unlike :meth:`~qadence.QuantumModel.from_dict`, this method does not create a new
|
439
|
+
QuantumModel instance, but rather loads the parameters into the same QuantumModel.
|
440
|
+
The behaviour of this method is similar to :meth:`~torch.nn.Module.load_state_dict`.
|
441
|
+
|
442
|
+
The dictionary is assumed to have the format as saved via
|
443
|
+
:meth:`~qadence.QuantumModel.to_dict`
|
444
|
+
|
445
|
+
Args:
|
446
|
+
d (dict): The dictionary
|
447
|
+
strict (bool, optional):
|
448
|
+
Whether to strictly enforce that the parameter keys in the dictionary and
|
449
|
+
in the model match exactly. Default: ``True``.
|
450
|
+
"""
|
451
|
+
param_dict = d["param_dict"]
|
452
|
+
missing_keys = set(self._params.keys()) - set(param_dict.keys())
|
453
|
+
unexpected_keys = set(param_dict.keys()) - set(self._params.keys())
|
454
|
+
|
455
|
+
if strict:
|
456
|
+
error_msgs = []
|
457
|
+
if len(unexpected_keys) > 0:
|
458
|
+
error_msgs.append(f"Unexpected key(s) in dictionary: {unexpected_keys}")
|
459
|
+
if len(missing_keys) > 0:
|
460
|
+
error_msgs.append(f"Missing key(s) in dictionary: {missing_keys}")
|
461
|
+
if len(error_msgs) > 0:
|
462
|
+
errors_string = "\n\t".join(error_msgs)
|
463
|
+
raise RuntimeError(
|
464
|
+
f"Error(s) loading the parameter dictionary due to: \n\t{errors_string}\n"
|
465
|
+
"This error was thrown because the `strict` argument is set `True`."
|
466
|
+
"If you don't need the parameter keys of the dictionary to exactly match "
|
467
|
+
"the model parameters, set `strict=False`."
|
468
|
+
)
|
469
|
+
|
470
|
+
for n, param in param_dict.items():
|
471
|
+
try:
|
472
|
+
with torch.no_grad():
|
473
|
+
self._params[n].copy_(
|
474
|
+
torch.nn.Parameter(param, requires_grad=param.requires_grad)
|
475
|
+
)
|
476
|
+
except Exception as e:
|
477
|
+
logger.warning(f"Unable to load parameter {n} from dictionary due to {e}.")
|
478
|
+
|
435
479
|
def save(
|
436
480
|
self, folder: str | Path, file_name: str = "quantum_model.pt", save_params: bool = True
|
437
481
|
) -> None:
|
qadence/types.py
CHANGED
@@ -156,9 +156,9 @@ class ReuploadScaling(StrEnum):
|
|
156
156
|
class MultivariateStrategy(StrEnum):
|
157
157
|
"""Multivariate strategy for feature maps."""
|
158
158
|
|
159
|
-
PARALLEL = "
|
159
|
+
PARALLEL = "Parallel"
|
160
160
|
"""Parallel strategy."""
|
161
|
-
SERIES = "
|
161
|
+
SERIES = "Series"
|
162
162
|
"""Serial strategy."""
|
163
163
|
|
164
164
|
|
@@ -1,6 +1,6 @@
|
|
1
1
|
Metadata-Version: 2.3
|
2
2
|
Name: qadence
|
3
|
-
Version: 1.7.
|
3
|
+
Version: 1.7.6
|
4
4
|
Summary: Pasqal interface for circuit-based quantum computing SDKs
|
5
5
|
Author-email: Aleksander Wennersteen <aleksander.wennersteen@pasqal.com>, Gert-Jan Both <gert-jan.both@pasqal.com>, Niklas Heim <niklas.heim@pasqal.com>, Mario Dagrada <mario.dagrada@pasqal.com>, Vincent Elfving <vincent.elfving@pasqal.com>, Dominik Seitz <dominik.seitz@pasqal.com>, Roland Guichard <roland.guichard@pasqal.com>, "Joao P. Moutinho" <joao.moutinho@pasqal.com>, Vytautas Abramavicius <vytautas.abramavicius@pasqal.com>, Gergana Velikova <gergana.velikova@pasqal.com>, Eduardo Maschio <eduardo.maschio@pasqal.com>, Smit Chaudhary <smit.chaudhary@pasqal.com>, Ignacio Fernández Graña <ignacio.fernandez-grana@pasqal.com>, Charles Moussa <charles.moussa@pasqal.com>, Giorgio Tosti Balducci <giorgio.tosti-balducci@pasqal.com>
|
6
6
|
License: Apache 2.0
|
@@ -22,7 +22,7 @@ Requires-Dist: matplotlib
|
|
22
22
|
Requires-Dist: nevergrad
|
23
23
|
Requires-Dist: numpy
|
24
24
|
Requires-Dist: openfermion
|
25
|
-
Requires-Dist: pyqtorch==1.
|
25
|
+
Requires-Dist: pyqtorch==1.4.4
|
26
26
|
Requires-Dist: pyyaml
|
27
27
|
Requires-Dist: rich
|
28
28
|
Requires-Dist: scipy
|
@@ -45,7 +45,7 @@ Requires-Dist: nvidia-pyindex; extra == 'dlprof'
|
|
45
45
|
Provides-Extra: horqrux
|
46
46
|
Requires-Dist: einops; extra == 'horqrux'
|
47
47
|
Requires-Dist: flax; extra == 'horqrux'
|
48
|
-
Requires-Dist: horqrux==0.6.
|
48
|
+
Requires-Dist: horqrux==0.6.2; extra == 'horqrux'
|
49
49
|
Requires-Dist: jax; extra == 'horqrux'
|
50
50
|
Requires-Dist: jaxopt; extra == 'horqrux'
|
51
51
|
Requires-Dist: optax; extra == 'horqrux'
|
@@ -57,7 +57,7 @@ Requires-Dist: mlflow; extra == 'mlflow'
|
|
57
57
|
Provides-Extra: protocols
|
58
58
|
Requires-Dist: qadence-protocols; extra == 'protocols'
|
59
59
|
Provides-Extra: pulser
|
60
|
-
Requires-Dist: pasqal-cloud==0.11.
|
60
|
+
Requires-Dist: pasqal-cloud==0.11.3; extra == 'pulser'
|
61
61
|
Requires-Dist: pulser-core==0.19.0; extra == 'pulser'
|
62
62
|
Requires-Dist: pulser-simulation==0.19.0; extra == 'pulser'
|
63
63
|
Provides-Extra: visualization
|