Trajectree 0.0.1__py3-none-any.whl → 0.0.3__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (124)
  1. trajectree/__init__.py +0 -3
  2. trajectree/fock_optics/devices.py +1 -1
  3. trajectree/fock_optics/light_sources.py +2 -2
  4. trajectree/fock_optics/measurement.py +9 -9
  5. trajectree/fock_optics/outputs.py +10 -6
  6. trajectree/fock_optics/utils.py +9 -6
  7. trajectree/sequence/swap.py +5 -4
  8. trajectree/trajectory.py +5 -4
  9. {trajectree-0.0.1.dist-info → trajectree-0.0.3.dist-info}/METADATA +2 -3
  10. trajectree-0.0.3.dist-info/RECORD +16 -0
  11. trajectree/quimb/docs/_pygments/_pygments_dark.py +0 -118
  12. trajectree/quimb/docs/_pygments/_pygments_light.py +0 -118
  13. trajectree/quimb/docs/conf.py +0 -158
  14. trajectree/quimb/docs/examples/ex_mpi_expm_evo.py +0 -62
  15. trajectree/quimb/quimb/__init__.py +0 -507
  16. trajectree/quimb/quimb/calc.py +0 -1491
  17. trajectree/quimb/quimb/core.py +0 -2279
  18. trajectree/quimb/quimb/evo.py +0 -712
  19. trajectree/quimb/quimb/experimental/__init__.py +0 -0
  20. trajectree/quimb/quimb/experimental/autojittn.py +0 -129
  21. trajectree/quimb/quimb/experimental/belief_propagation/__init__.py +0 -109
  22. trajectree/quimb/quimb/experimental/belief_propagation/bp_common.py +0 -397
  23. trajectree/quimb/quimb/experimental/belief_propagation/d1bp.py +0 -316
  24. trajectree/quimb/quimb/experimental/belief_propagation/d2bp.py +0 -653
  25. trajectree/quimb/quimb/experimental/belief_propagation/hd1bp.py +0 -571
  26. trajectree/quimb/quimb/experimental/belief_propagation/hv1bp.py +0 -775
  27. trajectree/quimb/quimb/experimental/belief_propagation/l1bp.py +0 -316
  28. trajectree/quimb/quimb/experimental/belief_propagation/l2bp.py +0 -537
  29. trajectree/quimb/quimb/experimental/belief_propagation/regions.py +0 -194
  30. trajectree/quimb/quimb/experimental/cluster_update.py +0 -286
  31. trajectree/quimb/quimb/experimental/merabuilder.py +0 -865
  32. trajectree/quimb/quimb/experimental/operatorbuilder/__init__.py +0 -15
  33. trajectree/quimb/quimb/experimental/operatorbuilder/operatorbuilder.py +0 -1631
  34. trajectree/quimb/quimb/experimental/schematic.py +0 -7
  35. trajectree/quimb/quimb/experimental/tn_marginals.py +0 -130
  36. trajectree/quimb/quimb/experimental/tnvmc.py +0 -1483
  37. trajectree/quimb/quimb/gates.py +0 -36
  38. trajectree/quimb/quimb/gen/__init__.py +0 -2
  39. trajectree/quimb/quimb/gen/operators.py +0 -1167
  40. trajectree/quimb/quimb/gen/rand.py +0 -713
  41. trajectree/quimb/quimb/gen/states.py +0 -479
  42. trajectree/quimb/quimb/linalg/__init__.py +0 -6
  43. trajectree/quimb/quimb/linalg/approx_spectral.py +0 -1109
  44. trajectree/quimb/quimb/linalg/autoblock.py +0 -258
  45. trajectree/quimb/quimb/linalg/base_linalg.py +0 -719
  46. trajectree/quimb/quimb/linalg/mpi_launcher.py +0 -397
  47. trajectree/quimb/quimb/linalg/numpy_linalg.py +0 -244
  48. trajectree/quimb/quimb/linalg/rand_linalg.py +0 -514
  49. trajectree/quimb/quimb/linalg/scipy_linalg.py +0 -293
  50. trajectree/quimb/quimb/linalg/slepc_linalg.py +0 -892
  51. trajectree/quimb/quimb/schematic.py +0 -1518
  52. trajectree/quimb/quimb/tensor/__init__.py +0 -401
  53. trajectree/quimb/quimb/tensor/array_ops.py +0 -610
  54. trajectree/quimb/quimb/tensor/circuit.py +0 -4824
  55. trajectree/quimb/quimb/tensor/circuit_gen.py +0 -411
  56. trajectree/quimb/quimb/tensor/contraction.py +0 -336
  57. trajectree/quimb/quimb/tensor/decomp.py +0 -1255
  58. trajectree/quimb/quimb/tensor/drawing.py +0 -1646
  59. trajectree/quimb/quimb/tensor/fitting.py +0 -385
  60. trajectree/quimb/quimb/tensor/geometry.py +0 -583
  61. trajectree/quimb/quimb/tensor/interface.py +0 -114
  62. trajectree/quimb/quimb/tensor/networking.py +0 -1058
  63. trajectree/quimb/quimb/tensor/optimize.py +0 -1818
  64. trajectree/quimb/quimb/tensor/tensor_1d.py +0 -4778
  65. trajectree/quimb/quimb/tensor/tensor_1d_compress.py +0 -1854
  66. trajectree/quimb/quimb/tensor/tensor_1d_tebd.py +0 -662
  67. trajectree/quimb/quimb/tensor/tensor_2d.py +0 -5954
  68. trajectree/quimb/quimb/tensor/tensor_2d_compress.py +0 -96
  69. trajectree/quimb/quimb/tensor/tensor_2d_tebd.py +0 -1230
  70. trajectree/quimb/quimb/tensor/tensor_3d.py +0 -2869
  71. trajectree/quimb/quimb/tensor/tensor_3d_tebd.py +0 -46
  72. trajectree/quimb/quimb/tensor/tensor_approx_spectral.py +0 -60
  73. trajectree/quimb/quimb/tensor/tensor_arbgeom.py +0 -3237
  74. trajectree/quimb/quimb/tensor/tensor_arbgeom_compress.py +0 -565
  75. trajectree/quimb/quimb/tensor/tensor_arbgeom_tebd.py +0 -1138
  76. trajectree/quimb/quimb/tensor/tensor_builder.py +0 -5411
  77. trajectree/quimb/quimb/tensor/tensor_core.py +0 -11179
  78. trajectree/quimb/quimb/tensor/tensor_dmrg.py +0 -1472
  79. trajectree/quimb/quimb/tensor/tensor_mera.py +0 -204
  80. trajectree/quimb/quimb/utils.py +0 -892
  81. trajectree/quimb/tests/__init__.py +0 -0
  82. trajectree/quimb/tests/test_accel.py +0 -501
  83. trajectree/quimb/tests/test_calc.py +0 -788
  84. trajectree/quimb/tests/test_core.py +0 -847
  85. trajectree/quimb/tests/test_evo.py +0 -565
  86. trajectree/quimb/tests/test_gen/__init__.py +0 -0
  87. trajectree/quimb/tests/test_gen/test_operators.py +0 -361
  88. trajectree/quimb/tests/test_gen/test_rand.py +0 -296
  89. trajectree/quimb/tests/test_gen/test_states.py +0 -261
  90. trajectree/quimb/tests/test_linalg/__init__.py +0 -0
  91. trajectree/quimb/tests/test_linalg/test_approx_spectral.py +0 -368
  92. trajectree/quimb/tests/test_linalg/test_base_linalg.py +0 -351
  93. trajectree/quimb/tests/test_linalg/test_mpi_linalg.py +0 -127
  94. trajectree/quimb/tests/test_linalg/test_numpy_linalg.py +0 -84
  95. trajectree/quimb/tests/test_linalg/test_rand_linalg.py +0 -134
  96. trajectree/quimb/tests/test_linalg/test_slepc_linalg.py +0 -283
  97. trajectree/quimb/tests/test_tensor/__init__.py +0 -0
  98. trajectree/quimb/tests/test_tensor/test_belief_propagation/__init__.py +0 -0
  99. trajectree/quimb/tests/test_tensor/test_belief_propagation/test_d1bp.py +0 -39
  100. trajectree/quimb/tests/test_tensor/test_belief_propagation/test_d2bp.py +0 -67
  101. trajectree/quimb/tests/test_tensor/test_belief_propagation/test_hd1bp.py +0 -64
  102. trajectree/quimb/tests/test_tensor/test_belief_propagation/test_hv1bp.py +0 -51
  103. trajectree/quimb/tests/test_tensor/test_belief_propagation/test_l1bp.py +0 -142
  104. trajectree/quimb/tests/test_tensor/test_belief_propagation/test_l2bp.py +0 -101
  105. trajectree/quimb/tests/test_tensor/test_circuit.py +0 -816
  106. trajectree/quimb/tests/test_tensor/test_contract.py +0 -67
  107. trajectree/quimb/tests/test_tensor/test_decomp.py +0 -40
  108. trajectree/quimb/tests/test_tensor/test_mera.py +0 -52
  109. trajectree/quimb/tests/test_tensor/test_optimizers.py +0 -488
  110. trajectree/quimb/tests/test_tensor/test_tensor_1d.py +0 -1171
  111. trajectree/quimb/tests/test_tensor/test_tensor_2d.py +0 -606
  112. trajectree/quimb/tests/test_tensor/test_tensor_2d_tebd.py +0 -144
  113. trajectree/quimb/tests/test_tensor/test_tensor_3d.py +0 -123
  114. trajectree/quimb/tests/test_tensor/test_tensor_arbgeom.py +0 -226
  115. trajectree/quimb/tests/test_tensor/test_tensor_builder.py +0 -441
  116. trajectree/quimb/tests/test_tensor/test_tensor_core.py +0 -2066
  117. trajectree/quimb/tests/test_tensor/test_tensor_dmrg.py +0 -388
  118. trajectree/quimb/tests/test_tensor/test_tensor_spectral_approx.py +0 -63
  119. trajectree/quimb/tests/test_tensor/test_tensor_tebd.py +0 -270
  120. trajectree/quimb/tests/test_utils.py +0 -85
  121. trajectree-0.0.1.dist-info/RECORD +0 -126
  122. {trajectree-0.0.1.dist-info → trajectree-0.0.3.dist-info}/WHEEL +0 -0
  123. {trajectree-0.0.1.dist-info → trajectree-0.0.3.dist-info}/licenses/LICENSE +0 -0
  124. {trajectree-0.0.1.dist-info → trajectree-0.0.3.dist-info}/top_level.txt +0 -0
@@ -1,892 +0,0 @@
1
- """Misc utility functions."""
2
-
3
- import collections
4
- import functools
5
- import itertools
6
- from importlib.util import find_spec
7
-
8
# Prefer the compiled ``cytoolz`` implementations when installed, falling back
# to the pure-python ``toolz`` package, which exposes the same functions.
try:
    import cytoolz

    last = cytoolz.last
    concat = cytoolz.concat
    frequencies = cytoolz.frequencies
    partition_all = cytoolz.partition_all
    merge_with = cytoolz.merge_with
    valmap = cytoolz.valmap
    partition = cytoolz.partition
    partitionby = cytoolz.partitionby
    concatv = cytoolz.concatv
    compose = cytoolz.compose
    identity = cytoolz.identity
    isiterable = cytoolz.isiterable
    unique = cytoolz.unique
    keymap = cytoolz.keymap
except ImportError:
    import toolz

    last = toolz.last
    concat = toolz.concat
    frequencies = toolz.frequencies
    partition_all = toolz.partition_all
    merge_with = toolz.merge_with
    valmap = toolz.valmap
    partition = toolz.partition
    partitionby = toolz.partitionby
    concatv = toolz.concatv
    compose = toolz.compose
    identity = toolz.identity
    isiterable = toolz.isiterable
    unique = toolz.unique
    keymap = toolz.keymap
44
-
45
-
46
_CHECK_OPT_MSG = "Option `{}` should be one of {}, but got '{}'."


def check_opt(name, value, valid):
    """Validate that ``value`` is one of the allowed ``valid`` choices.

    Parameters
    ----------
    name : str
        Name of the option, used only in the error message.
    value : object
        The value supplied for the option.
    valid : container
        The permitted values, tested with ``in``.

    Raises
    ------
    ValueError
        If ``value`` is not contained in ``valid``.
    """
    if value in valid:
        return
    raise ValueError(_CHECK_OPT_MSG.format(name, valid, value))
55
-
56
-
57
def find_library(x):
    """Report whether the importable module ``x`` can be located.

    Parameters
    ----------
    x : str
        Module / library name, as passed to ``importlib.util.find_spec``.

    Returns
    -------
    bool
        ``True`` if a module spec is found, else ``False``.
    """
    spec = find_spec(x)
    return spec is not None
71
-
72
-
73
def raise_cant_find_library_function(x, extra_msg=None):
    """Build a stand-in for a function whose optional dependency is missing.

    This lets optional dependencies be flagged only at the point where they
    are actually needed, rather than at import time.

    Parameters
    ----------
    x : str
        Name of the missing library.
    extra_msg : str, optional
        Additional text appended to the error message.

    Returns
    -------
    callable
        A function that accepts any arguments and always raises an
        ``ImportError`` naming the missing library.
    """

    def function_that_will_raise(*_, **__):
        pieces = [f"The library {x} is not installed. "]
        if extra_msg is not None:
            pieces.append(extra_msg)
        raise ImportError("".join(pieces))

    return function_that_will_raise
101
-
102
-
103
# ``tqdm`` is optional: define real progress bars only if it can be found,
# otherwise bind stand-ins that raise an informative ImportError when called.
FOUND_TQDM = find_library("tqdm")
if FOUND_TQDM:
    from tqdm import tqdm

    class continuous_progbar(tqdm):
        """A continuous version of tqdm, so that it can be updated with a float
        within some pre-given range, rather than a number of steps.

        Parameters
        ----------
        args : (stop) or (start, stop)
            Stopping point (and starting point if ``len(args) == 2``) of window
            within which to evaluate progress.
        total : int
            The number of steps to represent the continuous progress with.
        kwargs
            Supplied to ``tqdm.tqdm``
        """

        def __init__(self, *args, total=100, **kwargs):
            """Create the bar with ``total`` steps (unit ``"%"``, ASCII by
            default) and record the ``[start, stop]`` window from ``args``.
            At least one positional argument is required.
            """
            kwargs.setdefault("ascii", True)
            super(continuous_progbar, self).__init__(
                total=total, unit="%", **kwargs
            )

            # one positional arg -> (stop,), start at 0; two -> (start, stop)
            if len(args) == 2:
                self.start, self.stop = args
            else:
                self.start, self.stop = 0, args[0]

            # width of the window; a zero-width window would make cupdate
            # divide by zero
            self.range = self.stop - self.start
            # running count used to turn the absolute target position into an
            # increment; starts at 1 so that reaching ``stop`` reports about
            # ``total`` steps in all (up to float truncation)
            self.step = 1

        def cupdate(self, x):
            """'Continuous' update of progress bar.

            Parameters
            ----------
            x : float
                Current position within the range ``[self.start, self.stop]``.
            """
            # target position scaled to ``total + 1`` steps, minus what has
            # already been sent; truncated to an int
            num_update = int(
                (self.total + 1) * (x - self.start) / self.range - self.step
            )
            # only ever move the bar forwards
            if num_update > 0:
                self.update(num_update)
                self.step += num_update

    def progbar(*args, **kwargs):
        # thin wrapper around ``tqdm`` that defaults to ASCII bar characters
        kwargs.setdefault("ascii", True)
        return tqdm(*args, **kwargs)

else:  # pragma: no cover
    extra_msg = "This is needed to show progress bars."
    # both names raise ImportError when called, explaining tqdm is missing
    progbar = raise_cant_find_library_function("tqdm", extra_msg)
    continuous_progbar = raise_cant_find_library_function("tqdm", extra_msg)
160
-
161
-
162
def deprecated(fn, old_name, new_name):
    """Mark a function as deprecated, and indicate the new name.

    Parameters
    ----------
    fn : callable
        The implementation to forward calls to.
    old_name : str
        The deprecated name, shown in the warning.
    new_name : str
        The replacement name, shown in the warning.

    Returns
    -------
    callable
        A wrapper (with ``fn``'s metadata) that emits a ``Warning`` on every
        call and then returns ``fn(*args, **kwargs)``.
    """

    @functools.wraps(fn)
    def new_fn(*args, **kwargs):
        import warnings

        warnings.warn(
            f"The {old_name} function is deprecated in favor of {new_name}",
            Warning,
            # point the warning at the code calling the deprecated name,
            # rather than at this wrapper
            stacklevel=2,
        )
        return fn(*args, **kwargs)

    return new_fn
176
-
177
-
178
def int2tup(x):
    """Coerce ``x`` into a tuple.

    A tuple is returned unchanged. A single integer, meaning any
    ``numbers.Integral`` (which includes ``bool`` and numpy integer scalars),
    becomes a 1-tuple. Any other iterable is converted with ``tuple``.
    """
    import numbers

    if isinstance(x, tuple):
        return x
    # ``numbers.Integral`` rather than ``int`` so numpy integers are not
    # mistakenly treated as iterables
    if isinstance(x, numbers.Integral):
        return (x,)
    return tuple(x)
182
-
183
-
184
def ensure_dict(x):
    """Return a fresh ``dict`` built from ``x``, or ``{}`` when ``x is None``."""
    return {} if x is None else dict(x)
189
-
190
-
191
def pairwise(iterable):
    """Return an iterator over overlapping neighbour pairs of ``iterable``:
    ``(s0, s1), (s1, s2), ...``.
    """
    leading, trailing = itertools.tee(iterable)
    # advance the second copy by one so the two are offset
    next(trailing, None)
    return zip(leading, trailing)
196
-
197
-
198
def print_multi_line(*lines, max_width=None):
    """Print multiple lines, with a maximum width.

    If every line fits within ``max_width`` the lines are printed as-is.
    Otherwise all lines are split into aligned column chunks, printed chunk
    by chunk, with ellipses marking continuation.

    Parameters
    ----------
    lines : str
        The lines to print.
    max_width : int, optional
        Maximum width; defaults to the terminal width reported by
        ``shutil.get_terminal_size``.
    """
    if not lines:
        # nothing to print (and ``max`` below would fail on an empty sequence)
        return

    if max_width is None:
        import shutil

        max_width, _ = shutil.get_terminal_size()

    max_line_length = max(len(ln) for ln in lines)

    if max_line_length <= max_width:
        for ln in lines:
            print(ln)
        return

    # leave room for ellipses and padding, but keep the chunk width positive
    # (a zero or negative width would divide by zero / print nothing)
    max_width = max(max_width - 10, 1)
    n_lines = len(lines)
    n_blocks = (max_line_length - 1) // max_width + 1
    mid = n_lines // 2

    for i in range(n_blocks):
        start, stop = i * max_width, (i + 1) * max_width
        if i == 0 or i < n_blocks - 1:
            # not the final chunk: mark the middle row and add a separator
            for j, ln in enumerate(lines):
                marker = "..." if j == mid else " "
                print(marker, ln[start:stop], marker)
            print(f"{'...':^{max_width}}")
        else:
            for ln in lines:
                print(" ", ln[start:stop])
236
-
237
-
238
def format_number_with_error(x, err):
    """Given ``x`` with error ``err``, format a string showing the relevant
    digits of ``x`` with two significant digits of the error bracketed, and
    overall exponent if necessary.

    Parameters
    ----------
    x : float
        The value to print.
    err : float
        The error on ``x``. Only its magnitude is used, so a negative error
        is formatted the same as its absolute value.

    Returns
    -------
    str

    Examples
    --------

    >>> format_number_with_error(0.1542412, 0.0626653)
    '0.154(63)'

    >>> format_number_with_error(-128124123097, 6424)
    '-1.281241231(64)e+11'

    """
    # an uncertainty is a magnitude; a sign would leak into the mantissa
    err = abs(err)

    # compute an overall scaling for both values
    x_exponent = max(
        int(f"{x:e}".split("e")[1]),
        int(f"{err:e}".split("e")[1]) + 1,
    )
    # for readability try and show values close to 1 with no exponent
    hide_exponent = (
        # nicer showing 0.xxx(yy) than x.xx(yy)e-1
        (x_exponent in (0, -1))
        or
        # also nicer showing xx.xx(yy) than x.xxx(yy)e+1
        ((x_exponent == +1) and (err < abs(x / 10)))
    )
    if hide_exponent:
        suffix = ""
    else:
        x = x / 10**x_exponent
        err = err / 10**x_exponent
        suffix = f"e{x_exponent:+03d}"

    # the error's two significant digits become the bracketed mantissa, and
    # its decimal exponent fixes how many decimals of ``x`` to show
    mantissa, exponent = f"{err:.1e}".split("e")
    mantissa, exponent = mantissa.replace(".", ""), int(exponent)
    return f"{x:.{abs(exponent) + 1}f}({mantissa}){suffix}"
289
-
290
-
291
def save_to_disk(obj, fname, **dump_opts):
    """Serialize ``obj`` to ``fname`` with ``joblib.dump``.

    Any ``dump_opts`` are forwarded to ``joblib.dump``, whose return value is
    passed straight back.
    """
    import joblib

    dump = joblib.dump
    return dump(obj, fname, **dump_opts)
296
-
297
-
298
def load_from_disk(fname, **load_opts):
    """Load an object from disk using ``joblib.load``.

    Any ``load_opts`` are forwarded to ``joblib.load``. Note that joblib
    files are pickle-based, so only load files from trusted sources.
    """
    import joblib

    load = joblib.load
    return load(fname, **load_opts)
303
-
304
-
305
class Verbosify:  # pragma: no cover
    """Wrap a callable so that each call prints its inputs before running.

    Only used to illustrate an MPI example in the docs.

    Parameters
    ----------
    fn : callable
        The function to wrap.
    highlight : str, optional
        If given, print only this keyword argument (it must be passed as a
        keyword); otherwise print all positional and keyword arguments.
    mpi : bool, optional
        If true, prefix the output with the ``mpi4py`` COMM_WORLD rank.
    """

    def __init__(self, fn, highlight=None, mpi=False):
        self.fn = fn
        self.highlight = highlight
        self.mpi = mpi

    def __call__(self, *args, **kwargs):
        prefix = ""
        if self.mpi:
            from mpi4py import MPI

            prefix = f"{MPI.COMM_WORLD.Get_rank()}: "

        if self.highlight is not None:
            key = self.highlight
            print(f"{prefix}{key}={kwargs[key]}")
        else:
            print(f"{prefix} args {args}, kwargs {kwargs}")
        return self.fn(*args, **kwargs)
328
-
329
-
330
class oset:
    """An ordered set which stores elements as the keys of a dict (dicts keep
    insertion order). 'A few times' slower than using a set directly for
    small sizes, but makes everything deterministic.

    Set-algebra methods (``update``, ``union``, ``intersection[_update]``,
    ``difference[_update]``) accept other ``oset`` instances or any iterable
    of hashable elements. Calling them with no arguments behaves like the
    builtin ``set``: ``union``/``intersection``/``difference`` return a copy,
    and the ``*_update`` methods leave the set unchanged.
    """

    __slots__ = ("_d",)

    def __init__(self, it=()):
        self._d = dict.fromkeys(it)

    @classmethod
    def _from_dict(cls, d):
        # wrap ``d`` directly (no copy), bypassing ``__init__``
        obj = object.__new__(cls)
        obj._d = d
        return obj

    @classmethod
    def from_dict(cls, d):
        """Public method makes sure to copy incoming dictionary."""
        return cls._from_dict(d.copy())

    @staticmethod
    def _lookup(other):
        """Return a container supporting fast ``in`` tests for ``other``."""
        if isinstance(other, oset):
            return other._d
        if isinstance(other, (set, frozenset, dict)):
            return other
        return set(other)

    def copy(self):
        return self.from_dict(self._d)

    def __deepcopy__(self, memo):
        # always use hashable entries so just take normal copy
        new = self.copy()
        memo[id(self)] = new
        return new

    def add(self, k):
        self._d[k] = None

    def discard(self, k):
        self._d.pop(k, None)

    def remove(self, k):
        del self._d[k]

    def clear(self):
        self._d.clear()

    def update(self, *others):
        for o in others:
            try:
                # oset
                self._d.update(o._d)
            except AttributeError:
                # iterable
                for k in o:
                    self._d[k] = None

    def union(self, *others):
        u = self.copy()
        u.update(*others)
        return u

    def _intersection_lookup(self, others):
        # elements common to all of ``others``, as a fast-membership container
        if len(others) == 1:
            return self._lookup(others[0])
        return set.intersection(*(set(self._lookup(o)) for o in others))

    def _union_lookup(self, others):
        # elements present in any of ``others``, as a fast-membership container
        if len(others) == 1:
            return self._lookup(others[0])
        return set.union(*(set(self._lookup(o)) for o in others))

    def intersection_update(self, *others):
        if not others:
            return
        si = self._intersection_lookup(others)
        self._d = {k: None for k in self._d if k in si}

    def intersection(self, *others):
        if not others:
            return self.copy()
        si = self._intersection_lookup(others)
        return self._from_dict({k: None for k in self._d if k in si})

    def difference_update(self, *others):
        if not others:
            return
        su = self._union_lookup(others)
        self._d = {k: None for k in self._d if k not in su}

    def difference(self, *others):
        if not others:
            return self.copy()
        su = self._union_lookup(others)
        return self._from_dict({k: None for k in self._d if k not in su})

    def popleft(self):
        k = next(iter(self._d))
        del self._d[k]
        return k

    def popright(self):
        return self._d.popitem()[0]

    pop = popright

    def __eq__(self, other):
        # order-insensitive, like ``set`` (plain dict equality ignores order)
        if isinstance(other, oset):
            return self._d == other._d
        return False

    def __or__(self, other):
        return self.union(other)

    def __ior__(self, other):
        self.update(other)
        return self

    def __and__(self, other):
        return self.intersection(other)

    def __iand__(self, other):
        self.intersection_update(other)
        return self

    def __sub__(self, other):
        return self.difference(other)

    def __isub__(self, other):
        self.difference_update(other)
        return self

    def __len__(self):
        return self._d.__len__()

    def __iter__(self):
        return self._d.__iter__()

    def __contains__(self, x):
        return self._d.__contains__(x)

    def __repr__(self):
        return f"oset({list(self._d)})"
466
-
467
-
468
class LRU(collections.OrderedDict):
    """A bounded dict that drops its least recently used entry once it holds
    more than ``maxsize`` items. Adapted from the ``OrderedDict`` recipe in the
    python ``collections`` docs.

    Both reading (``d[key]``) and writing an entry mark it as most recently
    used.
    """

    def __init__(self, maxsize, *args, **kwds):
        self.maxsize = maxsize
        super().__init__(*args, **kwds)

    def __getitem__(self, key):
        found = super().__getitem__(key)
        # a read counts as a use: move ``key`` to the "newest" end
        self.move_to_end(key)
        return found

    def __setitem__(self, key, value):
        already_present = key in self
        if already_present:
            self.move_to_end(key)
        super().__setitem__(key, value)
        if len(self) <= self.maxsize:
            return
        # over capacity: evict the entry at the "oldest" end
        del self[next(iter(self))]
489
-
490
-
491
class ExponentialGeometricRollingDiffMean:
    """Exponentially weighted *geometric* rolling mean of the absolute change
    between successive values passed to ``update``.

    Each update (after the first) sets
    ``value = value ** (1 - factor) * |x - x_prev| ** factor``. The very first
    call only records ``x``. Note that a zero change drives ``value`` to zero.

    Parameters
    ----------
    factor : float, optional
        Weight given to the newest absolute change.
    initial : float, optional
        Starting value of the running mean.
    """

    def __init__(self, factor=1 / 3, initial=1.0):
        self.x_prev = None
        self.dx = None
        self.value = initial
        self.factor = factor

    def update(self, x):
        if self.x_prev is None:
            # nothing to difference against yet
            self.x_prev = x
            return
        self.dx = abs(x - self.x_prev)
        weight = self.factor
        self.value = self.value ** (1 - weight) * self.dx**weight
        self.x_prev = x
505
-
506
-
507
def gen_bipartitions(it):
    """Yield every way of splitting ``it`` into two non-empty lists.

    Each split is produced once: ``(l, r)`` and ``(r, l)`` are considered the
    same, so the first element of ``it`` always lands in ``l``. Input with
    fewer than two elements yields nothing. ``it`` must support ``len`` and
    repeated iteration.
    """
    n = len(it)
    if not n:
        return
    # masks below 2**(n-1) keep the leading bit 0 (first item -> left);
    # mask 0 is skipped since it would leave the right side empty
    for mask in range(1, 2 ** (n - 1)):
        bits = format(mask, f"0{n}b")
        left, right = [], []
        for bit, item in zip(bits, it):
            if bit == "0":
                left.append(item)
            else:
                right.append(item)
        yield left, right
519
-
520
-
521
TREE_MAP_REGISTRY = {}
TREE_APPLY_REGISTRY = {}
TREE_ITER_REGISTRY = {}


def tree_register_container(cls, mapper, iterator, applier):
    """Register a new container type for use with ``tree_map``,
    ``tree_iter`` and ``tree_apply``.

    Parameters
    ----------
    cls : type
        The container type to register.
    mapper : callable
        A function that takes ``f``, ``tree`` and ``is_leaf`` and returns a new
        tree of type ``cls`` with ``f`` applied to all leaves.
    iterator : callable
        A function that takes ``tree`` and ``is_leaf`` and yields all leaves
        in ``tree``.
    applier : callable
        A function that takes ``f``, ``tree`` and ``is_leaf`` and applies ``f``
        to all leaves in ``tree``.
    """
    TREE_MAP_REGISTRY[cls] = mapper
    TREE_ITER_REGISTRY[cls] = iterator
    TREE_APPLY_REGISTRY[cls] = applier
    # The per-class dispatch caches (defined further down this module, and
    # so already present whenever this is called) may hold decisions made
    # before ``cls`` was registered, e.g. instances cached as leaves.
    # Clear them so the new registration takes effect.
    IS_CONTAINER_CACHE.clear()
    TREE_MAPPER_CACHE.clear()
    TREE_ITER_CACHE.clear()
    TREE_APPLIER_CACHE.clear()
544
-
545
-
546
# maps exact class -> bool "instances are leaves" (not registered containers)
IS_CONTAINER_CACHE = {}


def is_not_container(x):
    """Default leaf test: ``True`` unless ``x`` is an instance of any type
    registered in ``TREE_MAP_REGISTRY``.

    The answer is memoized per exact class of ``x`` in ``IS_CONTAINER_CACHE``.
    """
    kind = x.__class__
    try:
        return IS_CONTAINER_CACHE[kind]
    except KeyError:
        pass
    leaf = all(not isinstance(x, container) for container in TREE_MAP_REGISTRY)
    IS_CONTAINER_CACHE[kind] = leaf
    return leaf
560
-
561
-
562
def _tmap_identity(f, tree, is_leaf):
    """Mapper for objects that are neither leaves nor containers: return
    ``tree`` untouched."""
    return tree


# maps exact class -> mapper function chosen for it
TREE_MAPPER_CACHE = {}


def tree_map(f, tree, is_leaf=is_not_container):
    """Map ``f`` over all leaves in ``tree``, returning a new pytree.

    Parameters
    ----------
    f : callable
        A function to apply to all leaves in ``tree``.
    tree : pytree
        A nested sequence of tuples, lists, dicts and other objects.
    is_leaf : callable
        A function to determine if an object is a leaf, ``f`` is only applied
        to objects for which ``is_leaf(x)`` returns ``True``.

    Returns
    -------
    pytree
    """
    if is_leaf(tree):
        return f(tree)

    # Resolve the mapper *before* calling it: if the call sat inside the
    # ``try``, a KeyError raised by ``f`` (or any nested mapper) would be
    # mistaken for a cache miss and ``f`` re-run on leaves already visited.
    try:
        mapper = TREE_MAPPER_CACHE[tree.__class__]
    except KeyError:
        # reverse so later registered classes take precedence
        for cls, mapper in reversed(TREE_MAP_REGISTRY.items()):
            if isinstance(tree, cls):
                break
        else:
            # neither leaf nor container -> simply return it
            mapper = _tmap_identity
        TREE_MAPPER_CACHE[tree.__class__] = mapper
    return mapper(f, tree, is_leaf)
601
-
602
-
603
def empty(tree, is_leaf):
    """Iterator for objects that are neither leaves nor containers: yields
    nothing."""
    return iter(())


# maps exact class -> iterator function chosen for it
TREE_ITER_CACHE = {}


def tree_iter(tree, is_leaf=is_not_container):
    """Iterate over all leaves in ``tree``.

    Parameters
    ----------
    tree : pytree
        A nested sequence of tuples, lists, dicts and other objects.
    is_leaf : callable
        A function to determine if an object is a leaf, only objects for
        which ``is_leaf(x)`` returns ``True`` are yielded.

    Yields
    ------
    object
        Each leaf of ``tree``, in traversal order.
    """
    if is_leaf(tree):
        yield tree
        return

    # Resolve the iterator before iterating: a KeyError raised mid-iteration
    # must not be treated as a cache miss, which would re-yield leaves.
    try:
        iterator = TREE_ITER_CACHE[tree.__class__]
    except KeyError:
        # reverse so later registered classes take precedence
        for cls, iterator in reversed(TREE_ITER_REGISTRY.items()):
            if isinstance(tree, cls):
                break
        else:
            # neither leaf nor container -> simply ignore it
            iterator = empty
        TREE_ITER_CACHE[tree.__class__] = iterator
    yield from iterator(tree, is_leaf)
639
-
640
-
641
def nothing(f, tree, is_leaf):
    """Applier for objects that are neither leaves nor containers: do
    nothing."""
    pass


# maps exact class -> applier function chosen for it
TREE_APPLIER_CACHE = {}


def tree_apply(f, tree, is_leaf=is_not_container):
    """Apply ``f`` to all leaves in ``tree``, no new pytree is built.

    Parameters
    ----------
    f : callable
        A function to apply to all leaves in ``tree``.
    tree : pytree
        A nested sequence of tuples, lists, dicts and other objects.
    is_leaf : callable
        A function to determine if an object is a leaf, ``f`` is only applied
        to objects for which ``is_leaf(x)`` returns ``True``.
    """
    if is_leaf(tree):
        f(tree)
        return

    # Resolve the applier before calling it, so a KeyError raised by ``f``
    # propagates instead of triggering a second pass over the leaves.
    try:
        applier = TREE_APPLIER_CACHE[tree.__class__]
    except KeyError:
        # reverse so later registered classes take precedence
        for cls, applier in reversed(TREE_APPLY_REGISTRY.items()):
            if isinstance(tree, cls):
                break
        else:
            # neither leaf nor container -> simply ignore it
            applier = nothing
        TREE_APPLIER_CACHE[tree.__class__] = applier
    applier(f, tree, is_leaf)
677
-
678
-
679
class Leaf:
    """Placeholder marking the position of a leaf in a reference tree."""

    __slots__ = ()

    def __repr__(self):
        return "Leaf"


# Rebind the name to the single instance, so ``Leaf`` acts as a sentinel.
Leaf = Leaf()


def is_leaf_object(x):
    """Return ``True`` only for the ``Leaf`` sentinel itself (identity)."""
    matches = x is Leaf
    return matches
691
-
692
-
693
def tree_flatten(tree, get_ref=False, is_leaf=is_not_container):
    """Collect the leaves of ``tree`` into a flat list.

    Parameters
    ----------
    tree : pytree
        A nested sequence of tuples, lists, dicts and other objects.
    get_ref : bool, optional
        Whether to also return a reference tree for ``tree_unflatten``.
    is_leaf : callable
        A function to determine if an object is a leaf, only objects for which
        ``is_leaf(x)`` returns ``True`` are returned in the flattened list.

    Returns
    -------
    objs : list
        The leaf objects, in traversal order.
    (ref_tree) : pytree
        Only if ``get_ref`` is ``True``: a copy of ``tree``'s structure with
        every leaf replaced by the ``Leaf`` sentinel.
    """
    objs = []
    if not get_ref:
        # no structure needed, so avoid building a new tree
        tree_apply(objs.append, tree, is_leaf)
        return objs

    def _collect(x):
        # record the leaf and leave a placeholder in its position
        objs.append(x)
        return Leaf

    return objs, tree_map(_collect, tree, is_leaf)
726
-
727
-
728
def tree_unflatten(objs, tree, is_leaf=is_leaf_object):
    """Unflatten ``objs`` into a pytree of the same structure as ``tree``.

    Parameters
    ----------
    objs : sequence
        A sequence of objects to be unflattened into a pytree.
    tree : pytree
        A nested sequence of tuples, lists, dicts and other objects, the objs
        will be inserted into a new pytree of the same structure.
    is_leaf : callable
        A function to determine if an object is a leaf, only objects for which
        ``is_leaf(x)`` returns ``True`` will have the next item from ``objs``
        inserted. By default checks for the ``Leaf`` object inserted by
        ``tree_flatten(..., get_ref=True)``.

    Returns
    -------
    pytree

    Raises
    ------
    ValueError
        If ``objs`` has fewer items than ``tree`` has leaf positions. Surplus
        items are ignored.
    """
    objs = iter(objs)

    def _next_obj(_):
        try:
            return next(objs)
        except StopIteration:
            # a bare StopIteration would otherwise escape (or be turned into
            # a confusing RuntimeError inside generator expressions)
            raise ValueError(
                "Not enough objects supplied to fill all leaves of the tree."
            ) from None

    return tree_map(_next_obj, tree, is_leaf)
750
-
751
-
752
def tree_map_tuple(f, tree, is_leaf):
    """Rebuild tuple ``tree`` with ``tree_map`` applied to each element."""
    return tuple(tree_map(f, item, is_leaf) for item in tree)


def tree_iter_tuple(tree, is_leaf):
    """Yield the leaves of every element of tuple ``tree``, in order."""
    for item in tree:
        yield from tree_iter(item, is_leaf)


def tree_apply_tuple(f, tree, is_leaf):
    """Run ``tree_apply`` over every element of tuple ``tree``, in order."""
    for item in tree:
        tree_apply(f, item, is_leaf)


tree_register_container(
    tuple,
    mapper=tree_map_tuple,
    iterator=tree_iter_tuple,
    applier=tree_apply_tuple,
)
769
-
770
-
771
def tree_map_list(f, tree, is_leaf):
    """Rebuild list ``tree`` with ``tree_map`` applied to each element."""
    return [tree_map(f, item, is_leaf) for item in tree]


def tree_iter_list(tree, is_leaf):
    """Yield the leaves of every element of list ``tree``, in order."""
    for item in tree:
        yield from tree_iter(item, is_leaf)


def tree_apply_list(f, tree, is_leaf):
    """Run ``tree_apply`` over every element of list ``tree``, in order."""
    for item in tree:
        tree_apply(f, item, is_leaf)


tree_register_container(
    list,
    mapper=tree_map_list,
    iterator=tree_iter_list,
    applier=tree_apply_list,
)
786
-
787
-
788
def tree_map_dict(f, tree, is_leaf):
    """Rebuild dict ``tree`` (same keys) with ``tree_map`` applied to each
    value."""
    return {key: tree_map(f, val, is_leaf) for key, val in tree.items()}


def tree_iter_dict(tree, is_leaf):
    """Yield the leaves of every value of dict ``tree``; keys are not
    traversed."""
    for val in tree.values():
        yield from tree_iter(val, is_leaf)


def tree_apply_dict(f, tree, is_leaf):
    """Run ``tree_apply`` over every value of dict ``tree``."""
    for val in tree.values():
        tree_apply(f, val, is_leaf)


tree_register_container(
    dict,
    mapper=tree_map_dict,
    iterator=tree_iter_dict,
    applier=tree_apply_dict,
)
803
-
804
-
805
# a style to use for matplotlib that works with light and dark backgrounds:
# mid-grey (0.5, 0.5, 0.5) text, ticks, spines and grid, drawn over
# transparent figure/axes facecolors. ``default_to_neutral_style`` applies
# this dict when called with ``style="neutral"`` (its default).
NEUTRAL_STYLE = {
    "axes.edgecolor": (0.5, 0.5, 0.5),
    # RGBA with zero alpha -> fully transparent background
    "axes.facecolor": (0, 0, 0, 0),
    "axes.grid": True,
    "axes.labelcolor": (0.5, 0.5, 0.5),
    "axes.spines.right": False,
    "axes.spines.top": False,
    "figure.facecolor": (0, 0, 0, 0),
    # faint grid so it does not dominate on either background
    "grid.alpha": 0.1,
    "grid.color": (0.5, 0.5, 0.5),
    "legend.frameon": False,
    "text.color": (0.5, 0.5, 0.5),
    "xtick.color": (0.5, 0.5, 0.5),
    "xtick.minor.visible": True,
    "ytick.color": (0.5, 0.5, 0.5),
    "ytick.minor.visible": True,
}
823
-
824
-
825
def default_to_neutral_style(fn):
    """Decorate ``fn`` so it runs inside a matplotlib style context.

    The wrapped callable gains two keyword-only options:

    style : "neutral", dict, str or falsy, optional
        ``"neutral"`` (the default) selects ``NEUTRAL_STYLE``. Any falsy
        value applies no style changes. Anything else is handed unchanged to
        ``matplotlib.pyplot.style.context``.
    show_and_close : bool, optional
        If true (the default), call ``plt.show()`` and then ``plt.close()``
        after ``fn`` returns.

    The wrapper returns whatever ``fn`` returns.
    """

    @functools.wraps(fn)
    def wrapper(*args, style="neutral", show_and_close=True, **kwargs):
        import matplotlib.pyplot as plt

        if style == "neutral":
            chosen = NEUTRAL_STYLE
        elif style:
            chosen = style
        else:
            chosen = {}

        with plt.style.context(chosen):
            result = fn(*args, **kwargs)

        if show_and_close:
            plt.show()
            plt.close()

        return result

    return wrapper
847
-
848
-
849
def autocorrect_kwargs(func=None, valid_kwargs=None):
    """A decorator that suggests the right keyword arguments if you get them
    wrong. Useful for functions with many specific options.

    Can be used bare (``@autocorrect_kwargs``) or with options
    (``@autocorrect_kwargs(valid_kwargs=[...])``).

    Parameters
    ----------
    func : callable, optional
        The function to decorate.
    valid_kwargs : sequence[str], optional
        The valid keyword arguments for ``func``, if not given these are
        inferred from the function signature.

    Returns
    -------
    callable
        The wrapped function. When called with any keyword argument not in
        ``valid_kwargs``, it prints and raises a ``ValueError``. The message
        lists the invalid options in sorted order, each paired with up to
        three close matches from ``valid_kwargs``.
    """
    if func is None:
        # decorator with options
        return functools.partial(autocorrect_kwargs, valid_kwargs=valid_kwargs)

    if valid_kwargs is None:
        import inspect

        valid_kwargs = set(inspect.signature(func).parameters)
    else:
        valid_kwargs = set(valid_kwargs)

    @functools.wraps(func)
    def wrapped(*args, **kwargs):
        # sorted, so the message does not depend on set/hash ordering
        wrong_opts = sorted(kw for kw in kwargs if kw not in valid_kwargs)
        if wrong_opts:
            import difflib

            suggestions = ", ".join(
                f"{opt!r} -> {difflib.get_close_matches(opt, valid_kwargs, n=3)}"
                for opt in wrong_opts
            )
            msg = "Option(s) {} not valid.\n Did you mean: {}?".format(
                wrong_opts, suggestions
            )
            print(msg)
            raise ValueError(msg)

        return func(*args, **kwargs)

    return wrapped