Trajectree 0.0.0__py3-none-any.whl → 0.0.1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- trajectree/__init__.py +3 -0
- trajectree/fock_optics/devices.py +1 -1
- trajectree/fock_optics/light_sources.py +2 -2
- trajectree/fock_optics/measurement.py +3 -3
- trajectree/fock_optics/utils.py +6 -6
- trajectree/quimb/docs/_pygments/_pygments_dark.py +118 -0
- trajectree/quimb/docs/_pygments/_pygments_light.py +118 -0
- trajectree/quimb/docs/conf.py +158 -0
- trajectree/quimb/docs/examples/ex_mpi_expm_evo.py +62 -0
- trajectree/quimb/quimb/__init__.py +507 -0
- trajectree/quimb/quimb/calc.py +1491 -0
- trajectree/quimb/quimb/core.py +2279 -0
- trajectree/quimb/quimb/evo.py +712 -0
- trajectree/quimb/quimb/experimental/__init__.py +0 -0
- trajectree/quimb/quimb/experimental/autojittn.py +129 -0
- trajectree/quimb/quimb/experimental/belief_propagation/__init__.py +109 -0
- trajectree/quimb/quimb/experimental/belief_propagation/bp_common.py +397 -0
- trajectree/quimb/quimb/experimental/belief_propagation/d1bp.py +316 -0
- trajectree/quimb/quimb/experimental/belief_propagation/d2bp.py +653 -0
- trajectree/quimb/quimb/experimental/belief_propagation/hd1bp.py +571 -0
- trajectree/quimb/quimb/experimental/belief_propagation/hv1bp.py +775 -0
- trajectree/quimb/quimb/experimental/belief_propagation/l1bp.py +316 -0
- trajectree/quimb/quimb/experimental/belief_propagation/l2bp.py +537 -0
- trajectree/quimb/quimb/experimental/belief_propagation/regions.py +194 -0
- trajectree/quimb/quimb/experimental/cluster_update.py +286 -0
- trajectree/quimb/quimb/experimental/merabuilder.py +865 -0
- trajectree/quimb/quimb/experimental/operatorbuilder/__init__.py +15 -0
- trajectree/quimb/quimb/experimental/operatorbuilder/operatorbuilder.py +1631 -0
- trajectree/quimb/quimb/experimental/schematic.py +7 -0
- trajectree/quimb/quimb/experimental/tn_marginals.py +130 -0
- trajectree/quimb/quimb/experimental/tnvmc.py +1483 -0
- trajectree/quimb/quimb/gates.py +36 -0
- trajectree/quimb/quimb/gen/__init__.py +2 -0
- trajectree/quimb/quimb/gen/operators.py +1167 -0
- trajectree/quimb/quimb/gen/rand.py +713 -0
- trajectree/quimb/quimb/gen/states.py +479 -0
- trajectree/quimb/quimb/linalg/__init__.py +6 -0
- trajectree/quimb/quimb/linalg/approx_spectral.py +1109 -0
- trajectree/quimb/quimb/linalg/autoblock.py +258 -0
- trajectree/quimb/quimb/linalg/base_linalg.py +719 -0
- trajectree/quimb/quimb/linalg/mpi_launcher.py +397 -0
- trajectree/quimb/quimb/linalg/numpy_linalg.py +244 -0
- trajectree/quimb/quimb/linalg/rand_linalg.py +514 -0
- trajectree/quimb/quimb/linalg/scipy_linalg.py +293 -0
- trajectree/quimb/quimb/linalg/slepc_linalg.py +892 -0
- trajectree/quimb/quimb/schematic.py +1518 -0
- trajectree/quimb/quimb/tensor/__init__.py +401 -0
- trajectree/quimb/quimb/tensor/array_ops.py +610 -0
- trajectree/quimb/quimb/tensor/circuit.py +4824 -0
- trajectree/quimb/quimb/tensor/circuit_gen.py +411 -0
- trajectree/quimb/quimb/tensor/contraction.py +336 -0
- trajectree/quimb/quimb/tensor/decomp.py +1255 -0
- trajectree/quimb/quimb/tensor/drawing.py +1646 -0
- trajectree/quimb/quimb/tensor/fitting.py +385 -0
- trajectree/quimb/quimb/tensor/geometry.py +583 -0
- trajectree/quimb/quimb/tensor/interface.py +114 -0
- trajectree/quimb/quimb/tensor/networking.py +1058 -0
- trajectree/quimb/quimb/tensor/optimize.py +1818 -0
- trajectree/quimb/quimb/tensor/tensor_1d.py +4778 -0
- trajectree/quimb/quimb/tensor/tensor_1d_compress.py +1854 -0
- trajectree/quimb/quimb/tensor/tensor_1d_tebd.py +662 -0
- trajectree/quimb/quimb/tensor/tensor_2d.py +5954 -0
- trajectree/quimb/quimb/tensor/tensor_2d_compress.py +96 -0
- trajectree/quimb/quimb/tensor/tensor_2d_tebd.py +1230 -0
- trajectree/quimb/quimb/tensor/tensor_3d.py +2869 -0
- trajectree/quimb/quimb/tensor/tensor_3d_tebd.py +46 -0
- trajectree/quimb/quimb/tensor/tensor_approx_spectral.py +60 -0
- trajectree/quimb/quimb/tensor/tensor_arbgeom.py +3237 -0
- trajectree/quimb/quimb/tensor/tensor_arbgeom_compress.py +565 -0
- trajectree/quimb/quimb/tensor/tensor_arbgeom_tebd.py +1138 -0
- trajectree/quimb/quimb/tensor/tensor_builder.py +5411 -0
- trajectree/quimb/quimb/tensor/tensor_core.py +11179 -0
- trajectree/quimb/quimb/tensor/tensor_dmrg.py +1472 -0
- trajectree/quimb/quimb/tensor/tensor_mera.py +204 -0
- trajectree/quimb/quimb/utils.py +892 -0
- trajectree/quimb/tests/__init__.py +0 -0
- trajectree/quimb/tests/test_accel.py +501 -0
- trajectree/quimb/tests/test_calc.py +788 -0
- trajectree/quimb/tests/test_core.py +847 -0
- trajectree/quimb/tests/test_evo.py +565 -0
- trajectree/quimb/tests/test_gen/__init__.py +0 -0
- trajectree/quimb/tests/test_gen/test_operators.py +361 -0
- trajectree/quimb/tests/test_gen/test_rand.py +296 -0
- trajectree/quimb/tests/test_gen/test_states.py +261 -0
- trajectree/quimb/tests/test_linalg/__init__.py +0 -0
- trajectree/quimb/tests/test_linalg/test_approx_spectral.py +368 -0
- trajectree/quimb/tests/test_linalg/test_base_linalg.py +351 -0
- trajectree/quimb/tests/test_linalg/test_mpi_linalg.py +127 -0
- trajectree/quimb/tests/test_linalg/test_numpy_linalg.py +84 -0
- trajectree/quimb/tests/test_linalg/test_rand_linalg.py +134 -0
- trajectree/quimb/tests/test_linalg/test_slepc_linalg.py +283 -0
- trajectree/quimb/tests/test_tensor/__init__.py +0 -0
- trajectree/quimb/tests/test_tensor/test_belief_propagation/__init__.py +0 -0
- trajectree/quimb/tests/test_tensor/test_belief_propagation/test_d1bp.py +39 -0
- trajectree/quimb/tests/test_tensor/test_belief_propagation/test_d2bp.py +67 -0
- trajectree/quimb/tests/test_tensor/test_belief_propagation/test_hd1bp.py +64 -0
- trajectree/quimb/tests/test_tensor/test_belief_propagation/test_hv1bp.py +51 -0
- trajectree/quimb/tests/test_tensor/test_belief_propagation/test_l1bp.py +142 -0
- trajectree/quimb/tests/test_tensor/test_belief_propagation/test_l2bp.py +101 -0
- trajectree/quimb/tests/test_tensor/test_circuit.py +816 -0
- trajectree/quimb/tests/test_tensor/test_contract.py +67 -0
- trajectree/quimb/tests/test_tensor/test_decomp.py +40 -0
- trajectree/quimb/tests/test_tensor/test_mera.py +52 -0
- trajectree/quimb/tests/test_tensor/test_optimizers.py +488 -0
- trajectree/quimb/tests/test_tensor/test_tensor_1d.py +1171 -0
- trajectree/quimb/tests/test_tensor/test_tensor_2d.py +606 -0
- trajectree/quimb/tests/test_tensor/test_tensor_2d_tebd.py +144 -0
- trajectree/quimb/tests/test_tensor/test_tensor_3d.py +123 -0
- trajectree/quimb/tests/test_tensor/test_tensor_arbgeom.py +226 -0
- trajectree/quimb/tests/test_tensor/test_tensor_builder.py +441 -0
- trajectree/quimb/tests/test_tensor/test_tensor_core.py +2066 -0
- trajectree/quimb/tests/test_tensor/test_tensor_dmrg.py +388 -0
- trajectree/quimb/tests/test_tensor/test_tensor_spectral_approx.py +63 -0
- trajectree/quimb/tests/test_tensor/test_tensor_tebd.py +270 -0
- trajectree/quimb/tests/test_utils.py +85 -0
- trajectree/trajectory.py +2 -2
- {trajectree-0.0.0.dist-info → trajectree-0.0.1.dist-info}/METADATA +2 -2
- trajectree-0.0.1.dist-info/RECORD +126 -0
- trajectree-0.0.0.dist-info/RECORD +0 -16
- {trajectree-0.0.0.dist-info → trajectree-0.0.1.dist-info}/WHEEL +0 -0
- {trajectree-0.0.0.dist-info → trajectree-0.0.1.dist-info}/licenses/LICENSE +0 -0
- {trajectree-0.0.0.dist-info → trajectree-0.0.1.dist-info}/top_level.txt +0 -0
|
@@ -0,0 +1,129 @@
|
|
|
1
|
+
"""Decorator for automatically just in time compiling tensor network functions.
|
|
2
|
+
|
|
3
|
+
TODO::
|
|
4
|
+
|
|
5
|
+
- [ ] go via an intermediate pytree / array function, that could be shared
|
|
6
|
+
e.g. with the TNOptimizer class.
|
|
7
|
+
|
|
8
|
+
"""
|
|
9
|
+
import functools
|
|
10
|
+
|
|
11
|
+
import autoray as ar
|
|
12
|
+
|
|
13
|
+
|
|
14
|
+
|
|
15
|
+
class AutojittedTN:
    """Class to hold the ``autojit_tn`` decorated function callable.

    Wraps a tensor network function so that its underlying array operations
    are traced/compiled by ``decorator`` (default ``autoray.autojit``).
    Compiled variants are cached in ``self.fn_store``, keyed on the input
    TN geometry (plus extra args/kwargs) when ``check_inputs`` is True.
    """

    def __init__(
        self,
        fn,
        decorator=ar.autojit,
        check_inputs=True,
        **decorator_opts
    ):
        # the raw (uncompiled) tensor network function
        self.fn = fn
        # cache of compiled functions, keyed on geometry/args (or None)
        self.fn_store = {}
        self.decorator_opts = decorator_opts
        self.check_inputs = check_inputs
        self.decorator = decorator

    def setup_fn(self, tn, *args, **kwargs):
        """Build and return a compiled array function for this particular
        combination of TN geometry, ``args`` and ``kwargs``.

        As a side effect of the first trace, ``self.inplace`` is set
        according to whether ``self.fn`` acts inplace on the TN or returns
        raw scalar/array output.
        """
        from quimb.tensor import TensorNetwork

        @self.decorator(**self.decorator_opts)
        def fn_jit(arrays):
            # use separate TN to trace through function
            jtn = tn.copy()

            # insert the tracing arrays
            for t, array in zip(jtn, arrays):
                t.modify(data=array)

            # run function on TN with tracing arrays
            result = self.fn(jtn, *args, **kwargs)

            # check for a inplace tn function
            if isinstance(result, TensorNetwork):
                if result is not jtn:
                    raise ValueError(
                        "If you are compiling a function that returns a"
                        " tensor network it needs to be inplace.")
                # NOTE: set during tracing, read later in __call__
                self.inplace = True
                return tuple(t.data for t in jtn)
            else:
                # function returns raw scalar/array(s)
                self.inplace = False
                return result

        return fn_jit

    def __call__(self, tn, *args, **kwargs):

        # do we need to generate a new function for these inputs
        if self.check_inputs:
            key = (
                tn.geometry_hash(strict_index_order=True),
                tuple(args),
                tuple(sorted(kwargs.items())),
            )
        else:
            # always use the same function
            key = None

        if key not in self.fn_store:
            self.fn_store[key] = self.setup_fn(tn, *args, **kwargs)
        fn_jit = self.fn_store[key]

        # run the compiled function
        arrays = tuple(t.data for t in tn)
        out = fn_jit(arrays)

        if self.inplace:
            # reinsert output arrays into input TN structure
            for t, array in zip(tn, out):
                t.modify(data=array)
            return tn

        return out
|
|
90
|
+
|
|
91
|
+
|
|
92
|
+
def autojit_tn(
    fn=None,
    decorator=ar.autojit,
    check_inputs=True,
    **decorator_opts,
):
    """Decorate a tensor network function to be just in time compiled /
    traced.

    Only array operations are traced, producing a completely static
    computational graph with no side-effects. The compiled function can be
    much faster when called repeatedly with only numeric changes, or
    hardware accelerated if a library such as ``jax`` supplies the
    ``decorator``.

    Parameters
    ----------
    fn : callable
        The function to decorate. Its first argument should be a
        :class:`~quimb.tensor.tensor_core.TensorNetwork`, which it either
        acts on inplace or uses to return a raw scalar or array.
    decorator : callable
        The wrapper applied to the underlying array function, e.g.
        ``jax.jit``. Defaults to ``autoray.autojit``.
    check_inputs : bool, optional
        Whether to inspect the inputs on every call to decide if a freshly
        compiled function is required. If ``False`` a single compiled
        function is reused for all inputs, which might be incorrect.
        Defaults to ``True``.
    decorator_opts
        Extra options forwarded to ``decorator``, e.g. ``backend`` for
        ``autoray.autojit``.
    """
    opts = dict(decorator_opts)
    opts["decorator"] = decorator
    opts["check_inputs"] = check_inputs

    if fn is None:
        # called with configuration only -> return a decorator awaiting fn
        return functools.partial(autojit_tn, **opts)

    return functools.wraps(fn)(AutojittedTN(fn, **opts))
|
|
@@ -0,0 +1,109 @@
|
|
|
1
|
+
"""Belief propagation (BP) routines. There are three potential categorizations
|
|
2
|
+
of BP and each combination of them is potentially valid specific algorithm.
|
|
3
|
+
|
|
4
|
+
1-norm vs 2-norm BP
|
|
5
|
+
-------------------
|
|
6
|
+
|
|
7
|
+
- 1-norm (normal): BP runs directly on the tensor network, messages have size
|
|
8
|
+
``d`` where ``d`` is the size of the bond(s) connecting two tensors or
|
|
9
|
+
regions.
|
|
10
|
+
- 2-norm (quantum): BP runs on the squared tensor network, messages have size
|
|
11
|
+
``d^2`` where ``d`` is the size of the bond(s) connecting two tensors or
|
|
12
|
+
regions. Each local tensor or region is partially traced (over dangling
|
|
13
|
+
indices) with its conjugate to create a single node.
|
|
14
|
+
|
|
15
|
+
|
|
16
|
+
Graph vs Hypergraph BP
|
|
17
|
+
----------------------
|
|
18
|
+
|
|
19
|
+
- Graph (simple): the tensor network lives on a graph, where indices either
|
|
20
|
+
appear on two tensors (a bond), or appear on a single tensor (are outputs).
|
|
21
|
+
In this case, messages are exchanged directly between tensors.
|
|
22
|
+
- Hypergraph: the tensor network lives on a hypergraph, where indices can
|
|
23
|
+
appear on any number of tensors. In this case, the update procedure is two
|
|
24
|
+
parts, first all 'tensor' messages are computed, these are then used in the
|
|
25
|
+
second step to compute all the 'index' messages, which are then fed back into
|
|
26
|
+
the 'tensor' message update and so forth. For 2-norm BP one likely needs to
|
|
27
|
+
specify which indices are outputs and should be traced over.
|
|
28
|
+
|
|
29
|
+
The hypergraph case of course includes the graph case, but since the 'index'
|
|
30
|
+
message update is simply the identity, it is convenient to have a separate
|
|
31
|
+
simpler implementation, where the standard TN bond vs physical index
|
|
32
|
+
definitions hold.
|
|
33
|
+
|
|
34
|
+
|
|
35
|
+
Dense vs Vectorized vs Lazy BP
|
|
36
|
+
------------------------------
|
|
37
|
+
|
|
38
|
+
- Dense: each node is a single tensor, or pair of tensors for 2-norm BP. If all
|
|
39
|
+
multibonds have been fused, then each message is a vector (1-norm case) or
|
|
40
|
+
matrix (2-norm case).
|
|
41
|
+
- Vectorized: the same as the above, but all matching tensor update and message
|
|
42
|
+
updates are stacked and performed simultaneously. This can be enormously more
|
|
43
|
+
efficient for large numbers of small tensors.
|
|
44
|
+
- Lazy: each node is potentially a tensor network itself with arbitrary inner
|
|
45
|
+
structure and number of bonds connecting to other nodes. The message are
|
|
46
|
+
generally tensors and each update is a lazy contraction, which is potentially
|
|
47
|
+
much cheaper / requires less memory than forming the 'dense' node for large
|
|
48
|
+
tensors.
|
|
49
|
+
|
|
50
|
+
(There is also the MPS flavor where each node has a 1D structure and the
|
|
51
|
+
messages are matrix product states, with updates involving compression.)
|
|
52
|
+
|
|
53
|
+
|
|
54
|
+
Overall that gives 12 possible BP flavors, some implemented here:
|
|
55
|
+
|
|
56
|
+
- [x] (HD1BP) hyper, dense, 1-norm - this is the standard BP algorithm
|
|
57
|
+
- [x] (HD2BP) hyper, dense, 2-norm
|
|
58
|
+
- [x] (HV1BP) hyper, vectorized, 1-norm
|
|
59
|
+
- [ ] (HV2BP) hyper, vectorized, 2-norm
|
|
60
|
+
- [ ] (HL1BP) hyper, lazy, 1-norm
|
|
61
|
+
- [ ] (HL2BP) hyper, lazy, 2-norm
|
|
62
|
+
- [x] (D1BP) simple, dense, 1-norm - simple BP for simple tensor networks
|
|
63
|
+
- [x] (D2BP) simple, dense, 2-norm - this is the standard PEPS BP algorithm
|
|
64
|
+
- [ ] (V1BP) simple, vectorized, 1-norm
|
|
65
|
+
- [ ] (V2BP) simple, vectorized, 2-norm
|
|
66
|
+
- [x] (L1BP) simple, lazy, 1-norm
|
|
67
|
+
- [x] (L2BP) simple, lazy, 2-norm
|
|
68
|
+
|
|
69
|
+
The 2-norm methods can be used to compress bonds or estimate the 2-norm.
|
|
70
|
+
The 1-norm methods can be used to estimate the 1-norm, i.e. contracted value.
|
|
71
|
+
Both methods can be used to compute index marginals and thus perform sampling.
|
|
72
|
+
|
|
73
|
+
The vectorized methods can be extremely fast for large numbers of small
|
|
74
|
+
tensors, but do currently require all dimensions to match.
|
|
75
|
+
|
|
76
|
+
The dense and lazy methods can can converge messages *locally*, i.e. only
|
|
77
|
+
update messages adjacent to messages which have changed.
|
|
78
|
+
"""
|
|
79
|
+
|
|
80
|
+
from .bp_common import initialize_hyper_messages
|
|
81
|
+
from .d1bp import D1BP, contract_d1bp
|
|
82
|
+
from .d2bp import D2BP, compress_d2bp, contract_d2bp, sample_d2bp
|
|
83
|
+
from .hd1bp import HD1BP, contract_hd1bp, sample_hd1bp
|
|
84
|
+
from .hv1bp import HV1BP, contract_hv1bp, sample_hv1bp
|
|
85
|
+
from .l1bp import L1BP, contract_l1bp
|
|
86
|
+
from .l2bp import L2BP, compress_l2bp, contract_l2bp
|
|
87
|
+
from .regions import RegionGraph
|
|
88
|
+
|
|
89
|
+
__all__ = (
|
|
90
|
+
"compress_d2bp",
|
|
91
|
+
"compress_l2bp",
|
|
92
|
+
"contract_d1bp",
|
|
93
|
+
"contract_d2bp",
|
|
94
|
+
"contract_hd1bp",
|
|
95
|
+
"contract_hv1bp",
|
|
96
|
+
"contract_l1bp",
|
|
97
|
+
"contract_l2bp",
|
|
98
|
+
"D1BP",
|
|
99
|
+
"D2BP",
|
|
100
|
+
"HD1BP",
|
|
101
|
+
"HV1BP",
|
|
102
|
+
"initialize_hyper_messages",
|
|
103
|
+
"L1BP",
|
|
104
|
+
"L2BP",
|
|
105
|
+
"RegionGraph",
|
|
106
|
+
"sample_d2bp",
|
|
107
|
+
"sample_hd1bp",
|
|
108
|
+
"sample_hv1bp",
|
|
109
|
+
)
|
|
@@ -0,0 +1,397 @@
|
|
|
1
|
+
import functools
|
|
2
|
+
import operator
|
|
3
|
+
|
|
4
|
+
import autoray as ar
|
|
5
|
+
|
|
6
|
+
import quimb.tensor as qtn
|
|
7
|
+
|
|
8
|
+
|
|
9
|
+
def prod(xs):
    """Return the product of all elements in ``xs``.

    Parameters
    ----------
    xs : iterable
        Elements supporting ``*`` (scalars or arrays).

    Returns
    -------
    The product of all elements, or ``1`` if ``xs`` is empty.
    """
    # supply an initial value so an empty iterable yields 1 (matching
    # ``math.prod`` semantics) instead of raising ``TypeError``
    return functools.reduce(operator.mul, xs, 1)
|
|
12
|
+
|
|
13
|
+
|
|
14
|
+
class RollingDiffMean:
    """Track the absolute rolling mean of successive differences between
    observed values, used to judge effective convergence of BP beyond the
    raw per-message tolerance.
    """

    def __init__(self, size=16):
        self.size = size
        self.diffs = []
        self.last_x = None
        self.dxsum = 0.0

    def update(self, x):
        """Record a new value ``x``, updating the window of diffs."""
        prev = self.last_x
        if prev is not None:
            step = x - prev
            self.diffs.append(step)
            self.dxsum += step / self.size
            if len(self.diffs) > self.size:
                evicted = self.diffs.pop(0)
                self.dxsum -= evicted / self.size
        self.last_x = x

    def absmeandiff(self):
        """Return the absolute mean of windowed diffs, or ``inf`` until
        the window has filled.
        """
        if len(self.diffs) < self.size:
            return float("inf")
        return abs(self.dxsum)
|
|
39
|
+
|
|
40
|
+
|
|
41
|
+
class BeliefPropagationCommon:
    """Common interfaces for belief propagation algorithms.

    Subclasses must provide an ``iterate(tol)`` method returning
    ``(nconv, ncheck, max_mdiff)``.

    Parameters
    ----------
    max_iterations : int, optional
        The maximum number of iterations to perform.
    tol : float, optional
        The convergence tolerance for messages.
    progbar : bool, optional
        Whether to show a progress bar.
    """

    def run(self, max_iterations=1000, tol=5e-6, info=None, progbar=False):
        """Iterate BP until messages converge (directly, or via the rolling
        mean of diffs), or ``max_iterations`` is reached.

        Sets ``self.converged`` and, if ``info`` is a dict, fills in
        ``converged``, ``iterations``, ``max_mdiff`` and
        ``rolling_abs_mean_diff``.

        NOTE(review): if ``max_iterations`` is 0 (or ``self.converged``
        starts True), ``nconv``/``max_mdiff`` are never bound and the
        warning / ``info`` blocks below would raise ``NameError`` — confirm
        callers never pass 0.
        """
        if progbar:
            import tqdm

            pbar = tqdm.tqdm()
        else:
            pbar = None

        try:
            it = 0
            rdm = RollingDiffMean()
            self.converged = False
            while not self.converged and it < max_iterations:
                # perform a single iteration of BP
                # we supply tol here for use with local convergence
                nconv, ncheck, max_mdiff = self.iterate(tol=tol)
                it += 1

                # check rolling mean convergence
                rdm.update(max_mdiff)
                self.converged = (max_mdiff < tol) or (rdm.absmeandiff() < tol)

                if pbar is not None:
                    pbar.set_description(
                        f"nconv: {nconv}/{ncheck} max|dM|={max_mdiff:.2e}",
                        refresh=False,
                    )
                    pbar.update()

        finally:
            # always close the progress bar, even if iterate() raises
            if pbar is not None:
                pbar.close()

        if tol != 0.0 and not self.converged:
            import warnings

            warnings.warn(
                f"Belief propagation did not converge after {max_iterations} "
                f"iterations, tol={tol:.2e}, max|dM|={max_mdiff:.2e}."
            )

        if info is not None:
            info["converged"] = self.converged
            info["iterations"] = it
            info["max_mdiff"] = max_mdiff
            info["rolling_abs_mean_diff"] = rdm.absmeandiff()
|
|
100
|
+
|
|
101
|
+
|
|
102
|
+
def initialize_hyper_messages(
    tn,
    fill_fn=None,
    smudge_factor=1e-12,
):
    """Initialize belief propagation messages for ``tn``, equivalent to a
    single round of BP starting from uniform messages.

    Parameters
    ----------
    tn : TensorNetwork
        The tensor network to initialize messages for.
    fill_fn : callable, optional
        If given, used to fill initial messages, with signature
        ``fill_fn(shape)``.
    smudge_factor : float, optional
        Small offset added to messages when dividing, to avoid numerical
        issues with zeros.

    Returns
    -------
    messages : dict
        For every index/tensor-id pair there is a message each way, keyed
        ``(tid, ix)`` and ``(ix, tid)``.
    """
    from quimb.tensor.contraction import array_contract

    backend = ar.infer_backend(next(t.data for t in tn))
    _sum = ar.get_lib_fn(backend, "sum")

    messages = {}

    # first pass: messages from tensors to indices
    for tid, t in tn.tensor_map.items():
        all_axes = tuple(range(t.ndim))
        for ax, ix in enumerate(t.inds):
            if fill_fn is not None:
                m = fill_fn((t.ind_size(ix),))
            else:
                # trace out every other axis to form the initial message
                m = array_contract(
                    arrays=(t.data,),
                    inputs=(all_axes,),
                    output=(ax,),
                )
            # normalize before storing
            messages[tid, ix] = m / _sum(m)

    # second pass: messages from indices back to tensors
    for ix, tids in tn.ind_map.items():
        incoming = [messages[tid, ix] for tid in tids]
        combined = prod(incoming)
        for m_in, tid in zip(incoming, tids):
            # divide out this tensor's own contribution, smudged to
            # avoid division by (near-)zero
            m = combined / (m_in + smudge_factor)
            messages[ix, tid] = m / _sum(m)

    return messages
|
|
160
|
+
|
|
161
|
+
|
|
162
|
+
def combine_local_contractions(
    tvals,
    mvals,
    backend,
    strip_exponent=False,
    check_for_zero=True,
):
    """Combine local contraction values ``tvals`` (multiplied in) and
    message overlap values ``mvals`` (divided out) into a single scalar,
    accumulating the magnitude separately as a base-10 exponent for
    numerical stability.

    Parameters
    ----------
    tvals, mvals : iterable of scalar
        Local factor values and message overlap corrections.
    backend : str
        Array backend used to look up ``abs`` and ``log10``.
    strip_exponent : bool, optional
        If ``True`` return ``(mantissa, exponent)`` rather than the
        combined value ``mantissa * 10**exponent``.
    check_for_zero : bool, optional
        If ``True``, short-circuit to zero as soon as any ``tvals`` entry
        has zero magnitude.
    """
    _abs = ar.get_lib_fn(backend, "abs")
    _log10 = ar.get_lib_fn(backend, "log10")

    mantissa = 1
    exponent = 0

    for val in tvals:
        mag = _abs(val)
        if check_for_zero and (mag == 0.0):
            # one zero factor makes the whole contraction zero
            return (0.0, 0.0) if strip_exponent else 0.0
        mantissa = mantissa * (val / mag)
        exponent = exponent + _log10(mag)

    for val in mvals:
        mag = _abs(val)
        mantissa = mantissa / (val / mag)
        exponent = exponent - _log10(mag)

    if strip_exponent:
        return mantissa, exponent
    return mantissa * 10**exponent
|
|
194
|
+
|
|
195
|
+
|
|
196
|
+
def contract_hyper_messages(
    tn,
    messages,
    strip_exponent=False,
    backend=None,
):
    """Estimate the contraction of ``tn`` given ``messages``, via the
    exponential of the Bethe free entropy.

    Parameters
    ----------
    tn : TensorNetwork
        The tensor network to estimate the contraction of.
    messages : dict
        BP messages matching ``tn``, keyed ``(tid, ix)`` and ``(ix, tid)``.
    strip_exponent : bool, optional
        If ``True`` return ``(mantissa, exponent)`` rather than a single
        scalar.
    backend : str, optional
        Array backend; inferred from the first tensor's data if not given.
    """
    tvals = []
    mvals = []

    for tid, t in tn.tensor_map.items():
        if backend is None:
            backend = ar.infer_backend(t.data)

        arrays = [t.data]
        inputs = [range(t.ndim)]
        for i, ix in enumerate(t.inds):
            arrays.append(messages[ix, tid])
            inputs.append((i,))

            # local message overlap correction
            mvals.append(
                qtn.array_contract(
                    (messages[tid, ix], messages[ix, tid]),
                    inputs=((0,), (0,)),
                    output=(),
                )
            )

        # local factor free entropy
        tvals.append(qtn.array_contract(arrays, inputs, output=()))

    for ix, tids in tn.ind_map.items():
        arrays = tuple(messages[tid, ix] for tid in tids)
        inputs = tuple((0,) for _ in tids)
        # local variable free entropy
        tvals.append(qtn.array_contract(arrays, inputs, output=()))

    return combine_local_contractions(
        tvals, mvals, backend, strip_exponent=strip_exponent
    )
|
|
239
|
+
|
|
240
|
+
|
|
241
|
+
def compute_index_marginal(tn, ind, messages):
    """Compute the marginal distribution for a single index given
    ``messages``.

    Parameters
    ----------
    tn : TensorNetwork
        The tensor network to compute the marginal for.
    ind : int
        The index to compute the marginal for.
    messages : dict
        The messages to use, which should match ``tn``.

    Returns
    -------
    marginal : array_like
        The marginal probability distribution for the index ``ind``.
    """
    # multiply every tensor->index message incident on ``ind`` ...
    unnormalized = prod(messages[tid, ind] for tid in tn.ind_map[ind])
    # ... and normalize to a probability distribution
    return unnormalized / ar.do("sum", unnormalized)
|
|
260
|
+
|
|
261
|
+
|
|
262
|
+
def compute_tensor_marginal(tn, tid, messages):
    """Compute the marginal for the region surrounding a single
    tensor/factor given ``messages``.

    Parameters
    ----------
    tn : TensorNetwork
        The tensor network to compute the marginal for.
    tid : int
        The tensor id to compute the marginal for.
    messages : dict
        The messages to use, which should match ``tn``.

    Returns
    -------
    marginal : array_like
        The marginal probability distribution for the tensor/factor
        ``tid``.
    """
    t = tn.tensor_map[tid]

    output = tuple(range(t.ndim))
    arrays = [t.data]
    inputs = [output]

    for ax, ix in enumerate(t.inds):
        # product of the messages from all *other* tensors sharing ``ix``
        incoming = prod(
            messages[otid, ix] for otid in tn.ind_map[ix] if otid != tid
        )
        arrays.append(incoming)
        inputs.append((ax,))

    m = qtn.array_contract(
        arrays=arrays,
        inputs=inputs,
        output=output,
    )

    return m / ar.do("sum", m)
|
|
300
|
+
|
|
301
|
+
|
|
302
|
+
def compute_all_index_marginals_from_messages(tn, messages):
    """Compute every index marginal of ``tn`` from belief propagation
    messages.

    Parameters
    ----------
    tn : TensorNetwork
        The tensor network to compute marginals for.
    messages : dict
        The belief propagation messages.

    Returns
    -------
    marginals : dict
        The marginals for each index.
    """
    return {
        ix: compute_index_marginal(tn, ix, messages)
        for ix in tn.ind_map
    }
|
|
318
|
+
|
|
319
|
+
|
|
320
|
+
def maybe_get_thread_pool(thread_pool):
    """Resolve ``thread_pool`` into an actual pool, if requested.

    ``False`` maps to ``None`` (no pool); ``True`` or an integer worker
    count fetch quimb's thread pool; any other object is passed through
    unchanged.
    """
    if thread_pool is False:
        return None

    # explicit request: fetch quimb's shared thread pool
    if (thread_pool is True) or isinstance(thread_pool, int):
        import quimb as qu

        if thread_pool is True:
            return qu.get_thread_pool()
        return qu.get_thread_pool(thread_pool)

    # assumed to already be a pool-like object
    return thread_pool
|
|
336
|
+
|
|
337
|
+
|
|
338
|
+
def create_lazy_community_edge_map(tn, site_tags=None, rank_simplify=True):
    """For lazy BP algorithms, create the data structures describing the
    effective graph of the lazily grouped 'sites' given by ``site_tags``.

    Parameters
    ----------
    tn : TensorNetwork
        The tensor network to group into sites.
    site_tags : sequence of str, optional
        The tags defining the sites; defaults to ``tn.site_tags``.
    rank_simplify : bool, optional
        Whether to rank-simplify each local site TN in preparation for
        contraction.

    Returns
    -------
    edges : dict
        Maps each ordered site pair ``(i, j)`` to the bonds between the
        two local TNs.
    neighbors : dict
        Maps each site to the list of its neighboring sites.
    local_tns : dict
        Maps each site to its (copied) local tensor network.
    touch_map : dict
        Maps each directed edge to the downstream edges whose messages
        need recomputing when it changes.
    """
    if site_tags is None:
        site_tags = set(tn.site_tags)
    else:
        site_tags = set(site_tags)

    edges = {}
    neighbors = {}
    local_tns = {}
    touch_map = {}

    for ix in tn.ind_map:
        ts = tn._inds_get(ix)
        tags = {tag for t in ts for tag in t.tags if tag in site_tags}
        # NOTE(review): assumes an index never spans more than two sites —
        # the 2-tuple unpack below raises ValueError otherwise; confirm.
        if len(tags) >= 2:
            i, j = tuple(sorted(tags))

            if (i, j) in edges:
                # already processed this edge
                continue

            # add to neighbor map
            neighbors.setdefault(i, []).append(j)
            neighbors.setdefault(j, []).append(i)

            # get local TNs and compute bonds between them,
            # rank simplify here also to prepare for contractions
            try:
                tn_i = local_tns[i]
            except KeyError:
                tn_i = local_tns[i] = tn.select(i, virtual=False)
                if rank_simplify:
                    tn_i.rank_simplify_()
            try:
                tn_j = local_tns[j]
            except KeyError:
                tn_j = local_tns[j] = tn.select(j, virtual=False)
                if rank_simplify:
                    tn_j.rank_simplify_()

            edges[i, j] = tuple(qtn.bonds(tn_i, tn_j))

    # a message on edge (i, j) 'touches' all edges leaving j except (j, i)
    for i, j in edges:
        touch_map[(i, j)] = tuple((j, k) for k in neighbors[j] if k != i)
        touch_map[(j, i)] = tuple((i, k) for k in neighbors[i] if k != j)

    if len(local_tns) != len(site_tags):
        # handle potentially disconnected sites
        # NOTE(review): this re-selects (and re-simplifies) every site,
        # overwriting entries already in ``local_tns`` — presumably
        # harmless but worth confirming against upstream intent.
        for i in sorted(site_tags):
            try:
                tn_i = local_tns[i] = tn.select(i, virtual=False)
                if rank_simplify:
                    tn_i.rank_simplify_()
            except KeyError:
                pass

    return edges, neighbors, local_tns, touch_map
|