Trajectree 0.0.1__py3-none-any.whl → 0.0.2__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- trajectree/__init__.py +0 -3
- trajectree/fock_optics/devices.py +1 -1
- trajectree/fock_optics/light_sources.py +2 -2
- trajectree/fock_optics/measurement.py +3 -3
- trajectree/fock_optics/utils.py +6 -6
- trajectree/trajectory.py +2 -2
- {trajectree-0.0.1.dist-info → trajectree-0.0.2.dist-info}/METADATA +2 -3
- trajectree-0.0.2.dist-info/RECORD +16 -0
- trajectree/quimb/docs/_pygments/_pygments_dark.py +0 -118
- trajectree/quimb/docs/_pygments/_pygments_light.py +0 -118
- trajectree/quimb/docs/conf.py +0 -158
- trajectree/quimb/docs/examples/ex_mpi_expm_evo.py +0 -62
- trajectree/quimb/quimb/__init__.py +0 -507
- trajectree/quimb/quimb/calc.py +0 -1491
- trajectree/quimb/quimb/core.py +0 -2279
- trajectree/quimb/quimb/evo.py +0 -712
- trajectree/quimb/quimb/experimental/__init__.py +0 -0
- trajectree/quimb/quimb/experimental/autojittn.py +0 -129
- trajectree/quimb/quimb/experimental/belief_propagation/__init__.py +0 -109
- trajectree/quimb/quimb/experimental/belief_propagation/bp_common.py +0 -397
- trajectree/quimb/quimb/experimental/belief_propagation/d1bp.py +0 -316
- trajectree/quimb/quimb/experimental/belief_propagation/d2bp.py +0 -653
- trajectree/quimb/quimb/experimental/belief_propagation/hd1bp.py +0 -571
- trajectree/quimb/quimb/experimental/belief_propagation/hv1bp.py +0 -775
- trajectree/quimb/quimb/experimental/belief_propagation/l1bp.py +0 -316
- trajectree/quimb/quimb/experimental/belief_propagation/l2bp.py +0 -537
- trajectree/quimb/quimb/experimental/belief_propagation/regions.py +0 -194
- trajectree/quimb/quimb/experimental/cluster_update.py +0 -286
- trajectree/quimb/quimb/experimental/merabuilder.py +0 -865
- trajectree/quimb/quimb/experimental/operatorbuilder/__init__.py +0 -15
- trajectree/quimb/quimb/experimental/operatorbuilder/operatorbuilder.py +0 -1631
- trajectree/quimb/quimb/experimental/schematic.py +0 -7
- trajectree/quimb/quimb/experimental/tn_marginals.py +0 -130
- trajectree/quimb/quimb/experimental/tnvmc.py +0 -1483
- trajectree/quimb/quimb/gates.py +0 -36
- trajectree/quimb/quimb/gen/__init__.py +0 -2
- trajectree/quimb/quimb/gen/operators.py +0 -1167
- trajectree/quimb/quimb/gen/rand.py +0 -713
- trajectree/quimb/quimb/gen/states.py +0 -479
- trajectree/quimb/quimb/linalg/__init__.py +0 -6
- trajectree/quimb/quimb/linalg/approx_spectral.py +0 -1109
- trajectree/quimb/quimb/linalg/autoblock.py +0 -258
- trajectree/quimb/quimb/linalg/base_linalg.py +0 -719
- trajectree/quimb/quimb/linalg/mpi_launcher.py +0 -397
- trajectree/quimb/quimb/linalg/numpy_linalg.py +0 -244
- trajectree/quimb/quimb/linalg/rand_linalg.py +0 -514
- trajectree/quimb/quimb/linalg/scipy_linalg.py +0 -293
- trajectree/quimb/quimb/linalg/slepc_linalg.py +0 -892
- trajectree/quimb/quimb/schematic.py +0 -1518
- trajectree/quimb/quimb/tensor/__init__.py +0 -401
- trajectree/quimb/quimb/tensor/array_ops.py +0 -610
- trajectree/quimb/quimb/tensor/circuit.py +0 -4824
- trajectree/quimb/quimb/tensor/circuit_gen.py +0 -411
- trajectree/quimb/quimb/tensor/contraction.py +0 -336
- trajectree/quimb/quimb/tensor/decomp.py +0 -1255
- trajectree/quimb/quimb/tensor/drawing.py +0 -1646
- trajectree/quimb/quimb/tensor/fitting.py +0 -385
- trajectree/quimb/quimb/tensor/geometry.py +0 -583
- trajectree/quimb/quimb/tensor/interface.py +0 -114
- trajectree/quimb/quimb/tensor/networking.py +0 -1058
- trajectree/quimb/quimb/tensor/optimize.py +0 -1818
- trajectree/quimb/quimb/tensor/tensor_1d.py +0 -4778
- trajectree/quimb/quimb/tensor/tensor_1d_compress.py +0 -1854
- trajectree/quimb/quimb/tensor/tensor_1d_tebd.py +0 -662
- trajectree/quimb/quimb/tensor/tensor_2d.py +0 -5954
- trajectree/quimb/quimb/tensor/tensor_2d_compress.py +0 -96
- trajectree/quimb/quimb/tensor/tensor_2d_tebd.py +0 -1230
- trajectree/quimb/quimb/tensor/tensor_3d.py +0 -2869
- trajectree/quimb/quimb/tensor/tensor_3d_tebd.py +0 -46
- trajectree/quimb/quimb/tensor/tensor_approx_spectral.py +0 -60
- trajectree/quimb/quimb/tensor/tensor_arbgeom.py +0 -3237
- trajectree/quimb/quimb/tensor/tensor_arbgeom_compress.py +0 -565
- trajectree/quimb/quimb/tensor/tensor_arbgeom_tebd.py +0 -1138
- trajectree/quimb/quimb/tensor/tensor_builder.py +0 -5411
- trajectree/quimb/quimb/tensor/tensor_core.py +0 -11179
- trajectree/quimb/quimb/tensor/tensor_dmrg.py +0 -1472
- trajectree/quimb/quimb/tensor/tensor_mera.py +0 -204
- trajectree/quimb/quimb/utils.py +0 -892
- trajectree/quimb/tests/__init__.py +0 -0
- trajectree/quimb/tests/test_accel.py +0 -501
- trajectree/quimb/tests/test_calc.py +0 -788
- trajectree/quimb/tests/test_core.py +0 -847
- trajectree/quimb/tests/test_evo.py +0 -565
- trajectree/quimb/tests/test_gen/__init__.py +0 -0
- trajectree/quimb/tests/test_gen/test_operators.py +0 -361
- trajectree/quimb/tests/test_gen/test_rand.py +0 -296
- trajectree/quimb/tests/test_gen/test_states.py +0 -261
- trajectree/quimb/tests/test_linalg/__init__.py +0 -0
- trajectree/quimb/tests/test_linalg/test_approx_spectral.py +0 -368
- trajectree/quimb/tests/test_linalg/test_base_linalg.py +0 -351
- trajectree/quimb/tests/test_linalg/test_mpi_linalg.py +0 -127
- trajectree/quimb/tests/test_linalg/test_numpy_linalg.py +0 -84
- trajectree/quimb/tests/test_linalg/test_rand_linalg.py +0 -134
- trajectree/quimb/tests/test_linalg/test_slepc_linalg.py +0 -283
- trajectree/quimb/tests/test_tensor/__init__.py +0 -0
- trajectree/quimb/tests/test_tensor/test_belief_propagation/__init__.py +0 -0
- trajectree/quimb/tests/test_tensor/test_belief_propagation/test_d1bp.py +0 -39
- trajectree/quimb/tests/test_tensor/test_belief_propagation/test_d2bp.py +0 -67
- trajectree/quimb/tests/test_tensor/test_belief_propagation/test_hd1bp.py +0 -64
- trajectree/quimb/tests/test_tensor/test_belief_propagation/test_hv1bp.py +0 -51
- trajectree/quimb/tests/test_tensor/test_belief_propagation/test_l1bp.py +0 -142
- trajectree/quimb/tests/test_tensor/test_belief_propagation/test_l2bp.py +0 -101
- trajectree/quimb/tests/test_tensor/test_circuit.py +0 -816
- trajectree/quimb/tests/test_tensor/test_contract.py +0 -67
- trajectree/quimb/tests/test_tensor/test_decomp.py +0 -40
- trajectree/quimb/tests/test_tensor/test_mera.py +0 -52
- trajectree/quimb/tests/test_tensor/test_optimizers.py +0 -488
- trajectree/quimb/tests/test_tensor/test_tensor_1d.py +0 -1171
- trajectree/quimb/tests/test_tensor/test_tensor_2d.py +0 -606
- trajectree/quimb/tests/test_tensor/test_tensor_2d_tebd.py +0 -144
- trajectree/quimb/tests/test_tensor/test_tensor_3d.py +0 -123
- trajectree/quimb/tests/test_tensor/test_tensor_arbgeom.py +0 -226
- trajectree/quimb/tests/test_tensor/test_tensor_builder.py +0 -441
- trajectree/quimb/tests/test_tensor/test_tensor_core.py +0 -2066
- trajectree/quimb/tests/test_tensor/test_tensor_dmrg.py +0 -388
- trajectree/quimb/tests/test_tensor/test_tensor_spectral_approx.py +0 -63
- trajectree/quimb/tests/test_tensor/test_tensor_tebd.py +0 -270
- trajectree/quimb/tests/test_utils.py +0 -85
- trajectree-0.0.1.dist-info/RECORD +0 -126
- {trajectree-0.0.1.dist-info → trajectree-0.0.2.dist-info}/WHEEL +0 -0
- {trajectree-0.0.1.dist-info → trajectree-0.0.2.dist-info}/licenses/LICENSE +0 -0
- {trajectree-0.0.1.dist-info → trajectree-0.0.2.dist-info}/top_level.txt +0 -0
|
@@ -1,336 +0,0 @@
|
|
|
1
|
-
"""Functions relating to tensor network contraction.
|
|
2
|
-
"""
|
|
3
|
-
import functools
|
|
4
|
-
import itertools
|
|
5
|
-
import threading
|
|
6
|
-
import contextlib
|
|
7
|
-
import collections
|
|
8
|
-
|
|
9
|
-
import cotengra as ctg
|
|
10
|
-
|
|
11
|
-
|
|
12
|
-
# global default strategy supplied as ``optimize`` to cotengra
_CONTRACT_STRATEGY = 'greedy'
# per-thread stacks of temporarily set contraction strategies,
# managed by the ``contract_strategy`` context manager below
_TEMP_CONTRACT_STRATEGIES = collections.defaultdict(list)
|
|
14
|
-
|
|
15
|
-
|
|
16
|
-
def get_contract_strategy():
    r"""Get the default contraction strategy - the option supplied as
    ``optimize`` to ``cotengra``.

    If the current thread is inside a ``contract_strategy`` context, the most
    recently entered strategy for this thread is returned, otherwise the
    global default ``_CONTRACT_STRATEGY``.
    """
    if not _TEMP_CONTRACT_STRATEGIES:
        # shortcut for when no temp strategies are in use
        return _CONTRACT_STRATEGY

    thread_id = threading.get_ident()
    if thread_id not in _TEMP_CONTRACT_STRATEGIES:
        return _CONTRACT_STRATEGY

    temp_strategies = _TEMP_CONTRACT_STRATEGIES[thread_id]
    # empty list -> not in context manager -> use default strategy
    if not temp_strategies:
        # clean up to allow above shortcuts
        del _TEMP_CONTRACT_STRATEGIES[thread_id]
        return _CONTRACT_STRATEGY

    # use most recently set strategy for this thread
    return temp_strategies[-1]
|
|
37
|
-
|
|
38
|
-
|
|
39
|
-
def set_contract_strategy(strategy):
    """Set the default contraction strategy - the option supplied as
    ``optimize`` to ``cotengra``. This sets the global default for all
    threads; use the ``contract_strategy`` context manager for a temporary,
    per-thread override.
    """
    global _CONTRACT_STRATEGY
    _CONTRACT_STRATEGY = strategy
|
|
45
|
-
|
|
46
|
-
|
|
47
|
-
@contextlib.contextmanager
def contract_strategy(strategy, set_globally=False):
    """A context manager to temporarily set the default contraction strategy
    supplied as ``optimize`` to ``cotengra``. By default, this only sets the
    contract strategy for the current thread.

    Parameters
    ----------
    strategy
        The contraction strategy to use while inside the context.
    set_globally : bool, optional
        Whether to set the strategy just for this thread, or for all threads.
        If you are entering the context, *then* using multithreading, you might
        want ``True``.
    """
    if set_globally:
        # swap the global default, restoring it on exit
        orig_strategy = get_contract_strategy()
        set_contract_strategy(strategy)
        try:
            yield
        finally:
            set_contract_strategy(orig_strategy)
    else:
        # push onto this thread's stack, popping on exit
        thread_id = threading.get_ident()
        temp_strategies = _TEMP_CONTRACT_STRATEGIES[thread_id]
        temp_strategies.append(strategy)
        try:
            yield
        finally:
            temp_strategies.pop()
|
|
75
|
-
|
|
76
|
-
|
|
77
|
-
get_symbol = ctg.get_symbol
|
|
78
|
-
|
|
79
|
-
|
|
80
|
-
def empty_symbol_map():
    """Get a default dictionary that will populate with symbol entries as they
    are accessed - each new key is assigned the next unused symbol.
    """
    # lazily produce an endless stream of fresh symbols
    symbols = (get_symbol(i) for i in itertools.count())
    return collections.defaultdict(lambda: next(symbols))
|
|
85
|
-
|
|
86
|
-
|
|
87
|
-
def inds_to_symbols(inputs):
    """Map a sequence of inputs terms, containing any hashable indices, to
    single unicode letters, appropriate for einsum. Thin wrapper delegating
    to ``cotengra``.

    Parameters
    ----------
    inputs : sequence of sequence of hashable
        The input indices per tensor.

    Returns
    -------
    symbols : dict[hashable, str]
        The mapping from index to symbol.
    """
    return ctg.get_symbol_map(inputs)
|
|
102
|
-
|
|
103
|
-
|
|
104
|
-
@functools.lru_cache(2**12)
def inds_to_eq(inputs, output=None):
    """Turn input and output indices of any sort into a single 'equation'
    string where each index is a single 'symbol' (unicode character).

    Parameters
    ----------
    inputs : sequence of sequence of hashable
        The input indices per tensor.
    output : sequence of hashable
        The output indices. If not supplied, taken to be every index that
        appears exactly once across the inputs.

    Returns
    -------
    eq : str
        The string to feed to einsum/contract.
    """
    symbols = empty_symbol_map()
    lhs = ",".join(
        "".join(symbols[ix] for ix in term) for term in inputs
    )
    if output is not None:
        rhs = "".join(symbols[ix] for ix in output)
    else:
        # free output indices = those appearing exactly once on the left
        rhs = "".join(s for s in symbols.values() if lhs.count(s) == 1)
    return lhs + "->" + rhs
|
|
131
|
-
|
|
132
|
-
|
|
133
|
-
# global default backend passed to cotengra for contractions
_CONTRACT_BACKEND = None
# global default backend for tensor network linear operator contractions
_TENSOR_LINOP_BACKEND = None
# per-thread stacks of temporarily set contraction backends
_TEMP_CONTRACT_BACKENDS = collections.defaultdict(list)
# per-thread stacks of temporarily set linear operator backends
_TEMP_TENSOR_LINOP_BACKENDS = collections.defaultdict(list)
|
|
137
|
-
|
|
138
|
-
|
|
139
|
-
def get_contract_backend():
    """Get the default backend used for tensor contractions, via 'cotengra'.

    See Also
    --------
    set_contract_backend, get_tensor_linop_backend, set_tensor_linop_backend,
    tensor_contract
    """
    if not _TEMP_CONTRACT_BACKENDS:
        # no thread has any temporary backend set at all
        return _CONTRACT_BACKEND

    tid = threading.get_ident()
    stack = _TEMP_CONTRACT_BACKENDS.get(tid)
    if stack is None:
        # this thread never entered a context
        return _CONTRACT_BACKEND
    if not stack:
        # stale empty stack -> remove so the shortcut above works again
        del _TEMP_CONTRACT_BACKENDS[tid]
        return _CONTRACT_BACKEND

    # most recently entered backend for this thread
    return stack[-1]
|
|
160
|
-
|
|
161
|
-
|
|
162
|
-
def set_contract_backend(backend):
    """Set the default backend used for tensor contractions, via 'cotengra'.
    This sets the global default for all threads; use the
    ``contract_backend`` context manager for a temporary, per-thread override.

    See Also
    --------
    get_contract_backend, set_tensor_linop_backend, get_tensor_linop_backend,
    tensor_contract
    """
    global _CONTRACT_BACKEND
    _CONTRACT_BACKEND = backend
|
|
172
|
-
|
|
173
|
-
|
|
174
|
-
@contextlib.contextmanager
def contract_backend(backend, set_globally=False):
    """A context manager to temporarily set the default backend used for tensor
    contractions, via 'cotengra'. By default, this only sets the contract
    backend for the current thread.

    Parameters
    ----------
    backend
        The contraction backend to use while inside the context.
    set_globally : bool, optional
        Whether to set the backend just for this thread, or for all threads. If
        you are entering the context, *then* using multithreading, you might
        want ``True``.
    """
    if set_globally:
        # swap the global default, restoring it on exit
        orig_backend = get_contract_backend()
        set_contract_backend(backend)
        try:
            yield
        finally:
            set_contract_backend(orig_backend)
    else:
        # push onto this thread's stack, popping on exit
        thread_id = threading.get_ident()
        temp_backends = _TEMP_CONTRACT_BACKENDS[thread_id]
        temp_backends.append(backend)
        try:
            yield
        finally:
            temp_backends.pop()
|
|
202
|
-
|
|
203
|
-
|
|
204
|
-
def get_tensor_linop_backend():
    """Get the default backend used for tensor network linear operators, via
    'cotengra'. This is different from the default contraction backend as
    the contractions are likely repeatedly called many times.

    See Also
    --------
    set_tensor_linop_backend, set_contract_backend, get_contract_backend,
    TNLinearOperator
    """
    if not _TEMP_TENSOR_LINOP_BACKENDS:
        # no thread has any temporary backend set at all
        return _TENSOR_LINOP_BACKEND

    tid = threading.get_ident()
    stack = _TEMP_TENSOR_LINOP_BACKENDS.get(tid)
    if stack is None:
        # this thread never entered a context
        return _TENSOR_LINOP_BACKEND
    if not stack:
        # stale empty stack -> remove so the shortcut above works again
        del _TEMP_TENSOR_LINOP_BACKENDS[tid]
        return _TENSOR_LINOP_BACKEND

    # most recently entered backend for this thread
    return stack[-1]
|
|
227
|
-
|
|
228
|
-
|
|
229
|
-
def set_tensor_linop_backend(backend):
    """Set the default backend used for tensor network linear operators, via
    'cotengra'. This is different from the default contraction backend as
    the contractions are likely repeatedly called many times.

    See Also
    --------
    get_tensor_linop_backend, set_contract_backend, get_contract_backend,
    TNLinearOperator
    """
    global _TENSOR_LINOP_BACKEND
    _TENSOR_LINOP_BACKEND = backend
|
|
241
|
-
|
|
242
|
-
|
|
243
|
-
@contextlib.contextmanager
def tensor_linop_backend(backend, set_globally=False):
    """A context manager to temporarily set the default backend used for tensor
    network linear operators, via 'cotengra'. By default, this
    only sets the contract backend for the current thread.

    Parameters
    ----------
    backend
        The linear operator contraction backend to use while inside the
        context.
    set_globally : bool, optional
        Whether to set the backend just for this thread, or for all threads. If
        you are entering the context, *then* using multithreading, you might
        want ``True``.
    """
    if set_globally:
        # swap the global default, restoring it on exit
        orig_backend = get_tensor_linop_backend()
        set_tensor_linop_backend(backend)
        try:
            yield
        finally:
            set_tensor_linop_backend(orig_backend)
    else:
        # push onto this thread's stack, popping on exit
        thread_id = threading.get_ident()
        temp_backends = _TEMP_TENSOR_LINOP_BACKENDS[thread_id]
        temp_backends.append(backend)
        try:
            yield
        finally:
            temp_backends.pop()
|
|
271
|
-
|
|
272
|
-
|
|
273
|
-
@functools.wraps(ctg.array_contract)
def array_contract(
    arrays,
    inputs,
    output=None,
    optimize=None,
    backend=None,
    **kwargs,
):
    # wrapper around ``ctg.array_contract`` that fills in the current default
    # contraction strategy and backend when not explicitly supplied
    if optimize is None:
        optimize = get_contract_strategy()
    if backend is None:
        backend = get_contract_backend()
    return ctg.array_contract(
        arrays, inputs, output, optimize=optimize, backend=backend, **kwargs
    )
|
|
289
|
-
|
|
290
|
-
|
|
291
|
-
@functools.wraps(ctg.array_contract_expression)
def array_contract_expression(*args, optimize=None, **kwargs):
    # wrapper that fills in the current default contraction strategy
    if optimize is None:
        optimize = get_contract_strategy()
    return ctg.array_contract_expression(*args, optimize=optimize, **kwargs)
|
|
296
|
-
|
|
297
|
-
|
|
298
|
-
@functools.wraps(ctg.array_contract_tree)
def array_contract_tree(*args, optimize=None, **kwargs):
    # wrapper that fills in the current default contraction strategy
    if optimize is None:
        optimize = get_contract_strategy()
    return ctg.array_contract_tree(*args, optimize=optimize, **kwargs)
|
|
303
|
-
|
|
304
|
-
|
|
305
|
-
@functools.wraps(ctg.array_contract_path)
def array_contract_path(*args, optimize=None, **kwargs):
    # wrapper that fills in the current default contraction strategy
    if optimize is None:
        optimize = get_contract_strategy()
    return ctg.array_contract_path(*args, optimize=optimize, **kwargs)
|
|
310
|
-
|
|
311
|
-
def array_contract_pathinfo(*args, **kwargs):
    """Build a contraction tree for the given arguments (as accepted by
    ``array_contract_tree``) and convert it into an ``opt_einsum`` path info
    object (the second return value of ``oe.contract_path``).
    """
    import opt_einsum as oe

    tree = array_contract_tree(*args, **kwargs)

    if tree.sliced_inds:
        import warnings

        warnings.warn(
            "The contraction tree has sliced indices, which are not "
            "supported by opt_einsum. Ignoring them for now."
        )

    shapes = tree.get_shapes()
    path = tree.get_path()
    eq = tree.get_eq()

    if (eq == "->") and (len(path) == 0):
        # XXX: opt_einsum does not support empty paths
        # https://github.com/jcmgray/quimb/issues/231
        # https://github.com/dgasmith/opt_einsum/pull/229
        path = ((0,),)

    return oe.contract_path(eq, *shapes, shapes=True, optimize=path)[1]
|
|
336
|
-
|