GEOPE 0.0.3__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- geope/__init__.py +72 -0
- geope/engine.py +622 -0
- geope/gecko.py +1120 -0
- geope/geope.py +953 -0
- geope/grape.py +496 -0
- geope/jax/__init__.py +26 -0
- geope/jax/dexpm.py +524 -0
- geope/jax/hessian.py +123 -0
- geope/jax/jacobian.py +122 -0
- geope/jax/logm.py +826 -0
- geope/lie/__init__.py +3 -0
- geope/lie/basis.py +384 -0
- geope/lie/hamiltonian.py +90 -0
- geope/lie/pauli_projector.py +132 -0
- geope/lie/unitary.py +144 -0
- geope/line_searches.py +117 -0
- geope/parameters.py +487 -0
- geope/utils/__init__.py +31 -0
- geope/utils/history.py +170 -0
- geope/utils/utils.py +1096 -0
- geope-0.0.3.dist-info/METADATA +56 -0
- geope-0.0.3.dist-info/RECORD +25 -0
- geope-0.0.3.dist-info/WHEEL +5 -0
- geope-0.0.3.dist-info/licenses/LICENSE.md +57 -0
- geope-0.0.3.dist-info/top_level.txt +1 -0
geope/__init__.py
ADDED
|
@@ -0,0 +1,72 @@
|
|
|
1
|
+
from .engine import (
|
|
2
|
+
fidelity,
|
|
3
|
+
infidelity,
|
|
4
|
+
fidelity_full,
|
|
5
|
+
infidelity_full,
|
|
6
|
+
)
|
|
7
|
+
|
|
8
|
+
from .geope import (
|
|
9
|
+
Geope,
|
|
10
|
+
)
|
|
11
|
+
|
|
12
|
+
from . import line_searches
|
|
13
|
+
from .line_searches import (
|
|
14
|
+
LineSearch,
|
|
15
|
+
Adam,
|
|
16
|
+
GoldenSection,
|
|
17
|
+
adam,
|
|
18
|
+
golden_section,
|
|
19
|
+
)
|
|
20
|
+
|
|
21
|
+
from .gecko import (
|
|
22
|
+
Gecko,
|
|
23
|
+
)
|
|
24
|
+
|
|
25
|
+
from .grape import (
|
|
26
|
+
Grape,
|
|
27
|
+
)
|
|
28
|
+
|
|
29
|
+
from .parameters import (
|
|
30
|
+
Parameters,
|
|
31
|
+
)
|
|
32
|
+
|
|
33
|
+
from .utils import (
|
|
34
|
+
History,
|
|
35
|
+
)
|
|
36
|
+
|
|
37
|
+
from .lie import (
|
|
38
|
+
Basis,
|
|
39
|
+
Hamiltonian,
|
|
40
|
+
Unitary,
|
|
41
|
+
)
|
|
42
|
+
|
|
43
|
+
from .utils import (
|
|
44
|
+
trace_dot_jit,
|
|
45
|
+
traces,
|
|
46
|
+
check_xy_comb,
|
|
47
|
+
check_Heisenberg_comb,
|
|
48
|
+
check_2_local_comb,
|
|
49
|
+
restriction_function,
|
|
50
|
+
restriction_order_function,
|
|
51
|
+
control_to_indices,
|
|
52
|
+
filter_basis_by_control,
|
|
53
|
+
make_per_element_transform,
|
|
54
|
+
construct_restricted_pauli_basis,
|
|
55
|
+
construct_Heisenberg_pauli_basis,
|
|
56
|
+
construct_two_body_pauli_basis,
|
|
57
|
+
construct_full_pauli_basis,
|
|
58
|
+
creation_annihilation_operators,
|
|
59
|
+
construct_full_spin_boson_basis,
|
|
60
|
+
construct_restricted_spin_boson_basis,
|
|
61
|
+
prepare_random_parameters,
|
|
62
|
+
construct_commuting_ansatz_matrix,
|
|
63
|
+
remove_solution_free_parameters,
|
|
64
|
+
multikron,
|
|
65
|
+
multimatmul,
|
|
66
|
+
multicontrol_unitary,
|
|
67
|
+
qft_unitary,
|
|
68
|
+
golden_section_search_np,
|
|
69
|
+
golden_section_search,
|
|
70
|
+
adam_line_search,
|
|
71
|
+
merge_constraints,
|
|
72
|
+
)
|
geope/engine.py
ADDED
|
@@ -0,0 +1,622 @@
|
|
|
1
|
+
from __future__ import annotations
|
|
2
|
+
|
|
3
|
+
import numpy as np
|
|
4
|
+
|
|
5
|
+
import jax
|
|
6
|
+
import jax.numpy as jnp
|
|
7
|
+
from jax import Array
|
|
8
|
+
|
|
9
|
+
jax.config.update("jax_enable_x64", True)
|
|
10
|
+
|
|
11
|
+
from .jax.logm import logm
|
|
12
|
+
from .jax.dexpm import (
|
|
13
|
+
get_Ui_fn,
|
|
14
|
+
get_dexpm,
|
|
15
|
+
get_dexpm_eig,
|
|
16
|
+
get_d2expm,
|
|
17
|
+
get_d2expm_eig,
|
|
18
|
+
)
|
|
19
|
+
from .jax.jacobian import manual_jacobian
|
|
20
|
+
from .jax.hessian import manual_hessian
|
|
21
|
+
|
|
22
|
+
import inspect
|
|
23
|
+
from functools import partial
|
|
24
|
+
from typing import Callable, TYPE_CHECKING
|
|
25
|
+
|
|
26
|
+
if TYPE_CHECKING:
|
|
27
|
+
from .parameters import Parameters
|
|
28
|
+
|
|
29
|
+
|
|
30
|
+
def fidelity(unitary: Array, target_unitary: Array) -> Array:
    r"""Return the projective (global-phase-insensitive) gate fidelity.

    Evaluates $|\mathrm{Tr}(U_T^\dagger U)| / d$, i.e. the modulus of the
    Hilbert-Schmidt overlap normalised by $d$, the number of columns of
    ``target_unitary``.

    Args:
        unitary: Matrix ``U`` being assessed.
        target_unitary: Reference matrix ``U_T``.

    Returns:
        Scalar ``Array``; lies in ``[0, 1]`` when both inputs are unitary.
    """
    overlap = jnp.einsum("ji,ji->", target_unitary.conj(), unitary)
    dim = len(target_unitary[0])
    return jnp.abs(overlap) / dim
|
|
46
|
+
|
|
47
|
+
|
|
48
|
+
def get_fidelity_fn(target_unitary: Array) -> Callable[[Array], Array]:
    """Bind ``target_unitary`` into :func:`fidelity`.

    Args:
        target_unitary: Reference matrix every call is compared against.

    Returns:
        A ``functools.partial`` of :func:`fidelity`; call it as
        ``fn(unitary)`` to get the projective fidelity against the target.
    """
    bound_fidelity = partial(fidelity, target_unitary=target_unitary)
    return bound_fidelity
|
|
59
|
+
|
|
60
|
+
|
|
61
|
+
def infidelity(unitary: Array, target_unitary: Array) -> Array:
    r"""Return the projective infidelity $1 - |\mathrm{Tr}(U_T^\dagger U)| / d$.

    Insensitive to a global phase on ``unitary``; ``d`` is the number of
    columns of ``target_unitary``.

    Args:
        unitary: Matrix ``U`` being assessed.
        target_unitary: Reference matrix ``U_T``.

    Returns:
        Scalar ``Array``; lies in ``[0, 1]`` when both inputs are unitary.
    """
    overlap = jnp.einsum("ji,ji->", target_unitary.conj(), unitary)
    dim = len(target_unitary[0])
    return 1 - jnp.abs(overlap) / dim
|
|
74
|
+
|
|
75
|
+
|
|
76
|
+
def get_infidelity_fn(target_unitary: Array) -> Callable[[Array], Array]:
    """Bind ``target_unitary`` into :func:`infidelity`.

    Args:
        target_unitary: Reference matrix every call is compared against.

    Returns:
        A ``functools.partial`` of :func:`infidelity`; call it as
        ``fn(unitary)`` to get the projective infidelity against the target.
    """
    bound_infidelity = partial(infidelity, target_unitary=target_unitary)
    return bound_infidelity
|
|
86
|
+
|
|
87
|
+
|
|
88
|
+
def fidelity_full(unitary: Array, target_unitary: Array) -> Array:
    r"""Return the phase-sensitive fidelity $\mathrm{Re}\,\mathrm{Tr}(U_T^\dagger U) / d$.

    Unlike :func:`fidelity`, a global phase on ``unitary`` changes the
    result; ``d`` is the number of columns of ``target_unitary``.

    Args:
        unitary: Matrix ``U`` being assessed.
        target_unitary: Reference matrix ``U_T``.

    Returns:
        Scalar ``Array``; lies in ``[-1, 1]`` when both inputs are unitary.
    """
    overlap = jnp.einsum("ji,ji->", target_unitary.conj(), unitary)
    dim = len(target_unitary[0])
    return jnp.real(overlap) / dim
|
|
105
|
+
|
|
106
|
+
|
|
107
|
+
def get_fidelity_full_fn(target_unitary: Array) -> Callable[[Array], Array]:
    """Bind ``target_unitary`` into :func:`fidelity_full`.

    Args:
        target_unitary: Reference matrix every call is compared against.

    Returns:
        A ``functools.partial`` of :func:`fidelity_full`; call it as
        ``fn(unitary)`` to get the phase-sensitive fidelity.
    """
    bound_fidelity_full = partial(fidelity_full, target_unitary=target_unitary)
    return bound_fidelity_full
|
|
117
|
+
|
|
118
|
+
|
|
119
|
+
def infidelity_full(unitary: Array, target_unitary: Array) -> Array:
    r"""Return the phase-sensitive infidelity $1 - \mathrm{Re}\,\mathrm{Tr}(U_T^\dagger U) / d$.

    ``d`` is the number of columns of ``target_unitary``.

    Args:
        unitary: Matrix ``U`` being assessed.
        target_unitary: Reference matrix ``U_T``.

    Returns:
        Scalar ``Array``; lies in ``[0, 2]`` when both inputs are unitary.
    """
    overlap = jnp.einsum("ji,ji->", target_unitary.conj(), unitary)
    dim = len(target_unitary[0])
    return 1 - jnp.real(overlap) / dim
|
|
132
|
+
|
|
133
|
+
|
|
134
|
+
def get_infidelity_full_fn(target_unitary: Array) -> Callable[[Array], Array]:
    """Bind ``target_unitary`` into :func:`infidelity_full`.

    Args:
        target_unitary: Reference matrix every call is compared against.

    Returns:
        A ``functools.partial`` of :func:`infidelity_full`; call it as
        ``fn(unitary)`` to get the phase-sensitive infidelity.
    """
    bound_infidelity_full = partial(infidelity_full, target_unitary=target_unitary)
    return bound_infidelity_full
|
|
144
|
+
|
|
145
|
+
|
|
146
|
+
def compute_matrices_params_list_fn(params_list: Array, basis: Array) -> Array:
    r"""Compute the product unitary from a list of parameter vectors.

    Row ``g`` of ``params_list`` defines $A_g = \sum_k \theta_{g,k} B_k$.
    The segment propagator $U_g = \exp(i A_g)$ is left-multiplied onto the
    running product via ``jax.lax.scan``. The result is therefore
    $U_{G-1} \cdots U_1 U_0$: row 0 acts first, i.e. is the rightmost factor.

    Args:
        params_list: ``Array`` of shape ``(piecewise_steps, K)`` (or a
            sequence of length-``K`` rows, which ``jnp.stack`` combines).
            Each row holds the Lie-algebra coefficients for one gate segment.
        basis: ``Array`` of shape ``(K, d, d)`` of Hermitian basis matrices.
            May be real-valued (e.g. only Pauli X/Z); the product is always
            accumulated in a complex dtype.

    Returns:
        The product unitary ``Array`` of shape ``(d, d)``.
    """
    stacked = jnp.stack(params_list)
    # ``lax.scan`` requires the carry to keep one dtype across iterations.
    # ``expm(1j * A)`` is complex with the combined precision of params and
    # basis, so seed the identity with that dtype. Seeding with
    # ``basis.dtype`` broke real-valued bases and mixed-precision inputs; for
    # the previously working case (complex basis at least as wide as params)
    # this yields exactly ``basis.dtype``.
    carry_dtype = jnp.result_type(stacked.dtype, basis.dtype, jnp.complex64)

    def step(U, params):
        A = jnp.tensordot(params, basis, axes=[[-1], [0]])
        Ui = jax.scipy.linalg.expm(1j * A)
        return jnp.matmul(Ui, U), None

    U0 = jnp.eye(basis.shape[1], dtype=carry_dtype)
    U_final, _ = jax.lax.scan(step, U0, stacked)
    return U_final
|
|
171
|
+
|
|
172
|
+
|
|
173
|
+
def get_compute_matrices_params_list_fn(basis: np.ndarray) -> Callable[[Array], Array]:
    """Bind ``basis`` into :func:`compute_matrices_params_list_fn`.

    Args:
        basis: Array of shape ``(K, d, d)`` of Hermitian basis matrices.

    Returns:
        A ``functools.partial`` of :func:`compute_matrices_params_list_fn`;
        call it as ``fn(params_list)`` to obtain the product unitary.
    """
    bound_compute = partial(compute_matrices_params_list_fn, basis=basis)
    return bound_compute
|
|
184
|
+
|
|
185
|
+
|
|
186
|
+
def geodesic_hamiltonian(
    unitary: Array,
    target_unitary: Array,
    projective: bool = True,
    key: Array = jax.random.key(0),
) -> Array:
    r"""Return the geodesic tangent from ``unitary`` towards ``target_unitary``.

    The generator $g = -i \log(U^\dagger U_T)$ is obtained via ``logm``.
    With ``projective=True``, its scaled-identity component
    $\frac{\mathrm{Re}\,\mathrm{Tr}(g)}{d}\mathbb{1}$ (the global-phase part)
    is removed. The result is left-multiplied by $U$.

    Args:
        unitary: The current unitary ``U``.
        target_unitary: The target unitary ``U_T``.
        projective: Strip the global-phase generator (SU geodesic) when
            ``True``; keep the full generator (U geodesic) when ``False``.
            Defaults to ``True``.
        key: JAX random key passed straight through to ``logm``. Defaults to
            ``jax.random.key(0)``, which is created once at import time.

    Returns:
        The ``Array`` ``U @ g'``, where ``g'`` is the (possibly
        phase-stripped) generator.
    """
    relative = jnp.einsum("ji,jk->ik", unitary.conj(), target_unitary)
    generator = -1.0j * logm(relative, key=key)
    if not projective:
        return unitary @ generator

    dim = generator.shape[0]
    identity = jnp.eye(dim)
    phase = jnp.real(jnp.einsum("ij,ji->", identity, generator)) / dim
    return unitary @ (generator - phase * identity)
|
|
217
|
+
|
|
218
|
+
|
|
219
|
+
def get_geodesic_hamiltonian_fn(
    target_unitary: Array, projective: bool = True
) -> Callable[[Array, Array], Array]:
    """Create a geodesic-Hamiltonian function with a fixed target.

    The returned callable forwards to :func:`geodesic_hamiltonian` with
    ``target_unitary`` and ``projective`` bound. It accepts the JAX random
    key either positionally, ``fn(unitary, key)`` (as the return annotation
    advertises), or by keyword, ``fn(unitary, key=key)``. A plain
    ``functools.partial`` binding ``target_unitary`` by keyword would route a
    positional key onto ``target_unitary`` and raise ``TypeError``.

    Args:
        target_unitary: The target unitary ``Array`` to bind.
        projective: If ``True``, return the projective (SU) geodesic.
            Defaults to ``True``.

    Returns:
        A callable ``fn(unitary, key=<omitted>, **overrides)``.
        - When ``key`` is omitted, :func:`geodesic_hamiltonian`'s own default
          key is used.
        - Extra keyword ``overrides`` (e.g. ``projective=False``) replace the
          bound values, matching ``functools.partial`` keyword semantics.
    """
    # Sentinel distinguishing "no key supplied" from any user value (incl. None).
    omitted = object()

    def geodesic_fn(unitary: Array, key=omitted, **overrides) -> Array:
        kwargs = {"target_unitary": target_unitary, "projective": projective}
        kwargs.update(overrides)
        if key is not omitted:
            kwargs["key"] = key
        return geodesic_hamiltonian(unitary, **kwargs)

    return geodesic_fn
|
|
236
|
+
|
|
237
|
+
|
|
238
|
+
def hvp_forward_over_reverse(
    f: Callable[[Array], Array], params: Array, v: Array
) -> Array:
    r"""Return the Hessian-vector product $\nabla^2 f(\mathrm{params}) \cdot v$.

    A forward-mode ``jax.jvp`` is pushed through the reverse-mode gradient
    ``jax.grad(f)``.

    Args:
        f: Scalar-valued callable of ``params``.
        params: Point at which the Hessian is evaluated.
        v: Direction; reshaped to ``params.shape`` before use, so a flat
            vector is accepted.

    Returns:
        An ``Array`` with the same shape as ``params``.
    """
    direction = v.reshape(params.shape)
    gradient_fn = jax.grad(f)
    _, product = jax.jvp(gradient_fn, (params,), (direction,))
    return product
|
|
253
|
+
|
|
254
|
+
|
|
255
|
+
def get_jacobian_fn(compute_U_fn: Callable[[Array], Array]) -> Callable[[Array], Array]:
    """Return the holomorphic autodiff Jacobian of ``compute_U_fn``.

    This ``jax.jacobian`` path is the one used for every system size. The
    manual Jacobian (``geope.jax.jacobian.get_jacobian_propagator``) exists
    and is tested separately, but is not wired into the optimisation
    pipeline; historically the autodiff path overwrote it for the >5-qubit
    branch (issue #4).

    The function is intentionally left un-jitted so it is traced into the
    enclosing ``@jax.jit`` update step on the first ``optimize()``.

    Args:
        compute_U_fn: Maps a parameter list to the product unitary.

    Returns:
        Callable evaluating the Jacobian with respect to the first argument.
        ``holomorphic=True`` requires complex-valued inputs.
    """
    jacobian = jax.jacobian(compute_U_fn, argnums=0, holomorphic=True)
    return jacobian
|
|
273
|
+
|
|
274
|
+
|
|
275
|
+
def get_gammas_fn(
    compute_U_fn: Callable[[Array], Array],
    geo_fn: Callable[..., Array],
    project_omegas_fn: Callable[[Array], Array],
) -> Callable[[Array, Array], Array]:
    """Build ``gammas(free_params, key)``: projected geodesic coefficients.

    Steps:
    1. Evaluate the unitary from ``free_params``.
    2. Obtain its geodesic Hamiltonian via ``geo_fn`` (``key`` passed by
       keyword).
    3. Project that single matrix with ``project_omegas_fn``, using a
       leading batch axis of size 1 that is squeezed afterwards.
    4. Divide by the matrix dimension.

    The result is left un-jitted so it composes inside an enclosing
    ``@jax.jit``.

    Args:
        compute_U_fn: Parameter list -> unitary.
        geo_fn: ``(unitary, key=...) -> geodesic Hamiltonian``.
        project_omegas_fn: Projects a batch of matrices onto the Lie-algebra
            basis.

    Returns:
        The ``gammas(free_params, key)`` callable.
    """

    def gammas(free_params: Array, key: Array) -> Array:
        current = compute_U_fn(free_params)
        geodesic = geo_fn(current, key=key)  # key seeds logm inside geo_fn
        batched = jnp.expand_dims(geodesic, axis=0)
        coefficients = project_omegas_fn(batched).squeeze(axis=0)
        return coefficients / (geodesic.shape[0])

    return gammas
|
|
303
|
+
|
|
304
|
+
|
|
305
|
+
def get_omegas_fn(
    jac_fn: Callable[[Array], Array],
    project_omegas_fn: Callable[[Array], Array],
    proj_indices: np.ndarray,
    has_proj_drift: bool,
) -> Callable[[Array], Array]:
    """Build ``omegas(free_params)``: per-gate Jacobians projected on the basis.

    The Jacobian's axes are permuted with ``[2, 3, 0, 1]``. Assuming
    ``jac_fn`` returns ``(d, d, G, K)`` (unitary axes first, then parameter
    axes, as ``jax.jacobian`` orders them for ``(G, K)`` parameters), each
    leading slice is one gate's ``(K, d, d)`` derivative stack.

    Each slice is multiplied by ``1j`` and passed to ``project_omegas_fn``,
    and the results are stacked along a new leading axis. When
    ``has_proj_drift`` is true, axis 1 of that stack is restricted to
    ``proj_indices``. The result is left un-jitted so it composes inside an
    enclosing ``@jax.jit``.

    Args:
        jac_fn: Jacobian of the unitary w.r.t. the free parameters.
        project_omegas_fn: Projects matrices onto the Lie-algebra basis.
        proj_indices: Projected indices within the proj+drift basis.
        has_proj_drift: Whether to apply the ``proj_indices`` restriction
            (mirrors the legacy ``np.any(proj_drift_basis)`` check).

    Returns:
        The ``omegas(free_params)`` callable.
    """

    def omegas(free_params: Array) -> Array:
        jacobian = jnp.array(jac_fn(free_params))
        per_gate = jnp.transpose(jacobian, [2, 3, 0, 1])
        projected = jnp.array(
            [project_omegas_fn(1.0j * gate_derivs) for gate_derivs in per_gate]
        )
        if not has_proj_drift:
            return projected
        return projected.at[:, proj_indices, :].get()

    return omegas
|
|
341
|
+
|
|
342
|
+
|
|
343
|
+
def get_gammas_and_omegas_fn(
    compute_U_fn: Callable[[Array], Array],
    jac_fn: Callable[[Array], Array],
    geo_fn: Callable[..., Array],
    project_omegas_fn: Callable[[Array], Array],
    proj_indices: np.ndarray,
    has_proj_drift: bool,
) -> Callable[[Array, Array], tuple[Array, Array]]:
    """Build the combined gammas-and-omegas function used by the GEOPE step.

    Gammas are the projected geodesic-Hamiltonian coefficients
    (:func:`get_gammas_fn`). Omegas are the per-gate Jacobians projected onto
    the basis (:func:`get_omegas_fn`).

    The combined function is composed from those two halves rather than
    duplicating their bodies, so it cannot drift from the separately tested
    pieces. It still performs exactly one ``compute_U_fn`` evaluation and
    then one ``jac_fn`` evaluation, in that order, as before. It is left
    un-jitted so it fuses into the enclosing ``@jax.jit`` update step on the
    first ``optimize()``.

    Args:
        compute_U_fn: Parameter list -> unitary.
        jac_fn: Jacobian of the unitary w.r.t. the free parameters.
        geo_fn: ``(unitary, key=...) -> geodesic Hamiltonian``.
        project_omegas_fn: Projection of matrices onto the Lie-algebra basis.
        proj_indices: Projected indices within the proj+drift basis.
        has_proj_drift: Whether the proj+drift basis is non-empty.

    Returns:
        A ``Callable[[Array, Array], tuple[Array, Array]]``
        ``gammas_and_omegas(free_params, key) -> (gammaU_params, omegas)``.
    """
    gammas = get_gammas_fn(compute_U_fn, geo_fn, project_omegas_fn)
    omegas = get_omegas_fn(jac_fn, project_omegas_fn, proj_indices, has_proj_drift)

    def gammas_and_omegas(free_params: Array, key: Array) -> tuple[Array, Array]:
        # Tuple elements evaluate left to right: gammas (compute_U, geo) first,
        # then omegas (jac), matching the legacy evaluation order.
        return gammas(free_params, key), omegas(free_params)

    return gammas_and_omegas
|
|
393
|
+
|
|
394
|
+
|
|
395
|
+
def get_hessian_fn(infid_fn: Callable[[Array], Array]) -> Callable[[Array], Array]:
    """Build ``hess(y)``, the Hessian of ``infid_fn`` assembled from HVPs.

    For each row of ``jnp.eye(y.size, dtype=y.dtype)``, a forward-over-reverse
    Hessian-vector product is computed: ``jax.jvp`` of ``jax.grad(infid_fn)``,
    with the row reshaped to ``y.shape``. The products are batched with
    ``jax.vmap``. The output therefore has shape ``(y.size, *y.shape)``,
    which is ``(P, P)`` only for 1-D ``y``.

    The function is left un-jitted so it fuses into the enclosing
    ``@jax.jit`` update step.

    Args:
        infid_fn: Scalar-valued infidelity callable of the free parameters.

    Returns:
        The ``hess(y)`` callable.
    """

    def hess(y: Array) -> Array:
        gradient_fn = jax.grad(infid_fn)

        def hvp_along(direction: Array) -> Array:
            tangent = direction.reshape(y.shape)
            return jax.jvp(gradient_fn, (y,), (tangent,))[1]

        unit_directions = jnp.eye(y.size, dtype=y.dtype)
        return jax.vmap(hvp_along)(unit_directions)

    return hess
|
|
415
|
+
|
|
416
|
+
|
|
417
|
+
def get_hessian_propagator_fn(
    basis: np.ndarray,
    target: Array,
    projective: bool = True,
    method: str = "eig",
    hermitian: bool = True,
) -> Callable[[Array], Array]:
    r"""Build the infidelity Hessian analytically (Goodwin-Kuprov NR-GRAPE).

    This is an analytic alternative to the autodiff ``get_hessian_fn``. It
    returns ``hess(y) -> (P, P)`` with ``P = y.size``, assembled from the
    manual propagator derivatives ``manual_jacobian`` / ``manual_hessian``
    instead of forward-over-reverse HVPs.

    With $z = \mathrm{Tr}(U_T^\dagger U)$ and its first/second parameter
    derivatives $\partial_a z$, $\partial_a\partial_b z$ (obtained by
    contracting $U$, $\partial U$, $\partial^2 U$ with $U_T^\dagger$):

    - phase-sensitive cost $C = 1 - \mathrm{Re}(z)/d$:
      Hessian $= -\mathrm{Re}(\partial_a\partial_b z)/d$;
    - projective cost $C = 1 - |z|/d$, using

      $$\partial_a\partial_b|z| =
      \frac{\mathrm{Re}(\overline{\partial_a z}\,\partial_b z)
      + \mathrm{Re}(\bar z\,\partial_a\partial_b z)}{|z|}
      - \frac{\mathrm{Re}(\bar z\,\partial_a z)\,
      \mathrm{Re}(\bar z\,\partial_b z)}{|z|^3}.$$

    The projective form divides by $|z|$ and is singular as $|z| \to 0$,
    e.g. for a traceless target near the identity.

    Memory: the dense propagator Hessian from ``manual_hessian`` is
    $O(G^2 d^2 K^2)$, so this targets small NR-GRAPE systems.

    Args:
        basis: Proj+drift basis ``(K, d, d)``; the same basis the bound
            ``compute_U_fn`` uses.
        target: Target unitary ``(d, d)``.
        projective: ``True`` for the projective infidelity, ``False`` for
            the phase-sensitive one.
        method: ``"eig"`` (spectral, default) or ``"block"``
            (auxiliary-matrix) per-step derivatives.
        hermitian: Only used with ``method="eig"``. Assume real parameters
            and use the faster ``eigh``-based derivatives; set ``False`` for
            complex-valued parameters.

    Returns:
        ``hess(y)`` returning the ``(P, P)`` Hessian. It expects 2-D ``y`` of
        shape ``(G, K)`` and is left un-jitted so it fuses into the enclosing
        ``@jax.jit`` update step.

    Raises:
        ValueError: If ``method`` is neither ``"eig"`` nor ``"block"``.
    """
    step_unitary_fn = get_Ui_fn(basis)
    if method == "eig":
        step_jac = get_dexpm_eig(basis, hermitian=hermitian)
        step_hess = get_d2expm_eig(basis, hermitian=hermitian)
    elif method == "block":
        step_jac = get_dexpm(basis)
        step_hess = get_d2expm(basis)
    else:
        raise ValueError(f"Unknown method {method!r}; expected 'eig' or 'block'.")

    propagate = get_compute_matrices_params_list_fn(basis)
    target_arr = jnp.asarray(target)
    target_conj = target_arr.conj()
    dim = target_arr.shape[0]

    def hess(y: Array) -> Array:
        U = propagate(y)
        dU = manual_jacobian(y, step_unitary_fn, step_jac)  # (G, d, d, K)
        d2U = manual_hessian(y, step_unitary_fn, step_jac, step_hess)  # (G, G, d, d, K, K)

        # Overlaps of U and its derivatives with the target.
        z = jnp.einsum("ab,ab->", target_conj, U)
        dz = jnp.einsum("ab,iabk->ik", target_conj, dU)  # (G, K)
        d2z = jnp.einsum("ab,ijabkl->ijkl", target_conj, d2U)  # (G, G, K, K)

        # Flatten (gate, coefficient) pairs to a single index g*K + k.
        n_steps, n_coeffs = y.shape
        n_params = n_steps * n_coeffs
        grad_z = dz.reshape(n_params)
        hess_z = jnp.transpose(d2z, (0, 2, 1, 3)).reshape(n_params, n_params)

        if not projective:
            return -jnp.real(hess_z) / dim

        abs_z = jnp.abs(z)
        z_conj = jnp.conj(z)
        radial = jnp.real(z_conj * grad_z)  # Re(z̄ ∂z), shape (P,)
        curvature = (
            jnp.real(jnp.outer(jnp.conj(grad_z), grad_z)) + jnp.real(z_conj * hess_z)
        ) / abs_z
        correction = jnp.outer(radial, radial) / abs_z**3
        return -(curvature - correction) / dim

    return hess
|
|
510
|
+
|
|
511
|
+
|
|
512
|
+
def wrap_compute_U_param_transform(
    params: "Parameters", raw_compute_U: Callable[[Array], Array]
) -> Callable[[Array], Array]:
    r"""Wrap ``compute_U`` to honour ``params.param_transform``.

    The experimental parameters $\phi^{\mathrm{exp}}$ (one row per step) are
    processed as follows:

    1. Mapped row-wise through ``params.param_transform``. The transform is
       called as ``tau(phi, step_index)`` when it accepts a second positional
       argument, otherwise as ``tau(phi)``.
    2. If the transform's output length differs from the projected basis
       size, the projected entries are extracted.
    3. The result is scattered into the proj+drift coefficient layout, with
       any drift parameters broadcast to every step.
    4. The full coefficient array is passed to ``raw_compute_U``.

    The wrapper is left un-jitted so it fuses into the enclosing
    ``@jax.jit`` update step on the first ``optimize()``.

    Args:
        params: The ``Parameters`` object carrying ``param_transform``.
        raw_compute_U: The projected-basis unitary-computation function.

    Returns:
        The wrapped experimental-space ``compute_U`` callable.
    """
    n_exp = params.n_experimental_params
    n_proj_drift = params.proj_drift_basis.lie_algebra_dim
    proj_idx_pd = params.proj_indices_projdrift_basis
    drift_idx_pd = params.drift_indices_projdrift_basis

    # Detect step-dependence: tau(phi) vs tau(phi, step_index).
    # Counting all signature parameters alone misclassifies transforms whose
    # second parameter cannot receive a positional argument, e.g.
    # ``lambda phi, *, scale=2.0: ...``, ``functools.partial(f, scale=2)`` or
    # ``def tau(phi, **kw)``. For those, the ``tau(phi, 0)`` probe below raised
    # TypeError. So additionally require that a second positional argument is
    # accepted; every case that previously worked is classified as before.
    _sig_params = list(inspect.signature(params.param_transform).parameters.values())
    _n_positional = sum(
        1
        for p in _sig_params
        if p.kind
        in (inspect.Parameter.POSITIONAL_ONLY, inspect.Parameter.POSITIONAL_OR_KEYWORD)
    )
    _has_var_positional = any(
        p.kind is inspect.Parameter.VAR_POSITIONAL for p in _sig_params
    )
    _step_dependent = len(_sig_params) >= 2 and (
        _n_positional >= 2 or _has_var_positional
    )

    # Detect whether the transform outputs full-basis or projected-basis coefficients.
    _test_out = (
        params.param_transform(jnp.zeros(n_exp), 0)
        if _step_dependent
        else params.param_transform(jnp.zeros(n_exp))
    )
    tf_out_dim = _test_out.shape[0]
    n_proj = params.projected_basis.lie_algebra_dim
    if tf_out_dim != n_proj:
        # Keep only the entries whose basis elements overlap the projected basis.
        _extract = jnp.array(
            np.where(np.array(params.projected_basis.overlap(params.basis)))[0]
        )
    else:
        _extract = None

    if params.drift_parameters is not None:
        _drift = jnp.array(params.drift_parameters, dtype=jnp.float64)
    else:
        _drift = None

    def _wrapped_compute_U(
        exp_params,
        _raw=raw_compute_U,
        _tf=params.param_transform,
        _pi=proj_idx_pd,
        _di=drift_idx_pd,
        _npd=n_proj_drift,
        _dr=_drift,
        _ext=_extract,
        _step_dep=_step_dependent,
    ):
        if _step_dep:
            ctrl = jax.vmap(_tf)(exp_params, jnp.arange(exp_params.shape[0]))
        else:
            ctrl = jax.vmap(_tf)(exp_params)
        if _ext is not None:
            ctrl = ctrl[:, _ext]
        # Promote dtype so complex tracing through real intermediates works.
        _dtype = jnp.result_type(ctrl.dtype, exp_params.dtype)
        ctrl = ctrl.astype(_dtype)
        full = jnp.zeros((exp_params.shape[0], _npd), dtype=_dtype)
        full = full.at[:, _pi].set(ctrl)
        if _dr is not None:
            # The same drift coefficients are applied at every step.
            full = full.at[:, _di].set(
                jnp.broadcast_to(
                    _dr.astype(_dtype), (exp_params.shape[0], _dr.shape[0])
                )
            )
        return _raw(full)

    return _wrapped_compute_U
|
|
591
|
+
|
|
592
|
+
|
|
593
|
+
def get_split_jacobian_fn(
    compute_U_fn: Callable[[Array], Array],
) -> Callable[[Array], Array]:
    """Return a complex Jacobian of ``compute_U_fn`` built from real/imag parts.

    This is used on the ``param_transform`` path. A holomorphic Jacobian
    through a real-valued user transform would lose the imaginary parts of
    intermediates. Instead, the unitary is stacked as ``[Re U, Im U]``, that
    real-valued output is differentiated with ``jax.jacobian``, and the
    result is recombined as ``J_re + 1j * J_im``.

    The function is left un-jitted so it fuses into the enclosing
    ``@jax.jit`` update step.

    Args:
        compute_U_fn: The (wrapped) experimental-space unitary function.

    Returns:
        Callable mapping parameters to the complex Jacobian of the unitary.
    """

    def real_imag_stack(x):
        unitary = compute_U_fn(x)
        return jnp.stack([jnp.real(unitary), jnp.imag(unitary)])

    stacked_jacobian = jax.jacobian(real_imag_stack, argnums=0)

    def complex_jacobian(x):
        parts = stacked_jacobian(x)
        return parts[0] + 1j * parts[1]

    return complex_jacobian
|