Trajectree 0.0.0__py3-none-any.whl → 0.0.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (122)
  1. trajectree/__init__.py +3 -0
  2. trajectree/fock_optics/devices.py +1 -1
  3. trajectree/fock_optics/light_sources.py +2 -2
  4. trajectree/fock_optics/measurement.py +3 -3
  5. trajectree/fock_optics/utils.py +6 -6
  6. trajectree/quimb/docs/_pygments/_pygments_dark.py +118 -0
  7. trajectree/quimb/docs/_pygments/_pygments_light.py +118 -0
  8. trajectree/quimb/docs/conf.py +158 -0
  9. trajectree/quimb/docs/examples/ex_mpi_expm_evo.py +62 -0
  10. trajectree/quimb/quimb/__init__.py +507 -0
  11. trajectree/quimb/quimb/calc.py +1491 -0
  12. trajectree/quimb/quimb/core.py +2279 -0
  13. trajectree/quimb/quimb/evo.py +712 -0
  14. trajectree/quimb/quimb/experimental/__init__.py +0 -0
  15. trajectree/quimb/quimb/experimental/autojittn.py +129 -0
  16. trajectree/quimb/quimb/experimental/belief_propagation/__init__.py +109 -0
  17. trajectree/quimb/quimb/experimental/belief_propagation/bp_common.py +397 -0
  18. trajectree/quimb/quimb/experimental/belief_propagation/d1bp.py +316 -0
  19. trajectree/quimb/quimb/experimental/belief_propagation/d2bp.py +653 -0
  20. trajectree/quimb/quimb/experimental/belief_propagation/hd1bp.py +571 -0
  21. trajectree/quimb/quimb/experimental/belief_propagation/hv1bp.py +775 -0
  22. trajectree/quimb/quimb/experimental/belief_propagation/l1bp.py +316 -0
  23. trajectree/quimb/quimb/experimental/belief_propagation/l2bp.py +537 -0
  24. trajectree/quimb/quimb/experimental/belief_propagation/regions.py +194 -0
  25. trajectree/quimb/quimb/experimental/cluster_update.py +286 -0
  26. trajectree/quimb/quimb/experimental/merabuilder.py +865 -0
  27. trajectree/quimb/quimb/experimental/operatorbuilder/__init__.py +15 -0
  28. trajectree/quimb/quimb/experimental/operatorbuilder/operatorbuilder.py +1631 -0
  29. trajectree/quimb/quimb/experimental/schematic.py +7 -0
  30. trajectree/quimb/quimb/experimental/tn_marginals.py +130 -0
  31. trajectree/quimb/quimb/experimental/tnvmc.py +1483 -0
  32. trajectree/quimb/quimb/gates.py +36 -0
  33. trajectree/quimb/quimb/gen/__init__.py +2 -0
  34. trajectree/quimb/quimb/gen/operators.py +1167 -0
  35. trajectree/quimb/quimb/gen/rand.py +713 -0
  36. trajectree/quimb/quimb/gen/states.py +479 -0
  37. trajectree/quimb/quimb/linalg/__init__.py +6 -0
  38. trajectree/quimb/quimb/linalg/approx_spectral.py +1109 -0
  39. trajectree/quimb/quimb/linalg/autoblock.py +258 -0
  40. trajectree/quimb/quimb/linalg/base_linalg.py +719 -0
  41. trajectree/quimb/quimb/linalg/mpi_launcher.py +397 -0
  42. trajectree/quimb/quimb/linalg/numpy_linalg.py +244 -0
  43. trajectree/quimb/quimb/linalg/rand_linalg.py +514 -0
  44. trajectree/quimb/quimb/linalg/scipy_linalg.py +293 -0
  45. trajectree/quimb/quimb/linalg/slepc_linalg.py +892 -0
  46. trajectree/quimb/quimb/schematic.py +1518 -0
  47. trajectree/quimb/quimb/tensor/__init__.py +401 -0
  48. trajectree/quimb/quimb/tensor/array_ops.py +610 -0
  49. trajectree/quimb/quimb/tensor/circuit.py +4824 -0
  50. trajectree/quimb/quimb/tensor/circuit_gen.py +411 -0
  51. trajectree/quimb/quimb/tensor/contraction.py +336 -0
  52. trajectree/quimb/quimb/tensor/decomp.py +1255 -0
  53. trajectree/quimb/quimb/tensor/drawing.py +1646 -0
  54. trajectree/quimb/quimb/tensor/fitting.py +385 -0
  55. trajectree/quimb/quimb/tensor/geometry.py +583 -0
  56. trajectree/quimb/quimb/tensor/interface.py +114 -0
  57. trajectree/quimb/quimb/tensor/networking.py +1058 -0
  58. trajectree/quimb/quimb/tensor/optimize.py +1818 -0
  59. trajectree/quimb/quimb/tensor/tensor_1d.py +4778 -0
  60. trajectree/quimb/quimb/tensor/tensor_1d_compress.py +1854 -0
  61. trajectree/quimb/quimb/tensor/tensor_1d_tebd.py +662 -0
  62. trajectree/quimb/quimb/tensor/tensor_2d.py +5954 -0
  63. trajectree/quimb/quimb/tensor/tensor_2d_compress.py +96 -0
  64. trajectree/quimb/quimb/tensor/tensor_2d_tebd.py +1230 -0
  65. trajectree/quimb/quimb/tensor/tensor_3d.py +2869 -0
  66. trajectree/quimb/quimb/tensor/tensor_3d_tebd.py +46 -0
  67. trajectree/quimb/quimb/tensor/tensor_approx_spectral.py +60 -0
  68. trajectree/quimb/quimb/tensor/tensor_arbgeom.py +3237 -0
  69. trajectree/quimb/quimb/tensor/tensor_arbgeom_compress.py +565 -0
  70. trajectree/quimb/quimb/tensor/tensor_arbgeom_tebd.py +1138 -0
  71. trajectree/quimb/quimb/tensor/tensor_builder.py +5411 -0
  72. trajectree/quimb/quimb/tensor/tensor_core.py +11179 -0
  73. trajectree/quimb/quimb/tensor/tensor_dmrg.py +1472 -0
  74. trajectree/quimb/quimb/tensor/tensor_mera.py +204 -0
  75. trajectree/quimb/quimb/utils.py +892 -0
  76. trajectree/quimb/tests/__init__.py +0 -0
  77. trajectree/quimb/tests/test_accel.py +501 -0
  78. trajectree/quimb/tests/test_calc.py +788 -0
  79. trajectree/quimb/tests/test_core.py +847 -0
  80. trajectree/quimb/tests/test_evo.py +565 -0
  81. trajectree/quimb/tests/test_gen/__init__.py +0 -0
  82. trajectree/quimb/tests/test_gen/test_operators.py +361 -0
  83. trajectree/quimb/tests/test_gen/test_rand.py +296 -0
  84. trajectree/quimb/tests/test_gen/test_states.py +261 -0
  85. trajectree/quimb/tests/test_linalg/__init__.py +0 -0
  86. trajectree/quimb/tests/test_linalg/test_approx_spectral.py +368 -0
  87. trajectree/quimb/tests/test_linalg/test_base_linalg.py +351 -0
  88. trajectree/quimb/tests/test_linalg/test_mpi_linalg.py +127 -0
  89. trajectree/quimb/tests/test_linalg/test_numpy_linalg.py +84 -0
  90. trajectree/quimb/tests/test_linalg/test_rand_linalg.py +134 -0
  91. trajectree/quimb/tests/test_linalg/test_slepc_linalg.py +283 -0
  92. trajectree/quimb/tests/test_tensor/__init__.py +0 -0
  93. trajectree/quimb/tests/test_tensor/test_belief_propagation/__init__.py +0 -0
  94. trajectree/quimb/tests/test_tensor/test_belief_propagation/test_d1bp.py +39 -0
  95. trajectree/quimb/tests/test_tensor/test_belief_propagation/test_d2bp.py +67 -0
  96. trajectree/quimb/tests/test_tensor/test_belief_propagation/test_hd1bp.py +64 -0
  97. trajectree/quimb/tests/test_tensor/test_belief_propagation/test_hv1bp.py +51 -0
  98. trajectree/quimb/tests/test_tensor/test_belief_propagation/test_l1bp.py +142 -0
  99. trajectree/quimb/tests/test_tensor/test_belief_propagation/test_l2bp.py +101 -0
  100. trajectree/quimb/tests/test_tensor/test_circuit.py +816 -0
  101. trajectree/quimb/tests/test_tensor/test_contract.py +67 -0
  102. trajectree/quimb/tests/test_tensor/test_decomp.py +40 -0
  103. trajectree/quimb/tests/test_tensor/test_mera.py +52 -0
  104. trajectree/quimb/tests/test_tensor/test_optimizers.py +488 -0
  105. trajectree/quimb/tests/test_tensor/test_tensor_1d.py +1171 -0
  106. trajectree/quimb/tests/test_tensor/test_tensor_2d.py +606 -0
  107. trajectree/quimb/tests/test_tensor/test_tensor_2d_tebd.py +144 -0
  108. trajectree/quimb/tests/test_tensor/test_tensor_3d.py +123 -0
  109. trajectree/quimb/tests/test_tensor/test_tensor_arbgeom.py +226 -0
  110. trajectree/quimb/tests/test_tensor/test_tensor_builder.py +441 -0
  111. trajectree/quimb/tests/test_tensor/test_tensor_core.py +2066 -0
  112. trajectree/quimb/tests/test_tensor/test_tensor_dmrg.py +388 -0
  113. trajectree/quimb/tests/test_tensor/test_tensor_spectral_approx.py +63 -0
  114. trajectree/quimb/tests/test_tensor/test_tensor_tebd.py +270 -0
  115. trajectree/quimb/tests/test_utils.py +85 -0
  116. trajectree/trajectory.py +2 -2
  117. {trajectree-0.0.0.dist-info → trajectree-0.0.1.dist-info}/METADATA +2 -2
  118. trajectree-0.0.1.dist-info/RECORD +126 -0
  119. trajectree-0.0.0.dist-info/RECORD +0 -16
  120. {trajectree-0.0.0.dist-info → trajectree-0.0.1.dist-info}/WHEEL +0 -0
  121. {trajectree-0.0.0.dist-info → trajectree-0.0.1.dist-info}/licenses/LICENSE +0 -0
  122. {trajectree-0.0.0.dist-info → trajectree-0.0.1.dist-info}/top_level.txt +0 -0
@@ -0,0 +1,144 @@
1
+ import itertools
2
+ import importlib
3
+
4
+ import pytest
5
+
6
+ import quimb as qu
7
+ import quimb.tensor as qtn
8
+
9
# ``import importlib`` on its own does not guarantee that the
# ``importlib.util`` submodule has been loaded, so import it explicitly
# before calling ``find_spec`` (otherwise this can raise AttributeError
# depending on what else happens to have been imported first).
import importlib.util

# True when pytorch can be imported in the current environment.
found_torch = importlib.util.find_spec("torch") is not None

# Backend parameter for ``pytest.mark.parametrize``: the "torch" case is
# skipped with a clear reason, rather than erroring, when pytorch is absent.
pytorch_case = pytest.param(
    "torch",
    marks=pytest.mark.skipif(not found_torch, reason="pytorch not installed"),
)
15
+
16
+
17
class TestLocalHam2DConstruct:
    """Construction and automatic term-ordering checks for ``LocalHam2D``."""

    @staticmethod
    def _nearest_neighbor_bonds(Lx, Ly):
        """Yield every nearest-neighbour bond of an open ``Lx`` x ``Ly`` grid.

        Sites are visited in row-major order; for each site the bond to
        ``(i + 1, j)`` is yielded before the bond to ``(i, j + 1)``.
        """
        for i, j in itertools.product(range(Lx), range(Ly)):
            if i + 1 < Lx:
                yield (i, j), (i + 1, j)
            if j + 1 < Ly:
                yield (i, j), (i, j + 1)

    @pytest.mark.parametrize("H2_type", ["default", "manual"])
    @pytest.mark.parametrize("H1_type", [None, "default", "manual"])
    @pytest.mark.parametrize("Lx", [3, 4])
    @pytest.mark.parametrize("Ly", [3, 4])
    def test_construct(self, Lx, Ly, H2_type, H1_type):
        """Build a ``LocalHam2D`` from shared or per-bond/per-site terms,
        check the number of terms, term caching, and that it can be
        printed and drawn.
        """
        import matplotlib
        from matplotlib import pyplot as plt

        matplotlib.use("Template")

        # two-site terms: one shared matrix, or an explicit dict per bond
        if H2_type == "default":
            H2 = qu.rand_herm(4)
        elif H2_type == "manual":
            H2 = {
                bond: qu.rand_herm(4)
                for bond in self._nearest_neighbor_bonds(Lx, Ly)
            }

        # one-site terms: absent, one shared matrix, or an explicit dict
        if H1_type is None:
            H1 = None
        elif H1_type == "default":
            H1 = qu.rand_herm(2)
        elif H1_type == "manual":
            H1 = {
                site: qu.rand_herm(2)
                for site in itertools.product(range(Lx), range(Ly))
            }

        ham = qtn.LocalHam2D(Lx, Ly, H2, H1)
        num_bonds = Lx * (Ly - 1) + Ly * (Lx - 1)
        assert len(ham.terms) == num_bonds

        # with a single shared H2 and no H1, every term should be the very
        # same (cached) object
        if H2_type == "default" and H1_type is None:
            assert len(set(map(id, ham.terms.values()))) == 1

        print(ham)
        fig, _ = ham.draw()
        plt.close(fig)

    @pytest.mark.parametrize("Lx", [4, 5])
    @pytest.mark.parametrize("Ly", [4, 5])
    @pytest.mark.parametrize("order", [None, "sort", "random", "smallest_last"])
    def test_ordering(self, Lx, Ly, order):
        """``get_auto_ordering`` returns a permutation of the J1-J2 terms
        whose first four pairs share no site.
        """
        ham = qtn.ham_2d_j1j2(Lx, Ly)
        n_nn = 2 * Lx * Ly - Lx - Ly
        # count of the additional (J2) terms in this Hamiltonian
        n_nnn = 2 * (Lx - 1) * (Ly - 1)
        assert len(ham.terms) == n_nn + n_nnn

        ordering = ham.get_auto_ordering(order)
        assert len(ordering) == len(ham.terms)
        assert set(ordering) == set(ham.terms)
        assert tuple(ordering) != tuple(ham.terms)

        # the first four pairs should at least form one commuting group,
        # i.e. no site may appear twice among them
        sites = [site for pair in ordering[:4] for site in pair]
        assert len(sites) == len(set(sites))
75
+ assert len(first_four_pairs) == len(set(first_four_pairs))
76
+
77
+
78
class TestSimpleUpdate:
    """Simple-update evolution of a small 2D Heisenberg PEPS."""

    @pytest.mark.parametrize("backend", ["numpy", pytorch_case])
    def test_heis_small(self, backend):
        """A 3x4, D=2 PEPS evolved with ``SimpleUpdate`` should reach a
        best energy below -6.25 on either backend.
        """
        import autoray

        Lx, Ly, D = 3, 4, 2

        ham = qtn.ham_2d_heis(Lx, Ly)
        psi0 = qtn.PEPS.rand(Lx, Ly, D)

        def convert(array):
            return autoray.do("array", array, like=backend)

        # move both the initial state and the Hamiltonian onto ``backend``
        for obj in (psi0, ham):
            obj.apply_to_arrays(convert)

        su = qtn.SimpleUpdate(
            psi0,
            ham,
            progbar=True,
            keep_best=True,
            compute_energy_every=10,
            ordering="largest_first",
        )

        # shrink the time step in stages, restarting each stage from the
        # best state recorded so far
        for tau in (0.3, 0.1, 0.03):
            su.evolve(33, tau=tau)
            su.state = su.best["state"]

        assert su.best["energy"] < -6.25
113
+
114
+
115
class TestFullUpdate:
    """Full-update evolution of a small 2D Heisenberg PEPS."""

    @pytest.mark.parametrize("backend", ["numpy", pytorch_case])
    def test_heis_small(self, backend):
        """A 3x4, D=2 PEPS evolved with ``FullUpdate`` should reach a
        best energy below -6.30 on either backend.
        """
        import autoray

        Lx, Ly, D = 3, 4, 2

        psi0 = qtn.PEPS.rand(Lx, Ly, D)
        ham = qtn.ham_2d_heis(Lx, Ly)

        def convert(array):
            return autoray.do("array", array, like=backend)

        # move both the initial state and the Hamiltonian onto ``backend``
        for obj in (psi0, ham):
            obj.apply_to_arrays(convert)

        fu = qtn.FullUpdate(
            psi0,
            ham,
            progbar=True,
            keep_best=True,
            compute_energy_every=1,
        )

        # shrink the time step in stages, restarting each stage from the
        # best state recorded so far
        for tau in (0.3, 0.1, 0.03):
            fu.evolve(33, tau=tau)
            fu.state = fu.best["state"]

        assert fu.best["energy"] < -6.30
@@ -0,0 +1,123 @@
1
+ import pytest
2
+ import autoray as ar
3
+
4
+ import quimb as qu
5
+ import quimb.tensor as qtn
6
+
7
+
8
class TestTensorNetwork3D:
    """Structural checks for empty 3D tensor networks."""

    def test_cyclic_basic(self):
        """Per-direction cyclic flags are reported correctly and determine
        the number of bond indices.
        """
        for cyclic in (
            True,
            (False, False, True),
            (False, True, False),
            (True, False, False),
        ):
            tn = qtn.TN3D_empty(Lx=3, Ly=4, Lz=5, D=2, cyclic=cyclic)

            cx, cy, cz = (True, True, True) if cyclic is True else cyclic
            assert bool(tn.is_cyclic_x()) is cx
            assert bool(tn.is_cyclic_y()) is cy
            assert bool(tn.is_cyclic_z()) is cz

            # fully periodic: three indices per site; each open direction
            # removes one face's worth of indices
            expected = 3 * tn.nsites
            if not cx:
                expected -= tn.Ly * tn.Lz
            if not cy:
                expected -= tn.Lx * tn.Lz
            if not cz:
                expected -= tn.Lx * tn.Ly
            assert tn.num_indices == expected
36
+
37
+
38
class Test3DManualContract:
    """Approximate contraction schemes (boundary, HOTRG, CTMRG) in 3D."""

    @pytest.mark.parametrize("canonize", [False, True])
    def test_contract_boundary_ising_model(self, canonize):
        """Boundary contraction reproduces the reference Ising free energy."""
        L = 5
        beta = 0.3
        fex = -2.7654417752878
        tn = qtn.TN3D_classical_ising_partition_function(L, L, L, beta=beta)
        Z = tn.contract_boundary(max_bond=8, canonize=canonize)
        f = -qu.log(Z) / (L**3 * beta)
        assert f == pytest.approx(fex, rel=1e-3)

    @pytest.mark.parametrize("dims", [(10, 4, 3), (4, 3, 10), (3, 10, 4)])
    def test_contract_boundary_stopping_criterion(self, dims):
        """Boundary contraction with ``final_contract=False`` leaves a
        bounded-size network for every permutation of the lattice shape.
        """
        tn = qtn.TN3D_from_fill_fn(
            lambda shape: ar.lazy.Variable(shape=shape, backend="numpy"),
            *dims,
            D=2,
        )
        tn.contract_boundary_(
            4, cutoff=0.0, final_contract=False, progbar=True
        )
        assert tn.max_bond() == 4
        assert 32 <= tn.num_tensors <= 40

    @pytest.mark.parametrize("lazy", [False, True])
    def test_coarse_grain_basics(self, lazy):
        """A single HOTRG step along each axis halves that axis (rounding
        up), compresses to ``max_bond``, leaves no outer indices and drops
        the tags of the removed sites.
        """
        tn = qtn.TN3D_from_fill_fn(
            lambda shape: ar.lazy.Variable(shape, backend="numpy"),
            Lx=6,
            Ly=7,
            Lz=8,
            D=2,
        )

        tncg = tn.coarse_grain_hotrg("x", max_bond=3, cutoff=0.0, lazy=lazy)
        assert (tncg.Lx, tncg.Ly, tncg.Lz) == (3, 7, 8)
        assert not tncg.outer_inds()
        assert tncg.max_bond() == 3
        assert "I4,0,0" not in tncg.tag_map
        assert "X3" not in tncg.tag_map

        tncg = tn.coarse_grain_hotrg("y", max_bond=3, cutoff=0.0, lazy=lazy)
        assert (tncg.Lx, tncg.Ly, tncg.Lz) == (6, 4, 8)
        assert not tncg.outer_inds()
        assert tncg.max_bond() == 3
        assert "I0,5,0" not in tncg.tag_map
        assert "Y4" not in tncg.tag_map

        tncg = tn.coarse_grain_hotrg("z", max_bond=3, cutoff=0.0, lazy=lazy)
        assert (tncg.Lx, tncg.Ly, tncg.Lz) == (6, 7, 4)
        # same invariants as checked above for the x and y directions
        assert not tncg.outer_inds()
        assert tncg.max_bond() == 3
        assert "I0,0,5" not in tncg.tag_map
        assert "Z4" not in tncg.tag_map

    def test_contract_hotrg_ising_model(self):
        """HOTRG contraction reproduces the reference Ising free energy."""
        L = 5
        beta = 0.3
        fex = -2.7654417752878
        tn = qtn.TN3D_classical_ising_partition_function(L, L, L, beta=beta)
        tn.contract_hotrg_(max_bond=4, progbar=True, equalize_norms=1.0)
        # The result is held as ``tn.item() * 10**tn.exponent``. Combine the
        # two parts in log space so that Z itself is never formed, which
        # would overflow a float for larger lattices.
        lnZ = qu.log(tn.item()) + tn.exponent * qu.log(10.0)
        f = -lnZ / (L**3 * beta)
        assert f == pytest.approx(fex, rel=1e-2)

    @pytest.mark.parametrize("cyclicx", [False, True])
    @pytest.mark.parametrize("cyclicy", [False, True])
    @pytest.mark.parametrize("cyclicz", [False, True])
    @pytest.mark.parametrize("mode", ["hotrg", "ctmrg"])
    def test_contract_cyclic(self, cyclicx, cyclicy, cyclicz, mode):
        """Lazy HOTRG/CTMRG contraction keeps ``history_max_size()`` under a
        budget, with a larger budget when any direction is periodic.
        """
        Lx, Ly, Lz = 3, 4, 5
        chi = 3
        tn = qtn.TN3D_from_fill_fn(
            lambda shape: ar.lazy.Variable(shape=shape, backend="numpy"),
            Lx,
            Ly,
            Lz,
            D=2,
            cyclic=(cyclicx, cyclicy, cyclicz),
        )
        # dict dispatch: an unknown mode fails loudly with KeyError instead
        # of leaving ``lZ`` unbound
        contract = {
            "hotrg": tn.contract_hotrg,
            "ctmrg": tn.contract_ctmrg,
        }[mode]
        lZ = contract(max_bond=chi, cutoff=0.0)

        if any((cyclicx, cyclicy, cyclicz)):
            assert lZ.history_max_size() < 2**16
        else:
            assert lZ.history_max_size() < 2**13
@@ -0,0 +1,226 @@
1
+ import pytest
2
+ from numpy.testing import assert_allclose
3
+
4
+ import quimb.tensor as qtn
5
+
6
+
7
@pytest.mark.parametrize("which_A", ["upper", "lower"])
@pytest.mark.parametrize("contract", [True, False])
@pytest.mark.parametrize("inplace", [True, False])
def test_tensor_network_apply_op_vec(which_A, contract, inplace):
    """``tensor_network_apply_op_vec`` agrees with the dense matrix-vector
    product and honours the ``inplace`` / ``contract`` options.
    """

    def rand_tn(D, site_ind_id):
        return qtn.TN_from_edges_rand(
            qtn.edges_2d_square(3, 2),
            D=D,
            phys_dim=2,
            site_ind_id=site_ind_id,
            dtype=complex,
        )

    A = rand_tn(2, ("k{}", "b{}"))
    x = rand_tn(3, "x{}")

    # dense reference, computed before ``x`` may be modified in place
    A_dense = A.to_dense()
    A_dense = A_dense.T if which_A == "upper" else A_dense
    expected = A_dense @ x.to_dense()

    Ax = qtn.tensor_network_apply_op_vec(
        A, x, which_A, inplace=inplace, contract=contract
    )

    if contract:
        # checks fusing: the index count should match the original vector's
        assert Ax.num_indices == x.num_indices

    if not inplace:
        assert isinstance(Ax, x.__class__)
        assert Ax.site_ind_id == x.site_ind_id
    else:
        assert Ax is x

    assert_allclose(Ax.to_dense(), expected)
51
+
52
+
53
@pytest.mark.parametrize("which_A", ["upper", "lower"])
@pytest.mark.parametrize("which_B", ["upper", "lower"])
@pytest.mark.parametrize("contract", [True, False])
@pytest.mark.parametrize("inplace", [True, False])
def test_tensor_network_apply_op_op(which_A, which_B, contract, inplace):
    """``tensor_network_apply_op_op`` agrees with the dense operator product
    and honours the ``inplace`` / ``contract`` options.
    """

    def rand_op(D, site_ind_id):
        return qtn.TN_from_edges_rand(
            qtn.edges_2d_square(3, 2),
            D=D,
            phys_dim=2,
            site_ind_id=site_ind_id,
            dtype=complex,
        )

    A = rand_op(2, ("k{}", "b{}"))
    B = rand_op(3, ("x{}", "y{}"))

    # dense reference, computed before ``B`` may be modified in place
    A_dense = A.to_dense()
    if which_A == "upper":
        A_dense = A_dense.T
    B_dense = B.to_dense()
    if which_B == "lower":
        # transpose B going in and the product coming out
        expected = (A_dense @ B_dense.T).T
    else:
        expected = A_dense @ B_dense

    AB = qtn.tensor_network_apply_op_op(
        A, B, which_A, which_B, inplace=inplace, contract=contract
    )

    if contract:
        # checks fusing: the index count should match the original operator's
        assert AB.num_indices == B.num_indices

    if not inplace:
        assert isinstance(AB, B.__class__)
        assert AB.upper_ind_id == B.upper_ind_id
        assert AB.lower_ind_id == B.lower_ind_id
    else:
        assert AB is B

    assert_allclose(AB.to_dense(), expected)
103
+
104
+
105
def test_gate_with_op():
    """Lazily gating an MPS with an MPO matches the dense operator action."""
    mpo = qtn.MPO_rand(5, 3, dtype=complex)
    mps = qtn.MPS_rand_state(5, 3, dtype=complex)
    expected = mpo.to_dense() @ mps.to_dense()

    mps.gate_with_op_lazy_(mpo)

    assert_allclose(mps.to_dense(), expected)
111
+
112
+
113
def test_gate_sandwich_with_op():
    """Lazily sandwiching an MPO between ``A`` and ``A^dagger`` matches the
    dense product ``A @ B @ A^dagger``.
    """
    B = qtn.MPO_rand(5, 3, dtype=complex)
    A = qtn.MPO_rand(5, 3, dtype=complex)

    A_dense = A.to_dense()
    expected = A_dense @ B.to_dense() @ A_dense.conj().T

    B.gate_sandwich_with_op_lazy_(A)

    assert_allclose(B.to_dense(), expected)
119
+
120
+
121
def test_normalize_simple():
    """After simple gauging and ``normalize_simple``, every local patch with
    its gauges inserted has unit norm.
    """
    psi = qtn.PEPS.rand(3, 3, 2, dtype=complex)
    gauges = {}
    psi.gauge_all_simple_(100, 5e-6, gauges=gauges)
    psi.normalize_simple(gauges)

    regions = (
        [(0, 0)],
        [(1, 1), (1, 2)],
        [(0, 0), (0, 1), (0, 2), (1, 2), (2, 2), (2, 1)],
    )
    for region in regions:
        patch = psi.select_any(
            [psi.site_tag(site) for site in region], virtual=False
        )
        patch.gauge_simple_insert(gauges)
        assert patch.H @ patch == pytest.approx(1.0)
137
+
138
+
139
def test_local_expectation_loop_expansions():
    """Loop-expansion estimates of a two-site expectation value stay close
    to the exact value, for generated and for manually supplied loops.
    """
    import quimb as qu

    edges = [(0, 1), (0, 2), (2, 3), (1, 3), (2, 4), (3, 5), (4, 5)]
    psi = qtn.TN_from_edges_rand(
        edges, D=3, phys_dim=2, seed=42, dist="uniform", loc=-0.1
    )
    G = qu.rand_herm(4)
    where = (0, 2)
    # exact reference, taken before psi is gauged in place
    o_ex = psi.local_expectation_exact(G, where)

    gauges = {}
    psi.gauge_all_simple_(100, 5e-6, gauges=gauges)
    psi.normalize_simple(gauges)

    def estimate(loops):
        return psi.local_expectation_loop_expansion(
            G, where, loops=loops, gauges=gauges
        )

    # with zero loops the expansion should equal the plain cluster estimate
    o_c0 = estimate(0)
    assert o_c0 == pytest.approx(
        psi.local_expectation_cluster(G, where, gauges=gauges)
    )
    assert o_ex == pytest.approx(o_c0, rel=0.5, abs=0.01)

    # loops generated per term, with the tolerance tightening as they grow
    for n_loops, rel in ((4, 0.5), (6, 0.4)):
        assert o_ex == pytest.approx(estimate(n_loops), rel=rel, abs=0.01)

    # manually supplied loops
    loops = tuple(psi.gen_paths_loops(6))
    assert o_ex == pytest.approx(estimate(loops), rel=0.4, abs=0.01)
182
+
183
+
184
def test_local_expectation_cluster_expansions():
    """Cluster-expansion estimates of a two-site expectation value stay close
    to the exact value, for generated and for manually supplied clusters.
    """
    import quimb as qu

    edges = [(0, 1), (0, 2), (2, 3), (1, 3), (2, 4), (3, 5), (4, 5)]
    psi = qtn.TN_from_edges_rand(
        edges, D=3, phys_dim=2, seed=42, dist="uniform", loc=-0.1
    )
    G = qu.rand_herm(4)
    where = (0, 2)
    # exact reference, taken before psi is gauged in place
    o_ex = psi.local_expectation_exact(G, where)

    gauges = {}
    psi.gauge_all_simple_(100, 5e-6, gauges=gauges)
    psi.normalize_simple(gauges)

    def estimate(clusters):
        return psi.local_expectation_cluster_expansion(
            G, where, clusters=clusters, gauges=gauges
        )

    # with zero clusters the expansion should equal the plain cluster estimate
    o_c0 = estimate(0)
    assert o_c0 == pytest.approx(
        psi.local_expectation_cluster(G, where, gauges=gauges)
    )
    assert o_ex == pytest.approx(o_c0, rel=0.5, abs=0.01)

    # clusters generated per term, with the tolerance tightening as they grow
    for n_clusters, rel in ((4, 0.5), (6, 0.4)):
        assert o_ex == pytest.approx(estimate(n_clusters), rel=rel, abs=0.01)

    # manually supplied clusters
    clusters = tuple(psi.gen_regions(4))
    assert o_ex == pytest.approx(estimate(clusters), rel=0.4, abs=0.01)