qadence 1.11.1__py3-none-any.whl → 1.11.3__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (36)
  1. qadence/backend.py +33 -10
  2. qadence/backends/agpsr_utils.py +96 -0
  3. qadence/backends/api.py +8 -1
  4. qadence/backends/horqrux/backend.py +24 -10
  5. qadence/backends/horqrux/config.py +17 -1
  6. qadence/backends/horqrux/convert_ops.py +20 -97
  7. qadence/backends/jax_utils.py +5 -2
  8. qadence/backends/{gpsr.py → parameter_shift_rules.py} +48 -30
  9. qadence/backends/pulser/backend.py +16 -9
  10. qadence/backends/pulser/config.py +18 -0
  11. qadence/backends/pyqtorch/backend.py +25 -11
  12. qadence/backends/pyqtorch/config.py +18 -0
  13. qadence/blocks/embedding.py +10 -1
  14. qadence/blocks/primitive.py +2 -3
  15. qadence/blocks/utils.py +33 -24
  16. qadence/engines/differentiable_backend.py +7 -1
  17. qadence/engines/jax/differentiable_backend.py +7 -1
  18. qadence/engines/torch/differentiable_backend.py +12 -9
  19. qadence/engines/torch/differentiable_expectation.py +12 -11
  20. qadence/extensions.py +0 -10
  21. qadence/ml_tools/__init__.py +2 -0
  22. qadence/ml_tools/callbacks/callbackmanager.py +4 -2
  23. qadence/ml_tools/constructors.py +264 -4
  24. qadence/ml_tools/qcnn_model.py +158 -0
  25. qadence/model.py +113 -8
  26. qadence/parameters.py +2 -0
  27. qadence/serialization.py +1 -1
  28. qadence/transpile/__init__.py +3 -2
  29. qadence/transpile/block.py +58 -5
  30. qadence/types.py +2 -4
  31. qadence/utils.py +39 -8
  32. {qadence-1.11.1.dist-info → qadence-1.11.3.dist-info}/METADATA +22 -11
  33. {qadence-1.11.1.dist-info → qadence-1.11.3.dist-info}/RECORD +35 -33
  34. qadence-1.11.3.dist-info/licenses/LICENSE +13 -0
  35. qadence-1.11.1.dist-info/licenses/LICENSE +0 -202
  36. {qadence-1.11.1.dist-info → qadence-1.11.3.dist-info}/WHEEL +0 -0
@@ -28,7 +28,7 @@ from qadence.noise import NoiseHandler
  from qadence.overlap import overlap_exact
  from qadence.register import Register
  from qadence.transpile import transpile
- from qadence.types import BackendName, DeviceType, Endianness, Engine, NoiseProtocol
+ from qadence.types import BackendName, DeviceType, Endianness, Engine, NoiseProtocol, ParamDictType

  from .channels import GLOBAL_CHANNEL, LOCAL_CHANNEL
  from .cloud import get_client
@@ -183,7 +183,7 @@ class Backend(BackendInterface):
      def run(
          self,
          circuit: ConvertedCircuit,
-         param_values: dict[str, Tensor] = {},
+         param_values: ParamDictType = {},
          state: Tensor | None = None,
          endianness: Endianness = Endianness.BIG,
          noise: NoiseHandler | None = None,
@@ -235,7 +235,7 @@ class Backend(BackendInterface):
          self,
          circuit: ConvertedCircuit,
          noise: NoiseHandler,
-         param_values: dict[str, Tensor] = dict(),
+         param_values: ParamDictType = dict(),
          state: Tensor | None = None,
          endianness: Endianness = Endianness.BIG,
      ) -> Tensor:
@@ -284,7 +284,7 @@ class Backend(BackendInterface):
      def sample(
          self,
          circuit: ConvertedCircuit,
-         param_values: dict[str, Tensor] = {},
+         param_values: ParamDictType = {},
          n_shots: int = 1,
          state: Tensor | None = None,
          noise: NoiseHandler | None = None,
@@ -322,7 +322,7 @@ class Backend(BackendInterface):
          self,
          circuit: ConvertedCircuit,
          observable: list[ConvertedObservable] | ConvertedObservable,
-         param_values: dict[str, Tensor] = {},
+         param_values: ParamDictType = {},
          state: Tensor | None = None,
          measurement: Measurements | None = None,
          noise: NoiseHandler | None = None,
@@ -330,14 +330,19 @@ class Backend(BackendInterface):
          endianness: Endianness = Endianness.BIG,
      ) -> Tensor:
          observable = observable if isinstance(observable, list) else [observable]
+         param_circuit = param_values["circuit"] if "circuit" in param_values else param_values
+         param_observables = (
+             param_values["observables"] if "observables" in param_values else param_values
+         )
          if mitigation is None:
              if noise is None:
                  state = self.run(
-                     circuit, param_values=param_values, state=state, endianness=endianness
+                     circuit, param_values=param_circuit, state=state, endianness=endianness
                  )
                  support = sorted(list(circuit.abstract.register.support))
                  res_list = [
-                     obs.native(state, param_values, qubit_support=support) for obs in observable
+                     obs.native(state, param_observables, qubit_support=support)
+                     for obs in observable
                  ]
                  res = torch.transpose(torch.stack(res_list), 0, 1)
                  res = res if len(res.shape) > 0 else res.reshape(1)
@@ -345,7 +350,7 @@ class Backend(BackendInterface):
              elif noise is not None:
                  dms = self.run(
                      circuit=circuit,
-                     param_values=param_values,
+                     param_values=param_circuit,
                      state=state,
                      endianness=endianness,
                      noise=noise,
@@ -353,7 +358,9 @@ class Backend(BackendInterface):
                  support = sorted(list(circuit.abstract.register.support))
                  res_list = [
                      [
-                         obs.native(dm.squeeze(), param_values, qubit_support=support, noise=noise)
+                         obs.native(
+                             dm.squeeze(), param_observables, qubit_support=support, noise=noise
+                         )
                          for dm in dms
                      ]
                      for obs in observable
@@ -66,6 +66,24 @@ class Configuration(BackendConfiguration):
      FIXME: To be deprecated.
      """

+     n_eqs: int | None = None
+     """Number of equations to use in aGPSR calculations."""
+
+     shift_prefac: float = 0.5
+     """Prefactor governing the magnitude of parameter shift values.
+
+     Select smaller value if spectral gaps are large.
+     """
+
+     gap_step: float = 1.0
+     """Step between generated pseudo-gaps when using aGPSR algorithm."""
+
+     lb: float | None = None
+     """Lower bound of optimal shift value search interval."""
+
+     ub: float | None = None
+     """Upper bound of optimal shift value search interval."""
+
      # configuration for cloud simulations
      cloud_configuration: Optional[CloudConfiguration] = None

@@ -31,7 +31,7 @@ from qadence.transpile import (
      set_noise,
      transpile,
  )
- from qadence.types import BackendName, Endianness, Engine
+ from qadence.types import BackendName, Endianness, Engine, ParamDictType

  from .config import Configuration, default_passes
  from .convert_ops import convert_block, convert_readout_noise
@@ -182,16 +182,21 @@ class Backend(BackendInterface):
          self,
          circuit: ConvertedCircuit,
          observable: list[ConvertedObservable] | ConvertedObservable,
-         param_values: dict[str, Tensor] = {},
+         param_values: ParamDictType = {},
          state: Tensor | None = None,
          measurement: Measurements | None = None,
          noise: NoiseHandler | None = None,
          endianness: Endianness = Endianness.BIG,
      ) -> Tensor:
          set_block_and_readout_noises(circuit, noise, self.config)
+         param_circuit = param_values["circuit"] if "circuit" in param_values else param_values
+         param_observables = (
+             param_values["observables"] if "observables" in param_values else param_values
+         )
+
          state = self.run(
              circuit,
-             param_values=param_values,
+             param_values=param_circuit,
              state=state,
              endianness=endianness,
              pyqify_state=True,
@@ -200,7 +205,7 @@ class Backend(BackendInterface):
          )
          observable = observable if isinstance(observable, list) else [observable]
          _expectation = torch.hstack(
-             [obs.native.expectation(state, param_values).reshape(-1, 1) for obs in observable]
+             [obs.native.expectation(state, param_observables).reshape(-1, 1) for obs in observable]
          )
          return _expectation

@@ -208,7 +213,7 @@ class Backend(BackendInterface):
          self,
          circuit: ConvertedCircuit,
          observable: list[ConvertedObservable] | ConvertedObservable,
-         param_values: dict[str, Tensor] = {},
+         param_values: ParamDictType = {},
          state: Tensor | None = None,
          measurement: Measurements | None = None,
          noise: NoiseHandler | None = None,
@@ -230,9 +235,18 @@ class Backend(BackendInterface):

          list_expvals = []
          observables = observable if isinstance(observable, list) else [observable]
-         for vals in to_list_of_dicts(param_values):
-             wf = self.run(circuit, vals, state, endianness, pyqify_state=True, unpyqify_state=False)
-             exs = torch.cat([obs.native.expectation(wf, vals) for obs in observables], 0)
+         param_circuits = param_values["circuit"] if "circuit" in param_values else param_values
+         param_observables = (
+             param_values["observables"] if "observables" in param_values else param_values
+         )
+
+         for vals_circ, vals_obs in zip(
+             to_list_of_dicts(param_circuits), to_list_of_dicts(param_observables)
+         ):
+             wf = self.run(
+                 circuit, vals_circ, state, endianness, pyqify_state=True, unpyqify_state=False
+             )
+             exs = torch.cat([obs.native.expectation(wf, vals_obs) for obs in observables], 0)
              list_expvals.append(exs)

          batch_expvals = torch.vstack(list_expvals)
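
Note: as the hunks above show, param_values may now be either a single flat dictionary (used for both the circuit and the observables, as before) or a dictionary split under the "circuit" and "observables" keys. A minimal sketch of the two accepted shapes; the parameter names below are illustrative only:

    import torch

    # Flat dictionary: the same values are used for circuit and observable parameters.
    flat_params = {"theta": torch.rand(10)}

    # Split dictionary: circuit and observable parameters are kept separate and are
    # looked up under the "circuit" and "observables" keys handled in the code above.
    split_params = {
        "circuit": {"theta": torch.rand(10)},
        "observables": {"obs_scale": torch.rand(10)},
    }
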
@@ -242,7 +256,7 @@ class Backend(BackendInterface):
          self,
          circuit: ConvertedCircuit,
          observable: list[ConvertedObservable] | ConvertedObservable,
-         param_values: dict[str, Tensor] = {},
+         param_values: ParamDictType = {},
          state: Tensor | None = None,
          measurement: Measurements | None = None,
          noise: NoiseHandler | None = None,
@@ -269,7 +283,7 @@ class Backend(BackendInterface):
      def sample(
          self,
          circuit: ConvertedCircuit,
-         param_values: dict[str, Tensor] = {},
+         param_values: ParamDictType = {},
          n_shots: int = 1,
          state: Tensor | None = None,
          noise: NoiseHandler | None = None,
@@ -295,7 +309,7 @@ class Backend(BackendInterface):
              samples = apply_mitigation(noise=noise, mitigation=mitigation, samples=samples)
          return samples

-     def assign_parameters(self, circuit: ConvertedCircuit, param_values: dict[str, Tensor]) -> Any:
+     def assign_parameters(self, circuit: ConvertedCircuit, param_values: ParamDictType) -> Any:
          raise NotImplementedError

      @staticmethod
@@ -70,3 +70,21 @@ class Configuration(BackendConfiguration):
      """Quantum dropout probability (0 means no dropout)."""
      dropout_mode: DropoutMode = DropoutMode.ROTATIONAL
      """Type of quantum dropout to perform."""
+
+     n_eqs: int | None = None
+     """Number of equations to use in aGPSR calculations."""
+
+     shift_prefac: float = 0.5
+     """Prefactor governing the magnitude of parameter shift values.
+
+     Select smaller value if spectral gaps are large.
+     """
+
+     gap_step: float = 1.0
+     """Step between generated pseudo-gaps when using aGPSR algorithm."""
+
+     lb: float | None = None
+     """Lower bound of optimal shift value search interval."""
+
+     ub: float | None = None
+     """Upper bound of optimal shift value search interval."""
@@ -16,7 +16,14 @@ from qadence.blocks.utils import (
      uuid_to_expression,
  )
  from qadence.parameters import evaluate, make_differentiable, stringify
- from qadence.types import ArrayLike, DifferentiableExpression, Engine, ParamDictType, TNumber
+ from qadence.types import (
+     ArrayLike,
+     DifferentiableExpression,
+     Engine,
+     ParamDictType,
+     TNumber,
+ )
+ from qadence.utils import merge_separate_params


  def _concretize_parameter(engine: Engine) -> Callable:
@@ -110,6 +117,8 @@ def embedding(

      def embedding_fn(params: ParamDictType, inputs: ParamDictType) -> ParamDictType:
          embedded_params: dict[sympy.Expr, ArrayLike] = {}
+         if "circuit" in inputs or "observables" in inputs:
+             inputs = merge_separate_params(inputs)
          for expr, fn in embeddings.items():
              angle: ArrayLike
              values = {}
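
Note: merge_separate_params itself is not shown in this diff (it lives in qadence/utils.py, one of the changed files). The sketch below is only an assumption about the merging behaviour relied on here, not the qadence implementation:

    from typing import Any

    def merge_separate_params_sketch(inputs: dict[str, Any]) -> dict[str, Any]:
        # Hypothetical illustration: flatten {"circuit": {...}, "observables": {...}}
        # back into a single parameter dictionary before embedding.
        merged: dict[str, Any] = {}
        for key, value in inputs.items():
            if key in ("circuit", "observables") and isinstance(value, dict):
                merged.update(value)
            else:
                merged[key] = value
        return merged
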
@@ -187,9 +187,8 @@ class ParametricBlock(PrimitiveBlock):
          if not isinstance(other, AbstractBlock):
              raise TypeError(f"Cant compare {type(self)} to {type(other)}")
          if isinstance(other, type(self)):
-             return (
-                 self.qubit_support == other.qubit_support
-                 and self.parameters.parameter == other.parameters.parameter
+             return self.qubit_support == other.qubit_support and self.parameters.parameter.equals(
+                 other.parameters.parameter
              )
          return False
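
Note: the comparison above switches from sympy's == (structural equality of expression trees) to .equals() (mathematical equivalence). A small standalone illustration of the difference:

    import sympy

    x = sympy.Symbol("x")

    # Structural comparison: the two expression trees differ, so == is False.
    print((x + 1) ** 2 == x**2 + 2 * x + 1)         # False

    # Mathematical comparison: .equals() recognises the expressions as equivalent.
    print(((x + 1) ** 2).equals(x**2 + 2 * x + 1))  # True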
 
qadence/blocks/utils.py CHANGED
@@ -7,6 +7,7 @@ from logging import getLogger
  from typing import Generator, List, Type, TypeVar, Union, get_args

  from sympy import Array, Basic, Expr
+ import torch
  from torch import Tensor

  from qadence.blocks import (
@@ -292,31 +293,39 @@ def uuid_to_eigen(

      result = {}
      for uuid, b in uuid_to_block(block).items():
-         if b.eigenvalues_generator is not None:
-             if b.eigenvalues_generator.numel() > 0:
-                 # GPSR assumes a factor 0.5 for differentiation
-                 # so need rescaling
-                 if isinstance(b, TimeEvolutionBlock) and rescale_eigenvals_timeevo:
-                     if b.eigenvalues_generator.numel() > 1:
-                         result[uuid] = (
-                             b.eigenvalues_generator * 2.0,
-                             0.5,
-                         )
+         eigs_generator = None
+
+         # this is to handle the case for the N operator
+         try:
+             eigs_generator = b.eigenvalues_generator
+         except ValueError:
+             result[uuid] = (torch.zeros(2), 1.0)
+         else:
+             if eigs_generator is not None:
+                 if eigs_generator.numel() > 0:
+                     # GPSR assumes a factor 0.5 for differentiation
+                     # so need rescaling
+                     if isinstance(b, TimeEvolutionBlock) and rescale_eigenvals_timeevo:
+                         if eigs_generator.numel() > 1:
+                             result[uuid] = (
+                                 eigs_generator * 2.0,
+                                 0.5,
+                             )
+                         else:
+                             result[uuid] = (
+                                 eigs_generator * 2.0,
+                                 (
+                                     1.0 / (eigs_generator.item() * 2.0)
+                                     if len(eigs_generator) == 1
+                                     else 1.0
+                                 ),
+                             )
                      else:
-                         result[uuid] = (
-                             b.eigenvalues_generator * 2.0,
-                             (
-                                 1.0 / (b.eigenvalues_generator.item() * 2.0)
-                                 if len(b.eigenvalues_generator) == 1
-                                 else 1.0
-                             ),
-                         )
-                 else:
-                     result[uuid] = (b.eigenvalues_generator, 1.0)
-
-         # leave only angle parameter uuid with eigenvals for ConstantAnalogRotation block
-         if isinstance(block, ConstantAnalogRotation):
-             break
+                         result[uuid] = (eigs_generator, 1.0)
+
+             # leave only angle parameter uuid with eigenvals for ConstantAnalogRotation block
+             if isinstance(block, ConstantAnalogRotation):
+                 break

      return result

@@ -12,7 +12,13 @@ from qadence.circuit import QuantumCircuit
  from qadence.measurements import Measurements
  from qadence.mitigations import Mitigations
  from qadence.noise import NoiseHandler
- from qadence.types import ArrayLike, DiffMode, Endianness, Engine, ParamDictType
+ from qadence.types import (
+     ArrayLike,
+     DiffMode,
+     Endianness,
+     Engine,
+     ParamDictType,
+ )


  @dataclass(frozen=True, eq=True)
@@ -8,7 +8,13 @@ from qadence.engines.jax.differentiable_expectation import DifferentiableExpecta
  from qadence.measurements import Measurements
  from qadence.mitigations import Mitigations
  from qadence.noise import NoiseHandler
- from qadence.types import ArrayLike, DiffMode, Endianness, Engine, ParamDictType
+ from qadence.types import (
+     ArrayLike,
+     DiffMode,
+     Endianness,
+     Engine,
+     ParamDictType,
+ )


  class DifferentiableBackend(DifferentiableBackendInterface):
@@ -8,11 +8,17 @@ from qadence.engines.differentiable_backend import (
      DifferentiableBackend as DifferentiableBackendInterface,
  )
  from qadence.engines.torch.differentiable_expectation import DifferentiableExpectation
- from qadence.extensions import get_gpsr_fns
+ from qadence.backends.parameter_shift_rules import general_psr
  from qadence.measurements import Measurements
  from qadence.mitigations import Mitigations
  from qadence.noise import NoiseHandler
- from qadence.types import ArrayLike, DiffMode, Endianness, Engine, ParamDictType
+ from qadence.types import (
+     ArrayLike,
+     DiffMode,
+     Endianness,
+     Engine,
+     ParamDictType,
+ )


  class DifferentiableBackend(DifferentiableBackendInterface):
@@ -75,11 +81,8 @@ class DifferentiableBackend(DifferentiableBackendInterface):
              expectation = differentiable_expectation.ad
          elif self.diff_mode == DiffMode.ADJOINT:
              expectation = differentiable_expectation.adjoint
-         else:
-             try:
-                 fns = get_gpsr_fns()
-                 psr_fn = fns[self.diff_mode]
-             except KeyError:
-                 raise ValueError(f"{self.diff_mode} differentiation mode is not supported")
-             expectation = partial(differentiable_expectation.psr, psr_fn=psr_fn, **self.psr_args)
+         elif self.diff_mode == DiffMode.GPSR:
+             expectation = partial(
+                 differentiable_expectation.psr, psr_fn=general_psr, **self.psr_args
+             )
          return expectation()
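
Note: GPSR differentiation is now wired directly to general_psr from qadence.backends.parameter_shift_rules instead of being looked up through the extensions module. A minimal usage sketch with the standard QuantumModel API; the one-qubit circuit is illustrative:

    import torch

    from qadence import RX, Z, QuantumCircuit, QuantumModel
    from qadence.parameters import FeatureParameter
    from qadence.types import DiffMode

    x = FeatureParameter("x")
    circuit = QuantumCircuit(1, RX(0, x))
    model = QuantumModel(circuit, observable=Z(0), diff_mode=DiffMode.GPSR)

    values = {"x": torch.rand(5, requires_grad=True)}
    expval = model.expectation(values)
    # Gradients flow through the parameter-shift rule rather than backend autodiff.
    grad = torch.autograd.grad(expval.sum(), values["x"])[0]
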
@@ -20,7 +20,7 @@ from qadence.measurements import Measurements
  from qadence.mitigations import Mitigations
  from qadence.ml_tools import promote_to_tensor
  from qadence.noise import NoiseHandler
- from qadence.types import Endianness
+ from qadence.types import Endianness, ParamDictType


  class PSRExpectation(Function):
@@ -94,7 +94,7 @@ class DifferentiableExpectation:
      backend: QuantumBackend
      circuit: ConvertedCircuit
      observable: list[ConvertedObservable] | ConvertedObservable
-     param_values: dict[str, Tensor]
+     param_values: ParamDictType
      state: Tensor | None = None
      measurement: Measurements | None = None
      noise: NoiseHandler | None = None
@@ -135,8 +135,6 @@ class DifferentiableExpectation:
          self.observable = (
              self.observable if isinstance(self.observable, list) else [self.observable]
          )
-         if len(self.observable) > 1:
-             raise NotImplementedError("AdjointExpectation currently only supports one observable.")

          n_qubits = self.circuit.abstract.n_qubits
          values_batch_size = infer_batchsize(self.param_values)
@@ -150,18 +148,21 @@ class DifferentiableExpectation:
              else self.state
          )
          batch_size = max(values_batch_size, self.state.size(-1))
-         return (
-             AdjointExpectation.apply(
+
+         def expectation_fn(i: int) -> Tensor:
+             return AdjointExpectation.apply(
                  self.circuit.native,
                  self.state,
-                 self.observable[0].native,  # Currently, adjoint only supports a single observable.
+                 self.observable[i].native,  # Currently, adjoint only supports a single observable.
                  None,
                  self.param_values.keys(),
                  *self.param_values.values(),
-             )
-             .unsqueeze(1)
-             .reshape(batch_size, 1)
-         )  # we expect (batch_size, n_observables) shape
+             ).reshape(
+                 batch_size, 1
+             )  # we expect (batch_size, n_observables) shape
+
+         expectation_list = [expectation_fn(i) for i in range(len(self.observable))]
+         return torch.vstack(expectation_list)

      def psr(self, psr_fn: Callable, **psr_args: int | float | None) -> Tensor:
          # wrapper which unpacks the parameters
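
Note: with the per-observable loop above, adjoint differentiation no longer rejects multiple observables. A hedged sketch of the user-facing effect, assuming the standard QuantumModel API:

    from qadence import RX, X, Z, QuantumCircuit, QuantumModel
    from qadence.parameters import VariationalParameter
    from qadence.types import DiffMode

    theta = VariationalParameter("theta")
    circuit = QuantumCircuit(1, RX(0, theta))

    # Two observables at once; previously this raised NotImplementedError
    # when diff_mode was DiffMode.ADJOINT.
    model = QuantumModel(circuit, observable=[Z(0), X(0)], diff_mode=DiffMode.ADJOINT)
    expvals = model.expectation({})  # one expectation value per observable
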
qadence/extensions.py CHANGED
@@ -108,14 +108,6 @@ def _supported_gates(backend_name: str) -> list[TAbstractBlock]:
      return [getattr(operations, gate) for gate in _supported_gates]


- def _gpsr_fns() -> dict:
-     """Fallback function for native Qadence GPSR functions if extensions is not present."""
-     # avoid circular import
-     from qadence.backends.gpsr import general_psr
-
-     return {DiffMode.GPSR: general_psr}
-
-
  def _validate_diff_mode(backend: Backend, diff_mode: DiffMode) -> None:
      """Fallback function for native Qadence diff_mode if extensions is not present."""
      if not backend.supports_ad and diff_mode == DiffMode.AD:
@@ -152,11 +144,9 @@ try:
      available_backends = getattr(module, "available_backends")
      available_engines = getattr(module, "available_engines")
      supported_gates = getattr(module, "supported_gates")
-     get_gpsr_fns = getattr(module, "gpsr_fns")
      set_backend_config = getattr(module, "set_backend_config")
  except ModuleNotFoundError:
      available_backends = _available_backends
      available_engines = _available_engines
      supported_gates = _supported_gates
-     get_gpsr_fns = _gpsr_fns
      set_backend_config = _set_backend_config
@@ -10,6 +10,7 @@ from .optimize_step import optimize_step as default_optimize_step
  from .parameters import get_parameters, num_parameters, set_parameters
  from .tensors import numpy_to_tensor, promote_to, promote_to_tensor
  from .trainer import Trainer
+ from .qcnn_model import QCNN

  # Modules to be automatically added to the qadence namespace
  __all__ = [
@@ -25,4 +26,5 @@ __all__ = [
      "OptimizeResult",
      "Trainer",
      "write_checkpoint",
+     "QCNN",
  ]
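
Note: QCNN is a new public export of qadence.ml_tools, implemented in qadence/ml_tools/qcnn_model.py (not shown in this diff). Only the import below is confirmed here; its constructor arguments are defined in that module:

    from qadence.ml_tools import QCNN  # new in 1.11.3
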
@@ -115,8 +115,10 @@ class CallbacksManager:
              self.add_callback("PlotMetrics", "train_end")
          # only save the last checkpoint if not checkpoint_best_only
          if not self.config.checkpoint_best_only:
-             self.add_callback("SaveCheckpoint", "train_end")
-             self.add_callback("WriteMetrics", "train_end")
+             if self.config.checkpoint_every != 0:
+                 self.add_callback("SaveCheckpoint", "train_end")
+             if self.config.write_every != 0:
+                 self.add_callback("WriteMetrics", "train_end")

      def add_callback(
          self, callback: str | Callback, on: str | TrainingStage, called_every: int = 1
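
Note: with the guards above, setting checkpoint_every or write_every to 0 now also suppresses the corresponding train-end callbacks, not just the periodic ones. A short sketch, assuming the ml_tools TrainConfig fields of the same names:

    from qadence.ml_tools import TrainConfig

    # Disable checkpointing and metric writing entirely, including at train_end.
    config = TrainConfig(
        max_iter=100,
        checkpoint_every=0,  # no SaveCheckpoint callback registered at train_end
        write_every=0,       # no WriteMetrics callback registered at train_end
    )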