Trajectree 0.0.0-py3-none-any.whl → 0.0.1-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (122)
  1. trajectree/__init__.py +3 -0
  2. trajectree/fock_optics/devices.py +1 -1
  3. trajectree/fock_optics/light_sources.py +2 -2
  4. trajectree/fock_optics/measurement.py +3 -3
  5. trajectree/fock_optics/utils.py +6 -6
  6. trajectree/quimb/docs/_pygments/_pygments_dark.py +118 -0
  7. trajectree/quimb/docs/_pygments/_pygments_light.py +118 -0
  8. trajectree/quimb/docs/conf.py +158 -0
  9. trajectree/quimb/docs/examples/ex_mpi_expm_evo.py +62 -0
  10. trajectree/quimb/quimb/__init__.py +507 -0
  11. trajectree/quimb/quimb/calc.py +1491 -0
  12. trajectree/quimb/quimb/core.py +2279 -0
  13. trajectree/quimb/quimb/evo.py +712 -0
  14. trajectree/quimb/quimb/experimental/__init__.py +0 -0
  15. trajectree/quimb/quimb/experimental/autojittn.py +129 -0
  16. trajectree/quimb/quimb/experimental/belief_propagation/__init__.py +109 -0
  17. trajectree/quimb/quimb/experimental/belief_propagation/bp_common.py +397 -0
  18. trajectree/quimb/quimb/experimental/belief_propagation/d1bp.py +316 -0
  19. trajectree/quimb/quimb/experimental/belief_propagation/d2bp.py +653 -0
  20. trajectree/quimb/quimb/experimental/belief_propagation/hd1bp.py +571 -0
  21. trajectree/quimb/quimb/experimental/belief_propagation/hv1bp.py +775 -0
  22. trajectree/quimb/quimb/experimental/belief_propagation/l1bp.py +316 -0
  23. trajectree/quimb/quimb/experimental/belief_propagation/l2bp.py +537 -0
  24. trajectree/quimb/quimb/experimental/belief_propagation/regions.py +194 -0
  25. trajectree/quimb/quimb/experimental/cluster_update.py +286 -0
  26. trajectree/quimb/quimb/experimental/merabuilder.py +865 -0
  27. trajectree/quimb/quimb/experimental/operatorbuilder/__init__.py +15 -0
  28. trajectree/quimb/quimb/experimental/operatorbuilder/operatorbuilder.py +1631 -0
  29. trajectree/quimb/quimb/experimental/schematic.py +7 -0
  30. trajectree/quimb/quimb/experimental/tn_marginals.py +130 -0
  31. trajectree/quimb/quimb/experimental/tnvmc.py +1483 -0
  32. trajectree/quimb/quimb/gates.py +36 -0
  33. trajectree/quimb/quimb/gen/__init__.py +2 -0
  34. trajectree/quimb/quimb/gen/operators.py +1167 -0
  35. trajectree/quimb/quimb/gen/rand.py +713 -0
  36. trajectree/quimb/quimb/gen/states.py +479 -0
  37. trajectree/quimb/quimb/linalg/__init__.py +6 -0
  38. trajectree/quimb/quimb/linalg/approx_spectral.py +1109 -0
  39. trajectree/quimb/quimb/linalg/autoblock.py +258 -0
  40. trajectree/quimb/quimb/linalg/base_linalg.py +719 -0
  41. trajectree/quimb/quimb/linalg/mpi_launcher.py +397 -0
  42. trajectree/quimb/quimb/linalg/numpy_linalg.py +244 -0
  43. trajectree/quimb/quimb/linalg/rand_linalg.py +514 -0
  44. trajectree/quimb/quimb/linalg/scipy_linalg.py +293 -0
  45. trajectree/quimb/quimb/linalg/slepc_linalg.py +892 -0
  46. trajectree/quimb/quimb/schematic.py +1518 -0
  47. trajectree/quimb/quimb/tensor/__init__.py +401 -0
  48. trajectree/quimb/quimb/tensor/array_ops.py +610 -0
  49. trajectree/quimb/quimb/tensor/circuit.py +4824 -0
  50. trajectree/quimb/quimb/tensor/circuit_gen.py +411 -0
  51. trajectree/quimb/quimb/tensor/contraction.py +336 -0
  52. trajectree/quimb/quimb/tensor/decomp.py +1255 -0
  53. trajectree/quimb/quimb/tensor/drawing.py +1646 -0
  54. trajectree/quimb/quimb/tensor/fitting.py +385 -0
  55. trajectree/quimb/quimb/tensor/geometry.py +583 -0
  56. trajectree/quimb/quimb/tensor/interface.py +114 -0
  57. trajectree/quimb/quimb/tensor/networking.py +1058 -0
  58. trajectree/quimb/quimb/tensor/optimize.py +1818 -0
  59. trajectree/quimb/quimb/tensor/tensor_1d.py +4778 -0
  60. trajectree/quimb/quimb/tensor/tensor_1d_compress.py +1854 -0
  61. trajectree/quimb/quimb/tensor/tensor_1d_tebd.py +662 -0
  62. trajectree/quimb/quimb/tensor/tensor_2d.py +5954 -0
  63. trajectree/quimb/quimb/tensor/tensor_2d_compress.py +96 -0
  64. trajectree/quimb/quimb/tensor/tensor_2d_tebd.py +1230 -0
  65. trajectree/quimb/quimb/tensor/tensor_3d.py +2869 -0
  66. trajectree/quimb/quimb/tensor/tensor_3d_tebd.py +46 -0
  67. trajectree/quimb/quimb/tensor/tensor_approx_spectral.py +60 -0
  68. trajectree/quimb/quimb/tensor/tensor_arbgeom.py +3237 -0
  69. trajectree/quimb/quimb/tensor/tensor_arbgeom_compress.py +565 -0
  70. trajectree/quimb/quimb/tensor/tensor_arbgeom_tebd.py +1138 -0
  71. trajectree/quimb/quimb/tensor/tensor_builder.py +5411 -0
  72. trajectree/quimb/quimb/tensor/tensor_core.py +11179 -0
  73. trajectree/quimb/quimb/tensor/tensor_dmrg.py +1472 -0
  74. trajectree/quimb/quimb/tensor/tensor_mera.py +204 -0
  75. trajectree/quimb/quimb/utils.py +892 -0
  76. trajectree/quimb/tests/__init__.py +0 -0
  77. trajectree/quimb/tests/test_accel.py +501 -0
  78. trajectree/quimb/tests/test_calc.py +788 -0
  79. trajectree/quimb/tests/test_core.py +847 -0
  80. trajectree/quimb/tests/test_evo.py +565 -0
  81. trajectree/quimb/tests/test_gen/__init__.py +0 -0
  82. trajectree/quimb/tests/test_gen/test_operators.py +361 -0
  83. trajectree/quimb/tests/test_gen/test_rand.py +296 -0
  84. trajectree/quimb/tests/test_gen/test_states.py +261 -0
  85. trajectree/quimb/tests/test_linalg/__init__.py +0 -0
  86. trajectree/quimb/tests/test_linalg/test_approx_spectral.py +368 -0
  87. trajectree/quimb/tests/test_linalg/test_base_linalg.py +351 -0
  88. trajectree/quimb/tests/test_linalg/test_mpi_linalg.py +127 -0
  89. trajectree/quimb/tests/test_linalg/test_numpy_linalg.py +84 -0
  90. trajectree/quimb/tests/test_linalg/test_rand_linalg.py +134 -0
  91. trajectree/quimb/tests/test_linalg/test_slepc_linalg.py +283 -0
  92. trajectree/quimb/tests/test_tensor/__init__.py +0 -0
  93. trajectree/quimb/tests/test_tensor/test_belief_propagation/__init__.py +0 -0
  94. trajectree/quimb/tests/test_tensor/test_belief_propagation/test_d1bp.py +39 -0
  95. trajectree/quimb/tests/test_tensor/test_belief_propagation/test_d2bp.py +67 -0
  96. trajectree/quimb/tests/test_tensor/test_belief_propagation/test_hd1bp.py +64 -0
  97. trajectree/quimb/tests/test_tensor/test_belief_propagation/test_hv1bp.py +51 -0
  98. trajectree/quimb/tests/test_tensor/test_belief_propagation/test_l1bp.py +142 -0
  99. trajectree/quimb/tests/test_tensor/test_belief_propagation/test_l2bp.py +101 -0
  100. trajectree/quimb/tests/test_tensor/test_circuit.py +816 -0
  101. trajectree/quimb/tests/test_tensor/test_contract.py +67 -0
  102. trajectree/quimb/tests/test_tensor/test_decomp.py +40 -0
  103. trajectree/quimb/tests/test_tensor/test_mera.py +52 -0
  104. trajectree/quimb/tests/test_tensor/test_optimizers.py +488 -0
  105. trajectree/quimb/tests/test_tensor/test_tensor_1d.py +1171 -0
  106. trajectree/quimb/tests/test_tensor/test_tensor_2d.py +606 -0
  107. trajectree/quimb/tests/test_tensor/test_tensor_2d_tebd.py +144 -0
  108. trajectree/quimb/tests/test_tensor/test_tensor_3d.py +123 -0
  109. trajectree/quimb/tests/test_tensor/test_tensor_arbgeom.py +226 -0
  110. trajectree/quimb/tests/test_tensor/test_tensor_builder.py +441 -0
  111. trajectree/quimb/tests/test_tensor/test_tensor_core.py +2066 -0
  112. trajectree/quimb/tests/test_tensor/test_tensor_dmrg.py +388 -0
  113. trajectree/quimb/tests/test_tensor/test_tensor_spectral_approx.py +63 -0
  114. trajectree/quimb/tests/test_tensor/test_tensor_tebd.py +270 -0
  115. trajectree/quimb/tests/test_utils.py +85 -0
  116. trajectree/trajectory.py +2 -2
  117. {trajectree-0.0.0.dist-info → trajectree-0.0.1.dist-info}/METADATA +2 -2
  118. trajectree-0.0.1.dist-info/RECORD +126 -0
  119. trajectree-0.0.0.dist-info/RECORD +0 -16
  120. {trajectree-0.0.0.dist-info → trajectree-0.0.1.dist-info}/WHEEL +0 -0
  121. {trajectree-0.0.0.dist-info → trajectree-0.0.1.dist-info}/licenses/LICENSE +0 -0
  122. {trajectree-0.0.0.dist-info → trajectree-0.0.1.dist-info}/top_level.txt +0 -0
trajectree/quimb/quimb/experimental/autojittn.py
@@ -0,0 +1,129 @@
+"""Decorator for automatically just in time compiling tensor network functions.
+
+TODO::
+
+    - [ ] go via an intermediate pytree / array function, that could be shared
+          e.g. with the TNOptimizer class.
+
+"""
+import functools
+
+import autoray as ar
+
+
+class AutojittedTN:
+    """Class to hold the ``autojit_tn`` decorated function callable.
+    """
+
+    def __init__(
+        self,
+        fn,
+        decorator=ar.autojit,
+        check_inputs=True,
+        **decorator_opts
+    ):
+        self.fn = fn
+        self.fn_store = {}
+        self.decorator_opts = decorator_opts
+        self.check_inputs = check_inputs
+        self.decorator = decorator
+
+    def setup_fn(self, tn, *args, **kwargs):
+        from quimb.tensor import TensorNetwork
+
+        @self.decorator(**self.decorator_opts)
+        def fn_jit(arrays):
+            # use separate TN to trace through function
+            jtn = tn.copy()
+
+            # insert the tracing arrays
+            for t, array in zip(jtn, arrays):
+                t.modify(data=array)
+
+            # run function on TN with tracing arrays
+            result = self.fn(jtn, *args, **kwargs)
+
+            # check for an inplace tn function
+            if isinstance(result, TensorNetwork):
+                if result is not jtn:
+                    raise ValueError(
+                        "If you are compiling a function that returns a"
+                        " tensor network it needs to be inplace.")
+                self.inplace = True
+                return tuple(t.data for t in jtn)
+            else:
+                # function returns raw scalar/array(s)
+                self.inplace = False
+                return result
+
+        return fn_jit
+
+    def __call__(self, tn, *args, **kwargs):
+
+        # do we need to generate a new function for these inputs
+        if self.check_inputs:
+            key = (
+                tn.geometry_hash(strict_index_order=True),
+                tuple(args),
+                tuple(sorted(kwargs.items())),
+            )
+        else:
+            # always use the same function
+            key = None
+
+        if key not in self.fn_store:
+            self.fn_store[key] = self.setup_fn(tn, *args, **kwargs)
+        fn_jit = self.fn_store[key]
+
+        # run the compiled function
+        arrays = tuple(t.data for t in tn)
+        out = fn_jit(arrays)
+
+        if self.inplace:
+            # reinsert output arrays into input TN structure
+            for t, array in zip(tn, out):
+                t.modify(data=array)
+            return tn
+
+        return out
+
+
+def autojit_tn(
+    fn=None,
+    decorator=ar.autojit,
+    check_inputs=True,
+    **decorator_opts,
+):
+    """Decorate a tensor network function to be just in time compiled /
+    traced. This traces solely array operations, resulting in a completely
+    static computational graph with no side-effects. The resulting function
+    can be much faster if called repeatedly with only numeric changes, or
+    hardware accelerated if a library such as ``jax`` is used.
+
+    Parameters
+    ----------
+    fn : callable
+        The function to be decorated. It should take a
+        :class:`~quimb.tensor.tensor_core.TensorNetwork` as its first
+        argument, and either act inplace on it or return a raw scalar or
+        array.
+    decorator : callable
+        The decorator to use to wrap the underlying array function. For
+        example ``jax.jit``. Defaults to ``autoray.autojit``.
+    check_inputs : bool, optional
+        Whether to check the inputs to the function every call, to see if a
+        new compiled function needs to be generated. If ``False`` the same
+        compiled function will be used for all inputs, which might be
+        incorrect. Defaults to ``True``.
+    decorator_opts
+        Options to pass to the decorator, e.g. ``backend`` for
+        ``autoray.autojit``.
+    """
+    kwargs = {
+        'decorator': decorator,
+        'check_inputs': check_inputs,
+        **decorator_opts,
+    }
+    if fn is None:
+        return functools.partial(autojit_tn, **kwargs)
+    return functools.wraps(fn)(AutojittedTN(fn, **kwargs))
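For reference, a minimal usage sketch of ``autojit_tn`` (not part of the diff; the function name ``norm_sq``, the random MPS, and the ``backend`` option are illustrative assumptions based on the docstring above):

import quimb.tensor as qtn
from trajectree.quimb.quimb.experimental.autojittn import autojit_tn

@autojit_tn(backend="numpy")  # 'backend' is forwarded to autoray.autojit
def norm_sq(tn):
    # purely array operations on the network, so the call can be traced
    return tn.H @ tn

mps = qtn.MPS_rand_state(10, bond_dim=4)
norm_sq(mps)  # first call traces the function and caches the result
norm_sq(mps)  # same geometry: the cached compiled graph is reused

With ``check_inputs=True`` (the default), a new trace is generated whenever the network geometry, extra args, or kwargs change, keyed by ``geometry_hash``.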
trajectree/quimb/quimb/experimental/belief_propagation/__init__.py
@@ -0,0 +1,109 @@
+"""Belief propagation (BP) routines. There are three potential categorizations
+of BP, and each combination of them is a potentially valid specific algorithm.
+
+1-norm vs 2-norm BP
+-------------------
+
+- 1-norm (normal): BP runs directly on the tensor network, messages have size
+  ``d`` where ``d`` is the size of the bond(s) connecting two tensors or
+  regions.
+- 2-norm (quantum): BP runs on the squared tensor network, messages have size
+  ``d^2`` where ``d`` is the size of the bond(s) connecting two tensors or
+  regions. Each local tensor or region is partially traced (over dangling
+  indices) with its conjugate to create a single node.
+
+
+Graph vs Hypergraph BP
+----------------------
+
+- Graph (simple): the tensor network lives on a graph, where indices either
+  appear on two tensors (a bond), or appear on a single tensor (are outputs).
+  In this case, messages are exchanged directly between tensors.
+- Hypergraph: the tensor network lives on a hypergraph, where indices can
+  appear on any number of tensors. In this case, the update procedure has two
+  parts: first all 'tensor' messages are computed, then these are used in the
+  second step to compute all the 'index' messages, which are then fed back
+  into the 'tensor' message update, and so forth. For 2-norm BP one likely
+  needs to specify which indices are outputs and should be traced over.
+
+The hypergraph case of course includes the graph case, but since the 'index'
+message update is simply the identity there, it is convenient to have a
+separate, simpler implementation, where the standard TN bond vs physical
+index definitions hold.
+
+
+Dense vs Vectorized vs Lazy BP
+------------------------------
+
+- Dense: each node is a single tensor, or pair of tensors for 2-norm BP. If
+  all multibonds have been fused, then each message is a vector (1-norm case)
+  or matrix (2-norm case).
+- Vectorized: the same as the above, but all matching tensor and message
+  updates are stacked and performed simultaneously. This can be enormously
+  more efficient for large numbers of small tensors.
+- Lazy: each node is potentially a tensor network itself with arbitrary inner
+  structure and number of bonds connecting to other nodes. The messages are
+  generally tensors, and each update is a lazy contraction, which is
+  potentially much cheaper / requires less memory than forming the 'dense'
+  node for large tensors.
+
+(There is also the MPS flavor, where each node has a 1D structure and the
+messages are matrix product states, with updates involving compression.)
+
+
+Overall that gives 12 possible BP flavors, some implemented here:
+
+- [x] (HD1BP) hyper, dense, 1-norm - this is the standard BP algorithm
+- [x] (HD2BP) hyper, dense, 2-norm
+- [x] (HV1BP) hyper, vectorized, 1-norm
+- [ ] (HV2BP) hyper, vectorized, 2-norm
+- [ ] (HL1BP) hyper, lazy, 1-norm
+- [ ] (HL2BP) hyper, lazy, 2-norm
+- [x] (D1BP) simple, dense, 1-norm - simple BP for simple tensor networks
+- [x] (D2BP) simple, dense, 2-norm - this is the standard PEPS BP algorithm
+- [ ] (V1BP) simple, vectorized, 1-norm
+- [ ] (V2BP) simple, vectorized, 2-norm
+- [x] (L1BP) simple, lazy, 1-norm
+- [x] (L2BP) simple, lazy, 2-norm
+
+The 2-norm methods can be used to compress bonds or estimate the 2-norm.
+The 1-norm methods can be used to estimate the 1-norm, i.e. the contracted
+value. Both methods can be used to compute index marginals and thus perform
+sampling.
+
+The vectorized methods can be extremely fast for large numbers of small
+tensors, but do currently require all dimensions to match.
+
+The dense and lazy methods can converge messages *locally*, i.e. only
+update messages adjacent to messages which have changed.
+"""
+
+from .bp_common import initialize_hyper_messages
+from .d1bp import D1BP, contract_d1bp
+from .d2bp import D2BP, compress_d2bp, contract_d2bp, sample_d2bp
+from .hd1bp import HD1BP, contract_hd1bp, sample_hd1bp
+from .hv1bp import HV1BP, contract_hv1bp, sample_hv1bp
+from .l1bp import L1BP, contract_l1bp
+from .l2bp import L2BP, compress_l2bp, contract_l2bp
+from .regions import RegionGraph
+
+__all__ = (
+    "compress_d2bp",
+    "compress_l2bp",
+    "contract_d1bp",
+    "contract_d2bp",
+    "contract_hd1bp",
+    "contract_hv1bp",
+    "contract_l1bp",
+    "contract_l2bp",
+    "D1BP",
+    "D2BP",
+    "HD1BP",
+    "HV1BP",
+    "initialize_hyper_messages",
+    "L1BP",
+    "L2BP",
+    "RegionGraph",
+    "sample_d2bp",
+    "sample_hd1bp",
+    "sample_hv1bp",
+)
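A hedged end-to-end sketch of the entry points exported above (assuming the vendored quimb resolves as ``trajectree.quimb.quimb``; ``PEPS.rand`` and the keyword defaults are standard quimb but not shown in this diff):

import quimb.tensor as qtn
from trajectree.quimb.quimb.experimental.belief_propagation import contract_d2bp

# a small random PEPS; D2BP runs on the squared network <psi|psi>
psi = qtn.PEPS.rand(3, 3, bond_dim=2, seed=42)

# BP estimate of the 2-norm contraction <psi|psi>
norm_sq_est = contract_d2bp(psi, max_iterations=100, tol=1e-6)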
trajectree/quimb/quimb/experimental/belief_propagation/bp_common.py
@@ -0,0 +1,397 @@
+import functools
+import operator
+
+import autoray as ar
+
+import quimb.tensor as qtn
+
+
+def prod(xs):
+    """Product of all elements in ``xs``."""
+    return functools.reduce(operator.mul, xs)
+
+
+class RollingDiffMean:
+    """Tracker for the absolute rolling mean of diffs between values, to
+    assess effective convergence of BP above actual message tolerance.
+    """
+
+    def __init__(self, size=16):
+        self.size = size
+        self.diffs = []
+        self.last_x = None
+        self.dxsum = 0.0
+
+    def update(self, x):
+        if self.last_x is not None:
+            dx = x - self.last_x
+            self.diffs.append(dx)
+            self.dxsum += dx / self.size
+        if len(self.diffs) > self.size:
+            dx = self.diffs.pop(0)
+            self.dxsum -= dx / self.size
+        self.last_x = x
+
+    def absmeandiff(self):
+        if len(self.diffs) < self.size:
+            return float("inf")
+        return abs(self.dxsum)
+
+
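A toy illustration of how the tracker flags a plateau (hypothetical values, not from the diff): once the window is full, a sequence whose successive differences cancel yields a near-zero rolling mean, even if each value is still above ``tol``.

rdm = RollingDiffMean(size=4)
for max_mdiff in [1.0, 0.5, 0.6, 0.5, 0.6, 0.5]:
    rdm.update(max_mdiff)
# the oscillating tail gives rdm.absmeandiff() ~ 0, signalling that BP has
# effectively stopped improving even though max_mdiff itself never drops
# below a tight tolerance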
+class BeliefPropagationCommon:
+    """Common interfaces for belief propagation algorithms.
+
+    Parameters
+    ----------
+    max_iterations : int, optional
+        The maximum number of iterations to perform.
+    tol : float, optional
+        The convergence tolerance for messages.
+    progbar : bool, optional
+        Whether to show a progress bar.
+    """
+
+    def run(self, max_iterations=1000, tol=5e-6, info=None, progbar=False):
+        if progbar:
+            import tqdm
+
+            pbar = tqdm.tqdm()
+        else:
+            pbar = None
+
+        try:
+            it = 0
+            rdm = RollingDiffMean()
+            self.converged = False
+            while not self.converged and it < max_iterations:
+                # perform a single iteration of BP
+                # we supply tol here for use with local convergence
+                nconv, ncheck, max_mdiff = self.iterate(tol=tol)
+                it += 1
+
+                # check rolling mean convergence
+                rdm.update(max_mdiff)
+                self.converged = (max_mdiff < tol) or (rdm.absmeandiff() < tol)
+
+                if pbar is not None:
+                    pbar.set_description(
+                        f"nconv: {nconv}/{ncheck} max|dM|={max_mdiff:.2e}",
+                        refresh=False,
+                    )
+                    pbar.update()
+
+        finally:
+            if pbar is not None:
+                pbar.close()
+
+        if tol != 0.0 and not self.converged:
+            import warnings
+
+            warnings.warn(
+                f"Belief propagation did not converge after {max_iterations} "
+                f"iterations, tol={tol:.2e}, max|dM|={max_mdiff:.2e}."
+            )
+
+        if info is not None:
+            info["converged"] = self.converged
+            info["iterations"] = it
+            info["max_mdiff"] = max_mdiff
+            info["rolling_abs_mean_diff"] = rdm.absmeandiff()
+
+
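Concrete BP classes in this subpackage drive ``run`` by supplying an ``iterate`` method; per the loop above, ``iterate(tol)`` must return the tuple ``(nconv, ncheck, max_mdiff)``. A skeletal, purely illustrative subclass of that contract (``MyBP`` is hypothetical):

class MyBP(BeliefPropagationCommon):
    def __init__(self, messages):
        self.messages = messages

    def iterate(self, tol=5e-6):
        nconv = 0        # messages already converged this sweep
        ncheck = 0       # messages checked this sweep
        max_mdiff = 0.0  # largest message change this sweep
        for key in self.messages:
            ncheck += 1
            mdiff = 0.0  # ... update self.messages[key], measure its change ...
            if mdiff < tol:
                nconv += 1
            max_mdiff = max(max_mdiff, mdiff)
        return nconv, ncheck, max_mdiff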
+def initialize_hyper_messages(
+    tn,
+    fill_fn=None,
+    smudge_factor=1e-12,
+):
+    """Initialize messages for belief propagation, this is equivalent to doing
+    a single round of belief propagation with uniform messages.
+
+    Parameters
+    ----------
+    tn : TensorNetwork
+        The tensor network to initialize messages for.
+    fill_fn : callable, optional
+        A function to fill the messages with, of signature ``fill_fn(shape)``.
+    smudge_factor : float, optional
+        A small number to add to the messages to avoid numerical issues.
+
+    Returns
+    -------
+    messages : dict
+        The initial messages. For every index and tensor id pair, there will
+        be a message to and from, with keys ``(ix, tid)`` and ``(tid, ix)``.
+    """
+    from quimb.tensor.contraction import array_contract
+
+    backend = ar.infer_backend(next(t.data for t in tn))
+    _sum = ar.get_lib_fn(backend, "sum")
+
+    messages = {}
+
+    # compute first messages from tensors to indices
+    for tid, t in tn.tensor_map.items():
+        k_inputs = tuple(range(t.ndim))
+        for i, ix in enumerate(t.inds):
+            if fill_fn is None:
+                # sum over all other indices to get initial message
+                m = array_contract(
+                    arrays=(t.data,),
+                    inputs=(k_inputs,),
+                    output=(i,),
+                )
+                # normalize and insert
+                messages[tid, ix] = m / _sum(m)
+            else:
+                d = t.ind_size(ix)
+                m = fill_fn((d,))
+                messages[tid, ix] = m / _sum(m)
+
+    # compute first messages from indices to tensors
+    for ix, tids in tn.ind_map.items():
+        ms = [messages[tid, ix] for tid in tids]
+        mp = prod(ms)
+        for mi, tid in zip(ms, tids):
+            m = mp / (mi + smudge_factor)
+            # normalize and insert
+            messages[ix, tid] = m / _sum(m)
+
+    return messages
+
+
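A short hedged sketch of the key convention documented above (``TN_rand_reg`` is standard quimb, assumed available from the vendored package):

import quimb.tensor as qtn

# random 3-regular tensor network with bond dimension 2
tn = qtn.TN_rand_reg(10, 3, D=2, seed=42)
messages = initialize_hyper_messages(tn)

tid = next(iter(tn.tensor_map))
ix = tn.tensor_map[tid].inds[0]
messages[tid, ix]  # tensor -> index message, a normalized vector of size d
messages[ix, tid]  # index -> tensor message, also size d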
+def combine_local_contractions(
+    tvals,
+    mvals,
+    backend,
+    strip_exponent=False,
+    check_for_zero=True,
+):
+    _abs = ar.get_lib_fn(backend, "abs")
+    _log10 = ar.get_lib_fn(backend, "log10")
+
+    mantissa = 1
+    exponent = 0
+    for vt in tvals:
+        avt = _abs(vt)
+
+        if check_for_zero and (avt == 0.0):
+            if strip_exponent:
+                return 0.0, 0.0
+            else:
+                return 0.0
+
+        mantissa = mantissa * (vt / avt)
+        exponent = exponent + _log10(avt)
+    for mt in mvals:
+        amt = _abs(mt)
+        mantissa = mantissa / (mt / amt)
+        exponent = exponent - _log10(amt)
+
+    if strip_exponent:
+        return mantissa, exponent
+    else:
+        return mantissa * 10**exponent
+
+
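The mantissa/exponent accumulation above is the usual trick for multiplying many local values without floating point overflow; a tiny numeric sketch with illustrative values:

# multiplying 1000 values of magnitude 10 directly would overflow float64
# (10**1000 >> 1.8e308), but tracked as mantissa * 10**exponent it is exact:
tvals = [10.0] * 1000
mantissa, exponent = combine_local_contractions(
    tvals, mvals=[], backend="numpy", strip_exponent=True
)
# mantissa ~ 1.0, exponent ~ 1000.0, i.e. the value is 1e1000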
+def contract_hyper_messages(
+    tn,
+    messages,
+    strip_exponent=False,
+    backend=None,
+):
+    """Estimate the contraction of ``tn`` given ``messages``, via the
+    exponential of the Bethe free entropy.
+    """
+    tvals = []
+    mvals = []
+
+    for tid, t in tn.tensor_map.items():
+        if backend is None:
+            backend = ar.infer_backend(t.data)
+
+        arrays = [t.data]
+        inputs = [range(t.ndim)]
+        for i, ix in enumerate(t.inds):
+            arrays.append(messages[ix, tid])
+            inputs.append((i,))
+
+            # local message overlap correction
+            mvals.append(
+                qtn.array_contract(
+                    (messages[tid, ix], messages[ix, tid]),
+                    inputs=((0,), (0,)),
+                    output=(),
+                )
+            )
+
+        # local factor free entropy
+        tvals.append(qtn.array_contract(arrays, inputs, output=()))
+
+    for ix, tids in tn.ind_map.items():
+        arrays = tuple(messages[tid, ix] for tid in tids)
+        inputs = tuple((0,) for _ in tids)
+        # local variable free entropy
+        tvals.append(qtn.array_contract(arrays, inputs, output=()))
+
+    return combine_local_contractions(
+        tvals, mvals, backend, strip_exponent=strip_exponent
+    )
+
+
+def compute_index_marginal(tn, ind, messages):
+    """Compute the marginal for a single index given ``messages``.
+
+    Parameters
+    ----------
+    tn : TensorNetwork
+        The tensor network to compute the marginal for.
+    ind : int
+        The index to compute the marginal for.
+    messages : dict
+        The messages to use, which should match ``tn``.
+
+    Returns
+    -------
+    marginal : array_like
+        The marginal probability distribution for the index ``ind``.
+    """
+    m = prod(messages[tid, ind] for tid in tn.ind_map[ind])
+    return m / ar.do("sum", m)
+
+
+def compute_tensor_marginal(tn, tid, messages):
+    """Compute the marginal for the region surrounding a single tensor/factor
+    given ``messages``.
+
+    Parameters
+    ----------
+    tn : TensorNetwork
+        The tensor network to compute the marginal for.
+    tid : int
+        The tensor id to compute the marginal for.
+    messages : dict
+        The messages to use, which should match ``tn``.
+
+    Returns
+    -------
+    marginal : array_like
+        The marginal probability distribution for the tensor/factor ``tid``.
+    """
+    t = tn.tensor_map[tid]
+
+    output = tuple(range(t.ndim))
+    inputs = [output]
+    arrays = [t.data]
+
+    for i, ix in enumerate(t.inds):
+        mix = prod(
+            messages[otid, ix] for otid in tn.ind_map[ix] if otid != tid
+        )
+        inputs.append((i,))
+        arrays.append(mix)
+
+    m = qtn.array_contract(
+        arrays=arrays,
+        inputs=inputs,
+        output=output,
+    )
+
+    return m / ar.do("sum", m)
+
+
+def compute_all_index_marginals_from_messages(tn, messages):
+    """Compute all index marginals from belief propagation messages.
+
+    Parameters
+    ----------
+    tn : TensorNetwork
+        The tensor network to compute marginals for.
+    messages : dict
+        The belief propagation messages.
+
+    Returns
+    -------
+    marginals : dict
+        The marginals for each index.
+    """
+    return {ix: compute_index_marginal(tn, ix, messages) for ix in tn.ind_map}
+
+
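A hedged sketch tying the helpers above together (in practice the messages should first be converged with one of the BP drivers in this subpackage; the uniform initialization is used here only to show the shapes involved):

import quimb.tensor as qtn

tn = qtn.TN_rand_reg(10, 3, D=2, seed=42)
messages = initialize_hyper_messages(tn)
marginals = compute_all_index_marginals_from_messages(tn, messages)
# marginals maps each index name to a length-d array summing to 1; for
# positive tensor networks these are the BP estimates of p(index), which
# is what the sample_* functions in this subpackage build on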
+def maybe_get_thread_pool(thread_pool):
+    """Get a thread pool if requested."""
+    if thread_pool is False:
+        return None
+
+    if thread_pool is True:
+        import quimb as qu
+
+        return qu.get_thread_pool()
+
+    if isinstance(thread_pool, int):
+        import quimb as qu
+
+        return qu.get_thread_pool(thread_pool)
+
+    return thread_pool
+
+
+def create_lazy_community_edge_map(tn, site_tags=None, rank_simplify=True):
+    """For lazy BP algorithms, create the data structures describing the
+    effective graph of the lazily grouped 'sites' given by ``site_tags``.
+    """
+    if site_tags is None:
+        site_tags = set(tn.site_tags)
+    else:
+        site_tags = set(site_tags)
+
+    edges = {}
+    neighbors = {}
+    local_tns = {}
+    touch_map = {}
+
+    for ix in tn.ind_map:
+        ts = tn._inds_get(ix)
+        tags = {tag for t in ts for tag in t.tags if tag in site_tags}
+        if len(tags) >= 2:
+            i, j = tuple(sorted(tags))
+
+            if (i, j) in edges:
+                # already processed this edge
+                continue
+
+            # add to neighbor map
+            neighbors.setdefault(i, []).append(j)
+            neighbors.setdefault(j, []).append(i)
+
+            # get local TNs and compute bonds between them,
+            # rank simplify here also to prepare for contractions
+            try:
+                tn_i = local_tns[i]
+            except KeyError:
+                tn_i = local_tns[i] = tn.select(i, virtual=False)
+                if rank_simplify:
+                    tn_i.rank_simplify_()
+            try:
+                tn_j = local_tns[j]
+            except KeyError:
+                tn_j = local_tns[j] = tn.select(j, virtual=False)
+                if rank_simplify:
+                    tn_j.rank_simplify_()
+
+            edges[i, j] = tuple(qtn.bonds(tn_i, tn_j))
+
+    for i, j in edges:
+        touch_map[(i, j)] = tuple((j, k) for k in neighbors[j] if k != i)
+        touch_map[(j, i)] = tuple((i, k) for k in neighbors[i] if k != j)
+
+    if len(local_tns) != len(site_tags):
+        # handle potentially disconnected sites
+        for i in sorted(site_tags):
+            try:
+                tn_i = local_tns[i] = tn.select(i, virtual=False)
+                if rank_simplify:
+                    tn_i.rank_simplify_()
+            except KeyError:
+                pass
+
+    return edges, neighbors, local_tns, touch_map
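For orientation, a sketch of how the lazy BP classes might consume this helper (assuming a PEPS, whose ``site_tags`` group its tensors into sites; the variable names are illustrative):

import quimb.tensor as qtn

psi = qtn.PEPS.rand(3, 3, bond_dim=2, seed=42)
edges, neighbors, local_tns, touch_map = create_lazy_community_edge_map(psi)

# edges[(i, j)]     -> the bond indices connecting site i and site j
# neighbors[i]      -> the sites adjacent to site i
# local_tns[i]      -> the (rank-simplified) local tensor network of site i
# touch_map[(i, j)] -> the messages to revisit after message (i, j) changes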