Trajectree 0.0.0__py3-none-any.whl → 0.0.1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- trajectree/__init__.py +3 -0
- trajectree/fock_optics/devices.py +1 -1
- trajectree/fock_optics/light_sources.py +2 -2
- trajectree/fock_optics/measurement.py +3 -3
- trajectree/fock_optics/utils.py +6 -6
- trajectree/quimb/docs/_pygments/_pygments_dark.py +118 -0
- trajectree/quimb/docs/_pygments/_pygments_light.py +118 -0
- trajectree/quimb/docs/conf.py +158 -0
- trajectree/quimb/docs/examples/ex_mpi_expm_evo.py +62 -0
- trajectree/quimb/quimb/__init__.py +507 -0
- trajectree/quimb/quimb/calc.py +1491 -0
- trajectree/quimb/quimb/core.py +2279 -0
- trajectree/quimb/quimb/evo.py +712 -0
- trajectree/quimb/quimb/experimental/__init__.py +0 -0
- trajectree/quimb/quimb/experimental/autojittn.py +129 -0
- trajectree/quimb/quimb/experimental/belief_propagation/__init__.py +109 -0
- trajectree/quimb/quimb/experimental/belief_propagation/bp_common.py +397 -0
- trajectree/quimb/quimb/experimental/belief_propagation/d1bp.py +316 -0
- trajectree/quimb/quimb/experimental/belief_propagation/d2bp.py +653 -0
- trajectree/quimb/quimb/experimental/belief_propagation/hd1bp.py +571 -0
- trajectree/quimb/quimb/experimental/belief_propagation/hv1bp.py +775 -0
- trajectree/quimb/quimb/experimental/belief_propagation/l1bp.py +316 -0
- trajectree/quimb/quimb/experimental/belief_propagation/l2bp.py +537 -0
- trajectree/quimb/quimb/experimental/belief_propagation/regions.py +194 -0
- trajectree/quimb/quimb/experimental/cluster_update.py +286 -0
- trajectree/quimb/quimb/experimental/merabuilder.py +865 -0
- trajectree/quimb/quimb/experimental/operatorbuilder/__init__.py +15 -0
- trajectree/quimb/quimb/experimental/operatorbuilder/operatorbuilder.py +1631 -0
- trajectree/quimb/quimb/experimental/schematic.py +7 -0
- trajectree/quimb/quimb/experimental/tn_marginals.py +130 -0
- trajectree/quimb/quimb/experimental/tnvmc.py +1483 -0
- trajectree/quimb/quimb/gates.py +36 -0
- trajectree/quimb/quimb/gen/__init__.py +2 -0
- trajectree/quimb/quimb/gen/operators.py +1167 -0
- trajectree/quimb/quimb/gen/rand.py +713 -0
- trajectree/quimb/quimb/gen/states.py +479 -0
- trajectree/quimb/quimb/linalg/__init__.py +6 -0
- trajectree/quimb/quimb/linalg/approx_spectral.py +1109 -0
- trajectree/quimb/quimb/linalg/autoblock.py +258 -0
- trajectree/quimb/quimb/linalg/base_linalg.py +719 -0
- trajectree/quimb/quimb/linalg/mpi_launcher.py +397 -0
- trajectree/quimb/quimb/linalg/numpy_linalg.py +244 -0
- trajectree/quimb/quimb/linalg/rand_linalg.py +514 -0
- trajectree/quimb/quimb/linalg/scipy_linalg.py +293 -0
- trajectree/quimb/quimb/linalg/slepc_linalg.py +892 -0
- trajectree/quimb/quimb/schematic.py +1518 -0
- trajectree/quimb/quimb/tensor/__init__.py +401 -0
- trajectree/quimb/quimb/tensor/array_ops.py +610 -0
- trajectree/quimb/quimb/tensor/circuit.py +4824 -0
- trajectree/quimb/quimb/tensor/circuit_gen.py +411 -0
- trajectree/quimb/quimb/tensor/contraction.py +336 -0
- trajectree/quimb/quimb/tensor/decomp.py +1255 -0
- trajectree/quimb/quimb/tensor/drawing.py +1646 -0
- trajectree/quimb/quimb/tensor/fitting.py +385 -0
- trajectree/quimb/quimb/tensor/geometry.py +583 -0
- trajectree/quimb/quimb/tensor/interface.py +114 -0
- trajectree/quimb/quimb/tensor/networking.py +1058 -0
- trajectree/quimb/quimb/tensor/optimize.py +1818 -0
- trajectree/quimb/quimb/tensor/tensor_1d.py +4778 -0
- trajectree/quimb/quimb/tensor/tensor_1d_compress.py +1854 -0
- trajectree/quimb/quimb/tensor/tensor_1d_tebd.py +662 -0
- trajectree/quimb/quimb/tensor/tensor_2d.py +5954 -0
- trajectree/quimb/quimb/tensor/tensor_2d_compress.py +96 -0
- trajectree/quimb/quimb/tensor/tensor_2d_tebd.py +1230 -0
- trajectree/quimb/quimb/tensor/tensor_3d.py +2869 -0
- trajectree/quimb/quimb/tensor/tensor_3d_tebd.py +46 -0
- trajectree/quimb/quimb/tensor/tensor_approx_spectral.py +60 -0
- trajectree/quimb/quimb/tensor/tensor_arbgeom.py +3237 -0
- trajectree/quimb/quimb/tensor/tensor_arbgeom_compress.py +565 -0
- trajectree/quimb/quimb/tensor/tensor_arbgeom_tebd.py +1138 -0
- trajectree/quimb/quimb/tensor/tensor_builder.py +5411 -0
- trajectree/quimb/quimb/tensor/tensor_core.py +11179 -0
- trajectree/quimb/quimb/tensor/tensor_dmrg.py +1472 -0
- trajectree/quimb/quimb/tensor/tensor_mera.py +204 -0
- trajectree/quimb/quimb/utils.py +892 -0
- trajectree/quimb/tests/__init__.py +0 -0
- trajectree/quimb/tests/test_accel.py +501 -0
- trajectree/quimb/tests/test_calc.py +788 -0
- trajectree/quimb/tests/test_core.py +847 -0
- trajectree/quimb/tests/test_evo.py +565 -0
- trajectree/quimb/tests/test_gen/__init__.py +0 -0
- trajectree/quimb/tests/test_gen/test_operators.py +361 -0
- trajectree/quimb/tests/test_gen/test_rand.py +296 -0
- trajectree/quimb/tests/test_gen/test_states.py +261 -0
- trajectree/quimb/tests/test_linalg/__init__.py +0 -0
- trajectree/quimb/tests/test_linalg/test_approx_spectral.py +368 -0
- trajectree/quimb/tests/test_linalg/test_base_linalg.py +351 -0
- trajectree/quimb/tests/test_linalg/test_mpi_linalg.py +127 -0
- trajectree/quimb/tests/test_linalg/test_numpy_linalg.py +84 -0
- trajectree/quimb/tests/test_linalg/test_rand_linalg.py +134 -0
- trajectree/quimb/tests/test_linalg/test_slepc_linalg.py +283 -0
- trajectree/quimb/tests/test_tensor/__init__.py +0 -0
- trajectree/quimb/tests/test_tensor/test_belief_propagation/__init__.py +0 -0
- trajectree/quimb/tests/test_tensor/test_belief_propagation/test_d1bp.py +39 -0
- trajectree/quimb/tests/test_tensor/test_belief_propagation/test_d2bp.py +67 -0
- trajectree/quimb/tests/test_tensor/test_belief_propagation/test_hd1bp.py +64 -0
- trajectree/quimb/tests/test_tensor/test_belief_propagation/test_hv1bp.py +51 -0
- trajectree/quimb/tests/test_tensor/test_belief_propagation/test_l1bp.py +142 -0
- trajectree/quimb/tests/test_tensor/test_belief_propagation/test_l2bp.py +101 -0
- trajectree/quimb/tests/test_tensor/test_circuit.py +816 -0
- trajectree/quimb/tests/test_tensor/test_contract.py +67 -0
- trajectree/quimb/tests/test_tensor/test_decomp.py +40 -0
- trajectree/quimb/tests/test_tensor/test_mera.py +52 -0
- trajectree/quimb/tests/test_tensor/test_optimizers.py +488 -0
- trajectree/quimb/tests/test_tensor/test_tensor_1d.py +1171 -0
- trajectree/quimb/tests/test_tensor/test_tensor_2d.py +606 -0
- trajectree/quimb/tests/test_tensor/test_tensor_2d_tebd.py +144 -0
- trajectree/quimb/tests/test_tensor/test_tensor_3d.py +123 -0
- trajectree/quimb/tests/test_tensor/test_tensor_arbgeom.py +226 -0
- trajectree/quimb/tests/test_tensor/test_tensor_builder.py +441 -0
- trajectree/quimb/tests/test_tensor/test_tensor_core.py +2066 -0
- trajectree/quimb/tests/test_tensor/test_tensor_dmrg.py +388 -0
- trajectree/quimb/tests/test_tensor/test_tensor_spectral_approx.py +63 -0
- trajectree/quimb/tests/test_tensor/test_tensor_tebd.py +270 -0
- trajectree/quimb/tests/test_utils.py +85 -0
- trajectree/trajectory.py +2 -2
- {trajectree-0.0.0.dist-info → trajectree-0.0.1.dist-info}/METADATA +2 -2
- trajectree-0.0.1.dist-info/RECORD +126 -0
- trajectree-0.0.0.dist-info/RECORD +0 -16
- {trajectree-0.0.0.dist-info → trajectree-0.0.1.dist-info}/WHEEL +0 -0
- {trajectree-0.0.0.dist-info → trajectree-0.0.1.dist-info}/licenses/LICENSE +0 -0
- {trajectree-0.0.0.dist-info → trajectree-0.0.1.dist-info}/top_level.txt +0 -0
|
@@ -0,0 +1,892 @@
|
|
|
1
|
+
"""Misc utility functions."""
|
|
2
|
+
|
|
3
|
+
import collections
|
|
4
|
+
import functools
|
|
5
|
+
import itertools
|
|
6
|
+
from importlib.util import find_spec
|
|
7
|
+
|
|
8
|
+
# Prefer the compiled ``cytoolz`` implementations where available, falling
# back to the pure python ``toolz`` package, which exposes the same functions.
try:
    import cytoolz

    last = cytoolz.last
    concat = cytoolz.concat
    frequencies = cytoolz.frequencies
    partition_all = cytoolz.partition_all
    merge_with = cytoolz.merge_with
    valmap = cytoolz.valmap
    partitionby = cytoolz.partitionby
    concatv = cytoolz.concatv
    partition = cytoolz.partition
    compose = cytoolz.compose
    identity = cytoolz.identity
    isiterable = cytoolz.isiterable
    unique = cytoolz.unique
    keymap = cytoolz.keymap
except ImportError:
    import toolz

    last = toolz.last
    concat = toolz.concat
    frequencies = toolz.frequencies
    partition_all = toolz.partition_all
    merge_with = toolz.merge_with
    valmap = toolz.valmap
    partitionby = toolz.partitionby
    concatv = toolz.concatv
    partition = toolz.partition
    compose = toolz.compose
    identity = toolz.identity
    isiterable = toolz.isiterable
    unique = toolz.unique
    keymap = toolz.keymap
|
|
44
|
+
|
|
45
|
+
|
|
46
|
+
_CHECK_OPT_MSG = "Option `{}` should be one of {}, but got '{}'."


def check_opt(name, value, valid):
    """Validate that ``value`` is one of the ``valid`` options.

    Parameters
    ----------
    name : str
        Name of the option, only used to build the error message.
    value : object
        The value supplied for the option.
    valid : container
        The allowed values, membership is tested with ``in``.

    Raises
    ------
    ValueError
        If ``value`` is not found in ``valid``.
    """
    if value in valid:
        return
    message = _CHECK_OPT_MSG.format(name, valid, value)
    raise ValueError(message)
|
|
55
|
+
|
|
56
|
+
|
|
57
|
+
def find_library(x):
    """Report whether the library named ``x`` can be found for import.

    Parameters
    ----------
    x : str
        Name of library

    Returns
    -------
    bool
        ``True`` if an import spec for the library exists, else ``False``.
    """
    spec = find_spec(x)
    return spec is not None
|
|
71
|
+
|
|
72
|
+
|
|
73
|
+
def raise_cant_find_library_function(x, extra_msg=None):
    """Build a stand-in function that flags a missing optional library.

    Returning such a placeholder means optional dependencies are only
    reported at the point they are actually needed, not at import time.

    Parameters
    ----------
    x : str
        Name of library
    extra_msg : str, optional
        Additional text appended to the error message.

    Returns
    -------
    callable
        A function accepting any arguments which, when called, raises an
        ``ImportError`` naming the required library.
    """

    def function_that_will_raise(*_, **__):
        parts = [f"The library {x} is not installed. "]
        if extra_msg is not None:
            parts.append(extra_msg)
        raise ImportError("".join(parts))

    return function_that_will_raise
|
|
101
|
+
|
|
102
|
+
|
|
103
|
+
FOUND_TQDM = find_library("tqdm")

if FOUND_TQDM:
    from tqdm import tqdm

    class continuous_progbar(tqdm):
        """A continuous version of tqdm, so that it can be updated with a float
        within some pre-given range, rather than a number of steps.

        Parameters
        ----------
        args : (stop) or (start, stop)
            Stopping point (and starting point if ``len(args) == 2``) of window
            within which to evaluate progress.
        total : int
            The number of steps to represent the continuous progress with.
        kwargs
            Supplied to ``tqdm.tqdm``, with ``ascii=True`` unless overridden.
        """

        def __init__(self, *args, total=100, **kwargs):
            """Create the bar with ``total`` steps, displayed in units of '%'."""
            kwargs.setdefault("ascii", True)
            super().__init__(total=total, unit="%", **kwargs)

            # any number of positional args other than two means only the stop
            # point was given, and the window starts at zero
            start, stop = args if len(args) == 2 else (0, args[0])
            self.start, self.stop = start, stop
            self.range = stop - start
            self.step = 1

        def cupdate(self, x):
            """'Continuous' update of progress bar.

            Parameters
            ----------
            x : float
                Current position within the range ``[self.start, self.stop]``.
            """
            scaled = (self.total + 1) * (x - self.start) / self.range
            num_update = int(scaled - self.step)
            if num_update <= 0:
                return
            self.update(num_update)
            self.step += num_update

    def progbar(*args, **kwargs):
        """Create a ``tqdm.tqdm`` progress bar, ASCII-only unless overridden."""
        options = {"ascii": True}
        options.update(kwargs)
        return tqdm(*args, **options)

else:  # pragma: no cover
    extra_msg = "This is needed to show progress bars."
    progbar = raise_cant_find_library_function("tqdm", extra_msg)
    continuous_progbar = raise_cant_find_library_function("tqdm", extra_msg)
|
|
160
|
+
|
|
161
|
+
|
|
162
|
+
def deprecated(fn, old_name, new_name):
    """Mark a function as deprecated, and indicate the new name.

    Parameters
    ----------
    fn : callable
        The function to wrap.
    old_name : str
        The deprecated name, shown in the warning.
    new_name : str
        The name users should switch to, shown in the warning.

    Returns
    -------
    callable
        A wrapper around ``fn`` (metadata copied via ``functools.wraps``) that
        emits a ``Warning`` on every call before delegating to ``fn``.
    """
    import warnings

    msg = f"The {old_name} function is deprecated in favor of {new_name}"

    @functools.wraps(fn)
    def new_fn(*args, **kwargs):
        # stacklevel=2 attributes the warning to the code calling the
        # deprecated function, rather than to this wrapper
        warnings.warn(msg, Warning, stacklevel=2)
        return fn(*args, **kwargs)

    return new_fn
|
|
176
|
+
|
|
177
|
+
|
|
178
|
+
def int2tup(x):
    """Coerce ``x`` to a tuple: tuples are returned unchanged, ints are
    wrapped into a 1-tuple and anything else is converted with ``tuple``.
    """
    if isinstance(x, tuple):
        return x
    if isinstance(x, int):
        return (x,)
    return tuple(x)
|
|
182
|
+
|
|
183
|
+
|
|
184
|
+
def ensure_dict(x):
    """Return a new ``dict`` built from ``x``, or an empty one if ``x is None``."""
    return {} if x is None else dict(x)
|
|
189
|
+
|
|
190
|
+
|
|
191
|
+
def pairwise(iterable):
    """Return an iterator over overlapping neighbour pairs of ``iterable``,
    i.e. ``(s0, s1), (s1, s2), ...``.
    """
    leading, trailing = itertools.tee(iterable)
    # offset the second copy by one element (no-op if it is empty)
    next(trailing, None)
    return zip(leading, trailing)
|
|
196
|
+
|
|
197
|
+
|
|
198
|
+
def print_multi_line(*lines, max_width=None):
|
|
199
|
+
"""Print multiple lines, with a maximum width."""
|
|
200
|
+
if max_width is None:
|
|
201
|
+
import shutil
|
|
202
|
+
|
|
203
|
+
max_width, _ = shutil.get_terminal_size()
|
|
204
|
+
|
|
205
|
+
max_line_lenth = max(len(ln) for ln in lines)
|
|
206
|
+
|
|
207
|
+
if max_line_lenth <= max_width:
|
|
208
|
+
for ln in lines:
|
|
209
|
+
print(ln)
|
|
210
|
+
|
|
211
|
+
else: # pragma: no cover
|
|
212
|
+
max_width -= 10 # for ellipses and pad
|
|
213
|
+
n_lines = len(lines)
|
|
214
|
+
n_blocks = (max_line_lenth - 1) // max_width + 1
|
|
215
|
+
|
|
216
|
+
for i in range(n_blocks):
|
|
217
|
+
if i == 0:
|
|
218
|
+
for j, l in enumerate(lines):
|
|
219
|
+
print(
|
|
220
|
+
"..." if j == n_lines // 2 else " ",
|
|
221
|
+
l[i * max_width : (i + 1) * max_width],
|
|
222
|
+
"..." if j == n_lines // 2 else " ",
|
|
223
|
+
)
|
|
224
|
+
print(("{:^" + str(max_width) + "}").format("..."))
|
|
225
|
+
elif i == n_blocks - 1:
|
|
226
|
+
for ln in lines:
|
|
227
|
+
print(" ", ln[i * max_width : (i + 1) * max_width])
|
|
228
|
+
else:
|
|
229
|
+
for j, ln in enumerate(lines):
|
|
230
|
+
print(
|
|
231
|
+
"..." if j == n_lines // 2 else " ",
|
|
232
|
+
ln[i * max_width : (i + 1) * max_width],
|
|
233
|
+
"..." if j == n_lines // 2 else " ",
|
|
234
|
+
)
|
|
235
|
+
print(("{:^" + str(max_width) + "}").format("..."))
|
|
236
|
+
|
|
237
|
+
|
|
238
|
+
def format_number_with_error(x, err):
    """Given ``x`` with error ``err``, format a string showing the relevant
    digits of ``x`` with two significant digits of the error bracketed, and
    overall exponent if necessary.

    Parameters
    ----------
    x : float
        The value to print.
    err : float
        The error on ``x``. Only its magnitude is used, so a negative value
        is treated like its absolute value.

    Returns
    -------
    str

    Examples
    --------

    >>> format_number_with_error(0.1542412, 0.0626653)
    '0.154(63)'

    >>> format_number_with_error(-128124123097, 6424)
    '-1.281241231(64)e+11'

    """
    # an error bar is a magnitude - a sign would otherwise leak into both the
    # scale comparison below and the bracketed digits, e.g. '0.154(-63)'
    err = abs(err)

    # compute an overall scaling for both values
    x_exponent = max(
        int(f"{x:e}".split("e")[1]),
        int(f"{err:e}".split("e")[1]) + 1,
    )
    # for readability try and show values close to 1 with no exponent
    hide_exponent = (
        # nicer showing 0.xxx(yy) than x.xx(yy)e-1
        (x_exponent in (0, -1))
        or
        # also nicer showing xx.xx(yy) than x.xxx(yy)e+1
        ((x_exponent == +1) and (err < abs(x / 10)))
    )
    if hide_exponent:
        suffix = ""
    else:
        x = x / 10**x_exponent
        err = err / 10**x_exponent
        suffix = f"e{x_exponent:+03d}"

    # two significant digits of the error, e.g. '6.3e-02' -> '63' and -2
    mantissa, exponent = f"{err:.1e}".split("e")
    mantissa, exponent = mantissa.replace(".", ""), int(exponent)
    # print x down to the decimal place of the error's second digit, which
    # sits at 10**(exponent - 1); this equals ``abs(exponent) + 1`` for
    # non-positive exponents, and avoids spurious decimals when the error
    # rounds up to a positive exponent (e.g. 9.96 -> '1.0e+01')
    n_decimals = max(1 - exponent, 0)
    return f"{x:.{n_decimals}f}({mantissa}){suffix}"
|
|
289
|
+
|
|
290
|
+
|
|
291
|
+
def save_to_disk(obj, fname, **dump_opts):
    """Write ``obj`` to ``fname`` with ``joblib.dump``.

    Any ``dump_opts`` are forwarded to ``joblib.dump``, and its return value
    is passed back to the caller.
    """
    from joblib import dump

    return dump(obj, fname, **dump_opts)
|
|
296
|
+
|
|
297
|
+
|
|
298
|
+
def load_from_disk(fname, **load_opts):
    """Load an object from disk using ``joblib.load``.

    Any ``load_opts`` are forwarded to ``joblib.load``.

    NOTE: ``joblib.load`` is pickle based, so only load files from trusted
    sources.
    """
    from joblib import load

    return load(fname, **load_opts)
|
|
303
|
+
|
|
304
|
+
|
|
305
|
+
class Verbosify:  # pragma: no cover
    """Wrap a function so that each call first prints its inputs. Simply for
    illustrating an MPI example in the docs.

    Parameters
    ----------
    fn : callable
        The function to wrap.
    highlight : str, optional
        If given, print only this keyword argument (which must then be passed
        by keyword) rather than all the arguments.
    mpi : bool, optional
        Whether to prefix the message with the MPI rank (imports ``mpi4py``).
    """

    def __init__(self, fn, highlight=None, mpi=False):
        self.fn, self.highlight, self.mpi = fn, highlight, mpi

    def __call__(self, *args, **kwargs):
        pre_msg = ""
        if self.mpi:
            from mpi4py import MPI

            pre_msg = f"{MPI.COMM_WORLD.Get_rank()}: "

        if self.highlight is not None:
            msg = f"{pre_msg}{self.highlight}={kwargs[self.highlight]}"
        else:
            msg = f"{pre_msg} args {args}, kwargs {kwargs}"
        print(msg)

        return self.fn(*args, **kwargs)
|
|
328
|
+
|
|
329
|
+
|
|
330
|
+
class oset:
    """An ordered set which stores elements as the keys of dict (ordered as of
    python 3.6). 'A few times' slower than using a set directly for small
    sizes, but makes everything deterministic.

    ``update`` and ``union`` accept ``oset`` instances or arbitrary iterables,
    whereas the ``intersection*`` and ``difference*`` methods expect ``oset``
    instances (they read ``other._d`` directly).
    """

    __slots__ = ("_d",)

    def __init__(self, it=()):
        self._d = dict.fromkeys(it)

    @classmethod
    def _from_dict(cls, d):
        """Wrap ``d`` directly (without copying) as a new set of type ``cls``."""
        obj = object.__new__(cls)
        obj._d = d
        return obj

    @classmethod
    def from_dict(cls, d):
        """Public method makes sure to copy incoming dictionary."""
        return cls._from_dict(d.copy())

    def copy(self):
        return self.from_dict(self._d)

    def __deepcopy__(self, memo):
        # always use hashable entries so just take normal copy
        new = self.copy()
        memo[id(self)] = new
        return new

    def add(self, k):
        self._d[k] = None

    def discard(self, k):
        self._d.pop(k, None)

    def remove(self, k):
        # raises KeyError if ``k`` is absent
        del self._d[k]

    def clear(self):
        self._d.clear()

    def update(self, *others):
        for o in others:
            try:
                # oset
                od = o._d
            except AttributeError:
                # generic iterable
                for k in o:
                    self._d[k] = None
            else:
                self._d.update(od)

    def union(self, *others):
        u = self.copy()
        u.update(*others)
        return u

    def intersection_update(self, *others):
        if not others:
            # intersecting with nothing leaves the set unchanged, like ``set``
            return
        if len(others) == 1:
            si = others[0]._d
        else:
            si = set.intersection(*(set(o._d) for o in others))
        self._d = {k: None for k in self._d if k in si}

    def intersection(self, *others):
        n_others = len(others)
        if n_others == 0:
            return self.copy()
        elif n_others == 1:
            si = others[0]._d
        else:
            si = set.intersection(*(set(o._d) for o in others))
        return self._from_dict({k: None for k in self._d if k in si})

    def difference_update(self, *others):
        if not others:
            # removing nothing leaves the set unchanged, like ``set``
            return
        if len(others) == 1:
            su = others[0]._d
        else:
            su = set.union(*(set(o._d) for o in others))
        self._d = {k: None for k in self._d if k not in su}

    def difference(self, *others):
        if not others:
            return self.copy()
        if len(others) == 1:
            su = others[0]._d
        else:
            su = set.union(*(set(o._d) for o in others))
        return self._from_dict({k: None for k in self._d if k not in su})

    def popleft(self):
        # NOTE: raises StopIteration (not KeyError) if the set is empty
        k = next(iter(self._d))
        del self._d[k]
        return k

    def popright(self):
        # raises KeyError if the set is empty
        return self._d.popitem()[0]

    pop = popright

    def __eq__(self, other):
        if isinstance(other, oset):
            return self._d == other._d
        return False

    def __or__(self, other):
        return self.union(other)

    def __ior__(self, other):
        self.update(other)
        return self

    def __and__(self, other):
        return self.intersection(other)

    def __iand__(self, other):
        self.intersection_update(other)
        return self

    def __sub__(self, other):
        return self.difference(other)

    def __isub__(self, other):
        self.difference_update(other)
        return self

    def __len__(self):
        return self._d.__len__()

    def __iter__(self):
        return self._d.__iter__()

    def __contains__(self, x):
        return self._d.__contains__(x)

    def __repr__(self):
        return f"oset({list(self._d)})"
|
|
466
|
+
|
|
467
|
+
|
|
468
|
+
class LRU(collections.OrderedDict):
    """Least recently used dict, which evicts the oldest entries once more
    than ``maxsize`` are stored. Taken from python collections OrderedDict
    docs.
    """

    def __init__(self, maxsize, *args, **kwds):
        # must be set before populating, since ``__setitem__`` reads it
        self.maxsize = maxsize
        super().__init__(*args, **kwds)

    def __getitem__(self, key):
        # a successful lookup marks ``key`` as the most recently used
        result = super().__getitem__(key)
        self.move_to_end(key)
        return result

    def __setitem__(self, key, value):
        existed = key in self
        if existed:
            self.move_to_end(key)
        super().__setitem__(key, value)
        if len(self) <= self.maxsize:
            return
        # the front of the ordering holds the least recently used entry
        del self[next(iter(self))]
|
|
489
|
+
|
|
490
|
+
|
|
491
|
+
class ExponentialGeometricRollingDiffMean:
    """Exponentially weighted geometric mean of the absolute differences
    between successive values passed to ``update``.

    Parameters
    ----------
    factor : float, optional
        Weight of each new absolute difference, the running value is updated
        as ``value**(1 - factor) * dx**factor``.
    initial : float, optional
        Starting value, left untouched until the second call to ``update``.
    """

    def __init__(self, factor=1 / 3, initial=1.0):
        self.x_prev = None
        self.dx = None
        self.value = initial
        self.factor = factor

    def update(self, x):
        """Feed in the next value ``x``."""
        if self.x_prev is not None:
            dx = abs(x - self.x_prev)
            self.dx = dx
            w = self.factor
            self.value = self.value ** (1 - w) * dx**w
        self.x_prev = x
|
|
505
|
+
|
|
506
|
+
|
|
507
|
+
def gen_bipartitions(it):
    """Generate all unique bipartitions of ``it``. Unique meaning
    ``(1, 2), (3, 4)`` is considered the same as ``(3, 4), (1, 2)``.

    Each bipartition is yielded as a pair of lists ``(left, right)``, both
    non-empty, with elements kept in their original order.
    """
    n = len(it)
    if not n:
        return
    # only masks below 2**(n - 1) are used, so the first element always
    # lands in ``left``, which removes the mirrored duplicates
    for mask in range(1, 2 ** (n - 1)):
        left, right = [], []
        for pos, item in enumerate(it):
            # read the bits of ``mask`` most significant first
            if (mask >> (n - 1 - pos)) & 1:
                right.append(item)
            else:
                left.append(item)
        yield left, right
|
|
519
|
+
|
|
520
|
+
|
|
521
|
+
TREE_MAP_REGISTRY = {}
TREE_APPLY_REGISTRY = {}
TREE_ITER_REGISTRY = {}

# per concrete class caches of the resolved leaf status / handler, these are
# all invalidated whenever a new container type is registered
IS_CONTAINER_CACHE = {}
TREE_MAPPER_CACHE = {}
TREE_ITER_CACHE = {}
TREE_APPLIER_CACHE = {}


def tree_register_container(cls, mapper, iterator, applier):
    """Register a new container type for use with ``tree_map``,
    ``tree_iter`` and ``tree_apply``.

    Parameters
    ----------
    cls : type
        The container type to register.
    mapper : callable
        A function that takes ``f``, ``tree`` and ``is_leaf`` and returns a new
        tree of type ``cls`` with ``f`` applied to all leaves.
    iterator : callable
        A function that takes ``tree`` and ``is_leaf`` and returns an iterator
        over all leaves in ``tree``.
    applier : callable
        A function that takes ``f``, ``tree`` and ``is_leaf`` and applies ``f``
        to all leaves in ``tree``.
    """
    TREE_MAP_REGISTRY[cls] = mapper
    TREE_ITER_REGISTRY[cls] = iterator
    TREE_APPLY_REGISTRY[cls] = applier
    # classes resolved before this registration (``cls`` itself previously
    # cached as a leaf, or a subclass cached with a parent's handler) may now
    # resolve differently, so drop every cached lookup
    IS_CONTAINER_CACHE.clear()
    TREE_MAPPER_CACHE.clear()
    TREE_ITER_CACHE.clear()
    TREE_APPLIER_CACHE.clear()


def is_not_container(x):
    """The default function to determine if an object is a leaf. This simply
    checks if the object is an instance of any of the registered container
    types.
    """
    try:
        return IS_CONTAINER_CACHE[x.__class__]
    except KeyError:
        isleaf = not any(isinstance(x, cls) for cls in TREE_MAP_REGISTRY)
        IS_CONTAINER_CACHE[x.__class__] = isleaf
        return isleaf


def _tree_resolve_handler(tree, registry, cache, default):
    """Find the handler in ``registry`` for ``tree``, cache it against
    ``tree``'s class in ``cache`` and return it, falling back to ``default``
    if ``tree`` is not an instance of any registered type.
    """
    # reverse so later registered classes take precedence
    for cls, handler in reversed(registry.items()):
        if isinstance(tree, cls):
            break
    else:
        handler = default
    cache[tree.__class__] = handler
    return handler


def _tmap_identity(f, tree, is_leaf):
    return tree


def tree_map(f, tree, is_leaf=is_not_container):
    """Map ``f`` over all leaves in ``tree``, returning a new pytree.

    Parameters
    ----------
    f : callable
        A function to apply to all leaves in ``tree``.
    tree : pytree
        A nested sequence of tuples, lists, dicts and other objects.
    is_leaf : callable
        A function to determine if an object is a leaf, ``f`` is only applied
        to objects for which ``is_leaf(x)`` returns ``True``.

    Returns
    -------
    pytree
        Objects that are neither leaves nor registered containers are
        returned as is.
    """
    if is_leaf(tree):
        return f(tree)

    try:
        mapper = TREE_MAPPER_CACHE[tree.__class__]
    except KeyError:
        # neither leaf nor container -> simply return it
        mapper = _tree_resolve_handler(
            tree, TREE_MAP_REGISTRY, TREE_MAPPER_CACHE, _tmap_identity
        )
    # called outside the ``try`` so that a KeyError raised by ``f`` (or a
    # nested mapper) propagates immediately rather than triggering a re-run
    return mapper(f, tree, is_leaf)


def empty(tree, is_leaf):
    return iter(())


def tree_iter(tree, is_leaf=is_not_container):
    """Iterate over all leaves in ``tree``.

    Parameters
    ----------
    tree : pytree
        A nested sequence of tuples, lists, dicts and other objects.
    is_leaf : callable
        A function to determine if an object is a leaf, only objects for
        which ``is_leaf(x)`` returns ``True`` are yielded.
    """
    if is_leaf(tree):
        yield tree
        return

    try:
        iterator = TREE_ITER_CACHE[tree.__class__]
    except KeyError:
        # neither leaf nor container -> simply ignore it
        iterator = _tree_resolve_handler(
            tree, TREE_ITER_REGISTRY, TREE_ITER_CACHE, empty
        )
    # iterate outside the ``try`` so that a KeyError raised while iterating
    # does not restart the iteration and yield duplicate leaves
    yield from iterator(tree, is_leaf)


def nothing(f, tree, is_leaf):
    pass


def tree_apply(f, tree, is_leaf=is_not_container):
    """Apply ``f`` to all leaves in ``tree``, no new pytree is built.

    Parameters
    ----------
    f : callable
        A function to apply to all leaves in ``tree``.
    tree : pytree
        A nested sequence of tuples, lists, dicts and other objects.
    is_leaf : callable
        A function to determine if an object is a leaf, ``f`` is only applied
        to objects for which ``is_leaf(x)`` returns ``True``.
    """
    if is_leaf(tree):
        f(tree)
        return

    try:
        applier = TREE_APPLIER_CACHE[tree.__class__]
    except KeyError:
        # neither leaf nor container -> simply ignore it
        applier = _tree_resolve_handler(
            tree, TREE_APPLY_REGISTRY, TREE_APPLIER_CACHE, nothing
        )
    # called outside the ``try`` so that a KeyError raised by ``f`` does not
    # cause ``f`` to be applied (with its side effects) a second time
    applier(f, tree, is_leaf)
|
|
677
|
+
|
|
678
|
+
|
|
679
|
+
class Leaf:
    """Sentinel type marking the position of a leaf in a reference tree, as
    produced by ``tree_flatten(..., get_ref=True)``.
    """

    __slots__ = ()

    def __repr__(self):
        return type(self).__name__


# rebind the name to the sole instance, so ``Leaf`` is the sentinel object
Leaf = Leaf()


def is_leaf_object(x):
    """Check whether ``x`` is exactly the ``Leaf`` sentinel (identity test)."""
    return Leaf is x
|
|
691
|
+
|
|
692
|
+
|
|
693
|
+
def tree_flatten(tree, get_ref=False, is_leaf=is_not_container):
    """Flatten ``tree`` into a list of leaves.

    Parameters
    ----------
    tree : pytree
        A nested sequence of tuples, lists, dicts and other objects.
    get_ref : bool, optional
        Whether to also build and return a reference tree describing the
        structure of ``tree``.
    is_leaf : callable
        A function to determine if an object is a leaf, only objects for which
        ``is_leaf(x)`` returns ``True`` are returned in the flattened list.

    Returns
    -------
    objs : list
        The flattened list of leaf objects.
    (ref_tree) : pytree
        If ``get_ref`` is ``True``, a reference tree, with leaves of type
        ``Leaf``, is returned which can be used to reconstruct the original
        tree.
    """
    objs = []

    if not get_ref:
        # only the leaves are needed, so avoid building a new tree
        tree_apply(objs.append, tree, is_leaf)
        return objs

    def collect(x):
        # record the leaf, leaving the ``Leaf`` placeholder in its position
        objs.append(x)
        return Leaf

    ref_tree = tree_map(collect, tree, is_leaf)
    return objs, ref_tree
|
|
726
|
+
|
|
727
|
+
|
|
728
|
+
def tree_unflatten(objs, tree, is_leaf=is_leaf_object):
    """Unflatten ``objs`` into a pytree of the same structure as ``tree``.

    Parameters
    ----------
    objs : iterable
        The objects to be unflattened into a pytree, consumed in order. Any
        items left over once every leaf of ``tree`` has been filled are
        ignored.
    tree : pytree
        A nested sequence of tuples, lists, dicts and other objects, the objs
        will be inserted into a new pytree of the same structure.
    is_leaf : callable
        A function to determine if an object is a leaf, only objects for which
        ``is_leaf(x)`` returns ``True`` will have the next item from ``objs``
        inserted. By default checks for the ``Leaf`` object inserted by
        ``tree_flatten(..., get_ref=True)``.

    Returns
    -------
    pytree

    Raises
    ------
    ValueError
        If ``objs`` is exhausted before every leaf of ``tree`` has been filled.
    """
    objs = iter(objs)

    def _take_next(_):
        # A bare StopIteration here would either escape raw or, inside
        # generator-based containers such as tuples, be converted by
        # PEP 479 into an opaque RuntimeError. Report the real problem.
        try:
            return next(objs)
        except StopIteration:
            raise ValueError(
                "Not enough objects in `objs` to fill every leaf of `tree`."
            ) from None

    return tree_map(_take_next, tree, is_leaf)
|
|
750
|
+
|
|
751
|
+
|
|
752
|
+
def tree_map_tuple(f, tree, is_leaf):
    """Map ``f`` over every leaf inside a tuple, returning a new tuple."""
    # kept as a generator expression so exception semantics are unchanged
    mapped = (tree_map(f, item, is_leaf) for item in tree)
    return tuple(mapped)
|
|
754
|
+
|
|
755
|
+
|
|
756
|
+
def tree_iter_tuple(tree, is_leaf):
    """Lazily yield the leaves of each element of a tuple, in order."""
    for element in tree:
        yield from tree_iter(element, is_leaf)
|
|
759
|
+
|
|
760
|
+
|
|
761
|
+
def tree_apply_tuple(f, tree, is_leaf):
    """Call ``f`` on every leaf inside a tuple; nothing is returned."""
    for element in tree:
        tree_apply(f, element, is_leaf)
|
|
764
|
+
|
|
765
|
+
|
|
766
|
+
# register ``tuple`` as a pytree container, supplying the tuple-specific
# map / iter / apply handlers defined just above
tree_register_container(
    tuple, tree_map_tuple, tree_iter_tuple, tree_apply_tuple
)
|
|
769
|
+
|
|
770
|
+
|
|
771
|
+
def tree_map_list(f, tree, is_leaf):
    """Map ``f`` over every leaf inside a list, returning a new list."""
    mapped = []
    for item in tree:
        mapped.append(tree_map(f, item, is_leaf))
    return mapped
|
|
773
|
+
|
|
774
|
+
|
|
775
|
+
def tree_iter_list(tree, is_leaf):
    """Lazily yield the leaves of each element of a list, in order."""
    for element in tree:
        yield from tree_iter(element, is_leaf)
|
|
778
|
+
|
|
779
|
+
|
|
780
|
+
def tree_apply_list(f, tree, is_leaf):
    """Call ``f`` on every leaf inside a list; nothing is returned."""
    for element in tree:
        tree_apply(f, element, is_leaf)
|
|
783
|
+
|
|
784
|
+
|
|
785
|
+
# register ``list`` as a pytree container with its map / iter / apply handlers
tree_register_container(list, tree_map_list, tree_iter_list, tree_apply_list)
|
|
786
|
+
|
|
787
|
+
|
|
788
|
+
def tree_map_dict(f, tree, is_leaf):
    """Map ``f`` over every leaf under a dict's values.

    Returns a new plain ``dict`` with the same keys, in the same order; the
    keys themselves are never mapped.
    """
    mapped = {}
    for key, value in tree.items():
        mapped[key] = tree_map(f, value, is_leaf)
    return mapped
|
|
790
|
+
|
|
791
|
+
|
|
792
|
+
def tree_iter_dict(tree, is_leaf):
    """Lazily yield the leaves found under each value of a dict (keys are
    skipped)."""
    for value in tree.values():
        yield from tree_iter(value, is_leaf)
|
|
795
|
+
|
|
796
|
+
|
|
797
|
+
def tree_apply_dict(f, tree, is_leaf):
    """Call ``f`` on every leaf under each value of a dict; nothing is
    returned."""
    for value in tree.values():
        tree_apply(f, value, is_leaf)
|
|
800
|
+
|
|
801
|
+
|
|
802
|
+
# register ``dict`` as a pytree container with its map / iter / apply handlers
tree_register_container(dict, tree_map_dict, tree_iter_dict, tree_apply_dict)
|
|
803
|
+
|
|
804
|
+
|
|
805
|
+
# a style to use for matplotlib that works with light and dark backgrounds:
# lines, labels, ticks and grid use mid-grey (0.5, 0.5, 0.5), halfway between
# black and white, while the figure and axes faces are RGBA with alpha 0
# (fully transparent) so whatever lies behind the plot shows through
NEUTRAL_STYLE = {
    "axes.edgecolor": (0.5, 0.5, 0.5),
    # transparent axes background
    "axes.facecolor": (0, 0, 0, 0),
    "axes.grid": True,
    "axes.labelcolor": (0.5, 0.5, 0.5),
    # only keep the left and bottom spines
    "axes.spines.right": False,
    "axes.spines.top": False,
    # transparent figure background
    "figure.facecolor": (0, 0, 0, 0),
    # faint grid so it does not dominate the data
    "grid.alpha": 0.1,
    "grid.color": (0.5, 0.5, 0.5),
    "legend.frameon": False,
    "text.color": (0.5, 0.5, 0.5),
    "xtick.color": (0.5, 0.5, 0.5),
    "xtick.minor.visible": True,
    "ytick.color": (0.5, 0.5, 0.5),
    "ytick.minor.visible": True,
}
|
|
823
|
+
|
|
824
|
+
|
|
825
|
+
def default_to_neutral_style(fn):
    """Wrap a function or method so that it runs under ``NEUTRAL_STYLE`` by
    default.

    The returned wrapper accepts two extra keyword arguments that are not
    passed on to ``fn``:

    - ``style``: ``"neutral"`` (the default) selects ``NEUTRAL_STYLE``; any
      falsy value applies no style; anything else is given unchanged to
      ``matplotlib.pyplot.style.context``.
    - ``show_and_close``: if truthy (the default), ``plt.show()`` and then
      ``plt.close()`` are called after ``fn`` has returned.

    Whatever ``fn`` returns is returned by the wrapper.
    """

    @functools.wraps(fn)
    def wrapper(*args, style="neutral", show_and_close=True, **kwargs):
        # imported lazily so matplotlib is only needed when plotting
        import matplotlib.pyplot as plt

        if style == "neutral":
            style_spec = NEUTRAL_STYLE
        elif style:
            style_spec = style
        else:
            style_spec = {}

        with plt.style.context(style_spec):
            result = fn(*args, **kwargs)

        if show_and_close:
            plt.show()
            plt.close()

        return result

    return wrapper
|
|
847
|
+
|
|
848
|
+
|
|
849
|
+
def autocorrect_kwargs(func=None, valid_kwargs=None):
    """A decorator that suggests the right keyword arguments if you get them
    wrong. Useful for functions with many specific options.

    It can be applied bare, ``@autocorrect_kwargs``, or with options,
    ``@autocorrect_kwargs(valid_kwargs=[...])``.

    Parameters
    ----------
    func : callable, optional
        The function to decorate. If not given, a decorator with
        ``valid_kwargs`` already bound is returned instead.
    valid_kwargs : sequence[str], optional
        The valid keyword arguments for ``func``, if not given these are
        inferred from the function signature. When inferring, a ``func``
        that itself accepts ``**kwargs`` is returned unchanged, since any
        keyword name is then legitimate.

    Returns
    -------
    callable
        A wrapper around ``func`` that checks keyword argument names before
        calling it.

    Raises
    ------
    ValueError
        Raised by the wrapper if any keyword argument is not valid. The
        message lists the offending names in sorted order along with the
        closest valid alternatives.
    """
    if func is None:
        # decorator with options
        return functools.partial(autocorrect_kwargs, valid_kwargs=valid_kwargs)

    if valid_kwargs is None:
        import inspect

        params = inspect.signature(func).parameters
        if any(
            p.kind is inspect.Parameter.VAR_KEYWORD for p in params.values()
        ):
            # ``func`` absorbs arbitrary keywords itself, so rejecting
            # unknown names here would break legitimate calls
            return func
        valid_kwargs = set(params.keys())
    else:
        valid_kwargs = set(valid_kwargs)

    @functools.wraps(func)
    def wrapped(*args, **kwargs):
        # sorted so that the error message is reproducible between runs
        wrong_opts = sorted(kw for kw in kwargs if kw not in valid_kwargs)
        if wrong_opts:
            import difflib

            right_opts = (
                difflib.get_close_matches(opt, valid_kwargs, n=3)
                for opt in wrong_opts
            )
            msg = "Option(s) {} not valid.\n Did you mean: {}?".format(
                wrong_opts, ", ".join(map(str, right_opts))
            )
            # the exception already carries the message; no stdout printing
            raise ValueError(msg)

        return func(*args, **kwargs)

    return wrapped
|