qadence 1.2.0__py3-none-any.whl → 1.2.1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- qadence/__init__.py +1 -0
- qadence/backend.py +18 -14
- qadence/backends/__init__.py +1 -2
- qadence/backends/api.py +27 -9
- qadence/backends/braket/backend.py +2 -1
- qadence/backends/horqrux/__init__.py +5 -0
- qadence/backends/horqrux/backend.py +216 -0
- qadence/backends/horqrux/config.py +26 -0
- qadence/backends/horqrux/convert_ops.py +273 -0
- qadence/backends/jax_utils.py +45 -0
- qadence/backends/pulser/backend.py +2 -1
- qadence/backends/pyqtorch/backend.py +2 -1
- qadence/backends/pyqtorch/convert_ops.py +3 -3
- qadence/backends/utils.py +8 -3
- qadence/blocks/embedding.py +46 -24
- qadence/blocks/utils.py +20 -1
- qadence/circuit.py +3 -9
- qadence/engines/__init__.py +0 -0
- qadence/engines/differentiable_backend.py +152 -0
- qadence/engines/jax/__init__.py +8 -0
- qadence/engines/jax/differentiable_backend.py +73 -0
- qadence/engines/jax/differentiable_expectation.py +94 -0
- qadence/engines/torch/__init__.py +4 -0
- qadence/engines/torch/differentiable_backend.py +85 -0
- qadence/{backends/pytorch_wrapper.py → engines/torch/differentiable_expectation.py} +1 -2
- qadence/extensions.py +20 -3
- qadence/ml_tools/models.py +10 -3
- qadence/models/quantum_model.py +13 -2
- qadence/parameters.py +19 -16
- qadence/types.py +13 -1
- {qadence-1.2.0.dist-info → qadence-1.2.1.dist-info}/METADATA +21 -2
- {qadence-1.2.0.dist-info → qadence-1.2.1.dist-info}/RECORD +34 -22
- {qadence-1.2.0.dist-info → qadence-1.2.1.dist-info}/WHEEL +1 -1
- {qadence-1.2.0.dist-info → qadence-1.2.1.dist-info}/licenses/LICENSE +0 -0
@@ -0,0 +1,273 @@
|
|
1
|
+
from __future__ import annotations
|
2
|
+
|
3
|
+
from dataclasses import dataclass, field
|
4
|
+
from functools import reduce
|
5
|
+
from itertools import chain as flatten
|
6
|
+
from operator import add
|
7
|
+
from typing import Any, Callable, Dict
|
8
|
+
|
9
|
+
import jax.numpy as jnp
|
10
|
+
from horqrux.gates import NOT, H, I, Rx, Ry, Rz, X, Y, Z
|
11
|
+
from horqrux.ops import apply_gate
|
12
|
+
from horqrux.types import Gate
|
13
|
+
from horqrux.utils import overlap
|
14
|
+
from jax import Array
|
15
|
+
from jax.tree_util import register_pytree_node_class
|
16
|
+
|
17
|
+
from qadence.blocks import (
|
18
|
+
AbstractBlock,
|
19
|
+
AddBlock,
|
20
|
+
ChainBlock,
|
21
|
+
CompositeBlock,
|
22
|
+
KronBlock,
|
23
|
+
ParametricBlock,
|
24
|
+
PrimitiveBlock,
|
25
|
+
ScaleBlock,
|
26
|
+
)
|
27
|
+
from qadence.operations import CNOT, CRX, CRY, CRZ
|
28
|
+
from qadence.types import OpName, ParamDictType
|
29
|
+
|
30
|
+
from .config import Configuration
|
31
|
+
|
32
|
+
# Maps each supported qadence operation name to the horqrux callable used to
# build it. The controlled rotations (CRX/CRY/CRZ) and CNOT reuse the plain
# Rx/Ry/Rz/NOT constructors; `convert_block` passes the control qubit
# separately when it wraps them.
ops_map: Dict[str, Callable] = {
    OpName.X: X,
    OpName.Y: Y,
    OpName.Z: Z,
    OpName.H: H,
    OpName.RX: Rx,
    OpName.RY: Ry,
    OpName.RZ: Rz,
    OpName.CRX: Rx,
    OpName.CRY: Ry,
    OpName.CRZ: Rz,
    OpName.CNOT: NOT,
    OpName.I: I,
}

# Dict keys are already unique, so no set round-trip is needed. Iterating the
# dict directly also keeps the list in the (deterministic) insertion order of
# `ops_map` instead of a hash-dependent set order.
supported_gates = list(ops_map)
|
48
|
+
|
49
|
+
|
50
|
+
@register_pytree_node_class
@dataclass
class HorqruxCircuit:
    """Ordered sequence of converted operations applied one after another.

    The class is registered as a JAX pytree node: `operators` is its only
    child and it carries no auxiliary data.
    """

    operators: list[Gate] = field(default_factory=list)

    def tree_flatten(self) -> tuple[tuple[list[Any]], tuple[()]]:
        """Return `(children, aux_data)` with `operators` as the single child."""
        return (self.operators,), ()

    @classmethod
    def tree_unflatten(cls, aux_data: Any, children: Any) -> Any:
        """Rebuild an instance from the pieces produced by `tree_flatten`."""
        return cls(*children, *aux_data)

    def forward(self, state: Array, values: ParamDictType) -> Array:
        """Thread `state` through every operator, in order, and return the result."""
        return reduce(lambda acc, op: op.forward(acc, values), self.operators, state)
|
68
|
+
|
69
|
+
|
70
|
+
@dataclass
class HorqruxObservable(HorqruxCircuit):
    """Circuit of observable operators whose `forward` returns an overlap.

    `forward` applies the operators to the input state and passes the input
    state together with the transformed state to horqrux's `overlap`.
    """

    def __init__(self, operators: list[Gate]):
        super().__init__(operators=operators)

    def _forward(self, state: Array, values: ParamDictType) -> Array:
        """Apply the observable's operators to `state`, in order."""
        return reduce(lambda acc, op: op.forward(acc, values), self.operators, state)

    def forward(self, state: Array, values: ParamDictType) -> Array:
        """Return `overlap(state, transformed)` where `transformed = _forward(state, values)`."""
        transformed = self._forward(state, values)
        return overlap(state, transformed)
|
82
|
+
|
83
|
+
|
84
|
+
def convert_observable(
    block: AbstractBlock, n_qubits: int, config: Configuration
) -> HorqruxObservable:
    """Convert `block` with `convert_block` and wrap the result in a `HorqruxObservable`.

    Args:
        block: The block describing the observable.
        n_qubits: Number of qubits, forwarded to `convert_block`.
        config: Backend configuration, forwarded to `convert_block`.

    Returns:
        A `HorqruxObservable` holding the converted operations.
    """
    return HorqruxObservable(convert_block(block, n_qubits, config))
|
89
|
+
|
90
|
+
|
91
|
+
def convert_block(
    block: AbstractBlock, n_qubits: int | None = None, config: Configuration | None = None
) -> list:
    """Recursively convert a qadence block into a list of horqrux-backed operations.

    Args:
        block: The block to convert.
        n_qubits: Number of qubits. Defaults to `max(block.qubit_support) + 1`.
        config: Backend configuration used to resolve parameter names. A fresh
            `Configuration()` is created per call when omitted, so no instance
            is shared between calls.

    Returns:
        A list of converted operations. Add, chain and kron blocks are wrapped
        in a single container operation; any other composite block yields the
        flattened operations of its children.

    Raises:
        NotImplementedError: If a parametric block has more than one parameter,
            or if the block type is not handled.
        KeyError: If `block.name` has no entry in `ops_map`.
    """
    if n_qubits is None:
        n_qubits = max(block.qubit_support) + 1
    if config is None:
        config = Configuration()

    ops = []
    if isinstance(block, CompositeBlock):
        # Convert the children first; the wrappers below consume the flat list.
        ops = list(
            flatten.from_iterable(convert_block(b, n_qubits, config) for b in block.blocks)
        )
        if isinstance(block, AddBlock):
            ops = [HorqAddGate(ops)]
        elif isinstance(block, ChainBlock):
            ops = [HorqruxCircuit(ops)]
        elif isinstance(block, KronBlock):
            if all(
                isinstance(b, ParametricBlock) and not isinstance(b, ScaleBlock)
                for b in block.blocks
            ):
                # NOTE(review): `param_names` is filtered by `b.is_parametric` while
                # `gates` and `target` are not. If any block here reports
                # `is_parametric` as falsy, the three lists would no longer line up
                # when zipped in `HorqKronParametric.forward` — confirm this cannot occur.
                param_names = [config.get_param_name(b)[0] for b in block.blocks if b.is_parametric]
                ops = [
                    HorqKronParametric(
                        gates=[ops_map[b.name] for b in block.blocks],
                        target=[b.qubit_support[0] for b in block.blocks],
                        param_names=param_names,
                    )
                ]

            elif all(b.name == "CNOT" for b in block.blocks):
                ops = [
                    HorqKronCNOT(
                        gates=[ops_map[b.name] for b in block.blocks],
                        target=[b.qubit_support[1] for b in block.blocks],
                        control=[b.qubit_support[0] for b in block.blocks],
                    )
                ]
            else:
                ops = [HorqruxCircuit(ops)]

    elif isinstance(block, CNOT):
        native_op = ops_map[block.name]
        # qadence orders the support as (control, target); `HorqCNOTGate` calls
        # the native op as (target, control), i.e. horqrux swaps the two.
        ops = [HorqCNOTGate(native_op, block.qubit_support[0], block.qubit_support[1])]

    elif isinstance(block, (CRX, CRY, CRZ)):
        native_op = ops_map[block.name]
        param_name = config.get_param_name(block)[0]

        ops = [
            HorqParametricGate(
                gate=native_op,
                qubit=block.qubit_support[1],
                parameter_name=param_name,
                control=block.qubit_support[0],
                name=block.name,
            )
        ]
    elif isinstance(block, ScaleBlock):
        # Must be tested before the generic ParametricBlock branch below, which
        # would otherwise try to look the scale block up in `ops_map`.
        op = convert_block(block.block, n_qubits, config=config)[0]
        param_name = config.get_param_name(block)[0]
        ops = [HorqScaleGate(op, param_name)]

    elif isinstance(block, ParametricBlock):
        native_op = ops_map[block.name]
        if len(block.parameters._uuid_dict) > 1:
            raise NotImplementedError("Only single-parameter operations are supported.")
        param_name = config.get_param_name(block)[0]

        ops = [
            HorqParametricGate(
                gate=native_op,
                qubit=block.qubit_support[0],
                parameter_name=param_name,
            )
        ]

    elif isinstance(block, PrimitiveBlock):
        native_op = ops_map[block.name]
        qubit = block.qubit_support[0]
        ops = [HorqPrimitiveGate(gate=native_op, qubit=qubit, name=block.name)]

    else:
        raise NotImplementedError(f"Non-supported operation of type {type(block)}.")

    return ops
|
177
|
+
|
178
|
+
|
179
|
+
class HorqPrimitiveGate:
    """Non-parametric gate acting on a single target qubit."""

    def __init__(self, gate: Gate, qubit: int, name: str):
        self.name = name
        self.target = qubit
        self.gates: Gate = gate

    def forward(self, state: Array, values: ParamDictType) -> Array:
        """Apply the gate to `state`; `values` is accepted but not read."""
        native_gate = self.gates(self.target)
        return apply_gate(state, native_gate)

    def __repr__(self) -> str:
        suffix = f"(target={self.target})"
        return self.name + suffix
|
190
|
+
|
191
|
+
|
192
|
+
class HorqCNOTGate:
    """CNOT wrapper storing the native gate constructor with its control and target."""

    def __init__(self, gate: Gate, control: int, target: int):
        self.target: int = target
        self.control: int = control
        self.gates: Callable = gate

    def forward(self, state: Array, values: ParamDictType) -> Array:
        """Apply the gate to `state`; `values` is accepted but not read."""
        # The native constructor is called with the target first, then the control.
        native_gate = self.gates(self.target, self.control)
        return apply_gate(state, native_gate)
|
200
|
+
|
201
|
+
|
202
|
+
class HorqKronParametric:
    """Group of parametric gates handed to `apply_gate` in a single call."""

    def __init__(self, gates: list[Gate], param_names: list[str], target: list[int]):
        self.param_names: list[str] = param_names
        self.target: list[int] = target
        self.operators: list[Gate] = gates

    def forward(self, state: Array, values: ParamDictType) -> Array:
        """Build every gate from its parameter value and target, then apply them together."""
        # zip stops at the shortest of the three lists.
        triples = zip(self.operators, self.target, self.param_names)
        native_gates = tuple(
            gate(values[param_name], target) for gate, target, param_name in triples
        )
        return apply_gate(state, native_gates)
|
216
|
+
|
217
|
+
|
218
|
+
class HorqKronCNOT(HorqruxCircuit):
    """Group of CNOT gates handed to `apply_gate` in a single call."""

    def __init__(self, gates: list[Gate], target: list[int], control: list[int]):
        self.control: list[int] = control
        self.target: list[int] = target
        self.operators: list[Gate] = gates

    def forward(self, state: Array, values: ParamDictType) -> Array:
        """Build every gate from its (target, control) pair and apply them together.

        `values` is accepted but not read.
        """
        triples = zip(self.operators, self.target, self.control)
        native_gates = tuple(gate(target, control) for gate, target, control in triples)
        return apply_gate(state, native_gates)
|
232
|
+
|
233
|
+
|
234
|
+
class HorqParametricGate:
    """Single-parameter gate on one target qubit, optionally with a control qubit."""

    def __init__(
        self, gate: Gate, qubit: int, parameter_name: str, control: int = None, name: str = ""
    ):
        self.name = name
        self.control: int | None = control
        self.parameter: str = parameter_name
        self.target: int = qubit
        self.gates: Callable = gate

    def forward(self, state: Array, values: ParamDictType) -> Array:
        """Look up this gate's parameter in `values` and apply the gate to `state`."""
        angle = jnp.array(values[self.parameter])
        native_gate = self.gates(angle, self.target, self.control)
        return apply_gate(state, native_gate)

    def __repr__(self) -> str:
        details = f"(target={self.target}, parameter={self.parameter}, control={self.control})"
        return self.name + details
|
253
|
+
|
254
|
+
|
255
|
+
class HorqAddGate(HorqruxCircuit):
    """Sum of operations: each operator acts on the same input state and the results are added."""

    def __init__(self, operations: list[Gate]):
        self.operators = operations
        self.name = "Add"

    def forward(self, state: Array, values: ParamDictType | None = None) -> Array:
        """Apply every operator to `state` independently and add the resulting states.

        Args:
            state: The input state, passed unchanged to each operator.
            values: Parameter values forwarded to each operator. Defaults to an
                empty dict created per call.

        Returns:
            The sum (via `operator.add`) of all operator outputs.

        Raises:
            TypeError: If `operators` is empty, since `reduce` has no initial value.
        """
        # A `None` sentinel replaces the former `{}` default, which was a single
        # dict object shared by every call (mutable default argument).
        if values is None:
            values = {}
        return reduce(add, (op.forward(state, values) for op in self.operators))

    def __repr__(self) -> str:
        return self.name + f"({self.operators})"
|
265
|
+
|
266
|
+
|
267
|
+
class HorqScaleGate:
    """Wrap a converted operation and multiply its output by a named parameter."""

    def __init__(self, op: Gate, parameter_name: str):
        self.parameter: str = parameter_name
        self.op = op

    def forward(self, state: Array, values: ParamDictType) -> Array:
        """Return `values[self.parameter] * self.op.forward(state, values)`."""
        scale = jnp.array(values[self.parameter])
        return scale * self.op.forward(state, values)
|
@@ -0,0 +1,45 @@
|
|
1
|
+
from __future__ import annotations
|
2
|
+
|
3
|
+
from typing import Any
|
4
|
+
|
5
|
+
import jax.numpy as jnp
|
6
|
+
from jax import Array, device_get
|
7
|
+
from sympy import Expr
|
8
|
+
from sympy2jax import SymbolicModule as JaxSympyModule
|
9
|
+
from torch import Tensor, cdouble, from_numpy
|
10
|
+
|
11
|
+
from qadence.types import ParamDictType
|
12
|
+
|
13
|
+
|
14
|
+
def jarr_to_tensor(arr: Array, dtype: Any = cdouble) -> Tensor:
    """Convert a JAX array into a torch tensor cast to `dtype`.

    The array is fetched with `jax.device_get`, wrapped with `torch.from_numpy`
    and then cast.
    """
    host_array = device_get(arr)
    return from_numpy(host_array).to(dtype=dtype)
|
16
|
+
|
17
|
+
|
18
|
+
def tensor_to_jnp(tensor: Tensor, dtype: Any = jnp.complex128) -> Array:
    """Convert a torch tensor into a JAX array of the given `dtype`.

    Tensors that require grad are detached first, because `Tensor.numpy()`
    cannot be called on them directly.
    """
    source = tensor.detach() if tensor.requires_grad else tensor
    return jnp.array(source.numpy(), dtype=dtype)
|
24
|
+
|
25
|
+
|
26
|
+
def values_to_jax(param_values: dict[str, Tensor]) -> dict[str, Array]:
    """Return a new dict with every torch tensor detached and converted to a JAX array."""
    converted: dict[str, Array] = {}
    for name, value in param_values.items():
        converted[name] = jnp.array(value.detach().numpy())
    return converted
|
28
|
+
|
29
|
+
|
30
|
+
def jaxify(expr: Expr) -> JaxSympyModule:
    """Wrap a sympy expression in a `sympy2jax.SymbolicModule`."""
    module = JaxSympyModule(expr)
    return module
|
32
|
+
|
33
|
+
|
34
|
+
def unhorqify(state: Array) -> Array:
    """Flatten a state array into a 1-D array with `jnp.ravel`.

    The result always has shape `(state.size,)`; no batch dimension is kept or
    moved.
    """
    # NOTE(review): the previous docstring promised a (batch_size, 2**n_qubits)
    # result, which `ravel` does not produce — confirm callers expect 1-D output.
    return jnp.ravel(state)
|
37
|
+
|
38
|
+
|
39
|
+
def uniform_batchsize(param_values: ParamDictType) -> ParamDictType:
    """Bring every parameter array up to the largest size found in `param_values`.

    Args:
        param_values: Mapping from parameter name to array.

    Returns:
        A new dict. Arrays already at the largest size are kept as they are;
        every other array is expanded with `.repeat(max_batch_size)`. An empty
        input yields an empty dict.
    """
    # `max()` over an empty iterable raises ValueError, so handle the
    # parameter-free case explicitly.
    if not param_values:
        return {}
    max_batch_size = max(p.size for p in param_values.values())
    # NOTE(review): `.repeat(n)` repeats each element n times, so this only gives
    # size `max_batch_size` for single-element arrays — presumably the smaller
    # entries are always of size 1; confirm against callers.
    return {
        k: (v if v.size == max_batch_size else v.repeat(max_batch_size))
        for k, v in param_values.items()
    }
|
@@ -28,7 +28,7 @@ from qadence.noise.protocols import apply_noise
|
|
28
28
|
from qadence.overlap import overlap_exact
|
29
29
|
from qadence.register import Register
|
30
30
|
from qadence.transpile import transpile
|
31
|
-
from qadence.types import BackendName, DeviceType, Endianness
|
31
|
+
from qadence.types import BackendName, DeviceType, Endianness, Engine
|
32
32
|
|
33
33
|
from .channels import GLOBAL_CHANNEL, LOCAL_CHANNEL
|
34
34
|
from .cloud import get_client
|
@@ -159,6 +159,7 @@ class Backend(BackendInterface):
|
|
159
159
|
with_noise: bool = False
|
160
160
|
native_endianness: Endianness = Endianness.BIG
|
161
161
|
config: Configuration = field(default_factory=Configuration)
|
162
|
+
engine: Engine = Engine.TORCH
|
162
163
|
|
163
164
|
def circuit(self, circ: QuantumCircuit) -> Sequence:
|
164
165
|
passes = self.config.transpilation_passes
|
@@ -30,7 +30,7 @@ from qadence.transpile import (
|
|
30
30
|
scale_primitive_blocks_only,
|
31
31
|
transpile,
|
32
32
|
)
|
33
|
-
from qadence.types import BackendName, Endianness
|
33
|
+
from qadence.types import BackendName, Endianness, Engine
|
34
34
|
from qadence.utils import infer_batchsize, int_to_basis
|
35
35
|
|
36
36
|
from .config import Configuration, default_passes
|
@@ -52,6 +52,7 @@ class Backend(BackendInterface):
|
|
52
52
|
with_noise: bool = False
|
53
53
|
native_endianness: Endianness = Endianness.BIG
|
54
54
|
config: Configuration = field(default_factory=Configuration)
|
55
|
+
engine: Engine = Engine.TORCH
|
55
56
|
|
56
57
|
def circuit(self, circuit: QuantumCircuit) -> ConvertedCircuit:
|
57
58
|
passes = self.config.transpilation_passes
|
@@ -252,12 +252,12 @@ class PyQObservable(Module):
|
|
252
252
|
convert_block(block, n_qubits, config),
|
253
253
|
)
|
254
254
|
|
255
|
-
def forward(self, state: Tensor, values: dict[str, Tensor]) -> Tensor:
|
256
|
-
return pyq.overlap(state, self.operation(state, values))
|
257
|
-
|
258
255
|
def run(self, state: Tensor, values: dict[str, Tensor]) -> Tensor:
|
259
256
|
return self.operation(state, values)
|
260
257
|
|
258
|
+
def forward(self, state: Tensor, values: dict[str, Tensor]) -> Tensor:
|
259
|
+
return pyq.overlap(state, self.run(state, values))
|
260
|
+
|
261
261
|
|
262
262
|
class PyQHamiltonianEvolution(Module):
|
263
263
|
def __init__(
|
qadence/backends/utils.py
CHANGED
@@ -17,8 +17,8 @@ from torch import (
|
|
17
17
|
no_grad,
|
18
18
|
rand,
|
19
19
|
)
|
20
|
-
from torch import flatten as torchflatten
|
21
20
|
|
21
|
+
from qadence.types import ParamDictType
|
22
22
|
from qadence.utils import Endianness, int_to_basis, is_qadence_shape
|
23
23
|
|
24
24
|
FINITE_DIFF_EPS = 1e-06
|
@@ -92,7 +92,7 @@ def count_bitstrings(sample: Tensor, endianness: Endianness = Endianness.BIG) ->
|
|
92
92
|
)
|
93
93
|
|
94
94
|
|
95
|
-
def to_list_of_dicts(param_values:
|
95
|
+
def to_list_of_dicts(param_values: ParamDictType) -> list[ParamDictType]:
|
96
96
|
if not param_values:
|
97
97
|
return [param_values]
|
98
98
|
|
@@ -119,7 +119,7 @@ def pyqify(state: Tensor, n_qubits: int = None) -> Tensor:
|
|
119
119
|
|
120
120
|
def unpyqify(state: Tensor) -> Tensor:
|
121
121
|
"""Convert a state of shape [2] * n_qubits + [batch_size] to (batch_size, 2**n_qubits)."""
|
122
|
-
return
|
122
|
+
return torch.flatten(state, start_dim=0, end_dim=-2).t()
|
123
123
|
|
124
124
|
|
125
125
|
def is_pyq_shape(state: Tensor, n_qubits: int) -> bool:
|
@@ -141,6 +141,11 @@ def validate_state(state: Tensor, n_qubits: int) -> None:
|
|
141
141
|
)
|
142
142
|
|
143
143
|
|
144
|
+
def infer_batchsize(param_values: ParamDictType = None) -> int:
    """Infer the batch size as the largest `len` among the parameter tensors.

    Returns 1 when `param_values` is `None` or empty.
    """
    if not param_values:
        return 1
    return max(len(tensor) for tensor in param_values.values())
|
147
|
+
|
148
|
+
|
144
149
|
# The following functions can be used to compute potentially higher order gradients using pyqtorch's
|
145
150
|
# native 'jacobian' methods.
|
146
151
|
|
qadence/blocks/embedding.py
CHANGED
@@ -2,11 +2,10 @@ from __future__ import annotations
|
|
2
2
|
|
3
3
|
from typing import Callable, Iterable, List
|
4
4
|
|
5
|
-
import numpy as np
|
6
5
|
import sympy
|
7
|
-
import
|
8
|
-
import
|
9
|
-
from torch import
|
6
|
+
from numpy import array as nparray
|
7
|
+
from numpy import cdouble as npcdouble
|
8
|
+
from torch import tensor
|
10
9
|
|
11
10
|
from qadence.blocks import (
|
12
11
|
AbstractBlock,
|
@@ -16,9 +15,24 @@ from qadence.blocks.utils import (
|
|
16
15
|
parameters,
|
17
16
|
uuid_to_expression,
|
18
17
|
)
|
19
|
-
from qadence.parameters import evaluate,
|
18
|
+
from qadence.parameters import evaluate, make_differentiable, stringify
|
19
|
+
from qadence.types import ArrayLike, DifferentiableExpression, Engine, ParamDictType, TNumber
|
20
20
|
|
21
|
-
|
21
|
+
|
22
|
+
def _concretize_parameter(engine: Engine) -> Callable:
    """Return a function that turns a number into a one-element array for `engine`.

    For `Engine.JAX` the returned function builds a float64 `jax.numpy` array
    and ignores `trainable`. For any other engine it builds a torch tensor
    whose `requires_grad` flag is set to `trainable`.
    """
    if engine == Engine.JAX:
        # Imported lazily so jax is only needed when the JAX engine is requested.
        from jax.numpy import array as jaxarray
        from jax.numpy import float64 as jaxfloat64

        def concretize_parameter(value: TNumber, trainable: bool = False) -> ArrayLike:
            return jaxarray([value], dtype=jaxfloat64)

        return concretize_parameter

    def concretize_parameter(value: TNumber, trainable: bool = False) -> ArrayLike:
        return tensor([value], requires_grad=trainable)

    return concretize_parameter
|
22
36
|
|
23
37
|
|
24
38
|
def unique(x: Iterable) -> List:
|
@@ -26,14 +40,13 @@ def unique(x: Iterable) -> List:
|
|
26
40
|
|
27
41
|
|
28
42
|
def embedding(
|
29
|
-
block: AbstractBlock, to_gate_params: bool = False
|
30
|
-
) -> tuple[
|
31
|
-
"""Construct embedding function
|
43
|
+
block: AbstractBlock, to_gate_params: bool = False, engine: Engine = Engine.TORCH
|
44
|
+
) -> tuple[ParamDictType, Callable[[ParamDictType, ParamDictType], ParamDictType],]:
|
45
|
+
"""Construct embedding function which maps user-facing parameters to either *expression-level*.
|
32
46
|
|
33
|
-
|
34
|
-
parameters or *gate-level* parameters. The construced embedding function has the signature:
|
47
|
+
parameters or *gate-level* parameters. The constructed embedding function has the signature:
|
35
48
|
|
36
|
-
embedding_fn(params:
|
49
|
+
embedding_fn(params: ParamDictType, inputs: ParamDictType) -> ParamDictType:
|
37
50
|
|
38
51
|
which means that it maps the *variational* parameter dict `params` and the *feature* parameter
|
39
52
|
dict `inputs` to one new parameter dict `embedded_dict` which holds all parameters that are
|
@@ -56,6 +69,13 @@ def embedding(
|
|
56
69
|
Returns:
|
57
70
|
A tuple with variational parameter dict and the embedding function.
|
58
71
|
"""
|
72
|
+
concretize_parameter = _concretize_parameter(engine)
|
73
|
+
if engine == Engine.TORCH:
|
74
|
+
cast_dtype = tensor
|
75
|
+
else:
|
76
|
+
from jax.numpy import array
|
77
|
+
|
78
|
+
cast_dtype = array
|
59
79
|
|
60
80
|
unique_expressions = unique(expressions(block))
|
61
81
|
unique_symbols = [p for p in unique(parameters(block)) if not isinstance(p, sympy.Array)]
|
@@ -77,16 +97,18 @@ def embedding(
|
|
77
97
|
# we dont need to care about constant symbols if they are contained in an symbolic expression
|
78
98
|
# we only care about gate params which are ONLY a constant
|
79
99
|
|
80
|
-
embeddings: dict[sympy.Expr,
|
81
|
-
expr:
|
100
|
+
embeddings: dict[sympy.Expr, DifferentiableExpression] = {
|
101
|
+
expr: make_differentiable(expr=expr, engine=engine)
|
102
|
+
for expr in unique_expressions
|
103
|
+
if not expr.is_number
|
82
104
|
}
|
83
105
|
|
84
106
|
uuid_to_expr = uuid_to_expression(block)
|
85
107
|
|
86
|
-
def embedding_fn(params:
|
87
|
-
embedded_params: dict[sympy.Expr,
|
108
|
+
def embedding_fn(params: ParamDictType, inputs: ParamDictType) -> ParamDictType:
|
109
|
+
embedded_params: dict[sympy.Expr, ArrayLike] = {}
|
88
110
|
for expr, fn in embeddings.items():
|
89
|
-
angle:
|
111
|
+
angle: ArrayLike
|
90
112
|
values = {}
|
91
113
|
for symbol in expr.free_symbols:
|
92
114
|
if symbol.name in inputs:
|
@@ -112,26 +134,26 @@ def embedding(
|
|
112
134
|
embedded_params[e] = params[stringify(e)]
|
113
135
|
|
114
136
|
if to_gate_params:
|
115
|
-
gate_lvl_params:
|
137
|
+
gate_lvl_params: ParamDictType = {}
|
116
138
|
for uuid, e in uuid_to_expr.items():
|
117
139
|
gate_lvl_params[uuid] = embedded_params[e]
|
118
140
|
return gate_lvl_params
|
119
141
|
else:
|
120
142
|
return {stringify(k): v for k, v in embedded_params.items()}
|
121
143
|
|
122
|
-
params:
|
123
|
-
params = {
|
144
|
+
params: ParamDictType
|
145
|
+
params = {
|
146
|
+
p.name: concretize_parameter(value=p.value, trainable=True) for p in trainable_symbols
|
147
|
+
}
|
124
148
|
params.update(
|
125
149
|
{
|
126
|
-
stringify(expr):
|
150
|
+
stringify(expr): concretize_parameter(value=evaluate(expr), trainable=False)
|
127
151
|
for expr in constant_expressions
|
128
152
|
}
|
129
153
|
)
|
130
154
|
params.update(
|
131
155
|
{
|
132
|
-
stringify(expr):
|
133
|
-
np.array(expr.tolist(), dtype=np.cdouble), requires_grad=False
|
134
|
-
)
|
156
|
+
stringify(expr): cast_dtype(nparray(expr.tolist(), dtype=npcdouble))
|
135
157
|
for expr in unique_const_matrices
|
136
158
|
}
|
137
159
|
)
|
qadence/blocks/utils.py
CHANGED
@@ -5,7 +5,7 @@ from enum import Enum
|
|
5
5
|
from itertools import chain as _flatten
|
6
6
|
from typing import Generator, List, Type, TypeVar, Union, get_args
|
7
7
|
|
8
|
-
from sympy import Basic, Expr
|
8
|
+
from sympy import Array, Basic, Expr
|
9
9
|
from torch import Tensor
|
10
10
|
|
11
11
|
from qadence.blocks import (
|
@@ -503,3 +503,22 @@ def assert_same_block(b1: AbstractBlock, b2: AbstractBlock) -> None:
|
|
503
503
|
), f"Blocks {b1} and {b2} have differing numbers of parameters."
|
504
504
|
for p1, p2 in zip(b1.parameters.expressions(), b2.parameters.expressions()):
|
505
505
|
assert p1 == p2
|
506
|
+
|
507
|
+
|
508
|
+
def unique_parameters(block: AbstractBlock) -> list[Parameter]:
    """Return the unique parameters in the block.

    These are the user-facing parameters that can be assigned by the user.
    Several gates may share the same parameter; it is reported once, in order
    of first appearance.

    Returns:
        list[Parameter]: List of unique parameters in the block
    """
    found = []
    for candidate in parameters(block):
        # sympy Array entries are skipped outright.
        if isinstance(candidate, Array):
            continue
        # Numeric constants and already-collected parameters are skipped as well.
        if candidate.is_number or candidate in found:
            continue
        found.append(candidate)
    return found
|
qadence/circuit.py
CHANGED
@@ -5,10 +5,10 @@ from itertools import chain as flatten
|
|
5
5
|
from pathlib import Path
|
6
6
|
from typing import Iterable
|
7
7
|
|
8
|
-
from sympy import
|
8
|
+
from sympy import Basic
|
9
9
|
|
10
10
|
from qadence.blocks import AbstractBlock, AnalogBlock, CompositeBlock, chain
|
11
|
-
from qadence.blocks.utils import parameters, primitive_blocks
|
11
|
+
from qadence.blocks.utils import parameters, primitive_blocks, unique_parameters
|
12
12
|
from qadence.parameters import Parameter
|
13
13
|
from qadence.register import Register
|
14
14
|
|
@@ -88,13 +88,7 @@ class QuantumCircuit:
|
|
88
88
|
Returns:
|
89
89
|
list[Parameter]: List of unique parameters in the circuit
|
90
90
|
"""
|
91
|
-
|
92
|
-
for p in parameters(self.block):
|
93
|
-
if isinstance(p, Array):
|
94
|
-
continue
|
95
|
-
elif not p.is_number and p not in symbols:
|
96
|
-
symbols.append(p)
|
97
|
-
return symbols
|
91
|
+
return unique_parameters(self.block)
|
98
92
|
|
99
93
|
@property
|
100
94
|
def num_unique_parameters(self) -> int:
|
File without changes
|