qadence 1.1.1__py3-none-any.whl → 1.2.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (64)
  1. qadence/__init__.py +1 -0
  2. qadence/analog/__init__.py +4 -2
  3. qadence/analog/addressing.py +167 -0
  4. qadence/analog/constants.py +59 -0
  5. qadence/analog/device.py +82 -0
  6. qadence/analog/hamiltonian_terms.py +101 -0
  7. qadence/analog/parse_analog.py +120 -0
  8. qadence/backend.py +42 -12
  9. qadence/backends/__init__.py +1 -2
  10. qadence/backends/api.py +27 -9
  11. qadence/backends/braket/backend.py +3 -2
  12. qadence/backends/horqrux/__init__.py +5 -0
  13. qadence/backends/horqrux/backend.py +216 -0
  14. qadence/backends/horqrux/config.py +26 -0
  15. qadence/backends/horqrux/convert_ops.py +273 -0
  16. qadence/backends/jax_utils.py +45 -0
  17. qadence/backends/pulser/__init__.py +0 -1
  18. qadence/backends/pulser/backend.py +31 -15
  19. qadence/backends/pulser/config.py +19 -10
  20. qadence/backends/pulser/devices.py +57 -63
  21. qadence/backends/pulser/pulses.py +70 -12
  22. qadence/backends/pyqtorch/backend.py +4 -4
  23. qadence/backends/pyqtorch/config.py +18 -12
  24. qadence/backends/pyqtorch/convert_ops.py +15 -7
  25. qadence/backends/utils.py +5 -9
  26. qadence/blocks/abstract.py +5 -1
  27. qadence/blocks/analog.py +18 -9
  28. qadence/blocks/block_to_tensor.py +11 -0
  29. qadence/blocks/embedding.py +46 -24
  30. qadence/blocks/primitive.py +81 -9
  31. qadence/blocks/utils.py +20 -1
  32. qadence/circuit.py +3 -9
  33. qadence/constructors/__init__.py +4 -0
  34. qadence/constructors/feature_maps.py +84 -60
  35. qadence/constructors/hamiltonians.py +27 -98
  36. qadence/constructors/rydberg_feature_maps.py +113 -0
  37. qadence/divergences.py +12 -0
  38. qadence/engines/__init__.py +0 -0
  39. qadence/engines/differentiable_backend.py +152 -0
  40. qadence/engines/jax/__init__.py +8 -0
  41. qadence/engines/jax/differentiable_backend.py +73 -0
  42. qadence/engines/jax/differentiable_expectation.py +94 -0
  43. qadence/engines/torch/__init__.py +4 -0
  44. qadence/engines/torch/differentiable_backend.py +85 -0
  45. qadence/extensions.py +21 -9
  46. qadence/finitediff.py +47 -0
  47. qadence/mitigations/readout.py +92 -25
  48. qadence/ml_tools/models.py +10 -3
  49. qadence/models/qnn.py +88 -23
  50. qadence/models/quantum_model.py +13 -2
  51. qadence/operations.py +55 -70
  52. qadence/parameters.py +24 -13
  53. qadence/register.py +91 -43
  54. qadence/transpile/__init__.py +1 -0
  55. qadence/transpile/apply_fn.py +40 -0
  56. qadence/types.py +32 -2
  57. qadence/utils.py +35 -0
  58. {qadence-1.1.1.dist-info → qadence-1.2.1.dist-info}/METADATA +22 -3
  59. {qadence-1.1.1.dist-info → qadence-1.2.1.dist-info}/RECORD +62 -44
  60. {qadence-1.1.1.dist-info → qadence-1.2.1.dist-info}/WHEEL +1 -1
  61. qadence/analog/interaction.py +0 -198
  62. qadence/analog/utils.py +0 -132
  63. /qadence/{backends/pytorch_wrapper.py → engines/torch/differentiable_expectation.py} +0 -0
  64. {qadence-1.1.1.dist-info → qadence-1.2.1.dist-info}/licenses/LICENSE +0 -0
@@ -2,7 +2,6 @@
2
2
  from __future__ import annotations
3
3
 
4
4
  from .api import backend_factory, config_factory
5
- from .pytorch_wrapper import DifferentiableBackend
6
5
 
7
6
  # Modules to be automatically added to the qadence namespace
8
- __all__ = ["backend_factory", "config_factory", "DifferentiableBackend"]
7
+ __all__ = ["backend_factory", "config_factory"]
qadence/backends/api.py CHANGED
@@ -1,9 +1,9 @@
1
1
  from __future__ import annotations
2
2
 
3
3
  from qadence.backend import Backend, BackendConfiguration
4
- from qadence.backends.pytorch_wrapper import DifferentiableBackend
5
- from qadence.extensions import available_backends, set_backend_config
6
- from qadence.types import BackendName, DiffMode
4
+ from qadence.engines.differentiable_backend import DifferentiableBackend
5
+ from qadence.extensions import available_backends, available_engines, set_backend_config
6
+ from qadence.types import BackendName, DiffMode, Engine
7
7
 
8
8
  __all__ = ["backend_factory", "config_factory"]
9
9
 
@@ -14,13 +14,18 @@ def backend_factory(
14
14
  configuration: BackendConfiguration | dict | None = None,
15
15
  ) -> Backend | DifferentiableBackend:
16
16
  backend_inst: Backend | DifferentiableBackend
17
- backend_name = BackendName(backend)
18
17
  backends = available_backends()
19
-
18
+ try:
19
+ backend_name = BackendName(backend)
20
+ except ValueError:
21
+ raise NotImplementedError(f"The requested backend '{backend}' is not implemented.")
20
22
  try:
21
23
  BackendCls = backends[backend_name]
22
- except (KeyError, ValueError):
23
- raise NotImplementedError(f"The requested backend '{backend_name}' is not implemented.")
24
+ except Exception as e:
25
+ raise ImportError(
26
+ f"The requested backend '{backend_name}' is either not installed\
27
+ or could not be imported due to {e}."
28
+ )
24
29
 
25
30
  default_config = BackendCls.default_configuration()
26
31
  if configuration is None:
@@ -44,9 +49,22 @@ def backend_factory(
44
49
 
45
50
  # Set backend configurations which depend on the differentiation mode
46
51
  set_backend_config(backend_inst, diff_mode)
47
-
52
+ # Wrap the quantum Backend in a DifferentiableBackend if a diff_mode is passed.
48
53
  if diff_mode is not None:
49
- backend_inst = DifferentiableBackend(backend_inst, DiffMode(diff_mode))
54
+ try:
55
+ engine_name = Engine(backend_inst.engine)
56
+ except ValueError:
57
+ raise NotImplementedError(
58
+ f"The requested engine '{backend_inst.engine}' is not implemented."
59
+ )
60
+ try:
61
+ diff_backend_cls = available_engines()[engine_name]
62
+ backend_inst = diff_backend_cls(backend=backend_inst, diff_mode=DiffMode(diff_mode)) # type: ignore[arg-type]
63
+ except Exception as e:
64
+ raise ImportError(
65
+ f"The requested engine '{engine_name}' is either not installed\
66
+ or could not be imported due to {e}."
67
+ )
50
68
  return backend_inst
51
69
 
52
70
 
@@ -23,7 +23,7 @@ from qadence.noise import Noise
23
23
  from qadence.noise.protocols import apply_noise
24
24
  from qadence.overlap import overlap_exact
25
25
  from qadence.transpile import transpile
26
- from qadence.types import BackendName
26
+ from qadence.types import BackendName, Engine
27
27
  from qadence.utils import Endianness
28
28
 
29
29
  from .config import Configuration, default_passes
@@ -55,6 +55,7 @@ class Backend(BackendInterface):
55
55
  with_noise: bool = False
56
56
  native_endianness: Endianness = Endianness.BIG
57
57
  config: Configuration = field(default_factory=Configuration)
58
+ engine: Engine = Engine.TORCH
58
59
 
59
60
  # braket specifics
60
61
  # TODO: include it in the configuration?
@@ -87,7 +88,7 @@ class Backend(BackendInterface):
87
88
  ).squeeze(0)
88
89
  return ConvertedObservable(native=native, abstract=obs, original=obs)
89
90
 
90
- def run(
91
+ def _run(
91
92
  self,
92
93
  circuit: ConvertedCircuit,
93
94
  param_values: dict[str, Tensor] = {},
@@ -0,0 +1,5 @@
1
+ from __future__ import annotations
2
+
3
+ from .backend import Backend
4
+ from .config import Configuration
5
+ from .convert_ops import supported_gates
@@ -0,0 +1,216 @@
1
+ from __future__ import annotations
2
+
3
+ from collections import Counter
4
+ from dataclasses import dataclass, field
5
+ from typing import Any
6
+
7
+ import jax
8
+ import jax.numpy as jnp
9
+ from horqrux.utils import prepare_state
10
+ from jax.typing import ArrayLike
11
+
12
+ from qadence.backend import Backend as BackendInterface
13
+ from qadence.backend import ConvertedCircuit, ConvertedObservable
14
+ from qadence.backends.jax_utils import (
15
+ tensor_to_jnp,
16
+ unhorqify,
17
+ uniform_batchsize,
18
+ )
19
+ from qadence.backends.utils import pyqify
20
+ from qadence.blocks import AbstractBlock
21
+ from qadence.circuit import QuantumCircuit
22
+ from qadence.measurements import Measurements
23
+ from qadence.mitigations import Mitigations
24
+ from qadence.noise import Noise
25
+ from qadence.transpile import flatten, scale_primitive_blocks_only, transpile
26
+ from qadence.types import BackendName, Endianness, Engine, ParamDictType
27
+ from qadence.utils import int_to_basis
28
+
29
+ from .config import Configuration, default_passes
30
+ from .convert_ops import HorqruxCircuit, convert_block, convert_observable
31
+
32
+
33
+ @dataclass(frozen=True, eq=True)
34
+ class Backend(BackendInterface):
35
+ # set standard interface parameters
36
+ name: BackendName = BackendName.HORQRUX # type: ignore[assignment]
37
+ supports_ad: bool = True
38
+ supports_adjoint: bool = False
39
+ support_bp: bool = True
40
+ supports_native_psr: bool = False
41
+ is_remote: bool = False
42
+ with_measurements: bool = True
43
+ with_noise: bool = False
44
+ native_endianness: Endianness = Endianness.BIG
45
+ config: Configuration = field(default_factory=Configuration)
46
+ engine: Engine = Engine.JAX
47
+
48
+ def circuit(self, circuit: QuantumCircuit) -> ConvertedCircuit:
49
+ passes = self.config.transpilation_passes
50
+ if passes is None:
51
+ passes = default_passes(self.config)
52
+
53
+ original_circ = circuit
54
+ if len(passes) > 0:
55
+ circuit = transpile(*passes)(circuit)
56
+ ops = convert_block(circuit.block, n_qubits=circuit.n_qubits, config=self.config)
57
+ return ConvertedCircuit(
58
+ native=HorqruxCircuit(ops), abstract=circuit, original=original_circ
59
+ )
60
+
61
+ def observable(self, observable: AbstractBlock, n_qubits: int) -> ConvertedObservable:
62
+ transpilations = [
63
+ flatten,
64
+ scale_primitive_blocks_only,
65
+ ]
66
+ block = transpile(*transpilations)(observable) # type: ignore[call-overload]
67
+ hq_obs = convert_observable(block, n_qubits=n_qubits, config=self.config)
68
+ return ConvertedObservable(native=hq_obs, abstract=block, original=observable)
69
+
70
+ def _run(
71
+ self,
72
+ circuit: ConvertedCircuit,
73
+ param_values: ParamDictType = {},
74
+ state: ArrayLike | None = None,
75
+ endianness: Endianness = Endianness.BIG,
76
+ horqify_state: bool = True,
77
+ unhorqify_state: bool = True,
78
+ ) -> ArrayLike:
79
+ n_qubits = circuit.abstract.n_qubits
80
+ if state is None:
81
+ state = prepare_state(n_qubits, "0" * n_qubits)
82
+ else:
83
+ state = tensor_to_jnp(pyqify(state)) if horqify_state else state
84
+ state = circuit.native.forward(state, param_values)
85
+ if endianness != self.native_endianness:
86
+ state = jnp.reshape(state, (1, 2**n_qubits)) # batch_size is always 1
87
+ ls = list(range(2**n_qubits))
88
+ permute_ind = jnp.array([int(f"{num:0{n_qubits}b}"[::-1], 2) for num in ls])
89
+ state = state[:, permute_ind]
90
+ if unhorqify_state:
91
+ state = unhorqify(state)
92
+ return state
93
+
94
+ def run_dm(
95
+ self,
96
+ circuit: ConvertedCircuit,
97
+ noise: Noise,
98
+ param_values: ParamDictType = {},
99
+ state: ArrayLike | None = None,
100
+ endianness: Endianness = Endianness.BIG,
101
+ ) -> ArrayLike:
102
+ raise NotImplementedError
103
+
104
+ def expectation(
105
+ self,
106
+ circuit: ConvertedCircuit,
107
+ observable: list[ConvertedObservable] | ConvertedObservable,
108
+ param_values: ParamDictType = {},
109
+ state: ArrayLike | None = None,
110
+ measurement: Measurements | None = None,
111
+ noise: Noise | None = None,
112
+ mitigation: Mitigations | None = None,
113
+ endianness: Endianness = Endianness.BIG,
114
+ ) -> ArrayLike:
115
+ observable = observable if isinstance(observable, list) else [observable]
116
+ batch_size = max([arr.size for arr in param_values.values()])
117
+ n_obs = len(observable)
118
+
119
+ def _expectation(params: ParamDictType) -> ArrayLike:
120
+ out_state = self.run(
121
+ circuit, params, state, endianness, horqify_state=True, unhorqify_state=False
122
+ )
123
+ return jnp.array([o.native.forward(out_state, params) for o in observable])
124
+
125
+ if batch_size > 1: # We vmap for batch_size > 1
126
+ expvals = jax.vmap(_expectation, in_axes=({k: 0 for k in param_values.keys()},))(
127
+ uniform_batchsize(param_values)
128
+ )
129
+ else:
130
+ expvals = _expectation(param_values)
131
+ if expvals.size > 1:
132
+ expvals = jnp.reshape(expvals, (batch_size, n_obs))
133
+ else:
134
+ expvals = jnp.squeeze(
135
+ expvals, 0
136
+ ) # For the case of batch_size == n_obs == 1, we remove the dims
137
+ return expvals
138
+
139
+ def sample(
140
+ self,
141
+ circuit: ConvertedCircuit,
142
+ param_values: ParamDictType = {},
143
+ n_shots: int = 1,
144
+ state: ArrayLike | None = None,
145
+ noise: Noise | None = None,
146
+ mitigation: Mitigations | None = None,
147
+ endianness: Endianness = Endianness.BIG,
148
+ ) -> list[Counter]:
149
+ """Samples from a batch of discrete probability distributions.
150
+
151
+ Args:
152
+ circuit: A ConvertedCircuit object holding the native PyQ Circuit.
153
+ param_values: A dict holding the embedded parameters which the native ciruit expects.
154
+ n_shots: The number of samples to generate per distribution.
155
+ state: The input state.
156
+ endianness (Endianness): The target endianness of the resulting samples.
157
+
158
+ Returns:
159
+ A list of Counter objects where each key represents a bitstring
160
+ and its value the number of times it has been sampled from the given wave function.
161
+ """
162
+ if n_shots < 1:
163
+ raise ValueError("You can only call sample with n_shots>0.")
164
+
165
+ def _sample(
166
+ _probs: ArrayLike, n_shots: int, endianness: Endianness, n_qubits: int
167
+ ) -> Counter:
168
+ _logits = jax.vmap(lambda _p: jnp.log(_p / (1 - _p)))(_probs)
169
+
170
+ def _smple(accumulator: ArrayLike, i: int) -> tuple[ArrayLike, None]:
171
+ accumulator = accumulator.at[i].set(
172
+ jax.random.categorical(jax.random.PRNGKey(i), _logits)
173
+ )
174
+ return accumulator, None
175
+
176
+ samples = jax.lax.scan(
177
+ _smple, jnp.empty_like(jnp.arange(n_shots)), jnp.arange(n_shots)
178
+ )[0]
179
+ return Counter(
180
+ {
181
+ int_to_basis(k=k, n_qubits=n_qubits, endianness=endianness): count.item()
182
+ for k, count in enumerate(jnp.bincount(samples))
183
+ if count > 0
184
+ }
185
+ )
186
+
187
+ wf = self.run(
188
+ circuit=circuit,
189
+ param_values=param_values,
190
+ state=state,
191
+ horqify_state=True,
192
+ unhorqify_state=False,
193
+ )
194
+ probs = jnp.abs(jnp.float_power(wf, 2.0)).ravel()
195
+ samples = [
196
+ _sample(
197
+ _probs=probs,
198
+ n_shots=n_shots,
199
+ endianness=endianness,
200
+ n_qubits=circuit.abstract.n_qubits,
201
+ ),
202
+ ]
203
+
204
+ return samples
205
+
206
+ def assign_parameters(self, circuit: ConvertedCircuit, param_values: ParamDictType) -> Any:
207
+ raise NotImplementedError
208
+
209
+ @staticmethod
210
+ def _overlap(bras: ArrayLike, kets: ArrayLike) -> ArrayLike:
211
+ # TODO
212
+ raise NotImplementedError
213
+
214
+ @staticmethod
215
+ def default_configuration() -> Configuration:
216
+ return Configuration()
@@ -0,0 +1,26 @@
1
+ from __future__ import annotations
2
+
3
+ from dataclasses import dataclass
4
+ from typing import Callable
5
+
6
+ from qadence.backend import BackendConfiguration
7
+ from qadence.logger import get_logger
8
+ from qadence.transpile import (
9
+ blockfn_to_circfn,
10
+ flatten,
11
+ scale_primitive_blocks_only,
12
+ )
13
+
14
+ logger = get_logger(__name__)
15
+
16
+
17
+ def default_passes(config: Configuration) -> list[Callable]:
18
+ return [
19
+ flatten,
20
+ blockfn_to_circfn(scale_primitive_blocks_only),
21
+ ]
22
+
23
+
24
@dataclass
class Configuration(BackendConfiguration):
    """Horqrux backend configuration; adds no options beyond ``BackendConfiguration``."""
@@ -0,0 +1,273 @@
1
+ from __future__ import annotations
2
+
3
+ from dataclasses import dataclass, field
4
+ from functools import reduce
5
+ from itertools import chain as flatten
6
+ from operator import add
7
+ from typing import Any, Callable, Dict
8
+
9
+ import jax.numpy as jnp
10
+ from horqrux.gates import NOT, H, I, Rx, Ry, Rz, X, Y, Z
11
+ from horqrux.ops import apply_gate
12
+ from horqrux.types import Gate
13
+ from horqrux.utils import overlap
14
+ from jax import Array
15
+ from jax.tree_util import register_pytree_node_class
16
+
17
+ from qadence.blocks import (
18
+ AbstractBlock,
19
+ AddBlock,
20
+ ChainBlock,
21
+ CompositeBlock,
22
+ KronBlock,
23
+ ParametricBlock,
24
+ PrimitiveBlock,
25
+ ScaleBlock,
26
+ )
27
+ from qadence.operations import CNOT, CRX, CRY, CRZ
28
+ from qadence.types import OpName, ParamDictType
29
+
30
+ from .config import Configuration
31
+
32
# Map qadence operation names to the horqrux gate constructors used to build them.
# Controlled rotations (CRX/CRY/CRZ) and CNOT reuse single-qubit horqrux gates; the
# control qubit is supplied separately when the converted operation is built.
ops_map: Dict[str, Callable] = {
    OpName.X: X,
    OpName.Y: Y,
    OpName.Z: Z,
    OpName.H: H,
    OpName.RX: Rx,
    OpName.RY: Ry,
    OpName.RZ: Rz,
    OpName.CRX: Rx,
    OpName.CRY: Ry,
    OpName.CRZ: Rz,
    OpName.CNOT: NOT,
    OpName.I: I,
}

# Names of the operations this backend can convert. Dict keys are already unique, so
# no set round-trip is needed. Building the list straight from the dict also keeps a
# deterministic (insertion) order, whereas set order of string keys varies between
# interpreter runs because of hash randomization.
supported_gates = list(ops_map)
48
+
49
+
50
@register_pytree_node_class
@dataclass
class HorqruxCircuit:
    """Ordered sequence of converted operations applied one after another.

    Registered as a JAX pytree whose single child is the ``operators`` list.
    """

    operators: list[Gate] = field(default_factory=list)

    def tree_flatten(self) -> tuple[tuple[list[Any]], tuple[()]]:
        # The operator list is the only child; there is no static auxiliary data.
        return (self.operators,), ()

    @classmethod
    def tree_unflatten(cls, aux_data: Any, children: Any) -> Any:
        return cls(*children, *aux_data)

    def forward(self, state: Array, values: ParamDictType) -> Array:
        """Feed ``state`` through every operator in order and return the result."""
        return reduce(lambda acc, operator: operator.forward(acc, values), self.operators, state)
68
+
69
+
70
@dataclass
class HorqruxObservable(HorqruxCircuit):
    """Observable built from converted operations.

    ``forward`` applies the operators to a copy of the input state and returns
    horqrux's ``overlap`` between the original and the transformed state.
    """

    def __init__(self, operators: list[Gate]):
        super().__init__(operators=operators)

    def _forward(self, state: Array, values: ParamDictType) -> Array:
        # Applying the observable's operators is exactly the circuit forward pass.
        return super().forward(state, values)

    def forward(self, state: Array, values: ParamDictType) -> Array:
        transformed = self._forward(state, values)
        return overlap(state, transformed)
82
+
83
+
84
def convert_observable(
    block: AbstractBlock, n_qubits: int, config: Configuration
) -> HorqruxObservable:
    """Convert an observable block into a ``HorqruxObservable`` wrapping its native ops."""
    return HorqruxObservable(convert_block(block, n_qubits, config))
89
+
90
+
91
def convert_block(
    block: AbstractBlock, n_qubits: int | None = None, config: Configuration | None = None
) -> list:
    """Recursively convert a qadence block into a list of horqrux-native operations.

    Args:
        block: The (transpiled) qadence block to convert.
        n_qubits: Total qubit count; inferred from ``block.qubit_support`` when omitted.
        config: Backend configuration used to resolve parameter names. A fresh
            default ``Configuration()`` is created when omitted.

    Returns:
        A list of converted operations (a single wrapper object for most block types).

    Raises:
        NotImplementedError: for multi-parameter gates or unsupported block types.
        KeyError: if a gate name is missing from ``ops_map``.
    """
    if config is None:
        # Create the default per call rather than sharing one instance built at import time.
        config = Configuration()
    if n_qubits is None:
        n_qubits = max(block.qubit_support) + 1
    ops = []
    if isinstance(block, CompositeBlock):
        # Convert the children first; the composite type decides how they are combined.
        ops = list(flatten(*(convert_block(b, n_qubits, config) for b in block.blocks)))
        if isinstance(block, AddBlock):
            ops = [HorqAddGate(ops)]
        elif isinstance(block, ChainBlock):
            ops = [HorqruxCircuit(ops)]
        elif isinstance(block, KronBlock):
            # Fast path: a kron of single-qubit, non-scaled parametric gates is applied in
            # one horqrux call. Multi-qubit (e.g. controlled) parametric blocks are left to
            # the generic path because only ``qubit_support[0]`` is used as target here.
            if all(
                isinstance(b, ParametricBlock)
                and not isinstance(b, ScaleBlock)
                and len(b.qubit_support) == 1
                for b in block.blocks
            ):
                # One name per gate, resolved exactly like the single-gate ParametricBlock
                # branch below, so names stay aligned with ``gates`` and ``target``.
                param_names = [config.get_param_name(b)[0] for b in block.blocks]
                ops = [
                    HorqKronParametric(
                        gates=[ops_map[b.name] for b in block.blocks],
                        target=[b.qubit_support[0] for b in block.blocks],
                        param_names=param_names,
                    )
                ]

            elif all(b.name == "CNOT" for b in block.blocks):
                ops = [
                    HorqKronCNOT(
                        gates=[ops_map[b.name] for b in block.blocks],
                        target=[b.qubit_support[1] for b in block.blocks],
                        control=[b.qubit_support[0] for b in block.blocks],
                    )
                ]
            else:
                ops = [HorqruxCircuit(ops)]

    elif isinstance(block, CNOT):
        native_op = ops_map[block.name]
        # qadence orders the support as (control, target); HorqCNOTGate hands them to
        # horqrux as (target, control), where target and control are swapped.
        ops = [HorqCNOTGate(native_op, block.qubit_support[0], block.qubit_support[1])]

    elif isinstance(block, (CRX, CRY, CRZ)):
        native_op = ops_map[block.name]
        param_name = config.get_param_name(block)[0]

        ops = [
            HorqParametricGate(
                gate=native_op,
                qubit=block.qubit_support[1],
                parameter_name=param_name,
                control=block.qubit_support[0],
                name=block.name,
            )
        ]
    elif isinstance(block, ScaleBlock):
        op = convert_block(block.block, n_qubits, config=config)[0]
        param_name = config.get_param_name(block)[0]
        ops = [HorqScaleGate(op, param_name)]

    elif isinstance(block, ParametricBlock):
        native_op = ops_map[block.name]
        if len(block.parameters._uuid_dict) > 1:
            raise NotImplementedError("Only single-parameter operations are supported.")
        param_name = config.get_param_name(block)[0]

        ops = [
            HorqParametricGate(
                gate=native_op,
                qubit=block.qubit_support[0],
                parameter_name=param_name,
            )
        ]

    elif isinstance(block, PrimitiveBlock):
        native_op = ops_map[block.name]
        qubit = block.qubit_support[0]
        ops = [HorqPrimitiveGate(gate=native_op, qubit=qubit, name=block.name)]

    else:
        raise NotImplementedError(f"Non-supported operation of type {type(block)}.")

    return ops
177
+
178
+
179
class HorqPrimitiveGate:
    """Non-parametric single-qubit gate (e.g. X or H) applied to one target qubit."""

    def __init__(self, gate: Gate, qubit: int, name: str):
        # ``gates`` holds the horqrux gate constructor; it is called with the target qubit.
        self.gates: Gate = gate
        self.target = qubit
        self.name = name

    def forward(self, state: Array, values: ParamDictType) -> Array:
        """Apply the gate to ``state``; ``values`` is unused (the gate has no parameters)."""
        native_gate = self.gates(self.target)
        return apply_gate(state, native_gate)

    def __repr__(self) -> str:
        suffix = f"(target={self.target})"
        return self.name + suffix
190
+
191
+
192
class HorqCNOTGate:
    """Two-qubit controlled gate (CNOT) storing its control and target qubits."""

    def __init__(self, gate: Gate, control: int, target: int):
        self.gates: Callable = gate
        self.control: int = control
        self.target: int = target

    def forward(self, state: Array, values: ParamDictType) -> Array:
        """Apply the gate to ``state``; ``values`` is unused."""
        # The horqrux constructor receives the target first and the control second.
        native_gate = self.gates(self.target, self.control)
        return apply_gate(state, native_gate)
200
+
201
+
202
class HorqKronParametric:
    """Parallel (kron) layer of single-qubit parametric gates applied in one call."""

    def __init__(self, gates: list[Gate], param_names: list[str], target: list[int]):
        self.operators: list[Gate] = gates
        self.target: list[int] = target
        self.param_names: list[str] = param_names

    def forward(self, state: Array, values: ParamDictType) -> Array:
        """Build each gate from its parameter value and apply them together to ``state``.

        ``operators``, ``target`` and ``param_names`` are paired positionally; ``zip``
        stops at the shortest of the three.
        """
        native_gates = tuple(
            op(values[name], qubit)
            for op, qubit, name in zip(self.operators, self.target, self.param_names)
        )
        return apply_gate(state, native_gates)
216
+
217
+
218
class HorqKronCNOT(HorqruxCircuit):
    """Parallel (kron) layer of CNOT gates applied in a single ``apply_gate`` call."""

    def __init__(self, gates: list[Gate], target: list[int], control: list[int]):
        # Does not call the dataclass ``__init__`` of HorqruxCircuit; all state is set here.
        self.operators: list[Gate] = gates
        self.target: list[int] = target
        self.control: list[int] = control

    def forward(self, state: Array, values: ParamDictType) -> Array:
        """Apply every CNOT to ``state`` at once; ``values`` is unused."""
        native_gates = tuple(
            op(tgt, ctrl) for op, tgt, ctrl in zip(self.operators, self.target, self.control)
        )
        return apply_gate(state, native_gates)
232
+
233
+
234
class HorqParametricGate:
    """Single-parameter gate (e.g. RX), optionally controlled (e.g. CRX)."""

    def __init__(
        self,
        gate: Gate,
        qubit: int,
        parameter_name: str,
        control: int | None = None,
        name: str = "",
    ):
        self.gates: Callable = gate
        self.target: int = qubit
        self.parameter: str = parameter_name
        self.control: int | None = control
        self.name = name

    def forward(self, state: Array, values: ParamDictType) -> Array:
        """Build the gate from ``values[self.parameter]`` and apply it to ``state``."""
        angle = jnp.array(values[self.parameter])
        native_gate = self.gates(angle, self.target, self.control)
        return apply_gate(state, native_gate)

    def __repr__(self) -> str:
        details = f"(target={self.target}, parameter={self.parameter}, control={self.control})"
        return self.name + details
253
+
254
+
255
class HorqAddGate(HorqruxCircuit):
    """Sum of operators: ``forward`` adds up each operator's result on the same state."""

    def __init__(self, operations: list[Gate]):
        self.operators = operations
        self.name = "Add"

    def forward(self, state: Array, values: ParamDictType | None = None) -> Array:
        """Apply every operator to ``state`` independently and sum the outputs.

        Raises:
            TypeError: if there are no operators (``reduce`` on an empty sequence).
        """
        # Avoid a shared mutable ``{}`` default; omitting ``values`` still means "no parameters".
        values = {} if values is None else values
        return reduce(add, (op.forward(state, values) for op in self.operators))

    def __repr__(self) -> str:
        return self.name + f"({self.operators})"
265
+
266
+
267
class HorqScaleGate:
    """Wrap a converted operator and multiply its output by a named scale parameter."""

    def __init__(self, op: Gate, parameter_name: str):
        self.op = op
        self.parameter: str = parameter_name

    def forward(self, state: Array, values: ParamDictType) -> Array:
        """Return ``values[self.parameter] * self.op.forward(state, values)``."""
        scale = jnp.array(values[self.parameter])
        return scale * self.op.forward(state, values)
@@ -0,0 +1,45 @@
1
+ from __future__ import annotations
2
+
3
+ from typing import Any
4
+
5
+ import jax.numpy as jnp
6
+ from jax import Array, device_get
7
+ from sympy import Expr
8
+ from sympy2jax import SymbolicModule as JaxSympyModule
9
+ from torch import Tensor, cdouble, from_numpy
10
+
11
+ from qadence.types import ParamDictType
12
+
13
+
14
def jarr_to_tensor(arr: Array, dtype: Any = cdouble) -> Tensor:
    """Convert a JAX array into a torch tensor of ``dtype`` (complex double by default).

    The host array returned by ``jax.device_get`` may be read-only. ``torch.from_numpy``
    shares memory with its input, and ``Tensor.to`` returns the same tensor when no
    dtype change is needed, so the result could alias a non-writable buffer. Copying
    first gives torch its own writable memory.
    """
    import numpy as np

    host_array = np.array(device_get(arr), copy=True)
    return from_numpy(host_array).to(dtype=dtype)
16
+
17
+
18
def tensor_to_jnp(tensor: Tensor, dtype: Any = jnp.complex128) -> Array:
    """Convert a torch tensor into a JAX array of ``dtype`` (complex128 by default).

    The tensor is detached and moved to host memory first, because ``Tensor.numpy()``
    refuses tensors that require grad or live on a non-CPU device. Gradients do not
    flow through this conversion.
    """
    return jnp.array(tensor.detach().cpu().numpy(), dtype=dtype)
24
+
25
+
26
def values_to_jax(param_values: dict[str, Tensor]) -> dict[str, Array]:
    """Convert a dict of torch tensors into a dict of JAX arrays with the same keys.

    Each tensor is detached and copied to host memory (``.cpu()``) before the NumPy
    conversion, so tensors requiring grad or living on a non-CPU device are accepted.
    """
    return {key: jnp.array(value.detach().cpu().numpy()) for key, value in param_values.items()}
28
+
29
+
30
def jaxify(expr: Expr) -> JaxSympyModule:
    """Wrap a sympy expression in a ``sympy2jax`` ``SymbolicModule``."""
    module = JaxSympyModule(expr)
    return module
32
+
33
+
34
def unhorqify(state: Array) -> Array:
    """Flatten a horqrux state into a one-dimensional array.

    ``jnp.ravel`` collapses every axis (e.g. a ``[2] * n_qubits`` state or a
    ``(1, 2**n_qubits)`` array) into shape ``(state.size,)`` in row-major order.
    No batch axis is kept.
    """
    return jnp.ravel(state)
37
+
38
+
39
def uniform_batchsize(param_values: ParamDictType) -> ParamDictType:
    """Broadcast every parameter array to the largest batch size in ``param_values``.

    Arrays that already hold ``max_batch_size`` elements are kept unchanged.
    Single-element arrays are repeated ``max_batch_size`` times (``.repeat`` without
    an axis flattens, so the result is one-dimensional). An empty dict yields ``{}``.

    Raises:
        ValueError: if an array has neither one element nor ``max_batch_size``
            elements. Repeating it would silently produce a wrong-sized batch.
    """
    if not param_values:
        return {}
    max_batch_size = max(p.size for p in param_values.values())
    batched_values = {}
    for k, v in param_values.items():
        if v.size == max_batch_size:
            batched_values[k] = v
        elif v.size == 1:
            batched_values[k] = v.repeat(max_batch_size)
        else:
            raise ValueError(
                f"Parameter '{k}' has {v.size} values; expected 1 or {max_batch_size}."
            )
    return batched_values
@@ -1,5 +1,4 @@
1
1
  from __future__ import annotations
2
2
 
3
3
  from .backend import Backend, Configuration
4
- from .devices import Device
5
4
  from .pulses import supported_gates