qadence 1.1.1__py3-none-any.whl → 1.2.1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- qadence/__init__.py +1 -0
- qadence/analog/__init__.py +4 -2
- qadence/analog/addressing.py +167 -0
- qadence/analog/constants.py +59 -0
- qadence/analog/device.py +82 -0
- qadence/analog/hamiltonian_terms.py +101 -0
- qadence/analog/parse_analog.py +120 -0
- qadence/backend.py +42 -12
- qadence/backends/__init__.py +1 -2
- qadence/backends/api.py +27 -9
- qadence/backends/braket/backend.py +3 -2
- qadence/backends/horqrux/__init__.py +5 -0
- qadence/backends/horqrux/backend.py +216 -0
- qadence/backends/horqrux/config.py +26 -0
- qadence/backends/horqrux/convert_ops.py +273 -0
- qadence/backends/jax_utils.py +45 -0
- qadence/backends/pulser/__init__.py +0 -1
- qadence/backends/pulser/backend.py +31 -15
- qadence/backends/pulser/config.py +19 -10
- qadence/backends/pulser/devices.py +57 -63
- qadence/backends/pulser/pulses.py +70 -12
- qadence/backends/pyqtorch/backend.py +4 -4
- qadence/backends/pyqtorch/config.py +18 -12
- qadence/backends/pyqtorch/convert_ops.py +15 -7
- qadence/backends/utils.py +5 -9
- qadence/blocks/abstract.py +5 -1
- qadence/blocks/analog.py +18 -9
- qadence/blocks/block_to_tensor.py +11 -0
- qadence/blocks/embedding.py +46 -24
- qadence/blocks/primitive.py +81 -9
- qadence/blocks/utils.py +20 -1
- qadence/circuit.py +3 -9
- qadence/constructors/__init__.py +4 -0
- qadence/constructors/feature_maps.py +84 -60
- qadence/constructors/hamiltonians.py +27 -98
- qadence/constructors/rydberg_feature_maps.py +113 -0
- qadence/divergences.py +12 -0
- qadence/engines/__init__.py +0 -0
- qadence/engines/differentiable_backend.py +152 -0
- qadence/engines/jax/__init__.py +8 -0
- qadence/engines/jax/differentiable_backend.py +73 -0
- qadence/engines/jax/differentiable_expectation.py +94 -0
- qadence/engines/torch/__init__.py +4 -0
- qadence/engines/torch/differentiable_backend.py +85 -0
- qadence/extensions.py +21 -9
- qadence/finitediff.py +47 -0
- qadence/mitigations/readout.py +92 -25
- qadence/ml_tools/models.py +10 -3
- qadence/models/qnn.py +88 -23
- qadence/models/quantum_model.py +13 -2
- qadence/operations.py +55 -70
- qadence/parameters.py +24 -13
- qadence/register.py +91 -43
- qadence/transpile/__init__.py +1 -0
- qadence/transpile/apply_fn.py +40 -0
- qadence/types.py +32 -2
- qadence/utils.py +35 -0
- {qadence-1.1.1.dist-info → qadence-1.2.1.dist-info}/METADATA +22 -3
- {qadence-1.1.1.dist-info → qadence-1.2.1.dist-info}/RECORD +62 -44
- {qadence-1.1.1.dist-info → qadence-1.2.1.dist-info}/WHEEL +1 -1
- qadence/analog/interaction.py +0 -198
- qadence/analog/utils.py +0 -132
- /qadence/{backends/pytorch_wrapper.py → engines/torch/differentiable_expectation.py} +0 -0
- {qadence-1.1.1.dist-info → qadence-1.2.1.dist-info}/licenses/LICENSE +0 -0
qadence/backends/__init__.py
CHANGED
@@ -2,7 +2,6 @@
|
|
2
2
|
from __future__ import annotations
|
3
3
|
|
4
4
|
from .api import backend_factory, config_factory
|
5
|
-
from .pytorch_wrapper import DifferentiableBackend
|
6
5
|
|
7
6
|
# Modules to be automatically added to the qadence namespace
|
8
|
-
__all__ = ["backend_factory", "config_factory"
|
7
|
+
__all__ = ["backend_factory", "config_factory"]
|
qadence/backends/api.py
CHANGED
@@ -1,9 +1,9 @@
|
|
1
1
|
from __future__ import annotations
|
2
2
|
|
3
3
|
from qadence.backend import Backend, BackendConfiguration
|
4
|
-
from qadence.
|
5
|
-
from qadence.extensions import available_backends, set_backend_config
|
6
|
-
from qadence.types import BackendName, DiffMode
|
4
|
+
from qadence.engines.differentiable_backend import DifferentiableBackend
|
5
|
+
from qadence.extensions import available_backends, available_engines, set_backend_config
|
6
|
+
from qadence.types import BackendName, DiffMode, Engine
|
7
7
|
|
8
8
|
__all__ = ["backend_factory", "config_factory"]
|
9
9
|
|
@@ -14,13 +14,18 @@ def backend_factory(
|
|
14
14
|
configuration: BackendConfiguration | dict | None = None,
|
15
15
|
) -> Backend | DifferentiableBackend:
|
16
16
|
backend_inst: Backend | DifferentiableBackend
|
17
|
-
backend_name = BackendName(backend)
|
18
17
|
backends = available_backends()
|
19
|
-
|
18
|
+
try:
|
19
|
+
backend_name = BackendName(backend)
|
20
|
+
except ValueError:
|
21
|
+
raise NotImplementedError(f"The requested backend '{backend}' is not implemented.")
|
20
22
|
try:
|
21
23
|
BackendCls = backends[backend_name]
|
22
|
-
except
|
23
|
-
raise
|
24
|
+
except Exception as e:
|
25
|
+
raise ImportError(
|
26
|
+
f"The requested backend '{backend_name}' is either not installed\
|
27
|
+
or could not be imported due to {e}."
|
28
|
+
)
|
24
29
|
|
25
30
|
default_config = BackendCls.default_configuration()
|
26
31
|
if configuration is None:
|
@@ -44,9 +49,22 @@ def backend_factory(
|
|
44
49
|
|
45
50
|
# Set backend configurations which depend on the differentiation mode
|
46
51
|
set_backend_config(backend_inst, diff_mode)
|
47
|
-
|
52
|
+
# Wrap the quantum Backend in a DifferentiableBackend if a diff_mode is passed.
|
48
53
|
if diff_mode is not None:
|
49
|
-
|
54
|
+
try:
|
55
|
+
engine_name = Engine(backend_inst.engine)
|
56
|
+
except ValueError:
|
57
|
+
raise NotImplementedError(
|
58
|
+
f"The requested engine '{backend_inst.engine}' is not implemented."
|
59
|
+
)
|
60
|
+
try:
|
61
|
+
diff_backend_cls = available_engines()[engine_name]
|
62
|
+
backend_inst = diff_backend_cls(backend=backend_inst, diff_mode=DiffMode(diff_mode)) # type: ignore[arg-type]
|
63
|
+
except Exception as e:
|
64
|
+
raise ImportError(
|
65
|
+
f"The requested engine '{engine_name}' is either not installed\
|
66
|
+
or could not be imported due to {e}."
|
67
|
+
)
|
50
68
|
return backend_inst
|
51
69
|
|
52
70
|
|
@@ -23,7 +23,7 @@ from qadence.noise import Noise
|
|
23
23
|
from qadence.noise.protocols import apply_noise
|
24
24
|
from qadence.overlap import overlap_exact
|
25
25
|
from qadence.transpile import transpile
|
26
|
-
from qadence.types import BackendName
|
26
|
+
from qadence.types import BackendName, Engine
|
27
27
|
from qadence.utils import Endianness
|
28
28
|
|
29
29
|
from .config import Configuration, default_passes
|
@@ -55,6 +55,7 @@ class Backend(BackendInterface):
|
|
55
55
|
with_noise: bool = False
|
56
56
|
native_endianness: Endianness = Endianness.BIG
|
57
57
|
config: Configuration = field(default_factory=Configuration)
|
58
|
+
engine: Engine = Engine.TORCH
|
58
59
|
|
59
60
|
# braket specifics
|
60
61
|
# TODO: include it in the configuration?
|
@@ -87,7 +88,7 @@ class Backend(BackendInterface):
|
|
87
88
|
).squeeze(0)
|
88
89
|
return ConvertedObservable(native=native, abstract=obs, original=obs)
|
89
90
|
|
90
|
-
def
|
91
|
+
def _run(
|
91
92
|
self,
|
92
93
|
circuit: ConvertedCircuit,
|
93
94
|
param_values: dict[str, Tensor] = {},
|
@@ -0,0 +1,216 @@
|
|
1
|
+
from __future__ import annotations
|
2
|
+
|
3
|
+
from collections import Counter
|
4
|
+
from dataclasses import dataclass, field
|
5
|
+
from typing import Any
|
6
|
+
|
7
|
+
import jax
|
8
|
+
import jax.numpy as jnp
|
9
|
+
from horqrux.utils import prepare_state
|
10
|
+
from jax.typing import ArrayLike
|
11
|
+
|
12
|
+
from qadence.backend import Backend as BackendInterface
|
13
|
+
from qadence.backend import ConvertedCircuit, ConvertedObservable
|
14
|
+
from qadence.backends.jax_utils import (
|
15
|
+
tensor_to_jnp,
|
16
|
+
unhorqify,
|
17
|
+
uniform_batchsize,
|
18
|
+
)
|
19
|
+
from qadence.backends.utils import pyqify
|
20
|
+
from qadence.blocks import AbstractBlock
|
21
|
+
from qadence.circuit import QuantumCircuit
|
22
|
+
from qadence.measurements import Measurements
|
23
|
+
from qadence.mitigations import Mitigations
|
24
|
+
from qadence.noise import Noise
|
25
|
+
from qadence.transpile import flatten, scale_primitive_blocks_only, transpile
|
26
|
+
from qadence.types import BackendName, Endianness, Engine, ParamDictType
|
27
|
+
from qadence.utils import int_to_basis
|
28
|
+
|
29
|
+
from .config import Configuration, default_passes
|
30
|
+
from .convert_ops import HorqruxCircuit, convert_block, convert_observable
|
31
|
+
|
32
|
+
|
33
|
+
@dataclass(frozen=True, eq=True)
class Backend(BackendInterface):
    """Horqrux (JAX) state-vector backend.

    Circuits and observables are converted into the horqrux operation wrappers
    produced by `convert_block` / `convert_observable`; the backend declares
    `Engine.JAX` as its differentiation engine.
    """

    # set standard interface parameters
    name: BackendName = BackendName.HORQRUX  # type: ignore[assignment]
    supports_ad: bool = True
    supports_adjoint: bool = False
    support_bp: bool = True
    supports_native_psr: bool = False
    is_remote: bool = False
    with_measurements: bool = True
    with_noise: bool = False
    native_endianness: Endianness = Endianness.BIG
    config: Configuration = field(default_factory=Configuration)
    engine: Engine = Engine.JAX

    def circuit(self, circuit: QuantumCircuit) -> ConvertedCircuit:
        """Transpile `circuit` and convert it into a native HorqruxCircuit.

        The configured transpilation passes are used, falling back to
        `default_passes(self.config)` when none are configured.
        """
        passes = self.config.transpilation_passes
        if passes is None:
            passes = default_passes(self.config)

        original_circ = circuit
        if len(passes) > 0:
            circuit = transpile(*passes)(circuit)
        ops = convert_block(circuit.block, n_qubits=circuit.n_qubits, config=self.config)
        return ConvertedCircuit(
            native=HorqruxCircuit(ops), abstract=circuit, original=original_circ
        )

    def observable(self, observable: AbstractBlock, n_qubits: int) -> ConvertedObservable:
        """Apply `flatten` and `scale_primitive_blocks_only`, then convert the observable."""
        transpilations = [
            flatten,
            scale_primitive_blocks_only,
        ]
        block = transpile(*transpilations)(observable)  # type: ignore[call-overload]
        hq_obs = convert_observable(block, n_qubits=n_qubits, config=self.config)
        return ConvertedObservable(native=hq_obs, abstract=block, original=observable)

    def _run(
        self,
        circuit: ConvertedCircuit,
        param_values: ParamDictType = {},
        state: ArrayLike | None = None,
        endianness: Endianness = Endianness.BIG,
        horqify_state: bool = True,
        unhorqify_state: bool = True,
    ) -> ArrayLike:
        """Apply the native circuit to `state` and return the resulting state.

        Args:
            circuit: Converted circuit whose `native` attribute is a HorqruxCircuit.
            param_values: Parameter values keyed by the names the native ops expect.
                Not mutated.
            state: Optional initial state; defaults to the all-zero basis state.
            endianness: Qubit ordering of the returned state.
            horqify_state: Convert a user-supplied (torch) state into a JAX array
                before simulation.
            unhorqify_state: Flatten the output state with `unhorqify`.
        """
        n_qubits = circuit.abstract.n_qubits
        if state is None:
            state = prepare_state(n_qubits, "0" * n_qubits)
        elif horqify_state:
            state = tensor_to_jnp(pyqify(state))
        state = circuit.native.forward(state, param_values)
        if endianness != self.native_endianness:
            # Reverse the bit order of every basis-state index. The permuted state
            # is reshaped back to the input shape so that callers which skip
            # `unhorqify` (e.g. `expectation`) still get the native layout instead
            # of a (1, 2**n_qubits) array.
            native_shape = jnp.shape(state)
            flat = jnp.reshape(state, (1, 2**n_qubits))  # batch_size is always 1
            permute_ind = jnp.array(
                [int(f"{num:0{n_qubits}b}"[::-1], 2) for num in range(2**n_qubits)]
            )
            state = jnp.reshape(flat[:, permute_ind], native_shape)
        if unhorqify_state:
            state = unhorqify(state)
        return state

    def run_dm(
        self,
        circuit: ConvertedCircuit,
        noise: Noise,
        param_values: ParamDictType = {},
        state: ArrayLike | None = None,
        endianness: Endianness = Endianness.BIG,
    ) -> ArrayLike:
        """Density-matrix simulation is not supported by this backend."""
        raise NotImplementedError

    def expectation(
        self,
        circuit: ConvertedCircuit,
        observable: list[ConvertedObservable] | ConvertedObservable,
        param_values: ParamDictType = {},
        state: ArrayLike | None = None,
        measurement: Measurements | None = None,
        noise: Noise | None = None,
        mitigation: Mitigations | None = None,
        endianness: Endianness = Endianness.BIG,
    ) -> ArrayLike:
        """Compute expectation values of one or more observables.

        Batched parameter values (arrays of size > 1) are evaluated with
        `jax.vmap`. `measurement`, `noise` and `mitigation` are accepted for
        interface compatibility but are not used here.

        Returns:
            An array of shape (batch_size, n_observables), or a scalar when both
            the batch size and the number of observables are 1.
        """
        observables = observable if isinstance(observable, list) else [observable]
        # A circuit without parameters has an empty dict; treat it as batch size 1
        # instead of letting `max` fail on an empty sequence.
        batch_size = max((arr.size for arr in param_values.values()), default=1)
        n_obs = len(observables)

        def _expectation(params: ParamDictType) -> ArrayLike:
            out_state = self.run(
                circuit, params, state, endianness, horqify_state=True, unhorqify_state=False
            )
            return jnp.array([o.native.forward(out_state, params) for o in observables])

        if batch_size > 1:  # We vmap for batch_size > 1
            expvals = jax.vmap(_expectation, in_axes=({k: 0 for k in param_values},))(
                uniform_batchsize(param_values)
            )
        else:
            expvals = _expectation(param_values)
        if expvals.size > 1:
            return jnp.reshape(expvals, (batch_size, n_obs))
        # For the case of batch_size == n_obs == 1, we remove the dims
        return jnp.squeeze(expvals, 0)

    def sample(
        self,
        circuit: ConvertedCircuit,
        param_values: ParamDictType = {},
        n_shots: int = 1,
        state: ArrayLike | None = None,
        noise: Noise | None = None,
        mitigation: Mitigations | None = None,
        endianness: Endianness = Endianness.BIG,
    ) -> list[Counter]:
        """Sample bitstrings from the output distribution of the circuit.

        Args:
            circuit: A ConvertedCircuit object holding the native horqrux circuit.
            param_values: A dict holding the embedded parameters which the native
                circuit expects.
            n_shots: The number of samples to generate.
            state: The input state.
            noise: Accepted for interface compatibility; not used.
            mitigation: Accepted for interface compatibility; not used.
            endianness (Endianness): The target endianness of the resulting samples.

        Returns:
            A list holding one Counter whose keys are bitstrings and whose values
            are the number of times each bitstring was sampled.

        Raises:
            ValueError: If `n_shots` is smaller than 1.
        """
        if n_shots < 1:
            raise ValueError("You can only call sample with n_shots>0.")

        n_qubits = circuit.abstract.n_qubits
        wf = self.run(
            circuit=circuit,
            param_values=param_values,
            state=state,
            horqify_state=True,
            unhorqify_state=False,
        )
        probs = (jnp.abs(wf) ** 2).ravel()
        # `jax.random.categorical` draws index i with probability softmax(logits)[i],
        # so the logits must be log-probabilities. The previous log(p / (1 - p))
        # skewed the distribution and diverged for p == 1. log(0) = -inf is fine:
        # such states are simply never drawn.
        logits = jnp.log(probs)
        # Fixed key: sampling stays deterministic across calls, as before.
        draws = jax.random.categorical(jax.random.PRNGKey(0), logits, shape=(n_shots,))
        counts = jnp.bincount(draws, length=probs.size)
        counter = Counter(
            {
                int_to_basis(k=k, n_qubits=n_qubits, endianness=endianness): count.item()
                for k, count in enumerate(counts)
                if count > 0
            }
        )
        return [counter]

    def assign_parameters(self, circuit: ConvertedCircuit, param_values: ParamDictType) -> Any:
        """Not supported by this backend."""
        raise NotImplementedError

    @staticmethod
    def _overlap(bras: ArrayLike, kets: ArrayLike) -> ArrayLike:
        # TODO
        raise NotImplementedError

    @staticmethod
    def default_configuration() -> Configuration:
        """Return a fresh default horqrux Configuration."""
        return Configuration()
|
@@ -0,0 +1,26 @@
|
|
1
|
+
from __future__ import annotations
|
2
|
+
|
3
|
+
from dataclasses import dataclass
|
4
|
+
from typing import Callable
|
5
|
+
|
6
|
+
from qadence.backend import BackendConfiguration
|
7
|
+
from qadence.logger import get_logger
|
8
|
+
from qadence.transpile import (
|
9
|
+
blockfn_to_circfn,
|
10
|
+
flatten,
|
11
|
+
scale_primitive_blocks_only,
|
12
|
+
)
|
13
|
+
|
14
|
+
logger = get_logger(__name__)


def default_passes(config: Configuration) -> list[Callable]:
    """Return the default transpilation passes for the horqrux backend.

    The passes flatten the circuit and then apply `scale_primitive_blocks_only`
    at circuit level. `config` is currently not consulted.
    """
    scale_pass = blockfn_to_circfn(scale_primitive_blocks_only)
    passes: list[Callable] = [flatten, scale_pass]
    return passes
|
22
|
+
|
23
|
+
|
24
|
+
@dataclass
class Configuration(BackendConfiguration):
    """Horqrux backend configuration; adds no options beyond BackendConfiguration."""
|
@@ -0,0 +1,273 @@
|
|
1
|
+
from __future__ import annotations
|
2
|
+
|
3
|
+
from dataclasses import dataclass, field
|
4
|
+
from functools import reduce
|
5
|
+
from itertools import chain as flatten
|
6
|
+
from operator import add
|
7
|
+
from typing import Any, Callable, Dict
|
8
|
+
|
9
|
+
import jax.numpy as jnp
|
10
|
+
from horqrux.gates import NOT, H, I, Rx, Ry, Rz, X, Y, Z
|
11
|
+
from horqrux.ops import apply_gate
|
12
|
+
from horqrux.types import Gate
|
13
|
+
from horqrux.utils import overlap
|
14
|
+
from jax import Array
|
15
|
+
from jax.tree_util import register_pytree_node_class
|
16
|
+
|
17
|
+
from qadence.blocks import (
|
18
|
+
AbstractBlock,
|
19
|
+
AddBlock,
|
20
|
+
ChainBlock,
|
21
|
+
CompositeBlock,
|
22
|
+
KronBlock,
|
23
|
+
ParametricBlock,
|
24
|
+
PrimitiveBlock,
|
25
|
+
ScaleBlock,
|
26
|
+
)
|
27
|
+
from qadence.operations import CNOT, CRX, CRY, CRZ
|
28
|
+
from qadence.types import OpName, ParamDictType
|
29
|
+
|
30
|
+
from .config import Configuration
|
31
|
+
|
32
|
+
# Qadence operation names mapped to the horqrux gate constructors implementing them.
# Controlled rotations map to the plain rotation gate: the control qubit is passed
# separately when the gate is applied (see HorqParametricGate / HorqCNOTGate).
ops_map: Dict[str, Callable] = dict(
    [
        (OpName.X, X),
        (OpName.Y, Y),
        (OpName.Z, Z),
        (OpName.H, H),
        (OpName.RX, Rx),
        (OpName.RY, Ry),
        (OpName.RZ, Rz),
        (OpName.CRX, Rx),
        (OpName.CRY, Ry),
        (OpName.CRZ, Rz),
        (OpName.CNOT, NOT),
        (OpName.I, I),
    ]
)

supported_gates = list(set(ops_map))
|
48
|
+
|
49
|
+
|
50
|
+
@register_pytree_node_class
@dataclass
class HorqruxCircuit:
    """Sequence of horqrux-wrapped operators applied one after another.

    Registered as a JAX pytree whose single child is the operator list.
    """

    operators: list[Gate] = field(default_factory=list)

    def tree_flatten(self) -> tuple[tuple[list[Any]], tuple[()]]:
        # No auxiliary (static) data: everything lives in `operators`.
        return (self.operators,), ()

    @classmethod
    def tree_unflatten(cls, aux_data: Any, children: Any) -> Any:
        return cls(*children, *aux_data)

    def forward(self, state: Array, values: ParamDictType) -> Array:
        """Pass `state` through every operator's `forward`, in order."""
        return reduce(lambda acc, op: op.forward(acc, values), self.operators, state)
|
68
|
+
|
69
|
+
|
70
|
+
@dataclass
class HorqruxObservable(HorqruxCircuit):
    """Observable built from horqrux-wrapped operators.

    `forward` applies the operators to `state` and returns horqrux's `overlap`
    between the input state and the transformed state.
    """

    def __init__(self, operators: list[Gate]):
        super().__init__(operators=operators)

    def _forward(self, state: Array, values: ParamDictType) -> Array:
        # Sequential application, identical to HorqruxCircuit.forward.
        return HorqruxCircuit.forward(self, state, values)

    def forward(self, state: Array, values: ParamDictType) -> Array:
        transformed = self._forward(state, values)
        return overlap(state, transformed)
|
79
|
+
|
80
|
+
def forward(self, state: Array, values: ParamDictType) -> Array:
|
81
|
+
return overlap(state, self._forward(state, values))
|
82
|
+
|
83
|
+
|
84
|
+
def convert_observable(
    block: AbstractBlock, n_qubits: int, config: Configuration
) -> HorqruxObservable:
    """Convert an observable block into a HorqruxObservable via `convert_block`."""
    return HorqruxObservable(convert_block(block, n_qubits, config))
|
89
|
+
|
90
|
+
|
91
|
+
def convert_block(
    block: AbstractBlock,
    n_qubits: int | None = None,
    config: Configuration | None = None,
) -> list:
    """Convert a qadence block into a list of horqrux-backed operation wrappers.

    Args:
        block: The qadence block to convert.
        n_qubits: Total number of qubits. Inferred as `max(qubit_support) + 1`
            when not given.
        config: Backend configuration used to resolve parameter names. A fresh
            default `Configuration` is created when not given.

    Returns:
        A list of wrappers, each exposing `forward(state, values)`.

    Raises:
        NotImplementedError: For parametric gates with more than one parameter
            or for unsupported block types.
    """
    if config is None:
        # Build a new default per call instead of sharing one mutable
        # Configuration instance created at import time.
        config = Configuration()
    if n_qubits is None:
        n_qubits = max(block.qubit_support) + 1
    ops: list = []
    if isinstance(block, CompositeBlock):
        ops = list(flatten(*(convert_block(b, n_qubits, config) for b in block.blocks)))
        if isinstance(block, AddBlock):
            ops = [HorqAddGate(ops)]
        elif isinstance(block, ChainBlock):
            ops = [HorqruxCircuit(ops)]
        elif isinstance(block, KronBlock):
            ops = _convert_kron_block(block, ops, config)

    elif isinstance(block, CNOT):
        native_op = ops_map[block.name]
        # HorqCNOTGate takes (gate, control, target); qadence lists the control first.
        ops = [HorqCNOTGate(native_op, block.qubit_support[0], block.qubit_support[1])]

    elif isinstance(block, (CRX, CRY, CRZ)):
        native_op = ops_map[block.name]
        param_name = config.get_param_name(block)[0]
        ops = [
            HorqParametricGate(
                gate=native_op,
                qubit=block.qubit_support[1],
                parameter_name=param_name,
                control=block.qubit_support[0],
                name=block.name,
            )
        ]

    elif isinstance(block, ScaleBlock):
        op = convert_block(block.block, n_qubits, config=config)[0]
        param_name = config.get_param_name(block)[0]
        ops = [HorqScaleGate(op, param_name)]

    elif isinstance(block, ParametricBlock):
        native_op = ops_map[block.name]
        if len(block.parameters._uuid_dict) > 1:
            raise NotImplementedError("Only single-parameter operations are supported.")
        param_name = config.get_param_name(block)[0]
        ops = [
            HorqParametricGate(
                gate=native_op,
                qubit=block.qubit_support[0],
                parameter_name=param_name,
            )
        ]

    elif isinstance(block, PrimitiveBlock):
        native_op = ops_map[block.name]
        qubit = block.qubit_support[0]
        ops = [HorqPrimitiveGate(gate=native_op, qubit=qubit, name=block.name)]

    else:
        raise NotImplementedError(f"Non-supported operation of type {type(block)}.")

    return ops


def _convert_kron_block(block: KronBlock, ops: list, config: Configuration) -> list:
    """Batch a KronBlock's gates into one wrapper, or fall back to sequential ops.

    `ops` are the already-converted children of `block`; they are used for the
    sequential fallback.
    """
    children = block.blocks
    # Fast path only for single-qubit rotations. Controlled rotations are
    # ParametricBlocks too, but this path would apply them to their control
    # qubit with no control at all, so they take the sequential path.
    if all(
        isinstance(b, ParametricBlock)
        and not isinstance(b, ScaleBlock)
        and len(b.qubit_support) == 1
        for b in children
    ):
        return [
            HorqKronParametric(
                gates=[ops_map[b.name] for b in children],
                target=[b.qubit_support[0] for b in children],
                # Exactly one name per gate (as in the single ParametricBlock
                # branch): filtering out numeric-valued gates would misalign the
                # zip in HorqKronParametric.forward and drop gates silently.
                param_names=[config.get_param_name(b)[0] for b in children],
            )
        ]
    if all(b.name == "CNOT" for b in children):
        return [
            HorqKronCNOT(
                gates=[ops_map[b.name] for b in children],
                target=[b.qubit_support[1] for b in children],
                control=[b.qubit_support[0] for b in children],
            )
        ]
    return [HorqruxCircuit(ops)]
|
177
|
+
|
178
|
+
|
179
|
+
class HorqPrimitiveGate:
    """Wrapper applying a non-parametric single-qubit horqrux gate to one qubit."""

    def __init__(self, gate: Gate, qubit: int, name: str):
        self.gates: Gate = gate
        self.target = qubit
        self.name = name

    def forward(self, state: Array, values: ParamDictType) -> Array:
        # `values` is ignored: a primitive gate carries no parameters.
        native_gate = self.gates(self.target)
        return apply_gate(state, native_gate)

    def __repr__(self) -> str:
        return f"{self.name}(target={self.target})"
|
190
|
+
|
191
|
+
|
192
|
+
class HorqCNOTGate:
    """Wrapper applying a horqrux controlled-NOT gate.

    The constructor takes (gate, control, target), whereas the horqrux gate is
    called as gate(target, control).
    """

    def __init__(self, gate: Gate, control: int, target: int):
        self.gates: Callable = gate
        self.control: int = control
        self.target: int = target

    def forward(self, state: Array, values: ParamDictType) -> Array:
        native_gate = self.gates(self.target, self.control)
        return apply_gate(state, native_gate)
|
200
|
+
|
201
|
+
|
202
|
+
class HorqKronParametric:
    """Apply several single-qubit parametric gates with one `apply_gate` call.

    Gate i is built as gates[i](values[param_names[i]], target[i]).
    """

    def __init__(self, gates: list[Gate], param_names: list[str], target: list[int]):
        self.operators: list[Gate] = gates
        self.target: list[int] = target
        self.param_names: list[str] = param_names

    def forward(self, state: Array, values: ParamDictType) -> Array:
        triples = zip(self.operators, self.target, self.param_names)
        native_gates = tuple(op(values[name], qubit) for op, qubit, name in triples)
        return apply_gate(state, native_gates)
|
216
|
+
|
217
|
+
|
218
|
+
class HorqKronCNOT(HorqruxCircuit):
    """Apply the CNOT gates of a KronBlock with one `apply_gate` call.

    Gate i is built as gates[i](target[i], control[i]).
    """

    def __init__(self, gates: list[Gate], target: list[int], control: list[int]):
        self.operators: list[Gate] = gates
        self.target: list[int] = target
        self.control: list[int] = control

    def forward(self, state: Array, values: ParamDictType) -> Array:
        triples = zip(self.operators, self.target, self.control)
        native_gates = tuple(op(tgt, ctl) for op, tgt, ctl in triples)
        return apply_gate(state, native_gates)
|
232
|
+
|
233
|
+
|
234
|
+
class HorqParametricGate:
    """Wrapper applying a single-parameter horqrux gate, optionally controlled.

    The parameter value is looked up in `values` under `parameter_name` at
    forward time.
    """

    def __init__(
        self, gate: Gate, qubit: int, parameter_name: str, control: int = None, name: str = ""
    ):
        self.gates: Callable = gate
        self.target: int = qubit
        self.parameter: str = parameter_name
        self.control: int | None = control
        self.name = name

    def forward(self, state: Array, values: ParamDictType) -> Array:
        angle = jnp.array(values[self.parameter])
        native_gate = self.gates(angle, self.target, self.control)
        return apply_gate(state, native_gate)

    def __repr__(self) -> str:
        details = f"target={self.target}, parameter={self.parameter}, control={self.control}"
        return f"{self.name}({details})"
|
253
|
+
|
254
|
+
|
255
|
+
class HorqAddGate(HorqruxCircuit):
    """Sum of the states produced by applying each operator to the same input."""

    def __init__(self, operations: list[Gate]):
        self.operators = operations
        self.name = "Add"

    def forward(self, state: Array, values: ParamDictType = {}) -> Array:
        # `values` is only read, never mutated, by the operators' forward calls here.
        partial_states = (op.forward(state, values) for op in self.operators)
        return reduce(add, partial_states)

    def __repr__(self) -> str:
        return f"{self.name}({self.operators})"
|
265
|
+
|
266
|
+
|
267
|
+
class HorqScaleGate:
    """Multiply the output of a wrapped operation by a named parameter value."""

    def __init__(self, op: Gate, parameter_name: str):
        self.op = op
        self.parameter: str = parameter_name

    def forward(self, state: Array, values: ParamDictType) -> Array:
        scale = jnp.array(values[self.parameter])
        return scale * self.op.forward(state, values)
|
@@ -0,0 +1,45 @@
|
|
1
|
+
from __future__ import annotations
|
2
|
+
|
3
|
+
from typing import Any
|
4
|
+
|
5
|
+
import jax.numpy as jnp
|
6
|
+
from jax import Array, device_get
|
7
|
+
from sympy import Expr
|
8
|
+
from sympy2jax import SymbolicModule as JaxSympyModule
|
9
|
+
from torch import Tensor, cdouble, from_numpy
|
10
|
+
|
11
|
+
from qadence.types import ParamDictType
|
12
|
+
|
13
|
+
|
14
|
+
def jarr_to_tensor(arr: Array, dtype: Any = cdouble) -> Tensor:
    """Copy a JAX array to host memory and wrap it as a torch tensor of `dtype`."""
    host_array = device_get(arr)
    tensor = from_numpy(host_array)
    return tensor.to(dtype=dtype)
|
16
|
+
|
17
|
+
|
18
|
+
def tensor_to_jnp(tensor: Tensor, dtype: Any = jnp.complex128) -> Array:
    """Convert a torch tensor into a JAX array of `dtype`.

    `Tensor.numpy()` only works on CPU tensors that do not require grad, so the
    tensor is detached from the autograd graph (a no-op when it does not require
    grad) and copied to host memory (a no-op for CPU tensors) first.
    """
    return jnp.array(tensor.detach().cpu().numpy(), dtype=dtype)
|
24
|
+
|
25
|
+
|
26
|
+
def values_to_jax(param_values: dict[str, Tensor]) -> dict[str, Array]:
    """Convert a dict of torch tensors into a dict of JAX arrays with the same keys.

    Each tensor is detached and moved to host memory before conversion, so
    tensors that require grad or live on an accelerator are supported.
    """
    return {key: jnp.array(value.detach().cpu().numpy()) for key, value in param_values.items()}
|
28
|
+
|
29
|
+
|
30
|
+
def jaxify(expr: Expr) -> JaxSympyModule:
    """Wrap a sympy expression in a sympy2jax SymbolicModule."""
    module = JaxSympyModule(expr)
    return module
|
32
|
+
|
33
|
+
|
34
|
+
def unhorqify(state: Array) -> Array:
    """Flatten a horqrux state into a 1-D array.

    Whatever the input shape (e.g. [2] * n_qubits), the result is the
    row-major ravel of `state`, of length `state.size`. No batch dimension is
    added.
    """
    return jnp.ravel(state)
|
37
|
+
|
38
|
+
|
39
|
+
def uniform_batchsize(param_values: ParamDictType) -> ParamDictType:
    """Broadcast every parameter array to the largest batch size in `param_values`.

    Arrays whose size already equals the largest size are returned unchanged.
    Arrays holding a single value are repeated to that size.

    Returns:
        A new dict with the same keys; an empty dict for empty input.

    Raises:
        ValueError: If an array holds neither one value nor the largest number of
            values. Such an array cannot be broadcast unambiguously, and repeating
            it would only produce a mismatched size that fails later in `vmap`.
    """
    if not param_values:
        return {}
    max_batch_size = max(p.size for p in param_values.values())
    batched_values = {}
    for k, v in param_values.items():
        if v.size == max_batch_size:
            batched_values[k] = v
        elif v.size == 1:
            batched_values[k] = v.repeat(max_batch_size)
        else:
            raise ValueError(
                f"Cannot broadcast parameter '{k}' of size {v.size} "
                f"to batch size {max_batch_size}."
            )
    return batched_values
|