qadence 1.2.0__py3-none-any.whl → 1.2.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (34) hide show
  1. qadence/__init__.py +1 -0
  2. qadence/backend.py +18 -14
  3. qadence/backends/__init__.py +1 -2
  4. qadence/backends/api.py +27 -9
  5. qadence/backends/braket/backend.py +2 -1
  6. qadence/backends/horqrux/__init__.py +5 -0
  7. qadence/backends/horqrux/backend.py +216 -0
  8. qadence/backends/horqrux/config.py +26 -0
  9. qadence/backends/horqrux/convert_ops.py +273 -0
  10. qadence/backends/jax_utils.py +45 -0
  11. qadence/backends/pulser/backend.py +2 -1
  12. qadence/backends/pyqtorch/backend.py +2 -1
  13. qadence/backends/pyqtorch/convert_ops.py +3 -3
  14. qadence/backends/utils.py +8 -3
  15. qadence/blocks/embedding.py +46 -24
  16. qadence/blocks/utils.py +20 -1
  17. qadence/circuit.py +3 -9
  18. qadence/engines/__init__.py +0 -0
  19. qadence/engines/differentiable_backend.py +152 -0
  20. qadence/engines/jax/__init__.py +8 -0
  21. qadence/engines/jax/differentiable_backend.py +73 -0
  22. qadence/engines/jax/differentiable_expectation.py +94 -0
  23. qadence/engines/torch/__init__.py +4 -0
  24. qadence/engines/torch/differentiable_backend.py +85 -0
  25. qadence/{backends/pytorch_wrapper.py → engines/torch/differentiable_expectation.py} +1 -2
  26. qadence/extensions.py +20 -3
  27. qadence/ml_tools/models.py +10 -3
  28. qadence/models/quantum_model.py +13 -2
  29. qadence/parameters.py +19 -16
  30. qadence/types.py +13 -1
  31. {qadence-1.2.0.dist-info → qadence-1.2.1.dist-info}/METADATA +21 -2
  32. {qadence-1.2.0.dist-info → qadence-1.2.1.dist-info}/RECORD +34 -22
  33. {qadence-1.2.0.dist-info → qadence-1.2.1.dist-info}/WHEEL +1 -1
  34. {qadence-1.2.0.dist-info → qadence-1.2.1.dist-info}/licenses/LICENSE +0 -0
@@ -0,0 +1,152 @@
1
+ from __future__ import annotations
2
+
3
+ from abc import ABC, abstractmethod
4
+ from collections import Counter
5
+ from dataclasses import dataclass
6
+ from typing import Any, Callable
7
+
8
+ from qadence.backend import Backend, Converted, ConvertedCircuit, ConvertedObservable
9
+ from qadence.blocks import AbstractBlock, PrimitiveBlock
10
+ from qadence.blocks.utils import uuid_to_block
11
+ from qadence.circuit import QuantumCircuit
12
+ from qadence.measurements import Measurements
13
+ from qadence.mitigations import Mitigations
14
+ from qadence.noise import Noise
15
+ from qadence.types import ArrayLike, DiffMode, Endianness, Engine, ParamDictType
16
+
17
+
18
@dataclass(frozen=True, eq=True)
class DifferentiableBackend(ABC):
    """Abstract wrapper placing a quantum backend inside an automatic differentiation engine.

    The wrapped backend may or may not be natively differentiable. Most methods
    simply forward to the wrapped backend; `expectation` has to be provided by the
    engine-specific subclass.

    Arguments:
        backend: An instance of the QuantumBackend type used to perform execution.
        engine: Which automatic differentiation engine the QuantumBackend runs on.
        diff_mode: A differentiable mode supported by the differentiation engine.
    """

    backend: Backend
    engine: Engine
    diff_mode: DiffMode

    # TODO: Add differentiable overlap calculation
    _overlap: Callable = None  # type: ignore [assignment]

    def sample(
        self,
        circuit: ConvertedCircuit,
        param_values: ParamDictType = {},
        n_shots: int = 100,
        state: ArrayLike | None = None,
        noise: Noise | None = None,
        mitigation: Mitigations | None = None,
        endianness: Endianness = Endianness.BIG,
    ) -> list[Counter]:
        """Sample bitstrings from the registered circuit.

        All arguments are forwarded unchanged to the wrapped backend's `sample`.

        Arguments:
            circuit: A backend native quantum circuit to be executed.
            param_values: The values of the parameters after embedding.
            n_shots: The number of shots. Defaults to 100.
            state: Initial state.
            noise: A noise model to use.
            mitigation: A mitigation protocol to apply to noisy samples.
            endianness: Endianness of the resulting bitstrings.

        Returns:
            A list of `Counter` objects holding the sampled bitstrings.
        """
        return self.backend.sample(
            circuit=circuit,
            param_values=param_values,
            n_shots=n_shots,
            state=state,
            noise=noise,
            mitigation=mitigation,
            endianness=endianness,
        )

    def run(
        self,
        circuit: ConvertedCircuit,
        param_values: ParamDictType = {},
        state: ArrayLike | None = None,
        endianness: Endianness = Endianness.BIG,
    ) -> ArrayLike:
        """Run the circuit on the underlying backend and return its result.

        Arguments:
            circuit: A converted circuit as returned by `backend.circuit`.
            param_values: The values of the parameters after embedding.
            state: Initial state.
            endianness: Endianness forwarded to the wrapped backend.
        """
        return self.backend.run(
            circuit=circuit, param_values=param_values, state=state, endianness=endianness
        )

    @abstractmethod
    def expectation(
        self,
        circuit: ConvertedCircuit,
        observable: list[ConvertedObservable] | ConvertedObservable,
        param_values: ParamDictType = {},
        state: ArrayLike | None = None,
        measurement: Measurements | None = None,
        noise: Noise | None = None,
        mitigation: Mitigations | None = None,
        endianness: Endianness = Endianness.BIG,
    ) -> Any:
        """Compute the expectation value of the `circuit` with the given `observable`.

        Arguments:
            circuit: A converted circuit as returned by `backend.circuit`.
            observable: A converted observable as returned by `backend.observable`.
            param_values: _**Already embedded**_ parameters of the circuit. See
                [`embedding`][qadence.blocks.embedding.embedding] for more info.
            state: Initial state.
            measurement: Optional measurement protocol. If None, use
                exact expectation value with a statevector simulator.
            noise: A noise model to use.
            mitigation: The error mitigation to use.
            endianness: Endianness of the resulting bit strings.

        Raises:
            NotImplementedError: Always, unless overridden by a subclass.
        """
        raise NotImplementedError(
            "A DifferentiableBackend needs to override the expectation method."
        )

    def default_configuration(self) -> Any:
        """Return the default configuration of the wrapped backend."""
        return self.backend.default_configuration()

    def circuit(self, circuit: QuantumCircuit) -> ConvertedCircuit:
        """Convert `circuit` with the wrapped backend.

        Raises:
            ValueError: If `diff_mode` is GPSR and any block returned by
                `uuid_to_block` is not a `PrimitiveBlock`.
        """
        if self.diff_mode == DiffMode.GPSR:
            parametrized_blocks = uuid_to_block(circuit.block).values()
            if any(not isinstance(b, PrimitiveBlock) for b in parametrized_blocks):
                # Implicit string concatenation keeps source indentation out of the message.
                raise ValueError(
                    "The circuit contains non-primitive blocks that are currently "
                    "not supported by the PSR differentiable mode."
                )
        return self.backend.circuit(circuit)

    def observable(self, observable: AbstractBlock, n_qubits: int) -> ConvertedObservable:
        """Convert `observable` with the wrapped backend.

        Accepts either a single observable or a list of observables for the
        parametric check.

        Raises:
            ValueError: If `diff_mode` is not AD and an observable is parametric.
        """
        if self.diff_mode != DiffMode.AD and observable is not None:
            candidates = observable if isinstance(observable, list) else [observable]
            if any(obs.is_parametric for obs in candidates):
                msg = (
                    f"Differentiation mode '{self.diff_mode}' does not support "
                    "parametric observables."
                )
                raise ValueError(msg)
        return self.backend.observable(observable, n_qubits)

    def convert(
        self,
        circuit: QuantumCircuit,
        observable: list[AbstractBlock] | AbstractBlock | None = None,
    ) -> Converted:
        """Forward circuit and observable conversion to the wrapped backend."""
        return self.backend.convert(circuit, observable)

    def assign_parameters(self, circuit: ConvertedCircuit, param_values: ParamDictType) -> Any:
        """Forward parameter assignment to the wrapped backend."""
        return self.backend.assign_parameters(circuit, param_values)
@@ -0,0 +1,8 @@
1
+ from __future__ import annotations
2
+
3
+ from jax import config
4
+
5
+ from .differentiable_backend import DifferentiableBackend
6
+ from .differentiable_expectation import DifferentiableExpectation
7
+
8
+ config.update("jax_enable_x64", True)
@@ -0,0 +1,73 @@
1
+ from __future__ import annotations
2
+
3
+ from qadence.backend import Backend, ConvertedCircuit, ConvertedObservable
4
+ from qadence.engines.differentiable_backend import (
5
+ DifferentiableBackend as DifferentiableBackendInterface,
6
+ )
7
+ from qadence.engines.jax.differentiable_expectation import DifferentiableExpectation
8
+ from qadence.measurements import Measurements
9
+ from qadence.mitigations import Mitigations
10
+ from qadence.noise import Noise
11
+ from qadence.types import ArrayLike, DiffMode, Endianness, Engine, ParamDictType
12
+
13
+
14
class DifferentiableBackend(DifferentiableBackendInterface):
    """JAX flavour of the differentiable backend wrapper.

    Arguments:
        backend: An instance of the QuantumBackend type used to perform execution.
        diff_mode: A differentiable mode supported by the differentiation engine.
        **psr_args: Extra parameter-shift options, stored on the instance as `psr_args`.
            NOTE(review): they are not forwarded to `DifferentiableExpectation` by
            `expectation` below.
    """

    def __init__(
        self,
        backend: Backend,
        diff_mode: DiffMode = DiffMode.AD,
        **psr_args: int | float | None,
    ) -> None:
        super().__init__(backend=backend, engine=Engine.JAX, diff_mode=diff_mode)
        self.psr_args = psr_args

    def expectation(
        self,
        circuit: ConvertedCircuit,
        observable: list[ConvertedObservable] | ConvertedObservable,
        param_values: ParamDictType = {},
        state: ArrayLike | None = None,
        measurement: Measurements | None = None,
        noise: Noise | None = None,
        mitigation: Mitigations | None = None,
        endianness: Endianness = Endianness.BIG,
    ) -> ArrayLike:
        """Compute the expectation value of the `circuit` with the given `observable`.

        With `DiffMode.AD` the wrapped backend's `expectation` is called directly;
        for every other mode the parameter-shift rule of `DifferentiableExpectation`
        is used.

        Arguments:
            circuit: A converted circuit as returned by `backend.circuit`.
            observable: A converted observable as returned by `backend.observable`.
            param_values: _**Already embedded**_ parameters of the circuit. See
                [`embedding`][qadence.blocks.embedding.embedding] for more info.
            state: Initial state.
            measurement: Optional measurement protocol. If None, use
                exact expectation value with a statevector simulator.
            noise: A noise model to use.
            mitigation: The error mitigation to use.
            endianness: Endianness of the resulting bit strings.
        """
        # A single observable is always wrapped into a list.
        obs_list = observable if isinstance(observable, list) else [observable]

        if self.diff_mode == DiffMode.AD:
            # Only circuit, observables, parameters and state are forwarded on this path.
            return self.backend.expectation(circuit, obs_list, param_values, state)

        psr_handler = DifferentiableExpectation(
            backend=self.backend,
            circuit=circuit,
            observable=obs_list,
            param_values=param_values,
            state=state,
            measurement=measurement,
            noise=noise,
            mitigation=mitigation,
            endianness=endianness,
        )
        return psr_handler.psr()
@@ -0,0 +1,94 @@
1
+ from __future__ import annotations
2
+
3
+ from dataclasses import dataclass
4
+ from typing import Any, Tuple
5
+
6
+ import jax.numpy as jnp
7
+ from jax import Array, custom_vjp, vmap
8
+
9
+ from qadence.backend import Backend as QuantumBackend
10
+ from qadence.backend import ConvertedCircuit, ConvertedObservable
11
+ from qadence.backends.jax_utils import (
12
+ tensor_to_jnp,
13
+ )
14
+ from qadence.blocks.utils import uuid_to_eigen
15
+ from qadence.measurements import Measurements
16
+ from qadence.mitigations import Mitigations
17
+ from qadence.noise import Noise
18
+ from qadence.types import Endianness, Engine, ParamDictType
19
+
20
+
21
def compute_single_gap(eigen_vals: Array, default_val: float = 2.0) -> Array:
    """Return a length-1 array with the spectral gap of a pair of eigenvalues.

    Arguments:
        eigen_vals: Array holding exactly two eigenvalues (it is reshaped to (1, 2)).
        default_val: Value substituted wherever the absolute difference is not
            strictly positive.

    Returns:
        The first (smallest) unique value among the absolute eigenvalue difference
        and `default_val`.
    """
    row = eigen_vals.reshape(1, 2)
    # Pairwise differences; only the lower triangle is kept before taking |.|.
    lower_abs = jnp.abs(jnp.tril(row.T - row))
    filled = jnp.where(lower_abs > 0.0, lower_abs, default_val)
    return jnp.unique(filled, size=1)
25
+
26
+
27
@dataclass
class DifferentiableExpectation:
    """A handler for differentiating expectation estimation with the JAX engine.

    Bundles everything needed for one expectation evaluation. `psr` wraps the
    backend's `expectation` in a `jax.custom_vjp` whose backward pass applies a
    parameter-shift rule.
    """

    backend: QuantumBackend
    circuit: ConvertedCircuit
    observable: list[ConvertedObservable] | ConvertedObservable
    param_values: ParamDictType
    state: Array | None = None
    # NOTE(review): measurement, noise, mitigation and endianness are stored but
    # not read anywhere in `psr` below.
    measurement: Measurements | None = None
    noise: Noise | None = None
    mitigation: Mitigations | None = None
    endianness: Endianness = Endianness.BIG
    engine: Engine = Engine.JAX

    def psr(self) -> Any:
        """Evaluate the expectation with a parameter-shift backward rule attached.

        Side effect: sets `self.psr_params` to the subset of `self.param_values`
        whose keys are returned by `uuid_to_eigen` for the abstract circuit block.

        Returns:
            Whatever `self.backend.expectation` returns for the stored circuit,
            observable, parameter values and state.
        """
        # Number of observables; decides how the cotangent is contracted below.
        n_obs = len(self.observable)

        def expectation_fn(state: Array, values: ParamDictType, psr_params: ParamDictType) -> Array:
            # `psr_params` is accepted only to mirror the custom_vjp signature; it is unused.
            return self.backend.expectation(
                circuit=self.circuit, observable=self.observable, param_values=values, state=state
            )

        @custom_vjp
        def expectation(state: Array, values: ParamDictType, psr_params: ParamDictType) -> Array:
            return expectation_fn(state, values, psr_params)

        uuid_to_eigs = {
            k: tensor_to_jnp(v) for k, v in uuid_to_eigen(self.circuit.abstract.block).items()
        }
        self.psr_params = {
            k: self.param_values[k] for k in uuid_to_eigs.keys()
        }  # Subset of params on which to perform PSR.

        def expectation_fwd(state: Array, values: ParamDictType, psr_params: ParamDictType) -> Any:
            # Forward pass: primal output plus the residuals needed by the backward pass.
            return expectation_fn(state, values, psr_params), (
                state,
                values,
                psr_params,
            )

        def expectation_bwd(res: Tuple[Array, ParamDictType, ParamDictType], tangent: Array) -> Any:
            state, values, psr_params = res
            grads = {}
            # FIXME Hardcoding the single spectral_gap to 2.
            spectral_gap = 2.0
            shift = jnp.pi / 2

            def shift_circ(param_name: str, values: dict) -> Array:
                # Evaluate the expectation at value + shift and value - shift for one parameter.
                shifted_values = values.copy()
                shiftvals = jnp.array(
                    [shifted_values[param_name] + shift, shifted_values[param_name] - shift]
                )

                def _expectation(val: Array) -> Array:
                    # Overwrites the entry in the copied dict, leaving `values` untouched.
                    shifted_values[param_name] = val
                    return expectation(state, shifted_values, psr_params)

                # vmap over the leading axis, i.e. over the (plus, minus) pair.
                return vmap(_expectation, in_axes=(0,))(shiftvals)

            for param_name, _ in psr_params.items():
                f_plus, f_min = shift_circ(param_name, values)
                # With spectral_gap = 2 and shift = pi/2 this reduces to (f_plus - f_min) / 2.
                grad = spectral_gap * (f_plus - f_min) / (4.0 * jnp.sin(spectral_gap * shift / 2.0))
                grads[param_name] = jnp.sum(tangent * grad, axis=1) if n_obs > 1 else tangent * grad
            # No cotangent for `state` and `values`; only `psr_params` gets gradients.
            return None, None, grads

        expectation.defvjp(expectation_fwd, expectation_bwd)
        return expectation(self.state, self.param_values, self.psr_params)
@@ -0,0 +1,4 @@
1
+ from __future__ import annotations
2
+
3
+ from .differentiable_backend import DifferentiableBackend
4
+ from .differentiable_expectation import DifferentiableExpectation
@@ -0,0 +1,85 @@
1
+ from __future__ import annotations
2
+
3
+ from functools import partial
4
+
5
+ from qadence.backend import Backend as QuantumBackend
6
+ from qadence.backend import ConvertedCircuit, ConvertedObservable
7
+ from qadence.engines.differentiable_backend import (
8
+ DifferentiableBackend as DifferentiableBackendInterface,
9
+ )
10
+ from qadence.engines.torch.differentiable_expectation import DifferentiableExpectation
11
+ from qadence.extensions import get_gpsr_fns
12
+ from qadence.measurements import Measurements
13
+ from qadence.mitigations import Mitigations
14
+ from qadence.noise import Noise
15
+ from qadence.types import ArrayLike, DiffMode, Endianness, Engine, ParamDictType
16
+
17
+
18
class DifferentiableBackend(DifferentiableBackendInterface):
    """A class which wraps a QuantumBackend with the automatic differentiation engine TORCH.

    Arguments:
        backend: An instance of the QuantumBackend type used to perform execution.
        diff_mode: A differentiable mode supported by the differentiation engine.
        **psr_args: Arguments that will be passed on to `DifferentiableExpectation.psr`.
    """

    def __init__(
        self,
        backend: QuantumBackend,
        diff_mode: DiffMode = DiffMode.AD,
        **psr_args: int | float | None,
    ) -> None:
        super().__init__(backend=backend, engine=Engine.TORCH, diff_mode=diff_mode)
        self.psr_args = psr_args

    def expectation(
        self,
        circuit: ConvertedCircuit,
        observable: list[ConvertedObservable] | ConvertedObservable,
        param_values: ParamDictType = {},
        state: ArrayLike | None = None,
        measurement: Measurements | None = None,
        noise: Noise | None = None,
        mitigation: Mitigations | None = None,
        endianness: Endianness = Endianness.BIG,
    ) -> ArrayLike:
        """Compute the expectation value of the `circuit` with the given `observable`.

        Arguments:
            circuit: A converted circuit as returned by `backend.circuit`.
            observable: A converted observable as returned by `backend.observable`.
            param_values: _**Already embedded**_ parameters of the circuit. See
                [`embedding`][qadence.blocks.embedding.embedding] for more info.
            state: Initial state.
            measurement: Optional measurement protocol. If None, use
                exact expectation value with a statevector simulator.
            noise: A noise model to use.
            mitigation: The error mitigation to use.
            endianness: Endianness of the resulting bit strings.

        Raises:
            ValueError: If `diff_mode` is neither AD nor ADJOINT and has no entry
                in the mapping returned by `get_gpsr_fns`.
        """
        observable = observable if isinstance(observable, list) else [observable]
        differentiable_expectation = DifferentiableExpectation(
            backend=self.backend,
            circuit=circuit,
            observable=observable,
            param_values=param_values,
            state=state,
            measurement=measurement,
            noise=noise,
            mitigation=mitigation,
            endianness=endianness,
        )

        if self.diff_mode == DiffMode.AD:
            expectation = differentiable_expectation.ad
        elif self.diff_mode == DiffMode.ADJOINT:
            expectation = differentiable_expectation.adjoint
        else:
            gpsr_fns = get_gpsr_fns()
            # Only the lookup is guarded, so an unrelated KeyError raised while
            # building the mapping is not misreported as an unsupported mode.
            try:
                psr_fn = gpsr_fns[self.diff_mode]
            except KeyError as err:
                raise ValueError(
                    f"{self.diff_mode} differentiation mode is not supported"
                ) from err
            expectation = partial(differentiable_expectation.psr, psr_fn=psr_fn, **self.psr_args)
        return expectation()
@@ -13,7 +13,7 @@ from torch.nn import Module
13
13
  from qadence.backend import Backend as QuantumBackend
14
14
  from qadence.backend import Converted, ConvertedCircuit, ConvertedObservable
15
15
  from qadence.backends.adjoint import AdjointExpectation
16
- from qadence.backends.utils import is_pyq_shape, param_dict, pyqify, validate_state
16
+ from qadence.backends.utils import infer_batchsize, is_pyq_shape, param_dict, pyqify, validate_state
17
17
  from qadence.blocks.abstract import AbstractBlock
18
18
  from qadence.blocks.primitive import PrimitiveBlock
19
19
  from qadence.blocks.utils import uuid_to_block, uuid_to_eigen
@@ -24,7 +24,6 @@ from qadence.mitigations import Mitigations
24
24
  from qadence.ml_tools import promote_to_tensor
25
25
  from qadence.noise import Noise
26
26
  from qadence.types import DiffMode, Endianness
27
- from qadence.utils import infer_batchsize
28
27
 
29
28
 
30
29
  class PSRExpectation(Function):
qadence/extensions.py CHANGED
@@ -6,15 +6,30 @@ from string import Template
6
6
  from qadence.backend import Backend
7
7
  from qadence.blocks.abstract import TAbstractBlock
8
8
  from qadence.logger import get_logger
9
- from qadence.types import BackendName, DiffMode
9
+ from qadence.types import BackendName, DiffMode, Engine
10
10
 
11
11
  backends_namespace = Template("qadence.backends.$name")
12
12
 
13
13
  logger = get_logger(__name__)
14
14
 
15
15
 
16
def _available_engines() -> dict:
    """Collect the differentiable-backend class of every importable qadence engine.

    Returns:
        A dict mapping each entry of `Engine.list()` whose
        `qadence.engines.<engine>.differentiable_backend` module can be imported
        to that module's `DifferentiableBackend` class.
    """
    found = {}
    for engine in Engine.list():
        try:
            engine_module = importlib.import_module(
                f"qadence.engines.{engine}.differentiable_backend"
            )
            found[engine] = engine_module.DifferentiableBackend
        except ImportError:
            # ModuleNotFoundError is a subclass of ImportError; such engines are skipped.
            continue
    logger.info(f"Found engines: {found.keys()}")
    return found
29
+
30
+
16
31
  def _available_backends() -> dict:
17
- """Fallback function for native Qadence available backends if extensions is not present."""
32
+ """Returns a dictionary of currently installed, native qadence backends."""
18
33
  res = {}
19
34
  for backend in BackendName.list():
20
35
  module_path = f"qadence.backends.{backend}.backend"
@@ -24,11 +39,12 @@ def _available_backends() -> dict:
24
39
  res[backend] = BackendCls
25
40
  except (ImportError, ModuleNotFoundError):
26
41
  pass
42
+ logger.info(f"Found backends: {res.keys()}")
27
43
  return res
28
44
 
29
45
 
30
46
  def _supported_gates(name: BackendName | str) -> list[TAbstractBlock]:
31
- """Fallback function for native Qadence backend supported gates if extensions is not present."""
47
+ """Returns a list of supported gates for the queried backend 'name'."""
32
48
  from qadence import operations
33
49
 
34
50
  name = str(BackendName(name).name.lower())
@@ -102,6 +118,7 @@ try:
102
118
  set_backend_config = getattr(module, "set_backend_config")
103
119
  except ModuleNotFoundError:
104
120
  available_backends = _available_backends
121
+ available_engines = _available_engines
105
122
  supported_gates = _supported_gates
106
123
  get_gpsr_fns = _gpsr_fns
107
124
  set_backend_config = _set_backend_config
@@ -178,9 +178,16 @@ class TransformedModule(torch.nn.Module):
178
178
  if isinstance(self.model, (QuantumModel, QNN)):
179
179
  if not isinstance(x, dict):
180
180
  x = self._format_to_dict(x)
181
- return {
182
- key: self._input_scaling * (val + self._input_shifting) for key, val in x.items()
183
- }
181
+ if self.in_features == 1:
182
+ return {
183
+ key: self._input_scaling * (val + self._input_shifting)
184
+ for key, val in x.items()
185
+ }
186
+ else:
187
+ return {
188
+ key: self._input_scaling[idx] * (val + self._input_shifting[idx])
189
+ for idx, (key, val) in enumerate(x.items())
190
+ }
184
191
 
185
192
  else:
186
193
  assert isinstance(self.model, torch.nn.Module) and isinstance(x, Tensor)
@@ -18,11 +18,14 @@ from qadence.backend import (
18
18
  )
19
19
  from qadence.backends.api import backend_factory, config_factory
20
20
  from qadence.blocks.abstract import AbstractBlock
21
+ from qadence.blocks.utils import chain, unique_parameters
21
22
  from qadence.circuit import QuantumCircuit
23
+ from qadence.engines.differentiable_backend import DifferentiableBackend
22
24
  from qadence.logger import get_logger
23
25
  from qadence.measurements import Measurements
24
26
  from qadence.mitigations import Mitigations
25
27
  from qadence.noise import Noise
28
+ from qadence.parameters import Parameter
26
29
  from qadence.types import DiffMode, Endianness
27
30
 
28
31
  logger = get_logger(__name__)
@@ -36,7 +39,7 @@ class QuantumModel(nn.Module):
36
39
  [here](/advanced_tutorials/custom-models.md).
37
40
  """
38
41
 
39
- backend: Backend
42
+ backend: Backend | DifferentiableBackend
40
43
  embedding_fn: Callable
41
44
  _params: nn.ParameterDict
42
45
  _circuit: ConvertedCircuit
@@ -77,7 +80,6 @@ class QuantumModel(nn.Module):
77
80
  f"The circuit should be of type '<class QuantumCircuit>'. Got {type(circuit)}."
78
81
  )
79
82
 
80
- self.inputs = [p for p in circuit.unique_parameters if not p.trainable and not p.is_number]
81
83
  if diff_mode is None:
82
84
  raise ValueError("`diff_mode` cannot be `None` in a `QuantumModel`.")
83
85
 
@@ -90,6 +92,15 @@ class QuantumModel(nn.Module):
90
92
  else:
91
93
  observable = [observable]
92
94
 
95
+ def _is_feature_param(p: Parameter) -> bool:
96
+ return not p.trainable and not p.is_number
97
+
98
+ if observable is None:
99
+ self.inputs = list(filter(_is_feature_param, circuit.unique_parameters))
100
+ else:
101
+ uparams = unique_parameters(chain(circuit.block, *observable))
102
+ self.inputs = list(filter(_is_feature_param, uparams))
103
+
93
104
  conv = self.backend.convert(circuit, observable)
94
105
  self.embedding_fn = conv.embedding_fn
95
106
  self._circuit = conv.circuit
qadence/parameters.py CHANGED
@@ -9,11 +9,11 @@ import sympy
9
9
  from sympy import *
10
10
  from sympy import Array, Basic, Expr, Symbol, sympify
11
11
  from sympy.physics.quantum.dagger import Dagger
12
- from sympytorch import SymPyModule
12
+ from sympytorch import SymPyModule as torchSympyModule
13
13
  from torch import Tensor, heaviside, no_grad, rand, tensor
14
14
 
15
15
  from qadence.logger import get_logger
16
- from qadence.types import TNumber
16
+ from qadence.types import DifferentiableExpression, Engine, TNumber
17
17
 
18
18
  # Modules to be automatically added to the qadence namespace
19
19
  __all__ = ["FeatureParameter", "Parameter", "VariationalParameter"]
@@ -190,23 +190,26 @@ def extract_original_param_entry(
190
190
  return param if not param.is_number else evaluate(param)
191
191
 
192
192
 
193
- def torchify(expr: Expr) -> SymPyModule:
194
- """
195
- Arguments:
193
def heaviside_func(x: Tensor, _: Any) -> Tensor:
    """Evaluate the Heaviside step of `x` with value 0.5 at zero, outside of autograd.

    Arguments:
        x: Input tensor.
        _: Ignored second argument.

    Returns:
        The result of `torch.heaviside(x, tensor(0.5))`, computed under `no_grad`.
    """
    value_at_zero = tensor(0.5)
    with no_grad():
        return heaviside(x, value_at_zero)
196
197
 
197
- expr: An expression consisting of Parameters.
198
198
 
199
- Returns:
200
- A torchified, differentiable Expression.
201
- """
199
def torchify(expr: Expr) -> torchSympyModule:
    """Turn a sympy expression into a torch module.

    Arguments:
        expr: An expression consisting of Parameters.

    Returns:
        A torchified, differentiable Expression.
    """
    numeric_expr = sympy.N(expr)
    # Custom translations for the imaginary unit and the Heaviside function.
    custom_ops = {
        sympy.core.numbers.ImaginaryUnit: 1.0j,
        sympy.Heaviside: heaviside_func,
    }
    return torchSympyModule(expressions=[numeric_expr], extra_funcs=custom_ops)
202
202
 
203
- def heaviside_func(x: Tensor, _: Any) -> Tensor:
204
- with no_grad():
205
- res = heaviside(x, tensor(0.5))
206
- return res
207
203
 
208
- extra_funcs = {sympy.core.numbers.ImaginaryUnit: 1.0j, sympy.Heaviside: heaviside_func}
209
- return SymPyModule(expressions=[sympy.N(expr)], extra_funcs=extra_funcs)
204
def make_differentiable(expr: Expr, engine: Engine = Engine.TORCH) -> DifferentiableExpression:
    """Convert a sympy expression into a callable for the requested engine.

    Arguments:
        expr: The expression to convert.
        engine: `Engine.JAX` selects `jaxify`; any other value selects `torchify`.

    Returns:
        The result of `jaxify(expr)` or `torchify(expr)`.
    """
    if engine == Engine.JAX:
        # Local import: only executed when the JAX path is taken.
        from qadence.backends.jax_utils import jaxify

        return jaxify(expr)
    return torchify(expr)
210
213
 
211
214
 
212
215
  def sympy_to_numeric(expr: Basic) -> TNumber:
@@ -261,7 +264,7 @@ def evaluate(expr: Expr, values: dict = {}, as_torch: bool = False) -> TNumber |
261
264
  else:
262
265
  raise ValueError(f"No value provided for symbol {s.name}")
263
266
  if as_torch:
264
- res_value = torchify(expr)(**{s.name: tensor(v) for s, v in query.items()})
267
+ res_value = make_differentiable(expr)(**{s.name: tensor(v) for s, v in query.items()})
265
268
  else:
266
269
  res = expr.subs(query)
267
270
  res_value = sympy_to_numeric(res)
qadence/types.py CHANGED
@@ -2,10 +2,11 @@ from __future__ import annotations
2
2
 
3
3
  import importlib
4
4
  from enum import Enum
5
- from typing import Iterable, Tuple, Union
5
+ from typing import Callable, Iterable, Tuple, Union
6
6
 
7
7
  import numpy as np
8
8
  import sympy
9
+ from numpy.typing import ArrayLike
9
10
  from torch import Tensor, pi
10
11
 
11
12
  TNumber = Union[int, float, complex]
@@ -197,6 +198,8 @@ class _BackendName(StrEnum):
197
198
  """The Braket backend."""
198
199
  PULSER = "pulser"
199
200
  """The Pulser backend."""
201
+ HORQRUX = "horqrux"
202
+ """The horqrux backend."""
200
203
 
201
204
 
202
205
  # If proprietary qadence_extensions is available, import the
@@ -386,3 +389,12 @@ class OpName(StrEnum):
386
389
  class ReadOutOptimization(StrEnum):
387
390
  MLE = "mle"
388
391
  CONSTRAINED = "constrained"
392
+
393
+
394
class Engine(StrEnum):
    """Identifiers of the automatic differentiation engines a backend can be wrapped in."""

    # PyTorch engine.
    TORCH = "torch"
    # JAX engine.
    JAX = "jax"
397
+
398
+
399
# Dictionary from a string key to an array-like value; used as the type of `param_values`.
ParamDictType = dict[str, ArrayLike]
# Any callable returning an array-like value; return type of `make_differentiable`.
DifferentiableExpression = Callable[..., ArrayLike]