GEOPE 0.0.3__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
geope/__init__.py ADDED
@@ -0,0 +1,72 @@
1
+ from .engine import (
2
+ fidelity,
3
+ infidelity,
4
+ fidelity_full,
5
+ infidelity_full,
6
+ )
7
+
8
+ from .geope import (
9
+ Geope,
10
+ )
11
+
12
+ from . import line_searches
13
+ from .line_searches import (
14
+ LineSearch,
15
+ Adam,
16
+ GoldenSection,
17
+ adam,
18
+ golden_section,
19
+ )
20
+
21
+ from .gecko import (
22
+ Gecko,
23
+ )
24
+
25
+ from .grape import (
26
+ Grape,
27
+ )
28
+
29
+ from .parameters import (
30
+ Parameters,
31
+ )
32
+
33
+ from .utils import (
34
+ History,
35
+ )
36
+
37
+ from .lie import (
38
+ Basis,
39
+ Hamiltonian,
40
+ Unitary,
41
+ )
42
+
43
+ from .utils import (
44
+ trace_dot_jit,
45
+ traces,
46
+ check_xy_comb,
47
+ check_Heisenberg_comb,
48
+ check_2_local_comb,
49
+ restriction_function,
50
+ restriction_order_function,
51
+ control_to_indices,
52
+ filter_basis_by_control,
53
+ make_per_element_transform,
54
+ construct_restricted_pauli_basis,
55
+ construct_Heisenberg_pauli_basis,
56
+ construct_two_body_pauli_basis,
57
+ construct_full_pauli_basis,
58
+ creation_annihilation_operators,
59
+ construct_full_spin_boson_basis,
60
+ construct_restricted_spin_boson_basis,
61
+ prepare_random_parameters,
62
+ construct_commuting_ansatz_matrix,
63
+ remove_solution_free_parameters,
64
+ multikron,
65
+ multimatmul,
66
+ multicontrol_unitary,
67
+ qft_unitary,
68
+ golden_section_search_np,
69
+ golden_section_search,
70
+ adam_line_search,
71
+ merge_constraints,
72
+ )
geope/engine.py ADDED
@@ -0,0 +1,622 @@
1
+ from __future__ import annotations
2
+
3
+ import numpy as np
4
+
5
+ import jax
6
+ import jax.numpy as jnp
7
+ from jax import Array
8
+
9
+ jax.config.update("jax_enable_x64", True)
10
+
11
+ from .jax.logm import logm
12
+ from .jax.dexpm import (
13
+ get_Ui_fn,
14
+ get_dexpm,
15
+ get_dexpm_eig,
16
+ get_d2expm,
17
+ get_d2expm_eig,
18
+ )
19
+ from .jax.jacobian import manual_jacobian
20
+ from .jax.hessian import manual_hessian
21
+
22
+ import inspect
23
+ from functools import partial
24
+ from typing import Callable, TYPE_CHECKING
25
+
26
+ if TYPE_CHECKING:
27
+ from .parameters import Parameters
28
+
29
+
30
def fidelity(unitary: Array, target_unitary: Array) -> Array:
    """Return the phase-insensitive overlap of two matrices.

    The element-wise product of ``conj(target_unitary)`` and ``unitary`` is
    summed over both indices (the Hilbert-Schmidt inner product), its
    modulus is taken, and the result is divided by the length of the first
    row of ``target_unitary``.

    Args:
        unitary: Matrix being scored.
        target_unitary: Reference matrix.

    Returns:
        Scalar ``Array``. For unitary inputs it equals 1 when the two
        matrices agree up to a global phase.
    """
    overlap = jnp.einsum("ji,ji->", target_unitary.conj(), unitary)
    dim = len(target_unitary[0])
    return jnp.abs(overlap) / dim
46
+
47
+
48
def get_fidelity_fn(target_unitary: Array) -> Callable[[Array], Array]:
    """Bind ``target_unitary`` into :func:`fidelity`.

    Args:
        target_unitary: Reference matrix fixed as the ``target_unitary``
            keyword of ``fidelity``.

    Returns:
        A ``functools.partial`` object; calling it with a single unitary
        evaluates ``fidelity(unitary, target_unitary=target_unitary)``.
    """
    bound_fidelity = partial(fidelity, target_unitary=target_unitary)
    return bound_fidelity
59
+
60
+
61
def infidelity(unitary: Array, target_unitary: Array) -> Array:
    """Return one minus the phase-insensitive overlap of two matrices.

    Computes ``1 - |sum(conj(target_unitary) * unitary)| / d`` where ``d``
    is the length of the first row of ``target_unitary``.

    Args:
        unitary: Matrix being scored.
        target_unitary: Reference matrix.

    Returns:
        Scalar ``Array``. For unitary inputs it is 0 when the two matrices
        agree up to a global phase.
    """
    overlap = jnp.einsum("ji,ji->", target_unitary.conj(), unitary)
    dim = len(target_unitary[0])
    return 1 - jnp.abs(overlap) / dim
74
+
75
+
76
def get_infidelity_fn(target_unitary: Array) -> Callable[[Array], Array]:
    """Bind ``target_unitary`` into :func:`infidelity`.

    Args:
        target_unitary: Reference matrix fixed as the ``target_unitary``
            keyword of ``infidelity``.

    Returns:
        A ``functools.partial`` object; calling it with a single unitary
        evaluates ``infidelity(unitary, target_unitary=target_unitary)``.
    """
    bound_infidelity = partial(infidelity, target_unitary=target_unitary)
    return bound_infidelity
86
+
87
+
88
def fidelity_full(unitary: Array, target_unitary: Array) -> Array:
    """Return the phase-sensitive overlap of two matrices.

    Computes ``Re(sum(conj(target_unitary) * unitary)) / d`` where ``d`` is
    the length of the first row of ``target_unitary``. Because the real
    part (not the modulus) is taken, a global phase on ``unitary`` changes
    the value.

    Args:
        unitary: Matrix being scored.
        target_unitary: Reference matrix.

    Returns:
        Scalar ``Array``; for unitary inputs it lies in ``[-1, 1]``.
    """
    overlap = jnp.einsum("ji,ji->", target_unitary.conj(), unitary)
    dim = len(target_unitary[0])
    return jnp.real(overlap) / dim
105
+
106
+
107
def get_fidelity_full_fn(target_unitary: Array) -> Callable[[Array], Array]:
    """Bind ``target_unitary`` into :func:`fidelity_full`.

    Args:
        target_unitary: Reference matrix fixed as the ``target_unitary``
            keyword of ``fidelity_full``.

    Returns:
        A ``functools.partial`` object; calling it with a single unitary
        evaluates ``fidelity_full(unitary, target_unitary=target_unitary)``.
    """
    bound_fidelity_full = partial(fidelity_full, target_unitary=target_unitary)
    return bound_fidelity_full
117
+
118
+
119
def infidelity_full(unitary: Array, target_unitary: Array) -> Array:
    """Return one minus the phase-sensitive overlap of two matrices.

    Computes ``1 - Re(sum(conj(target_unitary) * unitary)) / d`` where
    ``d`` is the length of the first row of ``target_unitary``.

    Args:
        unitary: Matrix being scored.
        target_unitary: Reference matrix.

    Returns:
        Scalar ``Array``; for unitary inputs it lies in ``[0, 2]``.
    """
    overlap = jnp.einsum("ji,ji->", target_unitary.conj(), unitary)
    dim = len(target_unitary[0])
    return 1 - jnp.real(overlap) / dim
132
+
133
+
134
def get_infidelity_full_fn(target_unitary: Array) -> Callable[[Array], Array]:
    """Bind ``target_unitary`` into :func:`infidelity_full`.

    Args:
        target_unitary: Reference matrix fixed as the ``target_unitary``
            keyword of ``infidelity_full``.

    Returns:
        A ``functools.partial`` object; calling it with a single unitary
        evaluates ``infidelity_full(unitary, target_unitary=target_unitary)``.
    """
    bound_infidelity_full = partial(infidelity_full, target_unitary=target_unitary)
    return bound_infidelity_full
144
+
145
+
146
def compute_matrices_params_list_fn(params_list: Array, basis: Array) -> Array:
    """Compute the product unitary from a list of parameter vectors.

    For each parameter vector in ``params_list`` a generator ``A`` is built
    as the linear combination of the ``basis`` elements, ``expm(1j * A)`` is
    formed, and the result is left-multiplied onto the running product via
    ``jax.lax.scan``.

    The scan carry is always complex: its dtype is the promotion of the
    parameter dtype, the basis dtype and ``complex64``. This keeps the
    initial identity and every per-step product at the same dtype, so a
    real-valued basis (or a basis of lower precision than the parameters)
    no longer trips ``scan``'s requirement that the carry type be constant.

    Args:
        params_list: ``Array`` of shape ``(piecewise_steps, K)`` (or a
            sequence of ``(K,)`` arrays) where each row holds the
            coefficients for one segment.
        basis: ``Array`` of shape ``(K, d, d)`` of basis matrices.

    Returns:
        The product unitary ``Array`` of shape ``(d, d)``.
    """
    stacked_params = jnp.stack(params_list)
    # expm(1j * A) is complex at the promoted precision of params and basis.
    carry_dtype = jnp.result_type(stacked_params.dtype, basis.dtype, jnp.complex64)

    def step(U, params):
        A = jnp.tensordot(params, basis, axes=[[-1], [0]])
        Ui = jax.scipy.linalg.expm(1j * A)
        # Later segments act from the left: U_new = U_i @ U.
        U_new = jnp.matmul(Ui, U).astype(carry_dtype)
        return U_new, None

    U0 = jnp.eye(basis.shape[1], dtype=carry_dtype)
    U_final, _ = jax.lax.scan(step, U0, stacked_params)
    return U_final
171
+
172
+
173
def get_compute_matrices_params_list_fn(basis: np.ndarray) -> Callable[[Array], Array]:
    """Bind ``basis`` into :func:`compute_matrices_params_list_fn`.

    Args:
        basis: Array of shape ``(K, d, d)`` of basis matrices, fixed as the
            ``basis`` keyword.

    Returns:
        A ``functools.partial`` object; calling it with a parameter list
        evaluates ``compute_matrices_params_list_fn(params_list, basis=basis)``.
    """
    bound_compute = partial(compute_matrices_params_list_fn, basis=basis)
    return bound_compute
184
+
185
+
186
def geodesic_hamiltonian(
    unitary: Array,
    target_unitary: Array,
    projective: bool = True,
    key: "Array | None" = None,
) -> Array:
    """Compute the geodesic Hamiltonian between a unitary and a target.

    Forms ``g = -i * logm(U^dagger U_T)`` and returns ``U @ g'``. With
    ``projective=True``, ``g'`` is ``g`` minus ``Re(Tr g) / d`` times the
    identity (the global-phase component is removed); with
    ``projective=False``, ``g' = g``.

    Args:
        unitary: The current unitary ``Array``.
        target_unitary: The target unitary ``Array``.
        projective: If ``True``, subtract the global-phase generator.
            Defaults to ``True``.
        key: JAX random key forwarded to ``logm``. When ``None`` (the
            default) ``jax.random.key(0)`` is created at call time.

    Returns:
        The geodesic tangent ``Array`` ``U @ g'`` at the current unitary.
    """
    if key is None:
        # Built lazily: a key in the signature default would run a JAX
        # computation (and initialise the backend) as a side effect of
        # merely importing this module.
        key = jax.random.key(0)

    # einsum "ji,jk->ik" on conj(U) and U_T is U^dagger @ U_T.
    g = -1.0j * logm(jnp.einsum("ji,jk->ik", unitary.conj(), target_unitary), key=key)
    if projective:
        Id = jnp.eye(g.shape[0])
        # Contracting with the identity yields Tr(g); only its real part is removed.
        global_phase = jnp.real(jnp.einsum("ij,ji->", Id, g)) / g.shape[0]
        g = g - global_phase * Id
    return unitary @ g
217
+
218
+
219
def get_geodesic_hamiltonian_fn(
    target_unitary: Array, projective: bool = True
) -> Callable[[Array, Array], Array]:
    """Bind the target and projectivity flag into :func:`geodesic_hamiltonian`.

    Args:
        target_unitary: Target unitary fixed as the ``target_unitary``
            keyword.
        projective: Value fixed as the ``projective`` keyword. Defaults to
            ``True``.

    Returns:
        A ``functools.partial`` object that accepts a unitary (and
        optionally ``key=...``) and returns the geodesic Hamiltonian.
    """
    bound_kwargs = {"target_unitary": target_unitary, "projective": projective}
    return partial(geodesic_hamiltonian, **bound_kwargs)
236
+
237
+
238
def hvp_forward_over_reverse(
    f: Callable[[Array], Array], params: Array, v: Array
) -> Array:
    """Evaluate a Hessian-vector product by pushing a tangent through ``grad(f)``.

    ``v`` is first reshaped to the shape of ``params``; the reverse-mode
    gradient of ``f`` is then differentiated in forward mode along that
    direction with ``jax.jvp``.

    Args:
        f: Scalar-valued callable of ``params``.
        params: Point at which the Hessian is evaluated.
        v: Direction vector; must be reshapeable to ``params.shape``.

    Returns:
        The tangent output of the JVP, i.e. the Hessian of ``f`` applied to
        ``v``, with the same shape as ``params``.
    """
    tangent = v.reshape(params.shape)
    grad_f = jax.grad(f)
    _primal_out, tangent_out = jax.jvp(grad_f, (params,), (tangent,))
    return tangent_out
253
+
254
+
255
def get_jacobian_fn(compute_U_fn: Callable[[Array], Array]) -> Callable[[Array], Array]:
    """Return the holomorphic autodiff Jacobian of ``compute_U_fn``.

    This wraps ``jax.jacobian`` with ``holomorphic=True`` on the first
    argument. It is the Jacobian path used for all system sizes; the manual
    Jacobian (``geope.jax.jacobian.get_jacobian_propagator``) exists but is
    not wired into the optimisation pipeline (see issue #4). No ``jax.jit``
    is applied here, so the result can fuse into an enclosing jitted update
    step.

    Args:
        compute_U_fn: Callable mapping a parameter list to the product
            unitary.

    Returns:
        A ``Callable[[Array], Array]`` returning the Jacobian of the unitary
        with respect to its first argument.
    """
    jacobian_of_unitary = jax.jacobian(compute_U_fn, argnums=0, holomorphic=True)
    return jacobian_of_unitary
273
+
274
+
275
def get_gammas_fn(
    compute_U_fn: Callable[[Array], Array],
    geo_fn: Callable[..., Array],
    project_omegas_fn: Callable[[Array], Array],
) -> Callable[[Array, Array], Array]:
    """Build the projected geodesic-Hamiltonian (``gammas``) function.

    The returned callable evaluates the unitary, passes it to ``geo_fn``
    together with the key, projects the result with ``project_omegas_fn``
    (as a batch of one), and divides by the matrix dimension. No
    ``jax.jit`` is applied here so it composes inside an enclosing jit.

    Args:
        compute_U_fn: Parameter-list -> unitary.
        geo_fn: Called as ``geo_fn(unitary, key=key)``; returns the geodesic
            Hamiltonian.
        project_omegas_fn: Projection applied to a stack of matrices.

    Returns:
        A ``Callable[[Array, Array], Array]`` ``gammas(free_params, key)``.
    """

    def gammas(free_params: Array, key: Array) -> Array:
        current_U = compute_U_fn(free_params)
        geodesic = geo_fn(current_U, key=key)  # key seeds logm
        # The projector expects a leading batch axis; add it, then drop it.
        batched = jnp.expand_dims(geodesic, axis=0)
        projected = project_omegas_fn(batched).squeeze(axis=0)
        return projected / (geodesic.shape[0])

    return gammas
303
+
304
+
305
def get_omegas_fn(
    jac_fn: Callable[[Array], Array],
    project_omegas_fn: Callable[[Array], Array],
    proj_indices: np.ndarray,
    has_proj_drift: bool,
) -> Callable[[Array], Array]:
    """Build the projected per-gate Jacobian (``omegas``) function.

    The returned callable evaluates ``jac_fn``, moves the last two Jacobian
    axes to the front, multiplies each leading-axis slice by ``1j`` and
    projects it with ``project_omegas_fn``. When ``has_proj_drift`` is true,
    only the ``proj_indices`` entries along axis 1 are kept. No ``jax.jit``
    is applied here so it composes inside an enclosing jit.

    Args:
        jac_fn: Jacobian of the unitary w.r.t. the free parameters.
        project_omegas_fn: Projection applied to a stack of matrices.
        proj_indices: Indices selected along axis 1 of the projected result.
        has_proj_drift: Whether to apply the ``proj_indices`` restriction.

    Returns:
        A ``Callable[[Array], Array]`` ``omegas(free_params)``.
    """

    def omegas(free_params: Array) -> Array:
        jacobian = jnp.array(jac_fn(free_params))
        # Parameter axes first, matrix axes last.
        per_step = jnp.transpose(jacobian, [2, 3, 0, 1])
        projected = []
        for step_derivs in per_step:
            projected.append(project_omegas_fn(1.0j * step_derivs))
        result = jnp.array(projected)
        if not has_proj_drift:
            return result
        return result.at[:, proj_indices, :].get()

    return omegas
341
+
342
+
343
def get_gammas_and_omegas_fn(
    compute_U_fn: Callable[[Array], Array],
    jac_fn: Callable[[Array], Array],
    geo_fn: Callable[..., Array],
    project_omegas_fn: Callable[[Array], Array],
    proj_indices: np.ndarray,
    has_proj_drift: bool,
) -> Callable[[Array, Array], tuple[Array, Array]]:
    """Build the combined gammas-and-omegas function used by the GEOPE step.

    The returned callable performs one ``compute_U_fn`` and one ``jac_fn``
    evaluation. Gammas are the geodesic Hamiltonian from ``geo_fn``,
    projected (as a batch of one) and divided by the matrix dimension.
    Omegas are the Jacobian with its last two axes moved to the front, each
    leading-axis slice multiplied by ``1j`` and projected; when
    ``has_proj_drift`` is true only ``proj_indices`` along axis 1 are kept.
    No ``jax.jit`` is applied here so it fuses into an enclosing jit.

    Args:
        compute_U_fn: Parameter-list -> unitary.
        jac_fn: Jacobian of the unitary w.r.t. the free parameters.
        geo_fn: Called as ``geo_fn(unitary, key=key)``.
        project_omegas_fn: Projection applied to a stack of matrices.
        proj_indices: Indices selected along axis 1 of the omegas.
        has_proj_drift: Whether to apply the ``proj_indices`` restriction.

    Returns:
        A ``Callable[[Array, Array], tuple[Array, Array]]``
        ``gammas_and_omegas(free_params, key) -> (gammaU_params, omegas)``.
    """

    def gammas_and_omegas(free_params: Array, key: Array) -> tuple[Array, Array]:
        # Gammas: projected geodesic Hamiltonian, normalised by dimension.
        current_U = compute_U_fn(free_params)
        geodesic = geo_fn(current_U, key=key)  # key seeds logm
        batched = jnp.expand_dims(geodesic, axis=0)
        gamma_coeffs = project_omegas_fn(batched).squeeze(axis=0) / (
            geodesic.shape[0]
        )

        # Omegas: projected Jacobian, one projection per leading-axis slice.
        jacobian = jnp.array(jac_fn(free_params))
        per_step = jnp.transpose(jacobian, [2, 3, 0, 1])
        projected = []
        for step_derivs in per_step:
            projected.append(project_omegas_fn(1.0j * step_derivs))
        omega_coeffs = jnp.array(projected)

        if has_proj_drift:
            omega_coeffs = omega_coeffs.at[:, proj_indices, :].get()

        return gamma_coeffs, omega_coeffs

    return gammas_and_omegas
393
+
394
+
395
def get_hessian_fn(infid_fn: Callable[[Array], Array]) -> Callable[[Array], Array]:
    """Build a Hessian function from forward-over-reverse HVPs.

    The returned callable maps :func:`hvp_forward_over_reverse` (via
    ``jax.vmap``) over the rows of a ``y.size``-by-``y.size`` identity
    matrix, one Hessian-vector product per unit direction. No ``jax.jit``
    is applied here so it fuses into an enclosing jit.

    Args:
        infid_fn: Scalar-valued infidelity callable of the free parameters.

    Returns:
        A ``Callable[[Array], Array]`` ``hess(y)``. Its leading axis has
        length ``y.size``; each entry along that axis has the shape of
        ``y`` (the shape returned by the HVP).
    """

    def hess(y: Array) -> Array:
        unit_directions = jnp.eye(y.size, dtype=y.dtype)

        def hessian_row(direction: Array) -> Array:
            return hvp_forward_over_reverse(infid_fn, y, direction)

        return jax.vmap(hessian_row)(unit_directions)

    return hess
415
+
416
+
417
def get_hessian_propagator_fn(
    basis: np.ndarray,
    target: Array,
    projective: bool = True,
    method: str = "eig",
    hermitian: bool = True,
) -> Callable[[Array], Array]:
    r"""Build the infidelity Hessian from manual propagator derivatives.

    Analytic counterpart of `get_hessian_fn`: the returned ``hess(y)`` gives
    a ``(P, P)`` matrix with ``P = y.size``, assembled from
    `manual_jacobian` and `manual_hessian` instead of autodiff HVPs.

    With $z = \mathrm{Tr}(U_T^\dagger U)$ and its first and second
    parameter derivatives obtained by contracting $U$, $\partial U$,
    $\partial^2 U$ against $\overline{U_T}$:

    * phase-sensitive cost $C = 1 - \mathrm{Re}(z)/d$: the Hessian is
      $-\mathrm{Re}(\partial_a\partial_b z)/d$;
    * projective cost $C = 1 - |z|/d$: the Hessian is
      $-\partial_a\partial_b|z|/d$ with

      $$\partial_a\partial_b|z| =
        \frac{\mathrm{Re}(\overline{\partial_a z}\,\partial_b z)
              + \mathrm{Re}(\bar z\,\partial_a\partial_b z)}{|z|}
        - \frac{\mathrm{Re}(\bar z\,\partial_a z)\,
                \mathrm{Re}(\bar z\,\partial_b z)}{|z|^3}.$$

    The projective expression divides by $|z|$ and $|z|^3$, so it is
    singular as $|z| \to 0$. The dense propagator Hessian from
    `manual_hessian` is materialised in full, so this is intended for small
    systems.

    Args:
        basis: Proj+drift basis ``(K, d, d)``.
        target: Target unitary ``(d, d)``.
        projective: ``True`` for the projective infidelity, ``False`` for
            the phase-sensitive one.
        method: ``"eig"`` (default) selects ``get_dexpm_eig`` /
            ``get_d2expm_eig``; ``"block"`` selects ``get_dexpm`` /
            ``get_d2expm``.
        hermitian: Forwarded to the ``"eig"`` derivative builders; ignored
            for ``"block"``.

    Returns:
        A ``Callable[[Array], Array]`` ``hess(y)`` returning the ``(P, P)``
        infidelity Hessian. No ``jax.jit`` is applied here.

    Raises:
        ValueError: If ``method`` is neither ``"eig"`` nor ``"block"``.
    """
    Ui_fn = get_Ui_fn(basis)
    if method == "eig":
        step_jac = get_dexpm_eig(basis, hermitian=hermitian)
        step_hess = get_d2expm_eig(basis, hermitian=hermitian)
    elif method == "block":
        step_jac = get_dexpm(basis)
        step_hess = get_d2expm(basis)
    else:
        raise ValueError(f"Unknown method {method!r}; expected 'eig' or 'block'.")

    propagate = get_compute_matrices_params_list_fn(basis)
    target_conj = jnp.asarray(target).conj()
    dim = jnp.asarray(target).shape[0]

    def hess(y: Array) -> Array:
        U = propagate(y)
        first_derivs = manual_jacobian(y, Ui_fn, step_jac)  # (G, d, d, K)
        second_derivs = manual_hessian(
            y, Ui_fn, step_jac, step_hess
        )  # (G, G, d, d, K, K)

        # Overlap with the target and its parameter derivatives.
        z = jnp.einsum("ab,ab->", target_conj, U)
        dz = jnp.einsum("ab,iabk->ik", target_conj, first_derivs)  # (G, K)
        d2z = jnp.einsum(
            "ab,ijabkl->ijkl", target_conj, second_derivs
        )  # (G, G, K, K)

        # Flatten (step, basis) index pairs into one parameter index.
        n_steps, n_basis = y.shape
        n_params = n_steps * n_basis
        dz_flat = dz.reshape(n_params)
        d2z_flat = jnp.transpose(d2z, (0, 2, 1, 3)).reshape(n_params, n_params)

        if not projective:
            return -jnp.real(d2z_flat) / dim

        modulus = jnp.abs(z)
        z_conj = jnp.conj(z)
        re_z_dz = jnp.real(z_conj * dz_flat)  # (P,)
        curvature = (
            jnp.real(jnp.outer(jnp.conj(dz_flat), dz_flat))
            + jnp.real(z_conj * d2z_flat)
        ) / modulus
        radial = jnp.outer(re_z_dz, re_z_dz) / modulus**3
        return -(curvature - radial) / dim

    return hess
510
+
511
+
512
def wrap_compute_U_param_transform(
    params: "Parameters", raw_compute_U: Callable[[Array], Array]
) -> Callable[[Array], Array]:
    """Wrap ``raw_compute_U`` so it accepts experimental parameters.

    The wrapper applies ``params.param_transform`` row-wise (with
    ``jax.vmap``) to the experimental parameters, scatters the result into
    the projected slots of a zero-initialised proj+drift coefficient array,
    writes the drift parameters into the drift slots when present, and
    calls ``raw_compute_U`` on that array.

    Two properties of the transform are probed once, up front:

    * it is treated as step-dependent when its signature has two or more
      parameters, in which case it is called as ``tau(phi, step_index)``;
    * it is evaluated on a zero vector, and if its output length differs
      from the projected-basis dimension, the output is later indexed by
      the nonzero positions of ``projected_basis.overlap(basis)``.

    No ``jax.jit`` is applied here so the wrapper fuses into an enclosing
    jit.

    Args:
        params: The ``Parameters`` object carrying ``param_transform``.
        raw_compute_U: The projected-basis unitary-computation function.

    Returns:
        The wrapped experimental-space ``compute_U`` callable.
    """
    n_exp = params.n_experimental_params
    pd_dim = params.proj_drift_basis.lie_algebra_dim
    proj_slots = params.proj_indices_projdrift_basis
    drift_slots = params.drift_indices_projdrift_basis

    # tau(phi) versus tau(phi, step_index): decided by parameter count.
    n_transform_args = len(inspect.signature(params.param_transform).parameters)
    takes_step_index = n_transform_args >= 2

    # Probe the transform on zeros to learn its output length.
    if takes_step_index:
        probe_out = params.param_transform(jnp.zeros(n_exp), 0)
    else:
        probe_out = params.param_transform(jnp.zeros(n_exp))

    extract_idx = None
    if probe_out.shape[0] != params.projected_basis.lie_algebra_dim:
        overlap_mask = np.array(params.projected_basis.overlap(params.basis))
        extract_idx = jnp.array(np.where(overlap_mask)[0])

    drift_values = None
    if params.drift_parameters is not None:
        drift_values = jnp.array(params.drift_parameters, dtype=jnp.float64)

    def _wrapped_compute_U(
        exp_params,
        _raw=raw_compute_U,
        _tf=params.param_transform,
        _pi=proj_slots,
        _di=drift_slots,
        _npd=pd_dim,
        _dr=drift_values,
        _ext=extract_idx,
        _step_dep=takes_step_index,
    ):
        n_steps = exp_params.shape[0]
        if _step_dep:
            ctrl = jax.vmap(_tf)(exp_params, jnp.arange(n_steps))
        else:
            ctrl = jax.vmap(_tf)(exp_params)
        if _ext is not None:
            ctrl = ctrl[:, _ext]
        # Promote dtype so complex tracing through real intermediates works.
        _dtype = jnp.result_type(ctrl.dtype, exp_params.dtype)
        full = jnp.zeros((n_steps, _npd), dtype=_dtype)
        full = full.at[:, _pi].set(ctrl.astype(_dtype))
        if _dr is not None:
            # Same drift coefficients for every step.
            drift_rows = jnp.broadcast_to(
                _dr.astype(_dtype), (n_steps, _dr.shape[0])
            )
            full = full.at[:, _di].set(drift_rows)
        return _raw(full)

    return _wrapped_compute_U
591
+
592
+
593
def get_split_jacobian_fn(
    compute_U_fn: Callable[[Array], Array],
) -> Callable[[Array], Array]:
    """Build a Jacobian of ``compute_U_fn`` via its real and imaginary parts.

    The output of ``compute_U_fn`` is stacked as ``[Re(U), Im(U)]``, that
    real-valued stack is differentiated with ``jax.jacobian``, and the two
    halves are recombined as ``d Re(U) + 1j * d Im(U)``. This is used on the
    ``param_transform`` path, where a holomorphic Jacobian through the
    real-valued user transform would discard the imaginary part of
    intermediates. No ``jax.jit`` is applied here so it fuses into an
    enclosing jit.

    Args:
        compute_U_fn: The (wrapped) experimental-space unitary function.

    Returns:
        A ``Callable[[Array], Array]`` returning the complex Jacobian.
    """

    def _real_imag_stack(x):
        unitary = compute_U_fn(x)
        parts = [jnp.real(unitary), jnp.imag(unitary)]
        return jnp.stack(parts)

    _stack_jacobian = jax.jacobian(_real_imag_stack, argnums=0)

    def _jac_fn(x):
        stacked = _stack_jacobian(x)
        d_real = stacked[0]
        d_imag = stacked[1]
        return d_real + 1j * d_imag

    return _jac_fn