qadence 1.5.0__tar.gz → 1.5.1__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (154)
  1. {qadence-1.5.0 → qadence-1.5.1}/PKG-INFO +1 -1
  2. {qadence-1.5.0 → qadence-1.5.1}/pyproject.toml +1 -1
  3. {qadence-1.5.0 → qadence-1.5.1}/qadence/backend.py +1 -26
  4. {qadence-1.5.0 → qadence-1.5.1}/qadence/backends/braket/backend.py +1 -1
  5. {qadence-1.5.0 → qadence-1.5.1}/qadence/backends/horqrux/backend.py +1 -1
  6. {qadence-1.5.0 → qadence-1.5.1}/qadence/backends/pulser/backend.py +1 -1
  7. {qadence-1.5.0 → qadence-1.5.1}/qadence/backends/pyqtorch/backend.py +1 -1
  8. {qadence-1.5.0 → qadence-1.5.1}/qadence/ml_tools/models.py +17 -5
  9. {qadence-1.5.0 → qadence-1.5.1}/qadence/ml_tools/train_grad.py +16 -7
  10. {qadence-1.5.0 → qadence-1.5.1}/qadence/models/quantum_model.py +5 -2
  11. {qadence-1.5.0 → qadence-1.5.1}/.coveragerc +0 -0
  12. {qadence-1.5.0 → qadence-1.5.1}/.github/dependabot.yml +0 -0
  13. {qadence-1.5.0 → qadence-1.5.1}/.github/workflows/build_docs.yml +0 -0
  14. {qadence-1.5.0 → qadence-1.5.1}/.github/workflows/dependabot.yml +0 -0
  15. {qadence-1.5.0 → qadence-1.5.1}/.github/workflows/lint.yml +0 -0
  16. {qadence-1.5.0 → qadence-1.5.1}/.github/workflows/test_all.yml +0 -0
  17. {qadence-1.5.0 → qadence-1.5.1}/.github/workflows/test_examples.yml +0 -0
  18. {qadence-1.5.0 → qadence-1.5.1}/.github/workflows/test_fast.yml +0 -0
  19. {qadence-1.5.0 → qadence-1.5.1}/.gitignore +0 -0
  20. {qadence-1.5.0 → qadence-1.5.1}/.pre-commit-config.yaml +0 -0
  21. {qadence-1.5.0 → qadence-1.5.1}/LICENSE +0 -0
  22. {qadence-1.5.0 → qadence-1.5.1}/MANIFEST.in +0 -0
  23. {qadence-1.5.0 → qadence-1.5.1}/README.md +0 -0
  24. {qadence-1.5.0 → qadence-1.5.1}/mkdocs.yml +0 -0
  25. {qadence-1.5.0 → qadence-1.5.1}/qadence/__init__.py +0 -0
  26. {qadence-1.5.0 → qadence-1.5.1}/qadence/analog/__init__.py +0 -0
  27. {qadence-1.5.0 → qadence-1.5.1}/qadence/analog/addressing.py +0 -0
  28. {qadence-1.5.0 → qadence-1.5.1}/qadence/analog/constants.py +0 -0
  29. {qadence-1.5.0 → qadence-1.5.1}/qadence/analog/device.py +0 -0
  30. {qadence-1.5.0 → qadence-1.5.1}/qadence/analog/hamiltonian_terms.py +0 -0
  31. {qadence-1.5.0 → qadence-1.5.1}/qadence/analog/parse_analog.py +0 -0
  32. {qadence-1.5.0 → qadence-1.5.1}/qadence/backends/__init__.py +0 -0
  33. {qadence-1.5.0 → qadence-1.5.1}/qadence/backends/adjoint.py +0 -0
  34. {qadence-1.5.0 → qadence-1.5.1}/qadence/backends/api.py +0 -0
  35. {qadence-1.5.0 → qadence-1.5.1}/qadence/backends/braket/__init__.py +0 -0
  36. {qadence-1.5.0 → qadence-1.5.1}/qadence/backends/braket/config.py +0 -0
  37. {qadence-1.5.0 → qadence-1.5.1}/qadence/backends/braket/convert_ops.py +0 -0
  38. {qadence-1.5.0 → qadence-1.5.1}/qadence/backends/gpsr.py +0 -0
  39. {qadence-1.5.0 → qadence-1.5.1}/qadence/backends/horqrux/__init__.py +0 -0
  40. {qadence-1.5.0 → qadence-1.5.1}/qadence/backends/horqrux/config.py +0 -0
  41. {qadence-1.5.0 → qadence-1.5.1}/qadence/backends/horqrux/convert_ops.py +0 -0
  42. {qadence-1.5.0 → qadence-1.5.1}/qadence/backends/jax_utils.py +0 -0
  43. {qadence-1.5.0 → qadence-1.5.1}/qadence/backends/pulser/__init__.py +0 -0
  44. {qadence-1.5.0 → qadence-1.5.1}/qadence/backends/pulser/channels.py +0 -0
  45. {qadence-1.5.0 → qadence-1.5.1}/qadence/backends/pulser/cloud.py +0 -0
  46. {qadence-1.5.0 → qadence-1.5.1}/qadence/backends/pulser/config.py +0 -0
  47. {qadence-1.5.0 → qadence-1.5.1}/qadence/backends/pulser/convert_ops.py +0 -0
  48. {qadence-1.5.0 → qadence-1.5.1}/qadence/backends/pulser/devices.py +0 -0
  49. {qadence-1.5.0 → qadence-1.5.1}/qadence/backends/pulser/pulses.py +0 -0
  50. {qadence-1.5.0 → qadence-1.5.1}/qadence/backends/pulser/waveforms.py +0 -0
  51. {qadence-1.5.0 → qadence-1.5.1}/qadence/backends/pyqtorch/__init__.py +0 -0
  52. {qadence-1.5.0 → qadence-1.5.1}/qadence/backends/pyqtorch/config.py +0 -0
  53. {qadence-1.5.0 → qadence-1.5.1}/qadence/backends/pyqtorch/convert_ops.py +0 -0
  54. {qadence-1.5.0 → qadence-1.5.1}/qadence/backends/utils.py +0 -0
  55. {qadence-1.5.0 → qadence-1.5.1}/qadence/blocks/__init__.py +0 -0
  56. {qadence-1.5.0 → qadence-1.5.1}/qadence/blocks/abstract.py +0 -0
  57. {qadence-1.5.0 → qadence-1.5.1}/qadence/blocks/analog.py +0 -0
  58. {qadence-1.5.0 → qadence-1.5.1}/qadence/blocks/block_to_tensor.py +0 -0
  59. {qadence-1.5.0 → qadence-1.5.1}/qadence/blocks/composite.py +0 -0
  60. {qadence-1.5.0 → qadence-1.5.1}/qadence/blocks/embedding.py +0 -0
  61. {qadence-1.5.0 → qadence-1.5.1}/qadence/blocks/manipulate.py +0 -0
  62. {qadence-1.5.0 → qadence-1.5.1}/qadence/blocks/matrix.py +0 -0
  63. {qadence-1.5.0 → qadence-1.5.1}/qadence/blocks/primitive.py +0 -0
  64. {qadence-1.5.0 → qadence-1.5.1}/qadence/blocks/utils.py +0 -0
  65. {qadence-1.5.0 → qadence-1.5.1}/qadence/circuit.py +0 -0
  66. {qadence-1.5.0 → qadence-1.5.1}/qadence/constructors/__init__.py +0 -0
  67. {qadence-1.5.0 → qadence-1.5.1}/qadence/constructors/ansatze.py +0 -0
  68. {qadence-1.5.0 → qadence-1.5.1}/qadence/constructors/daqc/__init__.py +0 -0
  69. {qadence-1.5.0 → qadence-1.5.1}/qadence/constructors/daqc/daqc.py +0 -0
  70. {qadence-1.5.0 → qadence-1.5.1}/qadence/constructors/daqc/gen_parser.py +0 -0
  71. {qadence-1.5.0 → qadence-1.5.1}/qadence/constructors/daqc/utils.py +0 -0
  72. {qadence-1.5.0 → qadence-1.5.1}/qadence/constructors/feature_maps.py +0 -0
  73. {qadence-1.5.0 → qadence-1.5.1}/qadence/constructors/hamiltonians.py +0 -0
  74. {qadence-1.5.0 → qadence-1.5.1}/qadence/constructors/iia.py +0 -0
  75. {qadence-1.5.0 → qadence-1.5.1}/qadence/constructors/qft.py +0 -0
  76. {qadence-1.5.0 → qadence-1.5.1}/qadence/constructors/rydberg_feature_maps.py +0 -0
  77. {qadence-1.5.0 → qadence-1.5.1}/qadence/constructors/rydberg_hea.py +0 -0
  78. {qadence-1.5.0 → qadence-1.5.1}/qadence/constructors/utils.py +0 -0
  79. {qadence-1.5.0 → qadence-1.5.1}/qadence/decompose.py +0 -0
  80. {qadence-1.5.0 → qadence-1.5.1}/qadence/divergences.py +0 -0
  81. {qadence-1.5.0 → qadence-1.5.1}/qadence/draw/__init__.py +0 -0
  82. {qadence-1.5.0 → qadence-1.5.1}/qadence/draw/assets/dark/measurement.png +0 -0
  83. {qadence-1.5.0 → qadence-1.5.1}/qadence/draw/assets/dark/measurement.svg +0 -0
  84. {qadence-1.5.0 → qadence-1.5.1}/qadence/draw/assets/light/measurement.png +0 -0
  85. {qadence-1.5.0 → qadence-1.5.1}/qadence/draw/assets/light/measurement.svg +0 -0
  86. {qadence-1.5.0 → qadence-1.5.1}/qadence/draw/themes.py +0 -0
  87. {qadence-1.5.0 → qadence-1.5.1}/qadence/draw/utils.py +0 -0
  88. {qadence-1.5.0 → qadence-1.5.1}/qadence/draw/vizbackend.py +0 -0
  89. {qadence-1.5.0 → qadence-1.5.1}/qadence/engines/__init__.py +0 -0
  90. {qadence-1.5.0 → qadence-1.5.1}/qadence/engines/differentiable_backend.py +0 -0
  91. {qadence-1.5.0 → qadence-1.5.1}/qadence/engines/jax/__init__.py +0 -0
  92. {qadence-1.5.0 → qadence-1.5.1}/qadence/engines/jax/differentiable_backend.py +0 -0
  93. {qadence-1.5.0 → qadence-1.5.1}/qadence/engines/jax/differentiable_expectation.py +0 -0
  94. {qadence-1.5.0 → qadence-1.5.1}/qadence/engines/torch/__init__.py +0 -0
  95. {qadence-1.5.0 → qadence-1.5.1}/qadence/engines/torch/differentiable_backend.py +0 -0
  96. {qadence-1.5.0 → qadence-1.5.1}/qadence/engines/torch/differentiable_expectation.py +0 -0
  97. {qadence-1.5.0 → qadence-1.5.1}/qadence/exceptions/__init__.py +0 -0
  98. {qadence-1.5.0 → qadence-1.5.1}/qadence/exceptions/exceptions.py +0 -0
  99. {qadence-1.5.0 → qadence-1.5.1}/qadence/execution.py +0 -0
  100. {qadence-1.5.0 → qadence-1.5.1}/qadence/extensions.py +0 -0
  101. {qadence-1.5.0 → qadence-1.5.1}/qadence/finitediff.py +0 -0
  102. {qadence-1.5.0 → qadence-1.5.1}/qadence/libs.py +0 -0
  103. {qadence-1.5.0 → qadence-1.5.1}/qadence/logger.py +0 -0
  104. {qadence-1.5.0 → qadence-1.5.1}/qadence/measurements/__init__.py +0 -0
  105. {qadence-1.5.0 → qadence-1.5.1}/qadence/measurements/protocols.py +0 -0
  106. {qadence-1.5.0 → qadence-1.5.1}/qadence/measurements/samples.py +0 -0
  107. {qadence-1.5.0 → qadence-1.5.1}/qadence/measurements/shadow.py +0 -0
  108. {qadence-1.5.0 → qadence-1.5.1}/qadence/measurements/tomography.py +0 -0
  109. {qadence-1.5.0 → qadence-1.5.1}/qadence/measurements/utils.py +0 -0
  110. {qadence-1.5.0 → qadence-1.5.1}/qadence/mitigations/__init__.py +0 -0
  111. {qadence-1.5.0 → qadence-1.5.1}/qadence/mitigations/analog_zne.py +0 -0
  112. {qadence-1.5.0 → qadence-1.5.1}/qadence/mitigations/protocols.py +0 -0
  113. {qadence-1.5.0 → qadence-1.5.1}/qadence/mitigations/readout.py +0 -0
  114. {qadence-1.5.0 → qadence-1.5.1}/qadence/ml_tools/__init__.py +0 -0
  115. {qadence-1.5.0 → qadence-1.5.1}/qadence/ml_tools/config.py +0 -0
  116. {qadence-1.5.0 → qadence-1.5.1}/qadence/ml_tools/data.py +0 -0
  117. {qadence-1.5.0 → qadence-1.5.1}/qadence/ml_tools/optimize_step.py +0 -0
  118. {qadence-1.5.0 → qadence-1.5.1}/qadence/ml_tools/parameters.py +0 -0
  119. {qadence-1.5.0 → qadence-1.5.1}/qadence/ml_tools/printing.py +0 -0
  120. {qadence-1.5.0 → qadence-1.5.1}/qadence/ml_tools/saveload.py +0 -0
  121. {qadence-1.5.0 → qadence-1.5.1}/qadence/ml_tools/tensors.py +0 -0
  122. {qadence-1.5.0 → qadence-1.5.1}/qadence/ml_tools/train_no_grad.py +0 -0
  123. {qadence-1.5.0 → qadence-1.5.1}/qadence/ml_tools/utils.py +0 -0
  124. {qadence-1.5.0 → qadence-1.5.1}/qadence/models/__init__.py +0 -0
  125. {qadence-1.5.0 → qadence-1.5.1}/qadence/models/qnn.py +0 -0
  126. {qadence-1.5.0 → qadence-1.5.1}/qadence/noise/__init__.py +0 -0
  127. {qadence-1.5.0 → qadence-1.5.1}/qadence/noise/protocols.py +0 -0
  128. {qadence-1.5.0 → qadence-1.5.1}/qadence/noise/readout.py +0 -0
  129. {qadence-1.5.0 → qadence-1.5.1}/qadence/operations/__init__.py +0 -0
  130. {qadence-1.5.0 → qadence-1.5.1}/qadence/operations/analog.py +0 -0
  131. {qadence-1.5.0 → qadence-1.5.1}/qadence/operations/control_ops.py +0 -0
  132. {qadence-1.5.0 → qadence-1.5.1}/qadence/operations/ham_evo.py +0 -0
  133. {qadence-1.5.0 → qadence-1.5.1}/qadence/operations/parametric.py +0 -0
  134. {qadence-1.5.0 → qadence-1.5.1}/qadence/operations/primitive.py +0 -0
  135. {qadence-1.5.0 → qadence-1.5.1}/qadence/overlap.py +0 -0
  136. {qadence-1.5.0 → qadence-1.5.1}/qadence/parameters.py +0 -0
  137. {qadence-1.5.0 → qadence-1.5.1}/qadence/protocols.py +0 -0
  138. {qadence-1.5.0 → qadence-1.5.1}/qadence/py.typed +0 -0
  139. {qadence-1.5.0 → qadence-1.5.1}/qadence/qubit_support.py +0 -0
  140. {qadence-1.5.0 → qadence-1.5.1}/qadence/register.py +0 -0
  141. {qadence-1.5.0 → qadence-1.5.1}/qadence/serialization.py +0 -0
  142. {qadence-1.5.0 → qadence-1.5.1}/qadence/states.py +0 -0
  143. {qadence-1.5.0 → qadence-1.5.1}/qadence/transpile/__init__.py +0 -0
  144. {qadence-1.5.0 → qadence-1.5.1}/qadence/transpile/apply_fn.py +0 -0
  145. {qadence-1.5.0 → qadence-1.5.1}/qadence/transpile/block.py +0 -0
  146. {qadence-1.5.0 → qadence-1.5.1}/qadence/transpile/circuit.py +0 -0
  147. {qadence-1.5.0 → qadence-1.5.1}/qadence/transpile/digitalize.py +0 -0
  148. {qadence-1.5.0 → qadence-1.5.1}/qadence/transpile/flatten.py +0 -0
  149. {qadence-1.5.0 → qadence-1.5.1}/qadence/transpile/invert.py +0 -0
  150. {qadence-1.5.0 → qadence-1.5.1}/qadence/transpile/transpile.py +0 -0
  151. {qadence-1.5.0 → qadence-1.5.1}/qadence/types.py +0 -0
  152. {qadence-1.5.0 → qadence-1.5.1}/qadence/utils.py +0 -0
  153. {qadence-1.5.0 → qadence-1.5.1}/renovate.json +0 -0
  154. {qadence-1.5.0 → qadence-1.5.1}/setup.py +0 -0
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.3
2
2
  Name: qadence
3
- Version: 1.5.0
3
+ Version: 1.5.1
4
4
  Summary: Pasqal interface for circuit-based quantum computing SDKs
5
5
  Author-email: Aleksander Wennersteen <aleksander.wennersteen@pasqal.com>, Gert-Jan Both <gert-jan.both@pasqal.com>, Niklas Heim <niklas.heim@pasqal.com>, Mario Dagrada <mario.dagrada@pasqal.com>, Vincent Elfving <vincent.elfving@pasqal.com>, Dominik Seitz <dominik.seitz@pasqal.com>, Roland Guichard <roland.guichard@pasqal.com>, "Joao P. Moutinho" <joao.moutinho@pasqal.com>, Vytautas Abramavicius <vytautas.abramavicius@pasqal.com>, Gergana Velikova <gergana.velikova@pasqal.com>
6
6
  License: Apache 2.0
@@ -20,7 +20,7 @@ authors = [
20
20
  ]
21
21
  requires-python = ">=3.9,<3.13"
22
22
  license = {text = "Apache 2.0"}
23
- version = "1.5.0"
23
+ version = "1.5.1"
24
24
  classifiers=[
25
25
  "License :: OSI Approved :: Apache Software License",
26
26
  "Programming Language :: Python",
@@ -26,7 +26,6 @@ from qadence.mitigations import Mitigations
26
26
  from qadence.noise import Noise
27
27
  from qadence.parameters import stringify
28
28
  from qadence.types import ArrayLike, BackendName, DiffMode, Endianness, Engine, ParamDictType
29
- from qadence.utils import validate_values_and_state
30
29
 
31
30
  logger = get_logger(__file__)
32
31
 
@@ -259,29 +258,6 @@ class Backend(ABC):
259
258
  """
260
259
  raise NotImplementedError
261
260
 
262
- @abstractmethod
263
- def _run(
264
- self,
265
- circuit: ConvertedCircuit,
266
- param_values: dict[str, ArrayLike] = {},
267
- state: ArrayLike | None = None,
268
- endianness: Endianness = Endianness.BIG,
269
- ) -> ArrayLike:
270
- """Run a circuit and return the resulting wave function.
271
-
272
- Arguments:
273
- circuit: A converted circuit as returned by `backend.circuit`.
274
- param_values: _**Already embedded**_ parameters of the circuit. See
275
- [`embedding`][qadence.blocks.embedding.embedding] for more info.
276
- state: Initial state.
277
- endianness: Endianness of the resulting wavefunction.
278
-
279
- Returns:
280
- A list of Counter objects where each key represents a bitstring
281
- and its value the number of times it has been sampled from the given wave function.
282
- """
283
- raise NotImplementedError
284
-
285
261
  def run(
286
262
  self,
287
263
  circuit: ConvertedCircuit,
@@ -304,8 +280,7 @@ class Backend(ABC):
304
280
  A list of Counter objects where each key represents a bitstring
305
281
  and its value the number of times it has been sampled from the given wave function.
306
282
  """
307
- validate_values_and_state(state, circuit.abstract.n_qubits, param_values)
308
- return self._run(circuit, param_values, state, endianness, *args, **kwargs)
283
+ raise NotImplementedError
309
284
 
310
285
  @abstractmethod
311
286
  def run_dm(
@@ -88,7 +88,7 @@ class Backend(BackendInterface):
88
88
  ).squeeze(0)
89
89
  return ConvertedObservable(native=native, abstract=obs, original=obs)
90
90
 
91
- def _run(
91
+ def run(
92
92
  self,
93
93
  circuit: ConvertedCircuit,
94
94
  param_values: dict[str, Tensor] = {},
@@ -66,7 +66,7 @@ class Backend(BackendInterface):
66
66
  hq_obs = convert_observable(block, n_qubits=n_qubits, config=self.config)
67
67
  return ConvertedObservable(native=hq_obs, abstract=block, original=observable)
68
68
 
69
- def _run(
69
+ def run(
70
70
  self,
71
71
  circuit: ConvertedCircuit,
72
72
  param_values: ParamDictType = {},
@@ -200,7 +200,7 @@ class Backend(BackendInterface):
200
200
 
201
201
  return circuit.native.build(**numpy_param_values)
202
202
 
203
- def _run(
203
+ def run(
204
204
  self,
205
205
  circuit: ConvertedCircuit,
206
206
  param_values: dict[str, Tensor] = {},
@@ -80,7 +80,7 @@ class Backend(BackendInterface):
80
80
  (native,) = convert_observable(block, n_qubits=n_qubits, config=self.config)
81
81
  return ConvertedObservable(native=native, abstract=block, original=observable)
82
82
 
83
- def _run(
83
+ def run(
84
84
  self,
85
85
  circuit: ConvertedCircuit,
86
86
  param_values: dict[str, Tensor] = {},
@@ -289,11 +289,23 @@ class TransformedModule(torch.nn.Module):
289
289
  def to(self, *args: Any, **kwargs: Any) -> TransformedModule:
290
290
  try:
291
291
  self.model = self.model.to(*args, **kwargs)
292
- self._input_scaling = self._input_scaling.to(*args, **kwargs)
293
- self._input_shifting = self._input_shifting.to(*args, **kwargs)
294
- self._output_scaling = self._output_scaling.to(*args, **kwargs)
295
- self._output_shifting = self._output_shifting.to(*args, **kwargs)
296
-
292
+ if isinstance(self.model, QuantumModel):
293
+ device = self.model._circuit.native.device
294
+ dtype = (
295
+ torch.float64
296
+ if self.model._circuit.native.dtype == torch.cdouble
297
+ else torch.float32
298
+ )
299
+
300
+ self._input_scaling = self._input_scaling.to(device=device, dtype=dtype)
301
+ self._input_shifting = self._input_shifting.to(device=device, dtype=dtype)
302
+ self._output_scaling = self._output_scaling.to(device=device, dtype=dtype)
303
+ self._output_shifting = self._output_shifting.to(device=device, dtype=dtype)
304
+ elif isinstance(self.model, torch.nn.Module):
305
+ self._input_scaling = self._input_scaling.to(*args, **kwargs)
306
+ self._input_shifting = self._input_shifting.to(*args, **kwargs)
307
+ self._output_scaling = self._output_scaling.to(*args, **kwargs)
308
+ self._output_shifting = self._output_shifting.to(*args, **kwargs)
297
309
  logger.debug(f"Moved {self} to {args}, {kwargs}.")
298
310
  except Exception as e:
299
311
  logger.warning(f"Unable to move {self} to {args}, {kwargs} due to {e}.")
@@ -3,6 +3,7 @@ from __future__ import annotations
3
3
  from typing import Callable, Union
4
4
 
5
5
  from rich.progress import BarColumn, Progress, TaskProgressColumn, TextColumn, TimeRemainingColumn
6
+ from torch import complex128, float32, float64
6
7
  from torch import device as torch_device
7
8
  from torch import dtype as torch_dtype
8
9
  from torch.nn import DataParallel, Module
@@ -110,17 +111,17 @@ def train(
110
111
  train_with_grad(model, data, optimizer, config, loss_fn=loss_fn)
111
112
  ```
112
113
  """
114
+ # load available checkpoint
115
+ init_iter = 0
116
+ if config.folder:
117
+ model, optimizer, init_iter = load_checkpoint(config.folder, model, optimizer)
118
+ logger.debug(f"Loaded model and optimizer from {config.folder}")
113
119
 
114
120
  # Move model to device before optimizer is loaded
115
121
  if isinstance(model, DataParallel):
116
122
  model = model.module.to(device=device, dtype=dtype)
117
123
  else:
118
124
  model = model.to(device=device, dtype=dtype)
119
- # load available checkpoint
120
- init_iter = 0
121
- if config.folder:
122
- model, optimizer, init_iter = load_checkpoint(config.folder, model, optimizer)
123
- logger.debug(f"Loaded model and optimizer from {config.folder}")
124
125
  # initialize tensorboard
125
126
  writer = SummaryWriter(config.folder, purge_step=init_iter)
126
127
 
@@ -131,7 +132,9 @@ def train(
131
132
  TaskProgressColumn(),
132
133
  TimeRemainingColumn(elapsed_when_finished=True),
133
134
  )
134
-
135
+ data_dtype = None
136
+ if dtype:
137
+ data_dtype = float64 if dtype == complex128 else float32
135
138
  with progress:
136
139
  dl_iter = iter(dataloader) if dataloader is not None else None
137
140
 
@@ -143,7 +146,12 @@ def train(
143
146
  # which do not have classical input data (e.g. chemistry)
144
147
  if dataloader is None:
145
148
  loss, metrics = optimize_step(
146
- model=model, optimizer=optimizer, loss_fn=loss_fn, xs=None, device=device
149
+ model=model,
150
+ optimizer=optimizer,
151
+ loss_fn=loss_fn,
152
+ xs=None,
153
+ device=device,
154
+ dtype=data_dtype,
147
155
  )
148
156
  loss = loss.item()
149
157
 
@@ -154,6 +162,7 @@ def train(
154
162
  loss_fn=loss_fn,
155
163
  xs=next(dl_iter), # type: ignore[arg-type]
156
164
  device=device,
165
+ dtype=data_dtype,
157
166
  )
158
167
 
159
168
  else:
@@ -342,9 +342,10 @@ class QuantumModel(nn.Module):
342
342
  return self.backend.assign_parameters(self._circuit, params)
343
343
 
344
344
  def to(self, *args: Any, **kwargs: Any) -> QuantumModel:
345
+ from pyqtorch import QuantumCircuit as PyQCircuit
346
+
345
347
  try:
346
- if isinstance(self._circuit.native, torch.nn.Module):
347
- # Backends which are not torch-based cannot be moved to 'device'
348
+ if isinstance(self._circuit.native, PyQCircuit):
348
349
  self._circuit.native = self._circuit.native.to(*args, **kwargs)
349
350
  if self._observable is not None:
350
351
  if isinstance(self._observable, ConvertedObservable):
@@ -359,6 +360,8 @@ class QuantumModel(nn.Module):
359
360
  else torch.float32,
360
361
  )
361
362
  logger.debug(f"Moved {self} to {args}, {kwargs}.")
363
+ else:
364
+ logger.debug("QuantumModel.to only supports pyqtorch.QuantumCircuits.")
362
365
  except Exception as e:
363
366
  logger.warning(f"Unable to move {self} to {args}, {kwargs} due to {e}.")
364
367
  return self
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes