tensorcircuit-nightly 1.2.0.dev20250326__py3-none-any.whl → 1.4.0.dev20251128__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of tensorcircuit-nightly might be problematic. Click here for more details.

Files changed (77)
  1. tensorcircuit/__init__.py +5 -1
  2. tensorcircuit/abstractcircuit.py +4 -0
  3. tensorcircuit/analogcircuit.py +413 -0
  4. tensorcircuit/applications/layers.py +1 -1
  5. tensorcircuit/applications/van.py +1 -1
  6. tensorcircuit/backends/abstract_backend.py +312 -5
  7. tensorcircuit/backends/cupy_backend.py +3 -1
  8. tensorcircuit/backends/jax_backend.py +100 -4
  9. tensorcircuit/backends/jax_ops.py +108 -0
  10. tensorcircuit/backends/numpy_backend.py +49 -3
  11. tensorcircuit/backends/pytorch_backend.py +92 -3
  12. tensorcircuit/backends/tensorflow_backend.py +102 -3
  13. tensorcircuit/basecircuit.py +157 -98
  14. tensorcircuit/circuit.py +115 -57
  15. tensorcircuit/cloud/local.py +1 -1
  16. tensorcircuit/cloud/quafu_provider.py +1 -1
  17. tensorcircuit/cloud/tencent.py +1 -1
  18. tensorcircuit/compiler/simple_compiler.py +2 -2
  19. tensorcircuit/cons.py +105 -23
  20. tensorcircuit/densitymatrix.py +16 -11
  21. tensorcircuit/experimental.py +733 -153
  22. tensorcircuit/fgs.py +254 -73
  23. tensorcircuit/gates.py +66 -22
  24. tensorcircuit/interfaces/jax.py +5 -3
  25. tensorcircuit/interfaces/tensortrans.py +6 -2
  26. tensorcircuit/interfaces/torch.py +14 -4
  27. tensorcircuit/keras.py +3 -3
  28. tensorcircuit/mpscircuit.py +154 -65
  29. tensorcircuit/quantum.py +698 -134
  30. tensorcircuit/quditcircuit.py +733 -0
  31. tensorcircuit/quditgates.py +618 -0
  32. tensorcircuit/results/counts.py +131 -18
  33. tensorcircuit/results/readout_mitigation.py +4 -1
  34. tensorcircuit/shadows.py +1 -1
  35. tensorcircuit/simplify.py +3 -1
  36. tensorcircuit/stabilizercircuit.py +29 -17
  37. tensorcircuit/templates/__init__.py +2 -0
  38. tensorcircuit/templates/blocks.py +2 -2
  39. tensorcircuit/templates/hamiltonians.py +174 -0
  40. tensorcircuit/templates/lattice.py +1789 -0
  41. tensorcircuit/timeevol.py +896 -0
  42. tensorcircuit/translation.py +10 -3
  43. tensorcircuit/utils.py +7 -0
  44. {tensorcircuit_nightly-1.2.0.dev20250326.dist-info → tensorcircuit_nightly-1.4.0.dev20251128.dist-info}/METADATA +66 -29
  45. tensorcircuit_nightly-1.4.0.dev20251128.dist-info/RECORD +96 -0
  46. {tensorcircuit_nightly-1.2.0.dev20250326.dist-info → tensorcircuit_nightly-1.4.0.dev20251128.dist-info}/WHEEL +1 -1
  47. {tensorcircuit_nightly-1.2.0.dev20250326.dist-info → tensorcircuit_nightly-1.4.0.dev20251128.dist-info}/top_level.txt +0 -1
  48. tensorcircuit_nightly-1.2.0.dev20250326.dist-info/RECORD +0 -118
  49. tests/__init__.py +0 -0
  50. tests/conftest.py +0 -67
  51. tests/test_backends.py +0 -1035
  52. tests/test_calibrating.py +0 -149
  53. tests/test_channels.py +0 -409
  54. tests/test_circuit.py +0 -1699
  55. tests/test_cloud.py +0 -219
  56. tests/test_compiler.py +0 -147
  57. tests/test_dmcircuit.py +0 -555
  58. tests/test_ensemble.py +0 -72
  59. tests/test_fgs.py +0 -310
  60. tests/test_gates.py +0 -156
  61. tests/test_interfaces.py +0 -562
  62. tests/test_keras.py +0 -160
  63. tests/test_miscs.py +0 -282
  64. tests/test_mpscircuit.py +0 -341
  65. tests/test_noisemodel.py +0 -156
  66. tests/test_qaoa.py +0 -86
  67. tests/test_qem.py +0 -152
  68. tests/test_quantum.py +0 -549
  69. tests/test_quantum_attr.py +0 -42
  70. tests/test_results.py +0 -380
  71. tests/test_shadows.py +0 -160
  72. tests/test_simplify.py +0 -46
  73. tests/test_stabilizer.py +0 -217
  74. tests/test_templates.py +0 -218
  75. tests/test_torchnn.py +0 -99
  76. tests/test_van.py +0 -102
  77. {tensorcircuit_nightly-1.2.0.dev20250326.dist-info → tensorcircuit_nightly-1.4.0.dev20251128.dist-info}/licenses/LICENSE +0 -0
@@ -0,0 +1,896 @@
1
+ """
2
+ Analog time evolution engines
3
+ """
4
+
5
+ from typing import Any, Tuple, Optional, Callable, List, Sequence, Dict
6
+ from functools import partial
7
+ import warnings
8
+
9
+ import numpy as np
10
+
11
+ from .cons import backend, dtypestr, rdtypestr, contractor
12
+ from .gates import Gate
13
+ from .utils import arg_alias
14
+
15
# Type aliases used in annotations in this module; both are bound to
# ``typing.Any`` and therefore perform no runtime type checking.
Tensor = Any
Circuit = Any
17
+
18
+
19
def lanczos_iteration_scan(
    hamiltonian: Any, initial_vector: Any, subspace_dimension: int
) -> Tuple[Any, Any]:
    """
    Use Lanczos algorithm to construct orthogonal basis and projected Hamiltonian
    of Krylov subspace, using `tc.backend.scan` for JIT compatibility.

    At every step the new residual vector is fully re-orthogonalized against all
    basis vectors built so far.

    :param hamiltonian: Sparse or dense Hamiltonian matrix
    :type hamiltonian: Tensor
    :param initial_vector: Initial quantum state vector; it is cast to ``dtypestr``
        and normalized before the iteration starts
    :type initial_vector: Tensor
    :param subspace_dimension: Dimension of Krylov subspace
    :type subspace_dimension: int
    :return: Tuple containing (basis matrix of shape
        ``[state_size, subspace_dimension]``, real tridiagonal projected
        Hamiltonian of shape ``[subspace_dimension, subspace_dimension]``)
    :rtype: Tuple[Tensor, Tensor]
    """
    # Cast up front: the scan carry and both ``cond`` branches below must keep a
    # single dtype, and the basis updates are complex (``dtypestr``). A real
    # initial vector would otherwise make the carry dtype change between steps.
    initial_vector = backend.cast(initial_vector, dtypestr)
    state_size = backend.shape_tuple(initial_vector)[0]
    if backend.is_sparse(hamiltonian):
        hamiltonian = backend.sparse_csr_from_coo(hamiltonian)

    # Main scan body for the outer loop (iterating j)
    def lanczos_step(carry: Tuple[Any, ...], j: int) -> Tuple[Any, ...]:
        v, basis, alphas, betas = carry

        if backend.is_sparse(hamiltonian):
            w = backend.sparse_dense_matmul(hamiltonian, v)
        else:
            w = backend.matvec(hamiltonian, v)

        # alpha_j = <v_j|H|v_j>
        alpha = backend.real(backend.sum(backend.conj(v) * w))
        w = w - backend.cast(alpha, dtypestr) * v

        # Inner scan: re-orthogonalize w against v_0, ..., v_j.
        # ``j`` is broadcast and fed in alongside ``k`` as scan elements.
        def ortho_step(w_carry: Any, elems_tuple: Tuple[Any, Any]) -> Any:
            k, j_from_elems = elems_tuple

            def do_projection() -> Any:
                v_k = basis[:, k]
                projection = backend.sum(backend.conj(v_k) * w_carry)
                return w_carry - projection * v_k

            def do_nothing() -> Any:
                return backend.cast(w_carry, dtype=dtypestr)

            # Columns k > j of ``basis`` have not been filled yet; skip them.
            w_new = backend.cond(k <= j_from_elems, do_projection, do_nothing)
            return w_new

        k_elems = backend.arange(subspace_dimension)
        j_elems = backend.tile(backend.reshape(j, [1]), [subspace_dimension])
        inner_elems = (k_elems, j_elems)
        w_ortho = backend.scan(ortho_step, inner_elems, w)

        # beta_{j+1} = ||w||
        beta = backend.norm(w_ortho)
        beta = backend.real(beta)

        # Record alpha_j and beta_{j+1} at position j
        new_alphas = backend.scatter(
            alphas, backend.reshape(j, [1, 1]), backend.reshape(alpha, [1])
        )
        new_betas = backend.scatter(
            betas, backend.reshape(j, [1, 1]), backend.reshape(beta, [1])
        )

        def update_state_fn() -> Tuple[Any, Any]:
            # epsilon avoids a division by zero when w_ortho vanishes (beta ~ 0)
            epsilon = 1e-15
            next_v = w_ortho / backend.cast(beta + epsilon, dtypestr)

            one_hot_update = backend.onehot(j + 1, subspace_dimension)
            one_hot_update = backend.cast(one_hot_update, dtype=dtypestr)

            # Create a mask to update only the (j+1)-th column
            mask = 1.0 - backend.reshape(one_hot_update, [1, subspace_dimension])
            new_basis = basis * mask + backend.reshape(
                next_v, [-1, 1]
            ) * backend.reshape(one_hot_update, [1, subspace_dimension])

            return next_v, new_basis

        def keep_state_fn() -> Tuple[Any, Any]:
            return v, basis

        # On the last step there is no column j+1 to fill.
        next_v_carry, new_basis = backend.cond(
            j < subspace_dimension - 1, update_state_fn, keep_state_fn
        )

        return (next_v_carry, new_basis, new_alphas, new_betas)

    # Prepare initial state for the main scan
    v0 = initial_vector / backend.norm(initial_vector)

    init_basis = backend.zeros((state_size, subspace_dimension), dtype=dtypestr)
    init_alphas = backend.zeros((subspace_dimension,), dtype=rdtypestr)
    init_betas = backend.zeros((subspace_dimension,), dtype=rdtypestr)

    # Place v0 in column 0 of the basis
    one_hot_0 = backend.onehot(0, subspace_dimension)
    one_hot_0 = backend.cast(one_hot_0, dtype=dtypestr)
    init_basis = init_basis + backend.reshape(v0, [-1, 1]) * backend.reshape(
        one_hot_0, [1, subspace_dimension]
    )

    init_carry = (v0, init_basis, init_alphas, init_betas)

    # Run the main scan
    final_carry = backend.scan(
        lanczos_step, backend.arange(subspace_dimension), init_carry
    )
    basis_matrix, alphas_tensor, betas_tensor = (
        final_carry[1],
        final_carry[2],
        final_carry[3],
    )

    # Only the first subspace_dimension - 1 betas enter the tridiagonal matrix
    betas_off_diag = betas_tensor[:-1]

    diag_part = backend.diagflat(alphas_tensor)
    if backend.shape_tuple(betas_off_diag)[0] > 0:
        off_diag_part = backend.diagflat(betas_off_diag, k=1)
        projected_hamiltonian = (
            diag_part + off_diag_part + backend.conj(backend.transpose(off_diag_part))
        )
    else:
        projected_hamiltonian = diag_part

    return basis_matrix, projected_hamiltonian
166
+
167
+
168
def lanczos_iteration(
    hamiltonian: Tensor, initial_vector: Tensor, subspace_dimension: int
) -> Tuple[Tensor, Tensor]:
    """
    Use Lanczos algorithm to construct orthogonal basis and projected Hamiltonian
    of Krylov subspace.

    The loop has a fixed number of iterations (``subspace_dimension``), and each
    new residual is fully re-orthogonalized against all previous basis vectors.

    :param hamiltonian: Sparse or dense Hamiltonian matrix
    :type hamiltonian: Tensor
    :param initial_vector: Initial quantum state vector; it is cast to ``dtypestr``
        and normalized before the iteration starts
    :type initial_vector: Tensor
    :param subspace_dimension: Dimension of Krylov subspace, must be >= 1
    :type subspace_dimension: int
    :raises ValueError: if ``subspace_dimension`` is smaller than 1
    :return: Tuple containing (basis matrix with ``subspace_dimension`` columns,
        tridiagonal projected Hamiltonian of size
        ``subspace_dimension x subspace_dimension``)
    :rtype: Tuple[Tensor, Tensor]
    """
    if subspace_dimension < 1:
        raise ValueError(
            f"subspace_dimension must be a positive integer, got {subspace_dimension}"
        )

    # Normalize the (dtype-cast) initial vector: it is the first basis vector
    vector = backend.cast(initial_vector, dtypestr)
    vector = vector / backend.norm(vector)
    basis_vectors: List[Any] = [vector]

    # Diagonal (alpha) and off-diagonal (beta) entries of the tridiagonal matrix
    alphas: List[Any] = []
    betas: List[Any] = []

    if backend.is_sparse(hamiltonian):
        hamiltonian = backend.sparse_csr_from_coo(hamiltonian)

    # Regularization against division by zero when the residual vanishes
    epsilon = 1e-15

    for j in range(subspace_dimension):
        # Calculate H|v_j>
        if backend.is_sparse(hamiltonian):
            w = backend.sparse_dense_matmul(hamiltonian, vector)
        else:
            w = backend.matvec(hamiltonian, vector)

        # alpha_j = <v_j|H|v_j>
        alpha = backend.real(backend.sum(backend.conj(vector) * w))
        alphas.append(alpha)

        if j == subspace_dimension - 1:
            # The final residual would only yield beta_m and v_{m+1}, neither of
            # which is part of the m x m projected problem.
            break

        # The three-term recurrence alone loses orthogonality numerically, so
        # re-orthogonalize against every basis vector v_0, ..., v_j.
        w = w - backend.cast(alpha, dtypestr) * vector
        for v_k in basis_vectors:
            projection = backend.sum(backend.conj(v_k) * w)
            w = w - projection * v_k

        # beta_{j+1} = ||w||
        beta = backend.norm(w)
        betas.append(beta)

        # |v_{j+1}> = w / beta_{j+1}
        norm_factor = 1.0 / (beta + epsilon)
        vector = w * backend.cast(norm_factor, dtypestr)
        basis_vectors.append(vector)

    basis_matrix = backend.stack(basis_vectors, axis=1)

    alphas_tensor = backend.cast(backend.stack(alphas), dtype=dtypestr)
    projected_hamiltonian = backend.diagflat(alphas_tensor)
    # ``betas`` is empty when subspace_dimension == 1: the matrix is just [alpha_0]
    if betas:
        betas_tensor = backend.cast(backend.stack(betas), dtype=dtypestr)
        off_diag_part = backend.diagflat(betas_tensor, k=1)
        projected_hamiltonian = (
            projected_hamiltonian + off_diag_part + backend.transpose(off_diag_part)
        )

    return basis_matrix, projected_hamiltonian
268
+
269
+
270
def krylov_evol(
    hamiltonian: Tensor,
    initial_state: Tensor,
    times: Tensor,
    subspace_dimension: int,
    callback: Optional[Callable[[Any], Any]] = None,
    scan_impl: bool = False,
) -> Any:
    """
    Perform quantum state time evolution using Krylov subspace method.

    A Lanczos basis ``V`` and tridiagonal projection ``T = V^dagger H V`` are
    built once. Then, for each time ``t``, ``|psi(t)> = V U exp(-i D t) U^dagger
    V^dagger |psi(0)>`` with ``T = U D U^dagger``.

    :param hamiltonian: Sparse or dense Hamiltonian matrix
    :type hamiltonian: Tensor
    :param initial_state: Initial quantum state
    :type initial_state: Tensor
    :param times: List of time points
    :type times: Tensor
    :param subspace_dimension: Krylov subspace dimension
    :type subspace_dimension: int
    :param callback: Optional callback function applied to quantum state at
        each evolution time point, return some observables
    :type callback: Optional[Callable[[Any], Any]], optional
    :param scan_impl: whether to use the scan implementation of Lanczos, suitable
        for jit but may be slow on numpy. Defaults to False. True does not work
        for the tensorflow backend + jit, due to tensorflow context separation
        (``InaccessibleTensorError``)
    :type scan_impl: bool, optional
    :return: Stacked evolved quantum states, or stacked callback results
        (if callback provided), one entry per time point
    :rtype: Any
    """
    # TODO(@refraction-ray): stable and efficient AD is to be investigated
    lanczos = lanczos_iteration_scan if scan_impl else lanczos_iteration
    basis_matrix, projected_hamiltonian = lanczos(
        hamiltonian, initial_state, subspace_dimension
    )

    psi0 = backend.cast(initial_state, dtypestr)
    # Krylov-space coordinates of the initial state: V^dagger |psi(0)>
    krylov_coeffs = backend.matvec(
        backend.conj(backend.transpose(basis_matrix)), psi0
    )

    # Diagonalize the small projected problem: T = U D U^dagger
    energies, modes = backend.eigh(projected_hamiltonian)
    energies = backend.cast(energies, dtypestr)
    modes = backend.cast(modes, dtypestr)
    times = backend.cast(backend.convert_to_tensor(times), dtypestr)

    # Coordinates in the eigenbasis of T: U^dagger V^dagger |psi(0)>
    mode_coeffs = backend.matvec(backend.conj(backend.transpose(modes)), krylov_coeffs)

    def _state_at(t: Any) -> Any:
        phases = backend.exp(-1j * energies * t)
        # back to the Krylov basis, then back to the full Hilbert space
        krylov_state = backend.matvec(modes, phases * mode_coeffs)
        return backend.matvec(basis_matrix, krylov_state)

    outputs = []
    for t in times:
        psi_t = _state_at(t)
        outputs.append(psi_t if callback is None else callback(psi_t))

    return backend.stack(outputs)
351
+
352
+
353
@partial(
    arg_alias,
    alias_dict={"h": ["hamiltonian"], "psi0": ["initial_state"], "tlist": ["times"]},
)
def hamiltonian_evol(
    h: Tensor,
    psi0: Tensor,
    tlist: Tensor,
    callback: Optional[Callable[..., Any]] = None,
) -> Tensor:
    """
    Fast implementation of time independent Hamiltonian evolution using eigendecomposition.
    By default, performs imaginary time evolution.

    With the eigendecomposition ``h = U D U^dagger``, the state at time ``t`` is
    ``U exp(-t D) U^dagger psi0``, normalized afterwards.

    :param h: Time-independent (Hermitian) Hamiltonian matrix
    :type h: Tensor
    :param psi0: Initial state vector
    :type psi0: Tensor
    :param tlist: Time points for evolution
    :type tlist: Tensor
    :param callback: Optional function to process state at each time point
    :type callback: Optional[Callable[..., Any]], optional
    :return: Evolution results at each time point. If callback is None, returns state vectors;
        otherwise returns callback results
    :rtype: Tensor

    :Example:

    >>> import tensorcircuit as tc
    >>> from tensorcircuit.timeevol import hamiltonian_evol
    >>> # Define a simple 2-qubit Hamiltonian
    >>> h = tc.array_to_tensor([
    ...     [1.0, 0.0, 0.0, 0.0],
    ...     [0.0, -1.0, 2.0, 0.0],
    ...     [0.0, 2.0, -1.0, 0.0],
    ...     [0.0, 0.0, 0.0, 1.0]
    ... ])
    >>> # Initial state |00>
    >>> psi0 = tc.array_to_tensor([1.0, 0.0, 0.0, 0.0])
    >>> # Evolution times
    >>> times = tc.array_to_tensor([0.0, 0.5, 1.0])
    >>> # Evolve and get states
    >>> states = hamiltonian_evol(h, psi0, times)
    >>> print(states.shape)  # (3, 4)


    Note:
        1. The Hamiltonian must be time-independent
        2. For time-dependent Hamiltonians, use ``evol_local`` or ``evol_global`` instead
        3. The evolution is performed in imaginary time by default (factor -t in exponential);
           since ``tlist`` is cast to a complex dtype, real-time evolution ``exp(-i h tau)``
           is obtained by passing ``tlist = 1j * tau``
        4. The state is automatically normalized at each time point
    """
    psi0 = backend.cast(psi0, dtypestr)
    es, u = backend.eigh(h)
    u = backend.cast(u, dtypestr)
    # Eigenbasis coefficients U^dagger |psi0>. The conjugate matters for complex
    # Hermitian h; for real symmetric h, U is real and this equals U^T |psi0>.
    udpsi0 = backend.convert_to_tensor(
        backend.conj(backend.transpose(u)) @ backend.reshape(psi0, [-1, 1])
    )  # in case np.matrix...
    udpsi0 = backend.reshape(udpsi0, [-1])
    es = backend.cast(es, dtypestr)
    tlist = backend.cast(backend.convert_to_tensor(tlist), dtypestr)

    @backend.jit
    def _evol(t: Tensor) -> Tensor:
        ebetah_udpsi0 = backend.exp(-t * es) * udpsi0
        # Back to the computational basis: U exp(-t D) U^dagger |psi0>
        psi_exact = u @ backend.reshape(ebetah_udpsi0, [-1, 1])
        psi_exact = backend.reshape(psi_exact, [-1])
        psi_exact = psi_exact / backend.norm(psi_exact)
        if callback is None:
            return psi_exact
        return callback(psi_exact)

    return backend.stack([_evol(t) for t in tlist])
426
+
427
+
428
# ``ed_evol`` is bound to the same function object as ``hamiltonian_evol``
# (the eigendecomposition-based evolution defined above).
ed_evol = hamiltonian_evol
429
+
430
+
431
def _solve_ode(
    f: Callable[..., Tensor],
    s: Tensor,
    times: Tensor,
    args: Any,
    solver_kws: Dict[str, Any],
) -> Tensor:
    """
    Integrate ``dy/dt = f(y, t, *args)`` starting from ``s`` at ``times[0]`` and
    return the solution at every point of ``times``.

    :param f: right-hand side with signature ``f(y, t, *args)``
    :type f: Callable[..., Tensor]
    :param s: initial value of ``y``
    :type s: Tensor
    :param times: 1D grid of time points
    :type times: Tensor
    :param args: extra positional arguments forwarded to ``f``
    :type args: Any
    :param solver_kws: solver options:

        - ``ode_backend``: ``'jaxode'`` (default) uses ``jax.experimental.ode.odeint``;
          any other value uses ``diffrax.diffeqsolve``
        - ``rtol`` / ``atol`` (default 1e-8 each): error tolerances
        - ``max_steps`` (default 4096): step limit, passed as ``mxstep`` to ``odeint``
          and as ``max_steps`` to ``diffeqsolve``
        - ``solver`` (default ``'Tsit5'``) and ``dt0`` (default 0.01): diffrax only

    :type solver_kws: Dict[str, Any]
    :raises KeyError: if the diffrax ``solver`` name is not one of
        'Dopri5', 'Tsit5', 'Dopri8', 'Kvaerno5'
    :return: the solution evaluated at ``times``
    :rtype: Tensor
    """
    rtol = solver_kws.get("rtol", 1e-8)
    atol = solver_kws.get("atol", 1e-8)
    ode_backend = solver_kws.get("ode_backend", "jaxode")
    max_steps = solver_kws.get("max_steps", 4096)

    ts = backend.convert_to_tensor(times)
    ts = backend.cast(ts, dtype=rdtypestr)

    if ode_backend == "jaxode":
        from jax.experimental.ode import odeint

        # extra ``args`` are positional arguments of ``odeint`` forwarded to ``f``
        return odeint(f, s, ts, *args, rtol=rtol, atol=atol, mxstep=max_steps)

    import diffrax

    solver = solver_kws.get("solver", "Tsit5")
    dt0 = solver_kws.get("dt0", 0.01)
    all_solvers = {
        "Dopri5": diffrax.Dopri5,
        "Tsit5": diffrax.Tsit5,
        "Dopri8": diffrax.Dopri8,
        "Kvaerno5": diffrax.Kvaerno5,
    }

    # diffrax expects the term signature (t, y, args)
    term = diffrax.ODETerm(lambda t, y, args: f(y, t, *args))

    # Ignore the complex-dtype UserWarning only for this solve, instead of
    # permanently changing the process-wide warning filters.
    with warnings.catch_warnings():
        warnings.simplefilter("ignore", category=UserWarning, append=True)
        s1 = diffrax.diffeqsolve(
            terms=term,
            solver=all_solvers[solver](),
            t0=times[0],
            t1=times[-1],
            dt0=dt0,
            y0=s,
            saveat=diffrax.SaveAt(ts=times),
            args=args,
            stepsize_controller=diffrax.PIDController(rtol=rtol, atol=atol),
            max_steps=max_steps,
        ).ys
    return s1
483
+
484
+
485
def ode_evol_local(
    hamiltonian: Callable[..., Tensor],
    initial_state: Tensor,
    times: Tensor,
    index: Sequence[int],
    callback: Optional[Callable[..., Tensor]] = None,
    *args: Any,
    **solver_kws: Any,
) -> Tensor:
    """
    ODE-based time evolution for a time-dependent Hamiltonian acting on a subsystem of qubits.
    This function solves the time-dependent Schrodinger equation using numerical ODE integration.
    The Hamiltonian is applied only to a specific subset of qubits (indices) in the system.
    The ode_backend parameter defaults to 'jaxode' (which uses ``jax.experimental.ode.odeint`` with a default solver
    of 'Dopri5'); if set to 'diffrax', it uses ``diffrax.diffeqsolve`` instead (with a default solver of 'Tsit5').

    Note: This function currently only supports the JAX backend.

    :param hamiltonian: A function that returns a dense Hamiltonian matrix for the specified
        subsystem size (a sparse return value is converted to dense).
        The function signature should be ``hamiltonian(time, *args) -> Tensor``.
    :type hamiltonian: Callable[..., Tensor]
    :param initial_state: The initial quantum state vector of the full system.
    :type initial_state: Tensor
    :param times: Time points for which to compute the evolution. Should be a 1D array of times.
    :type times: Tensor
    :param index: Indices of qubits where the Hamiltonian is applied; the i-th entry is the
        qubit acted on by the i-th leg of the local Hamiltonian.
    :type index: Sequence[int]
    :param callback: Optional function to apply to the state at each time step.
    :type callback: Optional[Callable[..., Tensor]]
    :param args: Additional arguments to pass to the Hamiltonian function.
    :type args: tuple | list
    :param solver_kws: Additional keyword arguments to pass to the ODE solver.

        - ``ode_backend='jaxode'`` (default) uses ``jax.experimental.ode.odeint``; ``ode_backend='diffrax'``
          uses ``diffrax.diffeqsolve`` (currently any value other than 'jaxode' selects diffrax).

        - ``rtol`` (default: 1e-8) and ``atol`` (default: 1e-8) are the relative and absolute
          error tolerances that control how accurately the equation is integrated.

        - ``max_steps`` (default: 4096) The maximum number of steps to take before quitting the
          computation; passed as ``mxstep`` for 'jaxode' and as ``max_steps`` for 'diffrax'.

        - The ``solver`` parameter accepts one of {'Tsit5' (default), 'Dopri5', 'Dopri8', 'Kvaerno5'}
          and only works when ``ode_backend='diffrax'``.

        - ``dt0`` (default: 0.01) specifies the initial step size and only works when ``ode_backend='diffrax'``.
    :type solver_kws: dict

    :return: Evolved quantum states at the specified time points. If callback is provided,
        returns the callback results; otherwise returns the state vectors.
    :rtype: Tensor
    """

    n = int(np.log2(backend.shape_tuple(initial_state)[-1]) + 1e-7)
    n_local = len(index)
    # Qubit -> leg of the local Hamiltonian gate. Built once here because it does
    # not depend on t (the RHS is evaluated many times), and it works for any
    # sequence type; setdefault keeps the first occurrence like ``list.index``.
    leg_of: Dict[int, int] = {}
    for leg, qubit in enumerate(index):
        leg_of.setdefault(qubit, leg)

    def f(y: Tensor, t: Tensor, *args: Any) -> Tensor:
        # dy/dt = -i H(t) y, with H acting only on the qubits in ``index``
        y = backend.reshape2(y)
        y = Gate(y)
        h = -1.0j * hamiltonian(t, *args)
        if backend.is_sparse(h):
            h = backend.to_dense(h)
        h = backend.reshape2(h)
        h = Gate(h)
        edges = []
        for i in range(n):
            if i in leg_of:
                j = leg_of[i]
                # output leg j replaces qubit i; input leg j + n_local eats y[i]
                edges.append(h[j])
                h[j + n_local] ^ y[i]
            else:
                edges.append(y[i])
        y = contractor([y, h], output_edge_order=edges)
        return backend.reshape(y.tensor, [-1])

    s1 = _solve_ode(f, initial_state, times, args, solver_kws)

    if callback is None:
        return s1
    return backend.stack([callback(a_state) for a_state in s1])
564
+
565
+
566
def ode_evol_global(
    hamiltonian: Callable[..., Tensor],
    initial_state: Tensor,
    times: Tensor,
    callback: Optional[Callable[..., Tensor]] = None,
    *args: Any,
    **solver_kws: Any,
) -> Tensor:
    """
    ODE-based time evolution for a time-dependent Hamiltonian acting on the entire system.
    This function solves the time-dependent Schrodinger equation using numerical ODE integration.
    The Hamiltonian is applied to the full system and should be provided in sparse matrix
    format for efficiency.
    The ode_backend parameter defaults to 'jaxode' (which uses ``jax.experimental.ode.odeint`` with a default solver
    of 'Dopri5'); if set to 'diffrax', it uses ``diffrax.diffeqsolve`` instead (with a default solver of 'Tsit5').

    Note: This function currently only supports the JAX backend.

    :param hamiltonian: A function that returns a sparse Hamiltonian matrix for the full system.
        The function signature should be ``hamiltonian(time, *args) -> Tensor``.
    :type hamiltonian: Callable[..., Tensor]
    :param initial_state: The initial quantum state vector.
    :type initial_state: Tensor
    :param times: Time points for which to compute the evolution. Should be a 1D array of times.
    :type times: Tensor
    :param callback: Optional function to apply to the state at each time step.
    :type callback: Optional[Callable[..., Tensor]]
    :param args: Additional arguments to pass to the Hamiltonian function.
    :type args: tuple | list
    :param solver_kws: Additional keyword arguments to pass to the ODE solver.

        - ``ode_backend='jaxode'`` (default) uses ``jax.experimental.ode.odeint``; ``ode_backend='diffrax'``
          uses ``diffrax.diffeqsolve`` (currently any value other than 'jaxode' selects diffrax).

        - ``rtol`` (default: 1e-8) and ``atol`` (default: 1e-8) are the relative and absolute
          error tolerances that control how accurately the equation is integrated.

        - ``max_steps`` (default: 4096) The maximum number of steps to take before quitting the
          computation; passed as ``mxstep`` for 'jaxode' and as ``max_steps`` for 'diffrax'.

        - The ``solver`` parameter accepts one of {'Tsit5' (default), 'Dopri5', 'Dopri8', 'Kvaerno5'}
          and only works when ``ode_backend='diffrax'``.

        - ``dt0`` (default: 0.01) specifies the initial step size and only works when ``ode_backend='diffrax'``.
    :type solver_kws: dict

    :return: Evolved quantum states at the specified time points. If callback is provided,
        returns the callback results; otherwise returns the state vectors.
    :rtype: Tensor
    """

    def f(y: Tensor, t: Tensor, *args: Any) -> Tensor:
        # Schrodinger equation: dy/dt = -i H(t) y
        h = -1.0j * hamiltonian(t, *args)
        return h @ y

    s1 = _solve_ode(f, initial_state, times, args, solver_kws)

    if callback is None:
        return s1
    return backend.stack([callback(a_state) for a_state in s1])
626
+
627
+
628
@partial(arg_alias, alias_dict={"h_fun": ["hamiltonian"], "t": ["times"]})
def evol_local(
    c: Circuit,
    index: Sequence[int],
    h_fun: Callable[..., Tensor],
    t: float,
    *args: Any,
    **solver_kws: Any,
) -> Circuit:
    """
    ode evolution of time dependent Hamiltonian on circuit of given indices
    [only jax backend support for now]

    :param c: circuit whose current state ``c.state()`` is evolved
    :type c: Circuit
    :param index: qubit sites to evolve
    :type index: Sequence[int]
    :param h_fun: h_fun should return a dense Hamiltonian matrix
        with input arguments ``time`` and ``*args``
    :type h_fun: Callable[..., Tensor]
    :param t: evolution time. A scalar (Python/NumPy number or 0-d tensor) means
        evolving from 0 to ``t``; otherwise a 1D sequence of time points, and the
        state at the last point is kept
    :type t: float
    :param args: extra positional arguments forwarded to ``h_fun``
    :param solver_kws: ODE solver options, see ``ode_evol_local``
    :return: a new circuit of the same type as ``c``, initialized with the evolved state
    :rtype: Circuit
    """
    s = c.state()
    n = int(np.log2(s.shape[-1]) + 1e-7)
    # Any scalar time (int, float, NumPy scalar, 0-d array or jit tracer) is
    # expanded to the grid [0, t]; the ODE solver needs a 1D time grid.
    if np.ndim(t) == 0:
        t = backend.stack([0.0, t])
    s1 = ode_evol_local(h_fun, s, t, index, None, *args, **solver_kws)
    return type(c)(n, inputs=s1[-1])
659
+
660
+
661
@partial(arg_alias, alias_dict={"h_fun": ["hamiltonian"], "t": ["times"]})
def evol_global(
    c: Circuit, h_fun: Callable[..., Tensor], t: float, *args: Any, **solver_kws: Any
) -> Circuit:
    """
    ode evolution of time dependent Hamiltonian on circuit of all qubits
    [only jax backend support for now]

    :param c: circuit whose current state ``c.state()`` is evolved
    :type c: Circuit
    :param h_fun: h_fun should return a **SPARSE** Hamiltonian matrix
        with input arguments ``time`` and ``*args``
    :type h_fun: Callable[..., Tensor]
    :param t: evolution time. A scalar (Python/NumPy number or 0-d tensor) means
        evolving from 0 to ``t``; otherwise a 1D sequence of time points, and the
        state at the last point is kept
    :type t: float
    :param args: extra positional arguments forwarded to ``h_fun``
    :param solver_kws: ODE solver options, see ``ode_evol_global``
    :return: a new circuit of the same type as ``c``, initialized with the evolved state
    :rtype: Circuit
    """
    s = c.state()
    n = c._nqubits
    # Any scalar time (int, float, NumPy scalar, 0-d array or jit tracer) is
    # expanded to the grid [0, t]; the ODE solver needs a 1D time grid.
    if np.ndim(t) == 0:
        t = backend.stack([0.0, t])
    s1 = ode_evol_global(h_fun, s, t, None, *args, **solver_kws)
    return type(c)(n, inputs=s1[-1])
685
+
686
+
687
+ def chebyshev_evol(
688
+ hamiltonian: Any,
689
+ initial_state: Tensor,
690
+ t: float,
691
+ spectral_bounds: Tuple[float, float],
692
+ k: int,
693
+ M: int,
694
+ ) -> Any:
695
+ """
696
+ Chebyshev evolution method by expanding the time evolution exponential operator
697
+ in Chebyshev series.
698
+ Note the state returned is not normalized. But the norm should be very close to 1 for
699
+ sufficiently large k and M, which can serve as a accuracy check of the final result.
700
+
701
+ :param hamiltonian: Hamiltonian matrix (sparse or dense)
702
+ :type hamiltonian: Any
703
+ :param initial_state: Initial state vector
704
+ :type initial_state: Tensor
705
+ :param time: Time to evolve
706
+ :type time: float
707
+ :param spectral_bounds: Spectral bounds for the Hamiltonian (Emax, Emin)
708
+ :type spectral_bounds: Tuple[float, float]
709
+ :param k: Number of Chebyshev coefficients, a good estimate is k > t*(Emax-Emin)/2
710
+ :type k: int
711
+ :param M: Number of iterations to estimate Bessel function, a good estimate is given
712
+ by `estimate_M` helper method.
713
+ :type M: int
714
+ :return: Evolved state
715
+ :rtype: Tensor
716
+ """
717
+ # TODO(@refraction-ray): no support for tf backend as bessel function has no implementation
718
+ E_max, E_min = spectral_bounds
719
+ if E_max <= E_min:
720
+ raise ValueError("E_max must be > E_min.")
721
+
722
+ a = (E_max - E_min) / 2.0
723
+ b = (E_max + E_min) / 2.0
724
+ tau = a * t # Rescaled time parameter
725
+
726
+ if backend.is_sparse(hamiltonian):
727
+ hamiltonian = backend.sparse_csr_from_coo(hamiltonian)
728
+
729
+ def apply_h_norm(psi: Any) -> Any:
730
+ """Applies the normalized Hamiltonian to a state."""
731
+ return ((hamiltonian @ psi) - b * psi) / a
732
+
733
+ # Handle edge case where no evolution is needed.
734
+ if k == 0:
735
+ # The phase factor still applies even for zero evolution of the series part.
736
+ phase = backend.exp(-1j * b * t)
737
+ return phase * backend.zeros_like(initial_state)
738
+
739
+ # --- 2. Calculate Chebyshev Expansion Coefficients ---
740
+ k_indices = backend.arange(k)
741
+ bessel_vals = backend.special_jv(k, tau, M)
742
+
743
+ # Prefactor is 1 for k=0 and 2 for k>0.
744
+ prefactor = backend.ones([k])
745
+ if k > 1:
746
+ # Using concat for backend compatibility (vs. jax's .at[1:].set(2.0))
747
+ prefactor = backend.concat(
748
+ [backend.ones([1]), backend.ones([k - 1]) * 2.0], axis=0
749
+ )
750
+
751
+ ik_powers = backend.power(0 - 1j, k_indices)
752
+ coeffs = prefactor * ik_powers * bessel_vals
753
+
754
+ # --- 3. Iteratively build the result using a scan ---
755
+
756
+ # Handle the simple case of k=1 separately.
757
+ if k == 1:
758
+ psi_unphased = coeffs[0] * initial_state
759
+ else: # k >= 2, use the scan operation.
760
+ # Initialize the first two Chebyshev vectors and the initial sum.
761
+ T0 = initial_state
762
+ T1 = apply_h_norm(T0)
763
+ initial_sum = coeffs[0] * T0 + coeffs[1] * T1
764
+
765
+ # The carry for the scan holds the state needed for the next iteration:
766
+ # (current vector T_k, previous vector T_{k-1}, and the running sum).
767
+ initial_carry = (T1, T0, initial_sum)
768
+
769
+ def scan_body(carry, i): # type: ignore
770
+ """The body of the scan operation."""
771
+ Tk, Tkm1, current_sum = carry
772
+
773
+ # Calculate the next Chebyshev vector using the recurrence relation.
774
+ Tkp1 = 2 * apply_h_norm(Tk) - Tkm1
775
+
776
+ # Add its contribution to the running sum.
777
+ new_sum = current_sum + coeffs[i] * Tkp1
778
+
779
+ # Return the updated carry for the next step. No intermediate output is needed.
780
+ return (Tkp1, Tk, new_sum)
781
+
782
+ # Run the scan over the remaining coefficients (from index 2 to k-1).
783
+ final_carry = backend.scan(scan_body, backend.arange(2, k), initial_carry)
784
+
785
+ # The final result is the sum accumulated in the last carry state.
786
+ psi_unphased = final_carry[2]
787
+
788
+ # --- 4. Final Step: Apply Phase Correction ---
789
+ # This undoes the energy shift from the Hamiltonian normalization.
790
+ phase = backend.exp(-1j * b * t)
791
+ psi_final = phase * psi_unphased
792
+
793
+ return psi_final
794
+
795
+
796
def estimate_k(t: float, spectral_bounds: Tuple[float, float]) -> int:
    """
    Estimate the number of Chebyshev coefficients ``k`` for `chebyshev_evol`.

    The expansion needs ``k`` above the rescaled time
    ``tau = t * (Emax - Emin) / 2``. The magnitude ``|tau|`` is used, so
    backward evolution (negative ``t``) gets the same positive estimate as
    forward evolution.

    :param t: time
    :type t: float
    :param spectral_bounds: spectral bounds (Emax, Emin)
    :type spectral_bounds: Tuple[float, float]
    :return: k, at least 20
    :rtype: int
    """
    E_max, E_min = spectral_bounds
    a = (E_max - E_min) / 2.0
    # Use the magnitude: a signed tau would give a tiny or negative k for t < 0,
    # consistent with `estimate_M`, which also uses abs(tau).
    tau = abs(a * t)
    return max(int(1.1 * tau), int(tau + 20))
811
+
812
+
813
def estimate_M(t: float, spectral_bounds: Tuple[float, float], k: int) -> int:
    """
    Estimate the number of iterations ``M`` used to evaluate the Bessel
    functions in `chebyshev_evol`.

    With ``tau = t * (Emax - Emin) / 2`` the estimate is
    ``max(k, int(|tau|)) + int(15 * sqrt(|tau|))``, and it is never smaller
    than ``k + 30``.

    :param t: time
    :type t: float
    :param spectral_bounds: spectral bounds (Emax, Emin)
    :type spectral_bounds: Tuple[float, float]
    :param k: number of Chebyshev coefficients
    :type k: int
    :return: M
    :rtype: int
    """
    e_max, e_min = spectral_bounds
    half_width = (e_max - e_min) / 2.0
    abs_tau = abs(half_width * t)
    # Extra iterations growing like sqrt(|tau|), scaled by a safety factor of 15.
    sqrt_margin = int(15 * np.sqrt(abs_tau))
    estimate = max(k, int(abs_tau)) + sqrt_margin
    return max(estimate, k + 30)
833
+
834
+
835
def estimate_spectral_bounds(
    h: Any, n_iter: int = 30, psi0: Optional[Any] = None
) -> Tuple[float, float]:
    """
    Lanczos algorithm to estimate the spectral bounds of a Hamiltonian.
    Just for quick run before `chebyshev_evol`, non jit-able.

    The extreme eigenvalues (Ritz values) of the Lanczos tridiagonal matrix
    are returned. No reorthogonalization is performed. The iteration stops
    early once the residual norm drops below 1e-8.

    :param h: Hamiltonian matrix, used via ``h @ vector``.
    :type h: Any
    :param n_iter: Maximum number of Lanczos iterations, must be >= 1.
    :type n_iter: int
    :param psi0: Optional initial state. If None, a random real Gaussian
        vector of length ``h.shape[-1]`` is used.
    :type psi0: Optional[Any]
    :raises ValueError: If ``n_iter`` is smaller than 1.
    :return: (E_max, E_min), as backend scalars
    :rtype: Tuple[float, float]
    """
    if n_iter < 1:
        raise ValueError("n_iter must be >= 1.")

    shape = h.shape
    D = shape[-1]
    if psi0 is None:
        psi0 = np.random.normal(size=[D])

    # Convert before taking the norm: e.g. the pytorch backend cannot compute
    # the norm of a raw numpy array.
    psi0 = backend.convert_to_tensor(psi0)
    psi0 = psi0 / backend.norm(psi0)
    psi0 = backend.cast(psi0, dtypestr)

    # Lanczos
    alphas = []
    betas = []
    q_prev = backend.zeros(psi0.shape, dtype=psi0.dtype)
    q = psi0
    beta = 0

    for _ in range(n_iter):
        r = h @ q
        r = backend.convert_to_tensor(r)  # in case np.matrix
        r = backend.reshape(r, [-1])
        if beta != 0:
            r -= backend.cast(beta, dtypestr) * q_prev

        alpha = backend.real(backend.sum(backend.conj(q) * r))
        alphas.append(alpha)

        r -= backend.cast(alpha, dtypestr) * q

        # Divide by the raw norm result; abs() only produces the real-valued
        # beta that is compared and stored in the tridiagonal matrix.
        norm_r = backend.norm(r)
        beta = backend.abs(norm_r)
        betas.append(beta)
        if beta < 1e-8:
            # Krylov space exhausted: stop before dividing by (nearly) zero.
            break
        q_prev = q
        q = r / norm_r

    alphas = backend.stack(alphas)
    betas = backend.stack(betas)
    # The last beta belongs to the (discarded) next step, so only the first
    # len(alphas) - 1 betas fill the off-diagonals.
    T = (
        backend.diagflat(alphas)
        + backend.diagflat(betas[:-1], k=1)
        + backend.diagflat(betas[:-1], k=-1)
    )

    ritz_values, _ = backend.eigh(T)

    return backend.max(ritz_values), backend.min(ritz_values)