Trajectree 0.0.1__py3-none-any.whl → 0.0.2__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (122)
  1. trajectree/__init__.py +0 -3
  2. trajectree/fock_optics/devices.py +1 -1
  3. trajectree/fock_optics/light_sources.py +2 -2
  4. trajectree/fock_optics/measurement.py +3 -3
  5. trajectree/fock_optics/utils.py +6 -6
  6. trajectree/trajectory.py +2 -2
  7. {trajectree-0.0.1.dist-info → trajectree-0.0.2.dist-info}/METADATA +2 -3
  8. trajectree-0.0.2.dist-info/RECORD +16 -0
  9. trajectree/quimb/docs/_pygments/_pygments_dark.py +0 -118
  10. trajectree/quimb/docs/_pygments/_pygments_light.py +0 -118
  11. trajectree/quimb/docs/conf.py +0 -158
  12. trajectree/quimb/docs/examples/ex_mpi_expm_evo.py +0 -62
  13. trajectree/quimb/quimb/__init__.py +0 -507
  14. trajectree/quimb/quimb/calc.py +0 -1491
  15. trajectree/quimb/quimb/core.py +0 -2279
  16. trajectree/quimb/quimb/evo.py +0 -712
  17. trajectree/quimb/quimb/experimental/__init__.py +0 -0
  18. trajectree/quimb/quimb/experimental/autojittn.py +0 -129
  19. trajectree/quimb/quimb/experimental/belief_propagation/__init__.py +0 -109
  20. trajectree/quimb/quimb/experimental/belief_propagation/bp_common.py +0 -397
  21. trajectree/quimb/quimb/experimental/belief_propagation/d1bp.py +0 -316
  22. trajectree/quimb/quimb/experimental/belief_propagation/d2bp.py +0 -653
  23. trajectree/quimb/quimb/experimental/belief_propagation/hd1bp.py +0 -571
  24. trajectree/quimb/quimb/experimental/belief_propagation/hv1bp.py +0 -775
  25. trajectree/quimb/quimb/experimental/belief_propagation/l1bp.py +0 -316
  26. trajectree/quimb/quimb/experimental/belief_propagation/l2bp.py +0 -537
  27. trajectree/quimb/quimb/experimental/belief_propagation/regions.py +0 -194
  28. trajectree/quimb/quimb/experimental/cluster_update.py +0 -286
  29. trajectree/quimb/quimb/experimental/merabuilder.py +0 -865
  30. trajectree/quimb/quimb/experimental/operatorbuilder/__init__.py +0 -15
  31. trajectree/quimb/quimb/experimental/operatorbuilder/operatorbuilder.py +0 -1631
  32. trajectree/quimb/quimb/experimental/schematic.py +0 -7
  33. trajectree/quimb/quimb/experimental/tn_marginals.py +0 -130
  34. trajectree/quimb/quimb/experimental/tnvmc.py +0 -1483
  35. trajectree/quimb/quimb/gates.py +0 -36
  36. trajectree/quimb/quimb/gen/__init__.py +0 -2
  37. trajectree/quimb/quimb/gen/operators.py +0 -1167
  38. trajectree/quimb/quimb/gen/rand.py +0 -713
  39. trajectree/quimb/quimb/gen/states.py +0 -479
  40. trajectree/quimb/quimb/linalg/__init__.py +0 -6
  41. trajectree/quimb/quimb/linalg/approx_spectral.py +0 -1109
  42. trajectree/quimb/quimb/linalg/autoblock.py +0 -258
  43. trajectree/quimb/quimb/linalg/base_linalg.py +0 -719
  44. trajectree/quimb/quimb/linalg/mpi_launcher.py +0 -397
  45. trajectree/quimb/quimb/linalg/numpy_linalg.py +0 -244
  46. trajectree/quimb/quimb/linalg/rand_linalg.py +0 -514
  47. trajectree/quimb/quimb/linalg/scipy_linalg.py +0 -293
  48. trajectree/quimb/quimb/linalg/slepc_linalg.py +0 -892
  49. trajectree/quimb/quimb/schematic.py +0 -1518
  50. trajectree/quimb/quimb/tensor/__init__.py +0 -401
  51. trajectree/quimb/quimb/tensor/array_ops.py +0 -610
  52. trajectree/quimb/quimb/tensor/circuit.py +0 -4824
  53. trajectree/quimb/quimb/tensor/circuit_gen.py +0 -411
  54. trajectree/quimb/quimb/tensor/contraction.py +0 -336
  55. trajectree/quimb/quimb/tensor/decomp.py +0 -1255
  56. trajectree/quimb/quimb/tensor/drawing.py +0 -1646
  57. trajectree/quimb/quimb/tensor/fitting.py +0 -385
  58. trajectree/quimb/quimb/tensor/geometry.py +0 -583
  59. trajectree/quimb/quimb/tensor/interface.py +0 -114
  60. trajectree/quimb/quimb/tensor/networking.py +0 -1058
  61. trajectree/quimb/quimb/tensor/optimize.py +0 -1818
  62. trajectree/quimb/quimb/tensor/tensor_1d.py +0 -4778
  63. trajectree/quimb/quimb/tensor/tensor_1d_compress.py +0 -1854
  64. trajectree/quimb/quimb/tensor/tensor_1d_tebd.py +0 -662
  65. trajectree/quimb/quimb/tensor/tensor_2d.py +0 -5954
  66. trajectree/quimb/quimb/tensor/tensor_2d_compress.py +0 -96
  67. trajectree/quimb/quimb/tensor/tensor_2d_tebd.py +0 -1230
  68. trajectree/quimb/quimb/tensor/tensor_3d.py +0 -2869
  69. trajectree/quimb/quimb/tensor/tensor_3d_tebd.py +0 -46
  70. trajectree/quimb/quimb/tensor/tensor_approx_spectral.py +0 -60
  71. trajectree/quimb/quimb/tensor/tensor_arbgeom.py +0 -3237
  72. trajectree/quimb/quimb/tensor/tensor_arbgeom_compress.py +0 -565
  73. trajectree/quimb/quimb/tensor/tensor_arbgeom_tebd.py +0 -1138
  74. trajectree/quimb/quimb/tensor/tensor_builder.py +0 -5411
  75. trajectree/quimb/quimb/tensor/tensor_core.py +0 -11179
  76. trajectree/quimb/quimb/tensor/tensor_dmrg.py +0 -1472
  77. trajectree/quimb/quimb/tensor/tensor_mera.py +0 -204
  78. trajectree/quimb/quimb/utils.py +0 -892
  79. trajectree/quimb/tests/__init__.py +0 -0
  80. trajectree/quimb/tests/test_accel.py +0 -501
  81. trajectree/quimb/tests/test_calc.py +0 -788
  82. trajectree/quimb/tests/test_core.py +0 -847
  83. trajectree/quimb/tests/test_evo.py +0 -565
  84. trajectree/quimb/tests/test_gen/__init__.py +0 -0
  85. trajectree/quimb/tests/test_gen/test_operators.py +0 -361
  86. trajectree/quimb/tests/test_gen/test_rand.py +0 -296
  87. trajectree/quimb/tests/test_gen/test_states.py +0 -261
  88. trajectree/quimb/tests/test_linalg/__init__.py +0 -0
  89. trajectree/quimb/tests/test_linalg/test_approx_spectral.py +0 -368
  90. trajectree/quimb/tests/test_linalg/test_base_linalg.py +0 -351
  91. trajectree/quimb/tests/test_linalg/test_mpi_linalg.py +0 -127
  92. trajectree/quimb/tests/test_linalg/test_numpy_linalg.py +0 -84
  93. trajectree/quimb/tests/test_linalg/test_rand_linalg.py +0 -134
  94. trajectree/quimb/tests/test_linalg/test_slepc_linalg.py +0 -283
  95. trajectree/quimb/tests/test_tensor/__init__.py +0 -0
  96. trajectree/quimb/tests/test_tensor/test_belief_propagation/__init__.py +0 -0
  97. trajectree/quimb/tests/test_tensor/test_belief_propagation/test_d1bp.py +0 -39
  98. trajectree/quimb/tests/test_tensor/test_belief_propagation/test_d2bp.py +0 -67
  99. trajectree/quimb/tests/test_tensor/test_belief_propagation/test_hd1bp.py +0 -64
  100. trajectree/quimb/tests/test_tensor/test_belief_propagation/test_hv1bp.py +0 -51
  101. trajectree/quimb/tests/test_tensor/test_belief_propagation/test_l1bp.py +0 -142
  102. trajectree/quimb/tests/test_tensor/test_belief_propagation/test_l2bp.py +0 -101
  103. trajectree/quimb/tests/test_tensor/test_circuit.py +0 -816
  104. trajectree/quimb/tests/test_tensor/test_contract.py +0 -67
  105. trajectree/quimb/tests/test_tensor/test_decomp.py +0 -40
  106. trajectree/quimb/tests/test_tensor/test_mera.py +0 -52
  107. trajectree/quimb/tests/test_tensor/test_optimizers.py +0 -488
  108. trajectree/quimb/tests/test_tensor/test_tensor_1d.py +0 -1171
  109. trajectree/quimb/tests/test_tensor/test_tensor_2d.py +0 -606
  110. trajectree/quimb/tests/test_tensor/test_tensor_2d_tebd.py +0 -144
  111. trajectree/quimb/tests/test_tensor/test_tensor_3d.py +0 -123
  112. trajectree/quimb/tests/test_tensor/test_tensor_arbgeom.py +0 -226
  113. trajectree/quimb/tests/test_tensor/test_tensor_builder.py +0 -441
  114. trajectree/quimb/tests/test_tensor/test_tensor_core.py +0 -2066
  115. trajectree/quimb/tests/test_tensor/test_tensor_dmrg.py +0 -388
  116. trajectree/quimb/tests/test_tensor/test_tensor_spectral_approx.py +0 -63
  117. trajectree/quimb/tests/test_tensor/test_tensor_tebd.py +0 -270
  118. trajectree/quimb/tests/test_utils.py +0 -85
  119. trajectree-0.0.1.dist-info/RECORD +0 -126
  120. {trajectree-0.0.1.dist-info → trajectree-0.0.2.dist-info}/WHEEL +0 -0
  121. {trajectree-0.0.1.dist-info → trajectree-0.0.2.dist-info}/licenses/LICENSE +0 -0
  122. {trajectree-0.0.1.dist-info → trajectree-0.0.2.dist-info}/top_level.txt +0 -0
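The bulk of this release is the removal of the vendored quimb source tree (entries 9-118 above). As a quick cross-check, the two wheels can be compared directly; the sketch below is illustrative only and assumes both wheel files have been downloaded locally under their usual filenames.

    import zipfile

    def package_entries(whl_path):
        # top-level entries inside the trajectree/ package directory of a wheel
        names = zipfile.ZipFile(whl_path).namelist()
        return {n.split("/")[1] for n in names
                if n.startswith("trajectree/") and n.split("/")[1]}

    old = package_entries("trajectree-0.0.1-py3-none-any.whl")
    new = package_entries("trajectree-0.0.2-py3-none-any.whl")
    print(old - new)  # expected to show {'quimb'}: the vendored tree is gone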
trajectree/quimb/quimb/linalg/approx_spectral.py
@@ -1,1109 +0,0 @@
1
- """Use stochastic Lanczos quadrature to approximate spectral function sums of
2
- any operator which has an efficient representation of action on a vector.
3
- """
4
-
5
- import functools
6
- import random
7
- import warnings
8
- from math import exp, inf, log2, nan, sqrt
9
-
10
- import numpy as np
11
- import scipy.linalg as scla
12
- from scipy.ndimage import uniform_filter1d
13
-
14
- from ..core import divide_update_, dot, njit, prod, ptr, subtract_update_, vdot
15
- from ..gen.rand import rand_phase, rand_rademacher, randn, seed_rand
16
- from ..linalg.mpi_launcher import get_mpi_pool
17
- from ..utils import (
18
- default_to_neutral_style,
19
- find_library,
20
- format_number_with_error,
21
- int2tup,
22
- raise_cant_find_library_function,
23
- )
24
- from ..utils import progbar as Progbar
25
-
26
- if find_library("cotengra") and find_library("autoray"):
27
- from ..tensor.tensor_1d import MatrixProductOperator
28
- from ..tensor.tensor_approx_spectral import construct_lanczos_tridiag_MPO
29
- from ..tensor.tensor_core import Tensor
30
- else:
31
- reqs = "[cotengra,autoray]"
32
- Tensor = raise_cant_find_library_function(reqs)
33
- construct_lanczos_tridiag_MPO = raise_cant_find_library_function(reqs)
34
-
35
-
36
- # --------------------------------------------------------------------------- #
37
- # 'Lazy' representation tensor contractions #
38
- # --------------------------------------------------------------------------- #
39
-
40
-
41
- def lazy_ptr_linop(psi_ab, dims, sysa, **linop_opts):
42
- r"""A linear operator representing action of partially tracing a bipartite
43
- state, then multiplying another 'unipartite' state::
44
-
45
- ( | )
46
- +-------+
47
- | psi_a | ______
48
- +_______+ / \
49
- a| |b |
50
- +-------------+ |
51
- | psi_ab.H | |
52
- +_____________+ |
53
- |
54
- +-------------+ |
55
- | psi_ab | |
56
- +_____________+ |
57
- a| |b |
58
- | \______/
59
-
60
- Parameters
61
- ----------
62
- psi_ab : ket
63
- State to partially trace and dot with another ket, with
64
- size ``prod(dims)``.
65
- dims : sequence of int, optional
66
- The sub dimensions of ``psi_ab``.
67
- sysa : int or sequence of int, optional
68
- Index(es) of the 'a' subsystem(s) to keep.
69
- """
70
- sysa = int2tup(sysa)
71
-
72
- Kab = Tensor(
73
- np.asarray(psi_ab).reshape(dims),
74
- inds=[
75
- ("kA{}" if i in sysa else "xB{}").format(i)
76
- for i in range(len(dims))
77
- ],
78
- )
79
-
80
- Bab = Tensor(
81
- Kab.data.conjugate(),
82
- inds=[
83
- ("bA{}" if i in sysa else "xB{}").format(i)
84
- for i in range(len(dims))
85
- ],
86
- )
87
-
88
- return (Kab & Bab).aslinearoperator(
89
- [f"kA{i}" for i in sysa], [f"bA{i}" for i in sysa], **linop_opts
90
- )
91
-
92
-
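For context, the linear operator returned above acts like the reduced density matrix rho_a without ever forming it densely. A minimal usage sketch follows, assuming a standalone quimb installation (with the cotengra and autoray extras) rather than this vendored copy; ``rand_ket`` and ``ptr`` are quimb's own helpers.

    import numpy as np
    import quimb as qu
    from quimb.linalg.approx_spectral import lazy_ptr_linop

    dims = [2, 2, 2, 2]
    psi_ab = qu.rand_ket(16)                  # random pure state on 4 qubits
    lo = lazy_ptr_linop(psi_ab, dims, sysa=[0, 1])

    # acting on a vector should agree with the dense reduced density matrix
    v = qu.rand_ket(4)
    rho_a = qu.ptr(psi_ab, dims, [0, 1])
    assert np.allclose(lo @ v, rho_a @ v)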
93
- def lazy_ptr_ppt_linop(psi_abc, dims, sysa, sysb, **linop_opts):
94
- r"""A linear operator representing action of partially tracing a tripartite
95
- state, partially transposing the remaining bipartite state, then
96
- multiplying another bipartite state::
97
-
98
- ( | )
99
- +--------------+
100
- | psi_ab |
101
- +______________+ _____
102
- a| ____ b| / \
103
- | / a\ | |c |
104
- | | +-------------+ |
105
- | | | psi_abc.H | |
106
- \ / +-------------+ |
107
- X |
108
- / \ +-------------+ |
109
- | | | psi_abc | |
110
- | | +-------------+ |
111
- | \____/a |b |c |
112
- a| | \_____/
113
-
114
- Parameters
115
- ----------
116
- psi_abc : ket
117
- State to partially trace, partially transpose, then dot with another
118
- ket, with size ``prod(dims)``.
119
- The ket it is dotted with should have size ``prod(dims[sysa] + dims[sysb])``.
120
- dims : sequence of int
121
- The sub dimensions of ``psi_abc``.
122
- sysa : int or sequence of int, optional
123
- Index(es) of the 'a' subsystem(s) to keep, with respect to all
124
- the dimensions, ``dims``, (i.e. pre-partial trace).
125
- sysb : int or sequence of int, optional
126
- Index(es) of the 'b' subsystem(s) to keep, with respect to all
127
- the dimensions, ``dims``, (i.e. pre-partial trace).
128
- """
129
- sysa, sysb = int2tup(sysa), int2tup(sysb)
130
- sys_ab = sorted(sysa + sysb)
131
-
132
- Kabc = Tensor(
133
- np.asarray(psi_abc).reshape(dims),
134
- inds=[
135
- ("kA{}" if i in sysa else "kB{}" if i in sysb else "xC{}").format(
136
- i
137
- )
138
- for i in range(len(dims))
139
- ],
140
- )
141
-
142
- Babc = Tensor(
143
- Kabc.data.conjugate(),
144
- inds=[
145
- ("bA{}" if i in sysa else "bB{}" if i in sysb else "xC{}").format(
146
- i
147
- )
148
- for i in range(len(dims))
149
- ],
150
- )
151
-
152
- return (Kabc & Babc).aslinearoperator(
153
- [("bA{}" if i in sysa else "kB{}").format(i) for i in sys_ab],
154
- [("kA{}" if i in sysa else "bB{}").format(i) for i in sys_ab],
155
- **linop_opts,
156
- )
157
-
158
-
159
- # --------------------------------------------------------------------------- #
160
- # Lanczos tri-diag technique #
161
- # --------------------------------------------------------------------------- #
162
-
163
-
164
- def inner(a, b):
165
- """Inner product between two vectors"""
166
- return vdot(a, b).real
167
-
168
-
169
- def norm_fro(a):
170
- """'Frobenius' norm of a vector."""
171
- return sqrt(inner(a, a))
172
-
173
-
174
- def norm_fro_approx(A, **kwargs):
175
- r"""Calculate the approximate frobenius norm of any hermitian linear
176
- operator:
177
-
178
- .. math::
179
-
180
- \sqrt{ \mathrm{Tr} \left[ A^{\dagger} A \right] }
181
-
182
- Parameters
183
- ----------
184
- A : linear operator like
185
- Operator with a dot method, assumed to be hermitian, to estimate the
186
- frobenius norm of.
187
- kwargs
188
- Supplied to :func:`approx_spectral_function`.
189
-
190
- Returns
191
- -------
192
- float
193
- """
194
- return approx_spectral_function(A, lambda x: x**2, **kwargs) ** 0.5
195
-
196
-
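A quick sanity check of the estimator above on a small dense matrix, where the exact Frobenius norm is cheap to compute. This is a rough sketch assuming an installed quimb exposes this module as ``quimb.linalg.approx_spectral``; the tolerance is loose because the method is stochastic.

    import numpy as np
    from quimb.linalg.approx_spectral import norm_fro_approx

    rng = np.random.default_rng(0)
    X = rng.normal(size=(128, 128))
    A = (X + X.T) / 2                      # Hermitian, as the estimator assumes

    approx = norm_fro_approx(A, tol=0.05)
    exact = np.linalg.norm(A, "fro")
    print(approx, exact)                   # typically agree to a few percent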
197
- def random_rect(
198
- shape,
199
- dist="rademacher",
200
- orthog=False,
201
- norm=True,
202
- seed=False,
203
- dtype=complex,
204
- ):
205
- """Generate a random array optionally orthogonal.
206
-
207
- Parameters
208
- ----------
209
- shape : tuple of int
210
- The shape of array.
211
- dist : {'rademacher', 'gaussian', 'phase'}
212
- Distribution of the random variables.
213
- orthog : bool, optional
214
- Orthogonalize the columns if more than one.
215
- norm : bool
216
- Explicitly normalize the frobenius norm to 1.
217
- """
218
- if seed:
219
- # needs to be truly random so e.g. MPI processes don't overlap
220
- seed_rand(random.SystemRandom().randint(0, 2**32 - 1))
221
-
222
- if dist == "rademacher":
223
- V = rand_rademacher(shape, scale=1 / sqrt(prod(shape)), dtype=dtype)
224
- # already normalized
225
-
226
- elif dist == "gaussian":
227
- V = randn(shape, scale=1 / (prod(shape) ** 0.5 * 2**0.5), dtype=dtype)
228
- if norm:
229
- V /= norm_fro(V)
230
-
231
- elif dist == "phase":
232
- V = rand_phase(shape, scale=1 / sqrt(prod(shape)), dtype=dtype)
233
- # already normalized
234
-
235
- else:
236
- raise ValueError(f"`dist={dist}` not understood.")
237
-
238
- if orthog and min(shape) > 1:
239
- V = scla.orth(V)
240
- V /= sqrt(min(V.shape))
241
-
242
- return V
243
-
244
-
245
- def construct_lanczos_tridiag(
246
- A,
247
- K,
248
- v0=None,
249
- bsz=1,
250
- k_min=10,
251
- orthog=False,
252
- beta_tol=1e-6,
253
- seed=False,
254
- v0_opts=None,
255
- ):
256
- """Construct the tridiagonal lanczos matrix using only matvec operators.
257
- This is a generator that iteratively yields the alpha and beta diagonals
258
- at each step.
259
-
260
- Parameters
261
- ----------
262
- A : dense array, sparse matrix or linear operator
263
- The operator to approximate, must implement ``.dot`` method to compute
264
- its action on a vector.
265
- K : int, optional
266
- The maximum number of iterations and thus rank of the matrix to find.
267
- v0 : vector, optional
268
- The starting vector to iterate with, defaults to random.
269
- bsz : int, optional
270
- The block size (number of columns) of random vectors to iterate with.
271
- k_min : int, optional
272
- The minimum size of the krylov subspace to form.
273
- orthog : bool, optional
274
- If True, perform full re-orthogonalization for each new vector.
275
- beta_tol : float, optional
276
- The 'breakdown' tolerance. If the next beta coefficient in the lanczos
277
- matrix is less than this, implying that the full non-null space has
278
- been found, terminate early.
279
- seed : bool, optional
280
- If True, seed the numpy random generator with a system random int.
281
-
282
- Yields
283
- ------
284
- alpha : sequence of float of length k
285
- The diagonal entries of the lanczos matrix.
286
- beta : sequence of float of length k
287
- The off-diagonal entries of the lanczos matrix, with the last entry
288
- the 'look-forward' value.
289
- scaling : float
290
- How to scale the overall weights.
291
- """
292
- d = A.shape[0]
293
-
294
- if bsz == 1:
295
- v_shp = (d,)
296
- else:
297
- orthog = False
298
- v_shp = (d, bsz)
299
-
300
- alpha = np.zeros(K + 1, dtype=get_equivalent_real_dtype(A.dtype))
301
- beta = np.zeros(K + 2, dtype=get_equivalent_real_dtype(A.dtype))
302
- beta[1] = sqrt(prod(v_shp)) # by construction
303
-
304
- if v0 is None:
305
- if v0_opts is None:
306
- v0_opts = {}
307
- q = random_rect(v_shp, seed=seed, dtype=A.dtype, **v0_opts)
308
- else:
309
- q = v0.astype(A.dtype)
310
- divide_update_(q, norm_fro(q), q)
311
- v = np.zeros_like(q)
312
-
313
- if orthog:
314
- Q = np.copy(q).reshape(-1, 1)
315
-
316
- for j in range(1, K + 1):
317
- r = dot(A, q)
318
- subtract_update_(r, beta[j], v)
319
- alpha[j] = inner(q, r)
320
- subtract_update_(r, alpha[j], q)
321
-
322
- # perform full orthogonalization
323
- if orthog:
324
- r -= Q.dot(Q.conj().T.dot(r))
325
-
326
- beta[j + 1] = norm_fro(r)
327
-
328
- # check for convergence
329
- if abs(beta[j + 1]) < beta_tol:
330
- yield (
331
- alpha[1 : j + 1].copy(),
332
- beta[2 : j + 2].copy(),
333
- beta[1] ** 2 / bsz,
334
- )
335
- break
336
-
337
- v[()] = q
338
- divide_update_(r, beta[j + 1], q)
339
-
340
- # keep all vectors
341
- if orthog:
342
- Q = np.concatenate((Q, q.reshape(-1, 1)), axis=1)
343
-
344
- if j >= k_min:
345
- yield (
346
- alpha[1 : j + 1].copy(),
347
- beta[2 : j + 2].copy(),
348
- beta[1] ** 2 / bsz,
349
- )
350
-
351
-
352
- def lanczos_tridiag_eig(alpha, beta, check_finite=True):
353
- """Find the eigen-values and -vectors of the Lanczos triadiagonal matrix.
354
-
355
- Parameters
356
- ----------
357
- alpha : array of float
358
- The diagonal.
359
- beta : array of float
360
- The k={-1, 1} off-diagonal. Only first ``len(alpha) - 1`` entries used.
361
- """
362
- Tk_banded = np.empty((2, alpha.size), dtype=alpha.dtype)
363
- Tk_banded[1, -1] = 0.0 # sometimes can get nan here? -> breaks eig_banded
364
- Tk_banded[0, :] = alpha
365
- Tk_banded[1, : beta.size] = beta
366
-
367
- try:
368
- tl, tv = scla.eig_banded(
369
- Tk_banded, lower=True, check_finite=check_finite
370
- )
371
-
372
- # sometimes get no convergence -> use dense hermitian method
373
- except scla.LinAlgError: # pragma: no cover
374
- tl, tv = np.linalg.eigh(
375
- np.diag(alpha) + np.diag(beta[: alpha.size - 1], -1), UPLO="L"
376
- )
377
-
378
- return tl, tv
379
-
380
-
381
- def calc_trace_fn_tridiag(tl, tv, f, pos=True):
382
- """Spectral ritz function sum, weighted by ritz vectors."""
383
- return sum(
384
- tv[0, i] ** 2 * f(max(tl[i], 0.0) if pos else tl[i])
385
- for i in range(tl.size)
386
- )
387
-
388
-
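Putting the last three pieces together gives a single stochastic Lanczos quadrature sample. The sketch below (again assuming these helpers are importable from an installed quimb) prints successive single-vector estimates of Tr(exp(A)) as the Krylov space grows; ``single_random_estimate`` below automates this loop and adds the convergence checks.

    import numpy as np
    from quimb.linalg.approx_spectral import (
        calc_trace_fn_tridiag,
        construct_lanczos_tridiag,
        lanczos_tridiag_eig,
    )

    rng = np.random.default_rng(1)
    X = rng.normal(size=(128, 128))
    A = (X + X.T) / (2 * np.sqrt(128))     # Hermitian, spectrum roughly in [-2, 2]

    for alpha, beta, scaling in construct_lanczos_tridiag(A, K=30):
        tl, tv = lanczos_tridiag_eig(alpha, beta)
        # one noisy single-vector estimate of Tr(exp(A)) per iteration
        print(scaling * calc_trace_fn_tridiag(tl, tv, f=np.exp, pos=False))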
389
- @njit
390
- def ext_per_trim(x, p=0.6, s=1.0): # pragma: no cover
391
- r"""Extended percentile trimmed-mean. Makes the mean robust to asymmetric
392
- outliers, while using all data when it is nicely clustered. This can be
393
- visualized roughly as::
394
-
395
- |--------|=========|--------|
396
- x x xx xx xxxxx xxx xx x x x
397
-
398
- Where the inner range contains the central ``p`` proportion of the data,
399
- and the outer ranges extend this by a factor of ``s`` on either side.
400
-
401
- Parameters
402
- ----------
403
- x : array
404
- Data to trim.
405
- p : Proportion of data used to define the 'central' percentile.
406
- For example, p=0.5 gives the inter-quartile range.
407
- s : Include data up to this factor times the central 'percentile' range
408
- away from the central percentile itself.
409
-
410
- Returns
411
- -------
- xt : array
412
- Trimmed data.
413
- """
414
- lb = np.percentile(x, 100 * (1 - p) / 2)
415
- ub = np.percentile(x, 100 * (1 + p) / 2)
416
- ib = ub - lb
417
-
418
- trimmed_x = x[(lb - s * ib < x) & (x < ub + s * ib)]
419
-
420
- return trimmed_x
421
-
422
-
423
- @njit # pragma: no cover
424
- def nbsum(xs):
425
- tot = 0
426
- for x in xs:
427
- tot += x
428
- return tot
429
-
430
-
431
- @njit # pragma: no cover
432
- def std(xs):
433
- """Simple standard deviation - don't invoke numpy for small lists."""
434
- N = len(xs)
435
- xm = nbsum(xs) / N
436
- var = nbsum([(x - xm) ** 2 for x in xs]) / N
437
- return var**0.5
438
-
439
-
440
- def calc_est_fit(estimates, conv_n, tau):
441
- """Make estimate by fitting exponential convergence to estimates."""
442
- n = len(estimates)
443
-
444
- if n < conv_n:
445
- return nan, inf
446
-
447
- # iteration number, fit function to inverse this to get k->infinity
448
- ks = np.arange(1, len(estimates) + 1)
449
-
450
- # smooth data with a running mean
451
- smoothed_estimates = uniform_filter1d(estimates, n // 2)
452
-
453
- # ignore this amount of the initial estimates and fit later part only
454
- ni = n // 2
455
-
456
- try:
457
- with warnings.catch_warnings():
458
- warnings.simplefilter("ignore")
459
-
460
- # fit the inverse data with a line, weighting recent ests more
461
- popt, pcov = np.polyfit(
462
- x=(1 / ks[ni:]),
463
- y=smoothed_estimates[ni:],
464
- w=ks[ni:],
465
- deg=1,
466
- cov=True,
467
- )
468
-
469
- # estimate of function at 1 / k = 0 and standard error
470
- est, err = popt[-1], abs(pcov[-1, -1]) ** 0.5
471
-
472
- except (ValueError, RuntimeError):
473
- est, err = nan, inf
474
-
475
- return est, err
476
-
477
-
478
- def calc_est_window(estimates, conv_n):
479
- """Make estimate from mean of last ``m`` samples, following:
480
-
481
- 1. Take between ``conv_n`` and 12 estimates.
482
- 2. Pair the estimates, since they alternate between upper and lower bounds.
483
- 3. Compute the standard error on the paired estimates.
484
- """
485
- m_est = min(max(conv_n, len(estimates) // 8), 12)
486
- est = sum(estimates[-m_est:]) / len(estimates[-m_est:])
487
-
488
- if len(estimates) > conv_n:
489
- # check for convergence using variance of paired last m estimates
490
- # -> paired because estimates alternate between upper and lower bound
491
- paired_ests = tuple(
492
- (a + b) / 2
493
- for a, b in zip(estimates[-m_est::2], estimates[-m_est + 1 :: 2])
494
- )
495
- err = std(paired_ests) / (m_est / 2) ** 0.5
496
- else:
497
- err = inf
498
-
499
- return est, err
500
-
501
-
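The pairing in step 2 matters because the raw estimates alternate between over- and under-shooting the true value, as noted in the code comment below. A toy plain-numpy illustration (not part of the module) of the paired standard error:

    import numpy as np

    raw = np.array([1.30, 0.72, 1.21, 0.80, 1.15, 0.86])  # alternating hi / lo
    pairs = (raw[0::2] + raw[1::2]) / 2                    # [1.01, 1.005, 1.005]
    err = pairs.std() / len(pairs) ** 0.5
    print(pairs.mean(), err)                               # tight error bar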
502
- def single_random_estimate(
503
- A,
504
- K,
505
- bsz,
506
- beta_tol,
507
- v0,
508
- f,
509
- pos,
510
- tau,
511
- tol_scale,
512
- k_min=10,
513
- verbosity=0,
514
- *,
515
- seed=None,
516
- v0_opts=None,
517
- info=None,
518
- **lanczos_opts,
519
- ):
520
- # choose normal (any LinearOperator) or MPO lanczos tridiag construction
521
- if isinstance(A, MatrixProductOperator):
522
- lanc_fn = construct_lanczos_tridiag_MPO
523
- else:
524
- lanc_fn = construct_lanczos_tridiag
525
- lanczos_opts["bsz"] = bsz
526
-
527
- estimates_raw = []
528
- estimates_window = []
529
- estimates_fit = []
530
- estimates = []
531
-
532
- # the number of samples to check standard deviation convergence with
533
- conv_n = 6 # 3 pairs
534
-
535
- # iteratively build the lanczos matrix, checking for convergence
536
- for alpha, beta, scaling in lanc_fn(
537
- A,
538
- K=K,
539
- beta_tol=beta_tol,
540
- seed=seed,
541
- k_min=k_min - 2 * conv_n,
542
- v0=v0() if callable(v0) else v0,
543
- v0_opts=v0_opts,
544
- **lanczos_opts,
545
- ):
546
- try:
547
- Tl, Tv = lanczos_tridiag_eig(alpha, beta, check_finite=False)
548
- Gf = scaling * calc_trace_fn_tridiag(Tl, Tv, f=f, pos=pos)
549
- except scla.LinAlgError: # pragma: no cover
550
- warnings.warn("Approx Spectral Gf tri-eig didn't converge.")
551
- estimates_raw.append(np.nan)
552
- continue
553
-
554
- k = alpha.size
555
- estimates_raw.append(Gf)
556
-
557
- # check for break-down convergence (e.g. found entire subspace)
558
- # in which case latest estimate should be accurate
559
- if abs(beta[-1]) < beta_tol:
560
- if verbosity >= 2:
561
- print(f"k={k}: Beta breadown, returning {Gf}.")
562
- est = Gf
563
- estimates.append(est)
564
- break
565
-
566
- # compute an estimate and error using a window of the last few results
567
- win_est, win_err = calc_est_window(estimates_raw, conv_n)
568
- estimates_window.append(win_est)
569
-
570
- # try and compute an estimate and error using exponential fit
571
- fit_est, fit_err = calc_est_fit(estimates_window, conv_n, tau)
572
- estimates_fit.append(fit_est)
573
-
574
- # take whichever has lowest error
575
- est, err = min(
576
- (win_est, win_err),
577
- (fit_est, fit_err),
578
- key=lambda est_err: est_err[1],
579
- )
580
- estimates.append(est)
581
- converged = err < tau * (abs(win_est) + tol_scale)
582
-
583
- if verbosity >= 2:
584
- if verbosity >= 3:
585
- print(f"est_win={win_est}, err_win={win_err}")
586
- print(f"est_fit={fit_est}, err_fit={fit_err}")
587
- print(f"k={k}: Gf={Gf}, Est={est}, Err={err}")
588
- if converged:
589
- print(f"k={k}: Converged to tau {tau}.")
590
-
591
- if converged:
592
- break
593
-
594
- if verbosity >= 1:
595
- print(f"k={k}: Returning estimate {est}.")
596
-
597
- if info is not None:
598
- if "estimates_raw" in info:
599
- info["estimates_raw"].append(estimates_raw)
600
- if "estimates_window" in info:
601
- info["estimates_window"].append(estimates_window)
602
- if "estimates_fit" in info:
603
- info["estimates_fit"].append(estimates_fit)
604
- if "estimates" in info:
605
- info["estimates"].append(estimates)
606
-
607
- return est
608
-
609
-
610
- def calc_stats(samples, mean_p, mean_s, tol, tol_scale):
611
- """Get an estimate from samples."""
612
- samples = np.array(samples)
613
-
614
- xtrim = ext_per_trim(samples, p=mean_p, s=mean_s)
615
-
616
- # sometimes everything is an outlier...
617
- if xtrim.size == 0: # pragma: no cover
618
- estimate, sdev = np.mean(samples), std(samples)
619
- else:
620
- estimate, sdev = np.mean(xtrim), std(xtrim)
621
-
622
- err = sdev / len(samples) ** 0.5
623
-
624
- converged = err < tol * (abs(estimate) + tol_scale)
625
-
626
- return estimate, err, converged
627
-
628
-
629
- def get_single_precision_dtype(dtype):
630
- if np.issubdtype(dtype, np.complexfloating):
631
- return np.complex64
632
- elif np.issubdtype(dtype, np.floating):
633
- return np.float32
634
- else:
635
- raise ValueError(f"dtype {dtype} not understood.")
636
-
637
-
638
- def get_equivalent_real_dtype(dtype):
639
- if dtype in ("float64", "complex128"):
640
- return "float64"
641
- elif dtype in ("float32", "complex64"):
642
- return "float32"
643
- else:
644
- raise ValueError(f"dtype {dtype} not understood.")
645
-
646
-
647
- @default_to_neutral_style
648
- def plot_approx_spectral_info(info):
649
- from matplotlib import pyplot as plt
650
- from matplotlib.ticker import MaxNLocator
651
-
652
- fig, axs = plt.subplots(
653
- ncols=2,
654
- figsize=(8, 4),
655
- sharey=True,
656
- gridspec_kw={"width_ratios": [3, 1]},
657
- )
658
- plt.subplots_adjust(wspace=0.0)
659
-
660
- Z = info["estimate"]
661
-
662
- alpha = len(info["estimates_raw"])**-(1 / 6)
663
-
664
- # plot the raw krylov runs
665
- for x in info["estimates_raw"]:
666
- axs[0].plot(x, ".-", alpha=alpha, lw=1 / 2, zorder=-10, markersize=1)
667
- axs[0].axhline(Z - info["error"], color="grey", linestyle="--")
668
- axs[0].axhline(Z + info["error"], color="grey", linestyle="--")
669
- axs[0].axhline(Z, color="black", linestyle="--")
670
- axs[0].set_rasterization_zorder(-5)
671
- axs[0].set_xlabel("krylov iteration (offset)")
672
- axs[0].xaxis.set_major_locator(MaxNLocator(integer=True))
673
- axs[0].set_ylabel("$Tr[f(x)]$ approximation")
674
-
675
- # plot the overall final samples
676
- axs[1].hist(
677
- info["samples"],
678
- bins=round(len(info["samples"])**0.5),
679
- orientation="horizontal",
680
- color=(0.2, 0.6, 1.0),
681
- )
682
- axs[1].axhline(Z - info["error"], color="grey", linestyle="--")
683
- axs[1].axhline(Z + info["error"], color="grey", linestyle="--")
684
- axs[1].axhline(Z, color="black", linestyle="--")
685
- axs[1].set_xlabel("sample count")
686
- axs[1].set_title(
687
- "estimate ≈ " + format_number_with_error(Z, info["error"]),
688
- ha="right",
689
- )
690
-
691
- # plot the correlation between raw and fitted estimates
692
- iax = axs[0].inset_axes((0.03, 0.6, 0.3, 0.3))
693
- iax.set_aspect("equal")
694
- x = [es[-1] for es in info["estimates"]]
695
- y = [es[-1] for es in info["estimates_raw"]]
696
- iax.scatter(x, y, marker=".", alpha=alpha, color=(0.3, 0.7, 0.3), s=1)
697
-
698
- return fig, axs
699
-
700
-
701
- def approx_spectral_function(
702
- A,
703
- f,
704
- tol=1e-2,
705
- *,
706
- bsz=1,
707
- R=1024,
708
- R_min=3,
709
- tol_scale=1,
710
- tau=1e-4,
711
- k_min=10,
712
- k_max=512,
713
- beta_tol=1e-6,
714
- mpi=False,
715
- mean_p=0.7,
716
- mean_s=1.0,
717
- pos=False,
718
- v0=None,
719
- verbosity=0,
720
- single_precision="AUTO",
721
- info=None,
722
- progbar=False,
723
- plot=False,
724
- **lanczos_opts,
725
- ):
726
- """Approximate a spectral function, that is, the quantity ``Tr(f(A))``.
727
-
728
- Parameters
729
- ----------
730
- A : dense array, sparse matrix or LinearOperator
731
- Operator to approximate spectral function for. Should implement
732
- ``A.dot(vec)``.
733
- f : callable
734
- Scalar function with which to act on approximate eigenvalues.
735
- tol : float, optional
736
- Relative convergence tolerance threshold for error on mean of repeats.
737
- This can pretty much be relied on as the overall accuracy. See also
738
- ``tol_scale`` and ``tau``. Default: 1%.
739
- bsz : int, optional
740
- Number of simultaneous vector columns to use at once, 1 equating to the
741
- standard lanczos method. If ``bsz > 1`` then ``A`` must implement
742
- matrix-matrix multiplication. This is a more performant way of
743
- essentially increasing ``R``, at the cost of more memory. Default: 1.
744
- R : int, optional
745
- The number of repeats with different initial random vectors to perform.
746
- Increasing this should increase accuracy as ``sqrt(R)``. Cost of
747
- algorithm thus scales linearly with ``R``. If ``tol`` is non-zero, this
748
- is the maximum number of repeats.
749
- R_min : int, optional
750
- The minimum number of repeats to perform. Default: 3.
751
- tau : float, optional
752
- The relative tolerance required for a single lanczos run to converge.
753
- This needs to be small enough that each estimate with a single random
754
- vector produces an unbiased sample of the operator's spectrum.
755
- k_min : int, optional
756
- The minimum size of the krylov subspace to form for each sample.
757
- k_max : int, optional
758
- The maximum size of the krylov space to form. Cost of algorithm scales
759
- linearly with ``K``. If ``tau`` is non-zero, this is the maximum size
760
- matrix to form.
761
- tol_scale : float, optional
762
- This sets the overall expected scale of each estimate, so that an
763
- absolute tolerance can be used for values near zero. Default: 1.
764
- beta_tol : float, optional
765
- The 'breakdown' tolerance. If the next beta coefficient in the lanczos
766
- matrix is less than this, implying that the full non-null space has
767
- been found, terminate early. Default: 1e-6.
768
- mpi : bool, optional
769
- Whether to parallelize repeat runs over MPI processes.
770
- mean_p : float, optional
771
- Factor for robustly finding mean and err of repeat estimates,
772
- see :func:`ext_per_trim`.
773
- mean_s : float, optional
774
- Factor for robustly finding mean and err of repeat estimates,
775
- see :func:`ext_per_trim`.
776
- v0 : vector, or callable
777
- Initial vector to iterate with, sets ``R=1`` if given. If callable, the
778
- function to produce a random initial vector (sequence).
779
- pos : bool, optional
780
- If True, make sure any approximate eigenvalues are positive by
781
- clipping below 0.
782
- verbosity : {0, 1, 2}, optional
783
- How much information to print while computing.
784
- single_precision : {'AUTO', False, True}, optional
785
- Try to convert the operator to single precision. This can lead to much
786
- faster operation, especially if a GPU is available. Additionally,
787
- double precision is not really needed given the stochastic nature of
788
- the algorithm.
789
- lanczos_opts
790
- Supplied to
791
- :func:`~quimb.linalg.approx_spectral.single_random_estimate` or
792
- :func:`~quimb.linalg.approx_spectral.construct_lanczos_tridiag`.
793
-
794
-
795
- Returns
796
- -------
797
- scalar
798
- The approximate value ``Tr(f(A))``.
799
-
800
- See Also
801
- --------
802
- construct_lanczos_tridiag
803
- """
804
- if single_precision == "AUTO":
805
- single_precision = hasattr(A, "astype")
806
- if single_precision:
807
- A = A.astype(get_single_precision_dtype(A.dtype))
808
-
809
- if (v0 is not None) and not callable(v0):
810
- # we only have one sample to run
811
- R = 1
812
- else:
813
- R = max(1, int(R / bsz))
814
-
815
- if tau is None:
816
- # require better precision for the lanczos procedure, otherwise biased
817
- tau = tol / 1000
818
-
819
- if verbosity:
820
- print(f"LANCZOS f(A) CALC: tol={tol}, tau={tau}, R={R}, bsz={bsz}")
821
-
822
- if plot:
823
- # need to store all the info
824
- if info is None:
825
- info = {}
826
- info.setdefault('estimate', None)
827
- info.setdefault('error', None)
828
- info.setdefault('samples', None)
829
- info.setdefault('estimates_raw', [])
830
- info.setdefault('estimates_window', [])
831
- info.setdefault('estimates_fit', [])
832
- info.setdefault('estimates', [])
833
-
834
- # generate repeat estimates
835
- kwargs = {
836
- "A": A,
837
- "K": k_max,
838
- "bsz": bsz,
839
- "beta_tol": beta_tol,
840
- "v0": v0,
841
- "f": f,
842
- "pos": pos,
843
- "tau": tau,
844
- "k_min": k_min,
845
- "tol_scale": tol_scale,
846
- "verbosity": verbosity,
847
- "info": info,
848
- **lanczos_opts,
849
- }
850
-
851
- if not mpi:
852
-
853
- def gen_results():
854
- for _ in range(R):
855
- yield single_random_estimate(**kwargs)
856
-
857
- else:
858
- pool = get_mpi_pool()
859
- kwargs["seed"] = True
860
- fs = [pool.submit(single_random_estimate, **kwargs) for _ in range(R)]
861
-
862
- def gen_results():
863
- for f in fs:
864
- yield f.result()
865
-
866
- if progbar:
867
- pbar = Progbar(total=R)
868
- else:
869
- pbar = None
870
-
871
- # iterate through estimates, waiting for convergence
872
- results = gen_results()
873
- estimate = None
874
- samples = []
875
- for _ in range(R):
876
- samples.append(next(results))
877
-
878
- if verbosity >= 1:
879
- print(f"Repeat {len(samples)}: estimate is {samples[-1]}")
880
-
881
- # wait a few iterations before checking error on mean breakout
882
- if len(samples) >= R_min:
883
- estimate, err, converged = calc_stats(
884
- samples, mean_p, mean_s, tol, tol_scale
885
- )
886
- if verbosity >= 1:
887
- print(f"Total estimate = {estimate} ± {err}")
888
- if converged:
889
- if verbosity >= 1:
890
- print(f"Repeat {len(samples)}: converged to tol {tol}")
891
- break
892
-
893
- if pbar:
894
- if len(samples) < R_min:
895
- estimate, err, _ = calc_stats(
896
- samples, mean_p, mean_s, tol, tol_scale
897
- )
898
- pbar.set_description(format_number_with_error(estimate, err))
899
-
900
- if pbar:
901
- pbar.update()
902
- if pbar:
903
- pbar.close()
904
-
905
- if mpi:
906
- # deal with remaining futures
907
- extra_futures = []
908
- for f in fs:
909
- if f.done() or f.running():
910
- extra_futures.append(f)
911
- else:
912
- f.cancel()
913
-
914
- if extra_futures:
915
- # might as well combine finished samples
916
- samples.extend(f.result() for f in extra_futures)
917
- estimate, err, converged = calc_stats(
918
- samples, mean_p, mean_s, tol, tol_scale
919
- )
920
-
921
- if estimate is None:
922
- estimate, err, _ = calc_stats(samples, mean_p, mean_s, tol, tol_scale)
923
-
924
- if verbosity >= 1:
925
- print(f"ESTIMATE is {estimate} ± {err}")
926
-
927
- if info is not None:
928
- if "samples" in info:
929
- info["samples"] = samples
930
- if "error" in info:
931
- info["error"] = err
932
- if "estimate" in info:
933
- info["estimate"] = estimate
934
-
935
- if plot:
936
- info["fig"], info["axs"] = plot_approx_spectral_info(info)
937
-
938
- return estimate
939
-
940
-
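A usage sketch for the main entry point above; it assumes an installed quimb exposing this module, and uses a dense test matrix in place of the sparse operator or LinearOperator one would use in practice.

    import numpy as np
    from quimb.linalg.approx_spectral import approx_spectral_function

    rng = np.random.default_rng(2)
    X = rng.normal(size=(256, 256))
    A = (X + X.T) / (2 * np.sqrt(256))     # Hermitian, spectrum roughly in [-2, 2]

    est = approx_spectral_function(A, np.exp, tol=1e-2)
    exact = np.sum(np.exp(np.linalg.eigvalsh(A)))
    print(est, exact)                      # should agree to about the requested tolerance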
941
- @functools.wraps(approx_spectral_function)
942
- def tr_abs_approx(*args, **kwargs):
943
- return approx_spectral_function(*args, f=abs, **kwargs)
944
-
945
-
946
- @functools.wraps(approx_spectral_function)
947
- def tr_exp_approx(*args, **kwargs):
948
- return approx_spectral_function(*args, f=exp, **kwargs)
949
-
950
-
951
- @functools.wraps(approx_spectral_function)
952
- def tr_sqrt_approx(*args, **kwargs):
953
- return approx_spectral_function(*args, f=sqrt, pos=True, **kwargs)
954
-
955
-
956
- def xlogx(x):
957
- return x * log2(x) if x > 0 else 0.0
958
-
959
-
960
- @functools.wraps(approx_spectral_function)
961
- def tr_xlogx_approx(*args, **kwargs):
962
- return approx_spectral_function(*args, f=xlogx, **kwargs)
963
-
964
-
965
- # --------------------------------------------------------------------------- #
966
- # Specific quantities #
967
- # --------------------------------------------------------------------------- #
968
-
969
-
970
- def entropy_subsys_approx(psi_ab, dims, sysa, backend=None, **kwargs):
971
- """Approximate the (Von Neumann) entropy of a pure state's subsystem.
972
-
973
- Parameters
974
- ----------
975
- psi_ab : ket
976
- Bipartite state to partially trace and find entropy of.
977
- dims : sequence of int, optional
978
- The sub dimensions of ``psi_ab``.
979
- sysa : int or sequence of int, optional
980
- Index(es) of the 'a' subsystem(s) to keep.
981
- kwargs
982
- Supplied to :func:`approx_spectral_function`.
983
- """
984
- lo = lazy_ptr_linop(psi_ab, dims=dims, sysa=sysa, backend=backend)
985
- return -tr_xlogx_approx(lo, **kwargs)
986
-
987
-
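A hedged end-to-end check of the approximate subsystem entropy against the exact value, for a state small enough to handle densely (again assuming an installed quimb with the cotengra and autoray extras):

    import quimb as qu
    from quimb.linalg.approx_spectral import entropy_subsys_approx

    dims = [2] * 10
    psi = qu.rand_ket(2**10)
    sysa = list(range(5))

    approx = entropy_subsys_approx(psi, dims, sysa, tol=0.05)
    exact = qu.entropy(qu.ptr(psi, dims, sysa))   # both in bits (log base 2)
    print(approx, exact)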
988
- def tr_sqrt_subsys_approx(psi_ab, dims, sysa, backend=None, **kwargs):
989
- """Approximate the trace sqrt of a pure state's subsystem.
990
-
991
- Parameters
992
- ----------
993
- psi_ab : ket
994
- Bipartite state to partially trace and find trace sqrt of.
995
- dims : sequence of int, optional
996
- The sub dimensions of ``psi_ab``.
997
- sysa : int or sequence of int, optional
998
- Index(es) of the 'a' subsystem(s) to keep.
999
- kwargs
1000
- Supplied to :func:`approx_spectral_function`.
1001
- """
1002
- lo = lazy_ptr_linop(psi_ab, dims=dims, sysa=sysa, backend=backend)
1003
- return tr_sqrt_approx(lo, **kwargs)
1004
-
1005
-
1006
- def norm_ppt_subsys_approx(psi_abc, dims, sysa, sysb, backend=None, **kwargs):
1007
- """Estimate the norm of the partial transpose of a pure state's subsystem."""
1008
- lo = lazy_ptr_ppt_linop(
1009
- psi_abc, dims=dims, sysa=sysa, sysb=sysb, backend=backend
1010
- )
1011
- return tr_abs_approx(lo, **kwargs)
1012
-
1013
-
1014
- def logneg_subsys_approx(psi_abc, dims, sysa, sysb, **kwargs):
1015
- """Estimate the logarithmic negativity of a pure state's subsystem.
1016
-
1017
- Parameters
1018
- ----------
1019
- psi_abc : ket
1020
- Pure tripartite state, for which to estimate the entanglement between
1021
- 'a' and 'b'.
1022
- dims : sequence of int
1023
- The sub dimensions of ``psi_abc``.
1024
- sysa : int or sequence of int, optional
1025
- Index(es) of the 'a' subsystem(s) to keep, with respect to all
1026
- the dimensions, ``dims``, (i.e. pre-partial trace).
1027
- sysb : int or sequence of int, optional
1028
- Index(es) of the 'b' subsystem(s) to keep, with respect to all
1029
- the dimensions, ``dims``, (i.e. pre-partial trace).
1030
- kwargs
1031
- Supplied to :func:`approx_spectral_function`.
1032
- """
1033
- nrm = norm_ppt_subsys_approx(psi_abc, dims, sysa, sysb, **kwargs)
1034
- return max(0.0, log2(nrm))
1035
-
1036
-
1037
- def negativity_subsys_approx(psi_abc, dims, sysa, sysb, **kwargs):
1038
- """Estimate the negativity of a pure state's subsystem.
1039
-
1040
- Parameters
1041
- ----------
1042
- psi_abc : ket
1043
- Pure tripartite state, for which to estimate the entanglement between
1044
- 'a' and 'b'.
1045
- dims : sequence of int
1046
- The sub dimensions of ``psi_abc``.
1047
- sysa : int or sequence of int, optional
1048
- Index(es) of the 'a' subsystem(s) to keep, with respect to all
1049
- the dimensions, ``dims``, (i.e. pre-partial trace).
1050
- sysb : int or sequence of int, optional
1051
- Index(es) of the 'b' subsystem(s) to keep, with respect to all
1052
- the dimensions, ``dims``, (i.e. pre-partial trace).
1053
- kwargs
1054
- Supplied to :func:`approx_spectral_function`.
1055
- """
1056
- nrm = norm_ppt_subsys_approx(psi_abc, dims, sysa, sysb, **kwargs)
1057
- return max(0.0, (nrm - 1) / 2)
1058
-
1059
-
1060
- def gen_bipartite_spectral_fn(exact_fn, approx_fn, pure_default):
1061
- """Generate a function that computes a spectral quantity of the subsystem
1062
- of a pure state. Automatically computes for the smaller subsystem, or
1063
- switches to the approximate method for large subsystems.
1064
-
1065
- Parameters
1066
- ----------
1067
- exact_fn : callable
1068
- The function that computes the quantity on a density matrix, with
1069
- signature: ``exact_fn(rho_a, rank=...)``.
1070
- approx_fn : callable
1071
- The function that approximately computes the quantity using a lazy
1072
- representation of the whole system. With signature
1073
- ``approx_fn(psi_ab, dims, sysa, **approx_opts)``.
1074
- pure_default : float
1075
- The default value when the whole state is the subsystem.
1076
-
1077
- Returns
1078
- -------
1079
- bipartite_spectral_fn : callable
1080
- The function, with signature:
1081
- ``(psi_ab, dims, sysa, approx_thresh=2**13, **approx_opts)``
1082
- """
1083
-
1084
- def bipartite_spectral_fn(
1085
- psi_ab, dims, sysa, approx_thresh=2**13, **approx_opts
1086
- ):
1087
- sysa = int2tup(sysa)
1088
- sz_a = prod(d for i, d in enumerate(dims) if i in sysa)
1089
- sz_b = prod(dims) // sz_a
1090
-
1091
- # pure state
1092
- if sz_b == 1:
1093
- return pure_default
1094
-
1095
- # also check if system b is smaller, since spectrum is same for both
1096
- if sz_b < sz_a:
1097
- # if so swap things around
1098
- sz_a = sz_b
1099
- sysb = [i for i in range(len(dims)) if i not in sysa]
1100
- sysa = sysb
1101
-
1102
- # check whether to use approx lanczos method
1103
- if (approx_thresh is not None) and (sz_a >= approx_thresh):
1104
- return approx_fn(psi_ab, dims, sysa, **approx_opts)
1105
-
1106
- rho_a = ptr(psi_ab, dims, sysa)
1107
- return exact_fn(rho_a)
1108
-
1109
- return bipartite_spectral_fn
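This factory can be used to build functions such as a subsystem entropy that automatically falls back to the Lanczos estimate for large subsystems. A minimal sketch of using it directly (hedged: ``qu.entropy`` is assumed as the exact density-matrix routine, and the default threshold means this small example takes the exact branch):

    import quimb as qu
    from quimb.linalg.approx_spectral import (
        entropy_subsys_approx,
        gen_bipartite_spectral_fn,
    )

    entropy_subsys = gen_bipartite_spectral_fn(
        exact_fn=qu.entropy,
        approx_fn=entropy_subsys_approx,
        pure_default=0.0,          # the whole state as subsystem is pure -> zero entropy
    )

    psi = qu.rand_ket(2**6)
    print(entropy_subsys(psi, [2] * 6, sysa=[0, 1, 2]))  # exact branch: 8-dim subsystem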