qadence 1.7.4__py3-none-any.whl → 1.7.5__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -5,7 +5,7 @@ import os
  from dataclasses import dataclass, field, fields
  from logging import getLogger
  from pathlib import Path
- from typing import Callable, Type
+ from typing import Any, Callable, Type
  from uuid import uuid4
 
  from sympy import Basic
@@ -13,6 +13,7 @@ from torch import Tensor
 
  from qadence.blocks.analog import AnalogBlock
  from qadence.blocks.primitive import ParametricBlock
+ from qadence.ml_tools.data import OptimizeResult
  from qadence.operations import RX, AnalogRX
  from qadence.parameters import Parameter
  from qadence.types import (
@@ -27,6 +28,84 @@ from qadence.types import (
 
  logger = getLogger(__file__)
 
+ CallbackFunction = Callable[[OptimizeResult], None]
+ CallbackConditionFunction = Callable[[OptimizeResult], bool]
+
+
+ class Callback:
+     """Callback functions are called within the train functions.
+
+     Each callback function should take an OptimizeResult instance
+     as its first argument.
+
+     Attributes:
+         callback (CallbackFunction): Callback function accepting an
+             OptimizeResult as first argument.
+         callback_condition (CallbackConditionFunction | None, optional): Function that
+             conditions the call to callback. Defaults to None.
+         called_every (int, optional): Callback to be called every `called_every` epoch.
+             Defaults to 1.
+             If callback_condition is None, we set
+             callback_condition to always return True.
+         call_before_opt (bool, optional): If true, callback is applied before training.
+             Defaults to False.
+         call_end_epoch (bool, optional): If true, callback is applied during training,
+             after an epoch is performed. Defaults to True.
+         call_after_opt (bool, optional): If true, callback is applied after training.
+             Defaults to False.
+         call_during_eval (bool, optional): If true, callback is applied during evaluation.
+             Defaults to False.
+     """
+
+     def __init__(
+         self,
+         callback: CallbackFunction,
+         callback_condition: CallbackConditionFunction | None = None,
+         called_every: int = 1,
+         call_before_opt: bool = False,
+         call_end_epoch: bool = True,
+         call_after_opt: bool = False,
+         call_during_eval: bool = False,
+     ) -> None:
+         """Initialize the Callback.
+
+         Args:
+             callback (CallbackFunction): Callback function accepting an
+                 OptimizeResult as first argument.
+             callback_condition (CallbackConditionFunction | None, optional): Function that
+                 conditions the call to callback. Defaults to None.
+             called_every (int, optional): Callback to be called every `called_every` epoch.
+                 Defaults to 1.
+                 If callback_condition is None, we set
+                 callback_condition to always return True.
+             call_before_opt (bool, optional): If true, callback is applied before training.
+                 Defaults to False.
+             call_end_epoch (bool, optional): If true, callback is applied during training,
+                 after an epoch is performed. Defaults to True.
+             call_after_opt (bool, optional): If true, callback is applied after training.
+                 Defaults to False.
+             call_during_eval (bool, optional): If true, callback is applied during evaluation.
+                 Defaults to False.
+         """
+         self.callback = callback
+         self.call_before_opt = call_before_opt
+         self.call_end_epoch = call_end_epoch
+         self.call_after_opt = call_after_opt
+         self.call_during_eval = call_during_eval
+
+         if called_every <= 0:
+             raise ValueError("Please provide a strictly positive `called_every` argument.")
+         self.called_every = called_every
+
+         if callback_condition is None:
+             self.callback_condition = lambda opt_result: True
+         else:
+             self.callback_condition = callback_condition
+
+     def __call__(self, opt_result: OptimizeResult) -> Any:
+         if opt_result.iteration % self.called_every == 0 and self.callback_condition(opt_result):
+             return self.callback(opt_result)
+
 
  @dataclass
  class TrainConfig:
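For orientation, a minimal usage sketch of the new `Callback` API. The `qadence.ml_tools.config` module path is assumed (this diff does not name the file); `OptimizeResult` is imported from `qadence/ml_tools/data.py`, changed further below.

```python
from qadence.ml_tools.config import Callback  # module path assumed
from qadence.ml_tools.data import OptimizeResult


def log_loss(opt_result: OptimizeResult) -> None:
    # OptimizeResult carries the iteration, model, optimizer and loss.
    print(f"iteration {opt_result.iteration}: loss = {opt_result.loss}")


loss_logger = Callback(
    callback=log_loss,
    called_every=100,     # fires only when iteration % 100 == 0
    call_end_epoch=True,  # applied after each training epoch (the default)
)
```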
@@ -45,13 +124,27 @@ class TrainConfig:
      max_iter: int = 10000
      """Number of training iterations."""
      print_every: int = 1000
-     """Print loss/metrics."""
+     """Print loss/metrics.
+
+     Set to 0 to disable.
+     """
      write_every: int = 50
-     """Write loss and metrics with the tracking tool."""
+     """Write loss and metrics with the tracking tool.
+
+     Set to 0 to disable.
+     """
      checkpoint_every: int = 5000
-     """Write model/optimizer checkpoint."""
+     """Write model/optimizer checkpoint.
+
+     Set to 0 to disable.
+     """
      plot_every: int = 5000
-     """Write figures."""
+     """Write figures.
+
+     Set to 0 to disable.
+     """
+     callbacks: list[Callback] = field(default_factory=lambda: list())
+     """List of callbacks."""
      log_model: bool = False
      """Logs a serialised version of the model."""
      folder: Path | None = None
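A sketch of wiring callbacks into the new `callbacks` field (the `qadence.ml_tools` import path is assumed; `loss_logger` is the hypothetical callback from the sketch above):

```python
from qadence.ml_tools import TrainConfig  # import path assumed

config = TrainConfig(
    max_iter=5000,
    print_every=0,            # 0 now disables console printing
    checkpoint_every=1000,
    callbacks=[loss_logger],  # hypothetical Callback from the earlier sketch
)
```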
@@ -234,13 +327,13 @@ class FeatureMapConfig:
 
      multivariate_strategy: MultivariateStrategy = MultivariateStrategy.PARALLEL
      """
-     The encoding strategy in case of multi-variate function.
+     The encoding strategy in case of multi-variate function.
 
      Takes qadence.MultivariateStrategy.
      If PARALLEL, the features are encoded in one block of rotation gates
-     with each feature given an equal number of qubits.
-     If SERIES, the features are encoded sequentially, with an ansatz block
-     between. PARALLEL is allowed only for DIGITAL `feature_map_strategy`.
+     with the register being split into sub-registers for each feature.
+     If SERIES, the features are encoded sequentially using the full register for each
+     feature, with an ansatz block between them. PARALLEL is allowed only for DIGITAL `feature_map_strategy`.
      """
 
      feature_map_strategy: Strategy = Strategy.DIGITAL
@@ -261,7 +354,7 @@ class FeatureMapConfig:
      account the domain of the feature-encoding function.
      Defaults to `None` and thus, the feature map is not trainable.
      Note that this is separate from the name of the parameter.
-     The user can provide a single prefix for all features, and they will be appended
+     The user can provide a single prefix for all features, and it will be appended
      by appropriate feature name automatically.
      """
@@ -269,7 +362,7 @@ class FeatureMapConfig:
      """
      Number of feature map layers repeated in the data reuploading step.
 
-     If all are to be repeated the same number of times, then can give a single
+     If all features are to be repeated the same number of times, one can give a single
      `int`. For different number of repetitions for each feature, provide a dict
      of (str, int) where the key is the name of the variable and the value is the
      number of repetitions for that feature.
@@ -300,8 +393,8 @@ class FeatureMapConfig:
      if self.multivariate_strategy == MultivariateStrategy.PARALLEL and self.num_features > 1:
          assert (
              self.feature_map_strategy == Strategy.DIGITAL
-         ), "For `parallel` encoding of multiple features, the `feature_map_strategy` must be \
-             of `digital` type."
+         ), "For parallel encoding of multiple features, the `feature_map_strategy` must be \
+             of `Strategy.DIGITAL`."
 
      if self.operation is None:
          if self.feature_map_strategy == Strategy.DIGITAL:
@@ -314,8 +407,8 @@ class FeatureMapConfig:
          if isinstance(self.operation, AnalogBlock):
              logger.warning(
                  "The `operation` is of type `AnalogBlock` but the `feature_map_strategy` is\
-                 `digital`. The `feature_map_strategy` will be modified and given operation\
-                 will be used."
+                 `Strategy.DIGITAL`. The `feature_map_strategy` will be modified and given \
+                 operation will be used."
              )
 
              self.feature_map_strategy = Strategy.ANALOG
@@ -324,12 +417,25 @@ class FeatureMapConfig:
          if isinstance(self.operation, ParametricBlock):
              logger.warning(
                  "The `operation` is a digital gate but the `feature_map_strategy` is\
-                 `analog`. The `feature_map_strategy` will be modified and given operation\
-                 will be used."
+                 `Strategy.ANALOG`. The `feature_map_strategy` will be modified and given\
+                 operation will be used."
              )
 
              self.feature_map_strategy = Strategy.DIGITAL
 
+     elif self.feature_map_strategy == Strategy.RYDBERG:
+         if self.operation is not None:
+             logger.warning(
+                 f"feature_map_strategy is `Strategy.RYDBERG`, which does not take any\
+                 operation, but an operation {self.operation} was provided. The \
+                 `feature_map_strategy` will be modified and the given operation will be used."
+             )
+
+             if isinstance(self.operation, AnalogBlock):
+                 self.feature_map_strategy = Strategy.ANALOG
+             else:
+                 self.feature_map_strategy = Strategy.DIGITAL
+
      if self.inputs is not None:
          assert (
              len(self.inputs) == self.num_features
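For illustration, a sketch of a config that satisfies the assertion above (the `FeatureMapConfig` import path is assumed):

```python
from qadence.ml_tools.config import FeatureMapConfig  # module path assumed
from qadence.types import MultivariateStrategy, Strategy

# Parallel encoding of multiple features requires a digital feature map,
# per the assertion in __post_init__ above.
fm_config = FeatureMapConfig(
    num_features=2,
    multivariate_strategy=MultivariateStrategy.PARALLEL,
    feature_map_strategy=Strategy.DIGITAL,
)
```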
@@ -388,16 +494,16 @@ class AnsatzConfig:
      ansatz_type: AnsatzType = AnsatzType.HEA
      """What type of ansatz.
 
-     HEA for Hardware Efficient Ansatz.
-     IIA for Identity intialized Ansatz.
+     `AnsatzType.HEA` for Hardware Efficient Ansatz.
+     `AnsatzType.IIA` for Identity Initialized Ansatz.
      """
 
      ansatz_strategy: Strategy = Strategy.DIGITAL
      """Ansatz strategy.
 
-     DIGITAL for fully digital ansatz. Required if `ansatz_type` is `iia`.
-     SDAQC for analog entangling block.
-     RYDBERG for fully rydberg hea ansatz.
+     `Strategy.DIGITAL` for fully digital ansatz. Required if `ansatz_type` is `AnsatzType.IIA`.
+     `Strategy.SDAQC` for analog entangling block.
+     `Strategy.RYDBERG` for a fully Rydberg HEA ansatz.
      """
 
      strategy_args: dict = field(default_factory=dict)
@@ -406,7 +512,7 @@ class AnsatzConfig:
 
      Details about each below.
 
-     For DIGITAL strategy, accepts the following:
+     For the `Strategy.DIGITAL` strategy, accepts the following:
      periodic (bool): if the qubits should be linked periodically.
      periodic=False is not supported in emu-c.
      operations (list): list of operations to cycle through in the
@@ -417,7 +523,7 @@ class AnsatzConfig:
      will have variational parameters on the rotation angles.
      Defaults to CNOT
 
-     For SDAQC strategy, accepts the following:
+     For the `Strategy.SDAQC` strategy, accepts the following:
      operations (list): list of operations to cycle through in the
      digital single-qubit rotations of each layer.
      Defaults to [RX, RY, RX] for hea and [RX, RY] for iia.
@@ -425,7 +531,7 @@ class AnsatzConfig:
      analog entangling layer. Time parameter is considered variational.
      Defaults to NN interaction.
 
-     For RYDBERG strategy, accepts the following:
+     For the `Strategy.RYDBERG` strategy, accepts the following:
      addressable_detuning: whether to turn on the trainable semi-local addressing pattern
      on the detuning (n_i terms in the Hamiltonian).
      Defaults to True.
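A sketch tying the documented `strategy_args` keys together (import path assumed; note the `"operations"` key, which the `strategy_args.get(...)` fixes further down in this diff make consistent with this docstring):

```python
from qadence.ml_tools.config import AnsatzConfig  # module path assumed
from qadence.operations import RX, RY
from qadence.types import AnsatzType, Strategy

ansatz_config = AnsatzConfig(
    ansatz_type=AnsatzType.HEA,
    ansatz_strategy=Strategy.DIGITAL,
    strategy_args={"operations": [RX, RY, RX], "periodic": False},
)
```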
@@ -45,7 +45,7 @@ from .models import QNN
  def _create_support_arrays(
      num_qubits: int,
      num_features: int,
-     multivariate_strategy: str,
+     multivariate_strategy: MultivariateStrategy,
  ) -> list[tuple[int, ...]]:
      """
      Create the support arrays for the digital feature map.
@@ -53,8 +53,8 @@ def _create_support_arrays(
      Args:
          num_qubits (int): The number of qubits.
          num_features (int): The number of features.
-         multivariate_strategy (str): The multivariate encoding strategy.
-             Either 'series' or 'parallel'.
+         multivariate_strategy (MultivariateStrategy): The multivariate encoding strategy.
+             Either `MultivariateStrategy.SERIES` or `MultivariateStrategy.PARALLEL`.
 
      Returns:
          list[tuple[int, ...]]: The list of support arrays. ith element of the list is the support
@@ -62,23 +62,25 @@ def _create_support_arrays(
 
      Raises:
          ValueError: If the number of features is greater than the number of qubits
-             with parallel encoding. Not possible to encode these features in parallel.
-         ValueError: If the multivariate strategy is not 'series' or 'parallel'.
+             and the strategy is `MultivariateStrategy.PARALLEL`; it is not possible to
+             assign a support array to each feature.
+         ValueError: If the multivariate strategy is not `MultivariateStrategy.SERIES` or
+             `MultivariateStrategy.PARALLEL`.
      """
-     if multivariate_strategy == "series":
+     if multivariate_strategy == MultivariateStrategy.SERIES:
          return [tuple(range(num_qubits)) for i in range(num_features)]
-     elif multivariate_strategy == "parallel":
+     elif multivariate_strategy == MultivariateStrategy.PARALLEL:
          if num_features <= num_qubits:
              return [tuple(x.tolist()) for x in np.array_split(np.arange(num_qubits), num_features)]
          else:
              raise ValueError(
                  f"Number of features {num_features} must be less than or equal to the number of \
-                 qubits {num_qubits}. if the features are to be encoded is parallely."
+                 qubits {num_qubits} if the features are to be encoded in parallel."
              )
      else:
          raise ValueError(
-             f"Invalid encoding strategy {multivariate_strategy} provided. Only 'series' or \
-             'parallel' are allowed."
+             f"Invalid encoding strategy {multivariate_strategy} provided. Only \
+             `MultivariateStrategy.SERIES` or `MultivariateStrategy.PARALLEL` are allowed."
          )
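The parallel branch above splits the register with `np.array_split`; a quick standalone check of the resulting support arrays:

```python
import numpy as np

num_qubits, num_features = 5, 2
supports = [tuple(x.tolist()) for x in np.array_split(np.arange(num_qubits), num_features)]
print(supports)  # [(0, 1, 2), (3, 4)] -- feature 0 gets 3 qubits, feature 1 gets 2
```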
 
@@ -212,7 +214,8 @@ def _create_digital_fm(
          list[AbstractBlock]: The list of digital feature map blocks.
 
      Raises:
-         ValueError: If the encoding strategy is invalid. Only 'series' or 'parallel' are allowed.
+         ValueError: If the encoding strategy is invalid. Only `MultivariateStrategy.SERIES` or
+             `MultivariateStrategy.PARALLEL` are allowed.
      """
      if config.multivariate_strategy == MultivariateStrategy.SERIES:
          fm_blocks = _encode_features_series_digital(register, config)
@@ -220,8 +223,8 @@ def _create_digital_fm(
          fm_blocks = _encode_features_parallel_digital(register, config)
      else:
          raise ValueError(
-             f"Invalid encoding strategy {config.multivariate_strategy} provided. Only 'series' or \
-             'parallel' are allowed."
+             f"Invalid encoding strategy {config.multivariate_strategy} provided. Only \
+             `MultivariateStrategy.SERIES` or `MultivariateStrategy.PARALLEL` are allowed."
          )
 
      return fm_blocks
@@ -348,7 +351,8 @@ def create_fm_blocks(
          list[AbstractBlock]: A list of feature map blocks.
 
      Raises:
-         ValueError: If the feature map strategy is not 'digital', 'analog' or 'rydberg'.
+         ValueError: If the feature map strategy is not `Strategy.DIGITAL`, `Strategy.ANALOG` or
+             `Strategy.RYDBERG`.
      """
      if config.feature_map_strategy == Strategy.DIGITAL:
          return _create_digital_fm(register=register, config=config)
@@ -359,7 +363,7 @@ def create_fm_blocks(
      else:
          raise NotImplementedError(
              f"Feature map not implemented for strategy {config.feature_map_strategy}. \
-             Only 'digital', 'analog' or 'rydberg' allowed."
+             Only `Strategy.DIGITAL`, `Strategy.ANALOG` or `Strategy.RYDBERG` allowed."
          )
 
 
@@ -463,7 +467,8 @@ def _create_iia(
          AbstractBlock: The Identity Initialized Ansatz.
 
      Raises:
-         ValueError: If the ansatz strategy is not supported. Only 'digital' and 'sdaqc' are allowed.
+         ValueError: If the ansatz strategy is not supported. Only `Strategy.DIGITAL` and
+             `Strategy.SDAQC` are allowed.
      """
      if config.ansatz_strategy == Strategy.DIGITAL:
          return _create_iia_digital(num_qubits=num_qubits, config=config)
@@ -471,8 +476,8 @@ def _create_iia(
          return _create_iia_sdaqc(num_qubits=num_qubits, config=config)
      else:
          raise ValueError(
-             f"Invalid ansatz strategy {config.ansatz_strategy} provided. Only 'digital', 'sdaqc', \
-             allowed for IIA."
+             f"Invalid ansatz strategy {config.ansatz_strategy} provided. Only `Strategy.DIGITAL` \
+             and `Strategy.SDAQC` allowed for IIA."
          )
 
 
@@ -487,7 +492,7 @@ def _create_hea_digital(num_qubits: int, config: AnsatzConfig) -> AbstractBlock:
      Returns:
          AbstractBlock: The Digital Hardware Efficient Ansatz.
      """
-     operations = config.strategy_args.get("rotations", [RX, RY, RX])
+     operations = config.strategy_args.get("operations", [RX, RY, RX])
      entangler = config.strategy_args.get("entangler", CNOT)
      periodic = config.strategy_args.get("periodic", False)
 
@@ -512,7 +517,7 @@ def _create_hea_sdaqc(num_qubits: int, config: AnsatzConfig) -> AbstractBlock:
      Returns:
          AbstractBlock: The SDAQC Hardware Efficient Ansatz.
      """
-     operations = config.strategy_args.get("rotations", [RX, RY, RX])
+     operations = config.strategy_args.get("operations", [RX, RY, RX])
      entangler = config.strategy_args.get(
          "entangler", hamiltonian_factory(num_qubits, interaction=Interaction.NN)
      )
@@ -556,7 +561,7 @@ def _create_hea_rydberg(
      )
 
 
- def _create_hea_ansatz(
+ def _create_hea(
      register: int | Register,
      config: AnsatzConfig,
  ) -> AbstractBlock:
@@ -571,7 +576,8 @@ def _create_hea_ansatz(
          AbstractBlock: The hardware efficient ansatz block.
 
      Raises:
-         ValueError: If the ansatz strategy is not 'digital', 'sdaqc', or 'rydberg'.
+         ValueError: If the ansatz strategy is not `Strategy.DIGITAL`, `Strategy.SDAQC`, or
+             `Strategy.RYDBERG`.
      """
      num_qubits = register if isinstance(register, int) else register.n_qubits
 
@@ -583,8 +589,8 @@ def _create_hea_ansatz(
          return _create_hea_rydberg(register=register, config=config)
      else:
          raise ValueError(
-             f"Invalid ansatz strategy {config.ansatz_strategy} provided. Only 'digital', 'sdaqc', \
-             and 'rydberg' allowed"
+             f"Invalid ansatz strategy {config.ansatz_strategy} provided. Only `Strategy.DIGITAL`, \
+             `Strategy.SDAQC`, and `Strategy.RYDBERG` allowed."
          )
 
 
@@ -610,11 +616,11 @@ def create_ansatz(
      if config.ansatz_type == AnsatzType.IIA:
          return _create_iia(num_qubits=num_qubits, config=config)
      elif config.ansatz_type == AnsatzType.HEA:
-         return _create_hea_ansatz(register=register, config=config)
+         return _create_hea(register=register, config=config)
      else:
          raise NotImplementedError(
-             f"Ansatz of type {config.ansatz_type} not implemented yet. Only 'hea' and\
-             'iia' available."
+             f"Ansatz of type {config.ansatz_type} not implemented yet. Only `AnsatzType.HEA` and \
+             `AnsatzType.IIA` available."
          )
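A sketch of the public entry point affected by the rename (module path assumed; `ansatz_config` is the hypothetical config from the earlier sketch):

```python
from qadence.ml_tools.constructors import create_ansatz  # module path assumed

# Dispatches to _create_iia or the renamed _create_hea based on ansatz_type.
ansatz = create_ansatz(register=4, config=ansatz_config)
```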
 
@@ -651,7 +657,7 @@ def load_observable_transformations(config: ObservableConfig) -> tuple[Parameter
          config (ObservableConfig): Observable configuration.
 
      Returns:
-         tuple[float, float]: The observable shifting and scaling factors.
+         tuple[Parameter, Parameter]: The observable shifting and scaling factors.
      """
      shift = config.shift
      scale = config.scale
@@ -665,9 +671,9 @@ def load_observable_transformations(config: ObservableConfig) -> tuple[Parameter
 
 
  ObservableTransformMap = {
-     ObservableTransform.RANGE: lambda detuning, scale, shift: (shift, shift - scale)
-     if detuning is N
-     else (0.5 * (shift - scale), 0.5 * (scale + shift)),
+     ObservableTransform.RANGE: lambda detuning, scale, shift: (
+         (shift, shift - scale) if detuning is N else (0.5 * (shift - scale), 0.5 * (scale + shift))
+     ),
      ObservableTransform.SCALE: lambda _, scale, shift: (scale, shift),
  }
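The reformatted `RANGE` branch is plain arithmetic; a quick worked check of the case where `detuning` is not `N`:

```python
scale, shift = 2.0, 1.0
# Half-difference and half-sum of shift and scale:
print((0.5 * (shift - scale), 0.5 * (scale + shift)))  # (-0.5, 1.5)
```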
 
@@ -718,8 +724,8 @@ def create_observable(
      """
      if transformation_type == ObservableTransform.RANGE:
          scale, shift = ObservableTransformMap[transformation_type](detuning, scale, shift)  # type: ignore[index]
-     shifting_term = shift * _global_identity(register)  # type: ignore[operator]
-     detuning_hamiltonian = scale * hamiltonian_factory(  # type: ignore[operator]
+     shifting_term: AbstractBlock = shift * _global_identity(register)  # type: ignore[operator]
+     detuning_hamiltonian: AbstractBlock = scale * hamiltonian_factory(  # type: ignore[operator]
          register=register,
          detuning=detuning,
      )
qadence/ml_tools/data.py CHANGED
@@ -1,15 +1,41 @@
  from __future__ import annotations
 
- from dataclasses import dataclass
+ from dataclasses import dataclass, field
  from functools import singledispatch
  from itertools import cycle
  from typing import Any, Iterator
 
+ from nevergrad.optimization.base import Optimizer as NGOptimizer
  from torch import Tensor
  from torch import device as torch_device
+ from torch.nn import Module
+ from torch.optim import Optimizer
  from torch.utils.data import DataLoader, IterableDataset, TensorDataset
 
 
+ @dataclass
+ class OptimizeResult:
+     """OptimizeResult stores intermediate values of an optimization run.
+
+     At a given iteration, we store the model, the optimizer, the loss
+     value and the metrics. An extra dict can be used to save any other
+     information needed by callbacks.
+     """
+
+     iteration: int
+     """Current iteration number."""
+     model: Module
+     """Model at iteration."""
+     optimizer: Optimizer | NGOptimizer
+     """Optimizer at iteration."""
+     loss: Tensor | float | None = None
+     """Loss value."""
+     metrics: dict = field(default_factory=lambda: dict())
+     """Metrics that can be saved during training."""
+     extra: dict = field(default_factory=lambda: dict())
+     """Extra dict for saving anything else to be used in callbacks."""
+
+
  @dataclass
  class DictDataLoader:
      """This class only holds a dictionary of `DataLoader`s and samples from them."""
@@ -29,10 +29,11 @@ def optimize_step(
          xs (dict | list | torch.Tensor | None): the input data. If None it means
              that the given model does not require any input data
          device (torch.device): A target device to run computation on.
+         dtype (torch.dtype): Data type for xs conversion.
 
      Returns:
-         tuple: tuple containing the model, the optimizer, a dictionary with
-             the collected metrics and the compute value loss
+         tuple: tuple containing the computed loss value and a dictionary with
+             the collected metrics.
      """
 
      loss, metrics = None, {}
@@ -72,7 +72,8 @@ def write_checkpoint(
      device = None
      try:
          # We extract the device from the pyqtorch native circuit
-         device = str(model.device).split(":")[0]  # in case of using several CUDA devices
+         device = model.device if isinstance(model, (QuantumModel, QNN)) else next(model.parameters()).device
+         device = str(device).split(":")[0]  # in case of using several CUDA devices
      except Exception as e:
          msg = (
              f"Unable to identify in which device the QuantumModel is stored due to {e}."
@@ -132,7 +133,7 @@ def load_model(
      try:
          iteration, model_dict = torch.load(folder / model_ckpt_name, *args, **kwargs)
          if isinstance(model, (QuantumModel, QNN)):
-             model._from_dict(model_dict, as_torch=True)
+             model.load_params_from_dict(model_dict)
          elif isinstance(model, Module):
              model.load_state_dict(model_dict, strict=True)
          # Load model to a specific gpu device if specified