Trajectree 0.0.1__py3-none-any.whl → 0.0.3__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (124)
  1. trajectree/__init__.py +0 -3
  2. trajectree/fock_optics/devices.py +1 -1
  3. trajectree/fock_optics/light_sources.py +2 -2
  4. trajectree/fock_optics/measurement.py +9 -9
  5. trajectree/fock_optics/outputs.py +10 -6
  6. trajectree/fock_optics/utils.py +9 -6
  7. trajectree/sequence/swap.py +5 -4
  8. trajectree/trajectory.py +5 -4
  9. {trajectree-0.0.1.dist-info → trajectree-0.0.3.dist-info}/METADATA +2 -3
  10. trajectree-0.0.3.dist-info/RECORD +16 -0
  11. trajectree/quimb/docs/_pygments/_pygments_dark.py +0 -118
  12. trajectree/quimb/docs/_pygments/_pygments_light.py +0 -118
  13. trajectree/quimb/docs/conf.py +0 -158
  14. trajectree/quimb/docs/examples/ex_mpi_expm_evo.py +0 -62
  15. trajectree/quimb/quimb/__init__.py +0 -507
  16. trajectree/quimb/quimb/calc.py +0 -1491
  17. trajectree/quimb/quimb/core.py +0 -2279
  18. trajectree/quimb/quimb/evo.py +0 -712
  19. trajectree/quimb/quimb/experimental/__init__.py +0 -0
  20. trajectree/quimb/quimb/experimental/autojittn.py +0 -129
  21. trajectree/quimb/quimb/experimental/belief_propagation/__init__.py +0 -109
  22. trajectree/quimb/quimb/experimental/belief_propagation/bp_common.py +0 -397
  23. trajectree/quimb/quimb/experimental/belief_propagation/d1bp.py +0 -316
  24. trajectree/quimb/quimb/experimental/belief_propagation/d2bp.py +0 -653
  25. trajectree/quimb/quimb/experimental/belief_propagation/hd1bp.py +0 -571
  26. trajectree/quimb/quimb/experimental/belief_propagation/hv1bp.py +0 -775
  27. trajectree/quimb/quimb/experimental/belief_propagation/l1bp.py +0 -316
  28. trajectree/quimb/quimb/experimental/belief_propagation/l2bp.py +0 -537
  29. trajectree/quimb/quimb/experimental/belief_propagation/regions.py +0 -194
  30. trajectree/quimb/quimb/experimental/cluster_update.py +0 -286
  31. trajectree/quimb/quimb/experimental/merabuilder.py +0 -865
  32. trajectree/quimb/quimb/experimental/operatorbuilder/__init__.py +0 -15
  33. trajectree/quimb/quimb/experimental/operatorbuilder/operatorbuilder.py +0 -1631
  34. trajectree/quimb/quimb/experimental/schematic.py +0 -7
  35. trajectree/quimb/quimb/experimental/tn_marginals.py +0 -130
  36. trajectree/quimb/quimb/experimental/tnvmc.py +0 -1483
  37. trajectree/quimb/quimb/gates.py +0 -36
  38. trajectree/quimb/quimb/gen/__init__.py +0 -2
  39. trajectree/quimb/quimb/gen/operators.py +0 -1167
  40. trajectree/quimb/quimb/gen/rand.py +0 -713
  41. trajectree/quimb/quimb/gen/states.py +0 -479
  42. trajectree/quimb/quimb/linalg/__init__.py +0 -6
  43. trajectree/quimb/quimb/linalg/approx_spectral.py +0 -1109
  44. trajectree/quimb/quimb/linalg/autoblock.py +0 -258
  45. trajectree/quimb/quimb/linalg/base_linalg.py +0 -719
  46. trajectree/quimb/quimb/linalg/mpi_launcher.py +0 -397
  47. trajectree/quimb/quimb/linalg/numpy_linalg.py +0 -244
  48. trajectree/quimb/quimb/linalg/rand_linalg.py +0 -514
  49. trajectree/quimb/quimb/linalg/scipy_linalg.py +0 -293
  50. trajectree/quimb/quimb/linalg/slepc_linalg.py +0 -892
  51. trajectree/quimb/quimb/schematic.py +0 -1518
  52. trajectree/quimb/quimb/tensor/__init__.py +0 -401
  53. trajectree/quimb/quimb/tensor/array_ops.py +0 -610
  54. trajectree/quimb/quimb/tensor/circuit.py +0 -4824
  55. trajectree/quimb/quimb/tensor/circuit_gen.py +0 -411
  56. trajectree/quimb/quimb/tensor/contraction.py +0 -336
  57. trajectree/quimb/quimb/tensor/decomp.py +0 -1255
  58. trajectree/quimb/quimb/tensor/drawing.py +0 -1646
  59. trajectree/quimb/quimb/tensor/fitting.py +0 -385
  60. trajectree/quimb/quimb/tensor/geometry.py +0 -583
  61. trajectree/quimb/quimb/tensor/interface.py +0 -114
  62. trajectree/quimb/quimb/tensor/networking.py +0 -1058
  63. trajectree/quimb/quimb/tensor/optimize.py +0 -1818
  64. trajectree/quimb/quimb/tensor/tensor_1d.py +0 -4778
  65. trajectree/quimb/quimb/tensor/tensor_1d_compress.py +0 -1854
  66. trajectree/quimb/quimb/tensor/tensor_1d_tebd.py +0 -662
  67. trajectree/quimb/quimb/tensor/tensor_2d.py +0 -5954
  68. trajectree/quimb/quimb/tensor/tensor_2d_compress.py +0 -96
  69. trajectree/quimb/quimb/tensor/tensor_2d_tebd.py +0 -1230
  70. trajectree/quimb/quimb/tensor/tensor_3d.py +0 -2869
  71. trajectree/quimb/quimb/tensor/tensor_3d_tebd.py +0 -46
  72. trajectree/quimb/quimb/tensor/tensor_approx_spectral.py +0 -60
  73. trajectree/quimb/quimb/tensor/tensor_arbgeom.py +0 -3237
  74. trajectree/quimb/quimb/tensor/tensor_arbgeom_compress.py +0 -565
  75. trajectree/quimb/quimb/tensor/tensor_arbgeom_tebd.py +0 -1138
  76. trajectree/quimb/quimb/tensor/tensor_builder.py +0 -5411
  77. trajectree/quimb/quimb/tensor/tensor_core.py +0 -11179
  78. trajectree/quimb/quimb/tensor/tensor_dmrg.py +0 -1472
  79. trajectree/quimb/quimb/tensor/tensor_mera.py +0 -204
  80. trajectree/quimb/quimb/utils.py +0 -892
  81. trajectree/quimb/tests/__init__.py +0 -0
  82. trajectree/quimb/tests/test_accel.py +0 -501
  83. trajectree/quimb/tests/test_calc.py +0 -788
  84. trajectree/quimb/tests/test_core.py +0 -847
  85. trajectree/quimb/tests/test_evo.py +0 -565
  86. trajectree/quimb/tests/test_gen/__init__.py +0 -0
  87. trajectree/quimb/tests/test_gen/test_operators.py +0 -361
  88. trajectree/quimb/tests/test_gen/test_rand.py +0 -296
  89. trajectree/quimb/tests/test_gen/test_states.py +0 -261
  90. trajectree/quimb/tests/test_linalg/__init__.py +0 -0
  91. trajectree/quimb/tests/test_linalg/test_approx_spectral.py +0 -368
  92. trajectree/quimb/tests/test_linalg/test_base_linalg.py +0 -351
  93. trajectree/quimb/tests/test_linalg/test_mpi_linalg.py +0 -127
  94. trajectree/quimb/tests/test_linalg/test_numpy_linalg.py +0 -84
  95. trajectree/quimb/tests/test_linalg/test_rand_linalg.py +0 -134
  96. trajectree/quimb/tests/test_linalg/test_slepc_linalg.py +0 -283
  97. trajectree/quimb/tests/test_tensor/__init__.py +0 -0
  98. trajectree/quimb/tests/test_tensor/test_belief_propagation/__init__.py +0 -0
  99. trajectree/quimb/tests/test_tensor/test_belief_propagation/test_d1bp.py +0 -39
  100. trajectree/quimb/tests/test_tensor/test_belief_propagation/test_d2bp.py +0 -67
  101. trajectree/quimb/tests/test_tensor/test_belief_propagation/test_hd1bp.py +0 -64
  102. trajectree/quimb/tests/test_tensor/test_belief_propagation/test_hv1bp.py +0 -51
  103. trajectree/quimb/tests/test_tensor/test_belief_propagation/test_l1bp.py +0 -142
  104. trajectree/quimb/tests/test_tensor/test_belief_propagation/test_l2bp.py +0 -101
  105. trajectree/quimb/tests/test_tensor/test_circuit.py +0 -816
  106. trajectree/quimb/tests/test_tensor/test_contract.py +0 -67
  107. trajectree/quimb/tests/test_tensor/test_decomp.py +0 -40
  108. trajectree/quimb/tests/test_tensor/test_mera.py +0 -52
  109. trajectree/quimb/tests/test_tensor/test_optimizers.py +0 -488
  110. trajectree/quimb/tests/test_tensor/test_tensor_1d.py +0 -1171
  111. trajectree/quimb/tests/test_tensor/test_tensor_2d.py +0 -606
  112. trajectree/quimb/tests/test_tensor/test_tensor_2d_tebd.py +0 -144
  113. trajectree/quimb/tests/test_tensor/test_tensor_3d.py +0 -123
  114. trajectree/quimb/tests/test_tensor/test_tensor_arbgeom.py +0 -226
  115. trajectree/quimb/tests/test_tensor/test_tensor_builder.py +0 -441
  116. trajectree/quimb/tests/test_tensor/test_tensor_core.py +0 -2066
  117. trajectree/quimb/tests/test_tensor/test_tensor_dmrg.py +0 -388
  118. trajectree/quimb/tests/test_tensor/test_tensor_spectral_approx.py +0 -63
  119. trajectree/quimb/tests/test_tensor/test_tensor_tebd.py +0 -270
  120. trajectree/quimb/tests/test_utils.py +0 -85
  121. trajectree-0.0.1.dist-info/RECORD +0 -126
  122. {trajectree-0.0.1.dist-info → trajectree-0.0.3.dist-info}/WHEEL +0 -0
  123. {trajectree-0.0.1.dist-info → trajectree-0.0.3.dist-info}/licenses/LICENSE +0 -0
  124. {trajectree-0.0.1.dist-info → trajectree-0.0.3.dist-info}/top_level.txt +0 -0
@@ -1,1483 +0,0 @@
1
- """Tools for generic VMC optimization of tensor networks.
2
- """
3
-
4
- import array
5
- import random
6
-
7
- import numpy as np
8
- import autoray as ar
9
-
10
- from quimb.utils import default_to_neutral_style
11
- from quimb import format_number_with_error
12
-
13
-
14
- # --------------------------------------------------------------------------- #
15
-
16
-
17
- def sample_bitstring_from_prob_ndarray(p, rng):
18
- flat_idx = rng.choice(np.arange(p.size), p=p.flat)
19
- return np.unravel_index(flat_idx, p.shape)
20
-
21
-
22
- def shuffled(it):
23
- """Return a copy of ``it`` in random order.
24
- """
25
- it = list(it)
26
- random.shuffle(it)
27
- return it
28
-
29
-
30
- class NoContext:
31
- """A convenience context manager that does nothing.
32
- """
33
-
34
- def __enter__(self):
35
- pass
36
-
37
- def __exit__(self, *_, **__):
38
- pass
39
-
40
-
41
- class MovingStatistics:
42
- """Keep track of the windowed mean and estimated variance of a stream of
43
- values on the fly.
44
- """
45
-
46
- def __init__(self, window_size):
47
- self.window_size = window_size
48
- self.xs = []
49
- self.vs = []
50
- self._xsum = 0.0
51
- self._vsum = 0.0
52
-
53
- def update(self, x):
54
-
55
- # update mean
56
- self.xs.append(x)
57
- if len(self.xs) > self.window_size:
58
- xr = self.xs.pop(0)
59
- else:
60
- xr = 0.0
61
- self._xsum += (x - xr)
62
-
63
- # update approx variance
64
- v = (x - self.mean)**2
65
- self.vs.append(v)
66
- if len(self.vs) > self.window_size:
67
- vr = self.vs.pop(0)
68
- else:
69
- vr = 0.0
70
- self._vsum += (v - vr)
71
-
72
- @property
73
- def mean(self) :
74
- N = len(self.xs)
75
- if N == 0:
76
- return 0.0
77
- return self._xsum / N
78
-
79
- @property
80
- def var(self):
81
- N = len(self.xs)
82
- if N == 0:
83
- return 0.0
84
- return self._vsum / N
85
-
86
- @property
87
- def std(self):
88
- return self.var**0.5
89
-
90
- @property
91
- def err(self):
92
- N = len(self.xs)
93
- if N == 0:
94
- return 0.0
95
- return self.std / N**0.5
96
-
97
-
98
- # --------------------------------------------------------------------------- #
99
-
100
- class DenseSampler:
101
- """Sampler that explicitly constructs the full probability distribution.
102
- Useful for debugging small problems.
103
- """
104
-
105
- def __init__(self, psi=None, seed=None, **contract_opts):
106
- if psi is not None:
107
- self._set_psi(psi)
108
- contract_opts.setdefault('optimize', 'auto-hq')
109
- self.contract_opts = contract_opts
110
- self.rng = np.random.default_rng(seed)
111
-
112
- def _set_psi(self, psi):
113
- psi_dense = psi.contract(
114
- ..., output_inds=psi.site_inds, **self.contract_opts,
115
- ).data
116
- self.p = (abs(psi_dense.ravel())**2)
117
- self.p /= self.p.sum()
118
- self.sites = psi.sites
119
- self.shape = tuple(psi.ind_size(ix) for ix in psi.site_inds)
120
- self.flat_indexes = np.arange(self.p.size)
121
-
122
- def sample(self):
123
- flat_idx = self.rng.choice(self.flat_indexes, p=self.p)
124
- omega = self.p[flat_idx]
125
- config = np.unravel_index(flat_idx, self.shape)
126
- return dict(zip(self.sites, config)), omega
127
-
128
- def update(self, **kwargs):
129
- self._set_psi(kwargs['psi'])
130
-
131
-
132
- class DirectTNSampler:
133
- """
134
-
135
- Parameters
136
- ----------
137
- tn : TensorNetwork
138
- The tensor network to sample from.
139
- sweeps : int, optional
140
- The number of sweeps to perform.
141
- max_group_size : int, optional
142
- The maximum number of sites to include in a single marginal.
143
- chi : int, optional
144
- The maximum bond dimension to use for compressed contraction.
145
- optimize : PathOptimizer, optional
146
- The path optimizer to use.
147
- optimize_share_path : bool, optional
148
- If ``True``, a single path will be used for all contractions regardless
149
- of which marginal (i.e. which indices are open) is being computed.
150
- """
151
-
152
- def __init__(
153
- self,
154
- tn,
155
- sweeps=1,
156
- max_group_size=8,
157
- chi=None,
158
- optimize=None,
159
- optimize_share_path=False,
160
- seed=None,
161
- track=False,
162
- ):
163
- self.tn = tn.copy()
164
-
165
- self.ind2site = {}
166
- self.tid2ind = {}
167
- for site in self.tn.sites:
168
- ix = self.tn.site_ind(site)
169
- tid, = self.tn._get_tids_from_inds(ix)
170
- self.tid2ind[tid] = ix
171
- self.ind2site[ix] = site
172
-
173
- self.chi = chi
174
- self.sweeps = sweeps
175
- self.max_group_size = max_group_size
176
-
177
- self.optimize = optimize
178
- self.optimize_share_path = optimize_share_path
179
- self.groups = None
180
- self.tree = None
181
- self.path = None
182
-
183
- self.rng = np.random.default_rng(seed)
184
-
185
- self.track = track
186
- if self.track:
187
- self.omegas = []
188
- self.probs = []
189
- else:
190
- self.omegas = self.probs = None
191
-
192
- def plot(self,):
193
- from matplotlib import pyplot as plt
194
- fig, ax = plt.subplots(figsize=(4, 4))
195
- mins = min(self.omegas)
196
- maxs = max(self.omegas)
197
- ax.plot([mins, maxs], [mins, maxs], color='red')
198
- ax.scatter(self.probs, self.omegas, marker='.', alpha=0.5)
199
- ax.set_xlabel('$\pi(x)$')
200
- ax.set_ylabel('$\omega(x)$')
201
- ax.set_xscale('log')
202
- ax.set_yscale('log')
203
- ax.grid(True, c=(0.97 ,0.97, 0.97), which='major')
204
- ax.set_axisbelow(True)
205
-
206
- def calc_groups(self, **kwargs):
207
- """Calculate how to group the sites into marginals.
208
- """
209
- self.groups = self.tn.compute_hierarchical_grouping(
210
- max_group_size=self.max_group_size,
211
- tids=tuple(self.tid2ind),
212
- **kwargs,
213
- )
214
-
215
- def get_groups(self):
216
- if self.groups is None:
217
- self.calc_groups()
218
- return self.groups
219
-
220
- def calc_path(self):
221
- tn0 = self.tn.isel({ix: 0 for ix in self.ind2site})
222
- self.tree = tn0.contraction_tree(self.optimize)
223
- self.path = self.tree.get_path()
224
-
225
- def get_path(self):
226
- if self.path is None:
227
- self.calc_path()
228
- return self.path
229
-
230
- def get_optimize(self):
231
- if self.optimize_share_path:
232
- return self.get_path()
233
- else:
234
- return self.optimize
235
-
236
- def contract(self, tn, output_inds):
237
- if self.chi is None:
238
- return tn.contract(
239
- optimize=self.get_optimize(),
240
- output_inds=output_inds,
241
- )
242
- else:
243
- return tn.contract_compressed(
244
- max_bond=self.chi,
245
- optimize=self.get_optimize(),
246
- output_inds=output_inds,
247
- cutoff=0.0,
248
- compress_opts=dict(absorb='both'),
249
- )
250
-
251
- def sample(self):
252
-
253
- config = {}
254
-
255
- tnm = self.tn.copy()
256
-
257
- for tid, ix in self.tid2ind.items():
258
- t = tnm.tensor_map[tid]
259
- t.rand_reduce_(
260
- ix, rand_fn=lambda d: self.rng.choice([-1.0, 1.0], size=d)
261
- )
262
-
263
- tnm.apply_to_arrays(ar.lazy.array)
264
- with ar.lazy.shared_intermediates():
265
- # with NoContext():
266
-
267
- for _ in range(self.sweeps):
268
-
269
- # random.shuffle(self.groups)
270
- omega = 1.0
271
-
272
- for group in self.get_groups():
273
- # get corresponding indices
274
- inds = [self.tid2ind[tid] for tid in group]
275
-
276
- # insert the orig tensors with output index
277
- for tid in group:
278
- t_full = self.tn.tensor_map[tid]
279
- tnm.tensor_map[tid].modify(
280
- data=ar.lazy.array(t_full.data),
281
- inds=t_full.inds,
282
- )
283
-
284
- # contract the current conditional marginal
285
- tg = self.contract(tnm, inds)
286
-
287
- # convert into normalized prob and sample a config
288
- prob_g = ar.do('abs', tg.data.compute())**2
289
- prob_g /= ar.do('sum', prob_g)
290
- config_g = sample_bitstring_from_prob_ndarray(
291
- prob_g, self.rng
292
- )
293
- omega *= prob_g[config_g]
294
-
295
- # re-project the tensors according to the sampled config
296
- for tid, ix, bi in zip(group, inds, config_g):
297
-
298
- # the 'measurement' for this tensor
299
- # bi = int(bi)
300
-
301
- # project tensor from full wavefunction
302
- t_full = self.tn.tensor_map[tid]
303
- tm = t_full.isel({ix: bi})
304
- tnm.tensor_map[tid].modify(
305
- data=ar.lazy.array(tm.data),
306
- inds=tm.inds,
307
- )
308
-
309
- # update the bitstring
310
- config[self.ind2site[ix]] = bi
311
-
312
- if self.track:
313
- self.omegas.append(omega)
314
- self.probs.append(abs(tg.data[config_g].compute())**2)
315
-
316
- # final chosen marginal is prob of whole config
317
- return config, omega # tg.data[config_g].compute()
318
-
319
-
320
- def compute_amplitude(tn, config, chi, optimize):
321
- tni = tn.isel({tn.site_ind(site): v for site, v in config.items()})
322
- return tni.contract_compressed_(
323
- optimize=optimize,
324
- max_bond=chi,
325
- cutoff=0.0,
326
- compress_opts={'absorb': 'both'},
327
- inplace=True,
328
- )
329
-
330
-
331
- def compute_amplitudes(tn, configs, chi, optimize):
332
- with ar.lazy.shared_intermediates():
333
- tnlz = tn.copy()
334
- tnlz.apply_to_arrays(ar.lazy.array)
335
-
336
- amps = []
337
- for config in configs:
338
- amps.append(compute_amplitude(tnlz, config, chi, optimize))
339
-
340
- amps = ar.do('stack', amps)
341
- return amps.compute()
342
-
343
-
344
- def compute_local_energy(ham, tn, config, chi, optimize):
345
- """
346
- """
347
- c_configs, c_coeffs = ham.config_coupling(config)
348
- amps = compute_amplitudes(tn, [config] + c_configs, chi, optimize)
349
- c_coeffs = ar.do('array', c_coeffs, like=amps)
350
- return ar.do('sum', amps[1:] * c_coeffs) / amps[0]
351
-
352
-
353
- def draw_config(edges, config):
354
- import networkx as nx
355
- G = nx.Graph(edges)
356
- pos = nx.kamada_kawai_layout(G)
357
- nx.draw(G, node_color=[config[node] for node in G.nodes], pos=pos)
358
-
359
-
360
- class ClusterSampler:
361
-
362
- def __init__(
363
- self,
364
- psi=None,
365
- max_distance=1,
366
- use_gauges=True,
367
- seed=None,
368
- contract_opts=None,
369
- ):
370
- self.rng = np.random.default_rng(seed)
371
- self.use_gauges = use_gauges
372
- self.max_distance = max_distance
373
- self.contract_opts = (
374
- {} if contract_opts is None else dict(contract_opts)
375
- )
376
- self.contract_opts.setdefault('optimize', 'auto-hq')
377
- if psi is not None:
378
- self._set_psi(psi)
379
-
380
- def _set_psi(self, psi):
381
- self._psi = psi.copy()
382
- if self.use_gauges:
383
- self.gauges0 = {}
384
- self._psi.gauge_all_simple_(gauges=self.gauges0)
385
- else:
386
- self.gauges0 = None
387
-
388
- self.tid2site = {}
389
- for site in self._psi.sites:
390
- tid, = self._psi._get_tids_from_tags(site)
391
- self.tid2site[tid] = site
392
- self.ordering = self._psi.compute_hierarchical_ordering()
393
-
394
- def sample(self):
395
- """
396
- """
397
- config = {}
398
- psi = self._psi.copy()
399
-
400
- if self.use_gauges:
401
- gauges = self.gauges0.copy()
402
- else:
403
- gauges = None
404
-
405
- omega = 1.0
406
-
407
- for tid in self.ordering:
408
- site = self.tid2site[tid]
409
- ind = psi.site_ind(site)
410
-
411
- # select a local patch
412
- k = psi._select_local_tids(
413
- [tid],
414
- max_distance=self.max_distance,
415
- fillin=0,
416
- virtual=False,
417
- )
418
-
419
- if self.use_gauges:
420
- # gauge it including dangling bonds
421
- k.gauge_simple_insert(gauges)
422
-
423
- # contract the approx reduced density matrix diagonal
424
- pk = (k.H & k).contract(
425
- ...,
426
- output_inds=[ind], # directly extract diagonal
427
- **self.contract_opts,
428
- ).data
429
-
430
- # normalize and sample a state for this site
431
- pk /= pk.sum()
432
- idx = self.rng.choice(np.arange(2), p=pk)
433
- config[site] = idx
434
-
435
- # track the probability chain
436
- omega *= pk[idx]
437
-
438
- # fix the site to measurement
439
- psi.tensor_map[tid].isel_({ind: idx})
440
-
441
- if self.use_gauges:
442
- # update local gauges to take measurement into account
443
- psi._gauge_local_tids(
444
- [tid],
445
- max_distance=(self.max_distance + 1),
446
- method='simple', gauges=gauges
447
- )
448
-
449
- return config, omega
450
-
451
- candidate = sample
452
-
453
- def accept(self, config):
454
- pass
455
-
456
- def update(self, **kwargs):
457
- self._set_psi(kwargs['psi'])
458
-
459
-
460
- class ExchangeSampler:
461
-
462
- def __init__(self, edges, seed=None):
463
- self.edges = tuple(sorted(edges))
464
- self.Ne = len(self.edges)
465
- self.sites = sorted(set(site for edge in edges for site in edge))
466
- self.N = len(self.sites)
467
- self.rng = np.random.default_rng(seed)
468
- values0 = [0] * (self.N // 2) + [1] * (self.N // 2)
469
- if self.N % 2 == 1:
470
- values0.append(0)
471
- values0 = self.rng.permutation(values0)
472
- self.config = dict(zip(self.sites, values0))
473
-
474
- def candidate(self):
475
- nconfig = self.config.copy()
476
- for i in self.rng.permutation(np.arange(self.Ne)):
477
- cooa, coob = self.edges[i]
478
- xa, xb = nconfig[cooa], nconfig[coob]
479
- if xa == xb:
480
- continue
481
- nconfig[cooa], nconfig[coob] = xb, xa
482
- return nconfig, 1.0
483
-
484
- def accept(self, config):
485
- self.config = config
486
-
487
- def sample(self):
488
- config, omega = self.candidate()
489
- self.accept(config)
490
- return config, omega
491
-
492
- def update(self, **_):
493
- pass
494
-
495
-
496
- class HamiltonianSampler:
497
-
498
- def __init__(self, ham, seed=None):
499
- self.ham = ham
500
- self.rng = np.random.default_rng(seed)
501
-
502
- self.N = len(self.ham.sites)
503
- values0 = [0] * (self.N // 2) + [1] * (self.N // 2)
504
- if self.N % 2 == 1:
505
- values0.append(0)
506
- values0 = self.rng.permutation(values0)
507
- self.config = dict(zip(self.ham.sites, values0))
508
-
509
- def candidate(self):
510
- generate = True
511
- while generate:
512
- # XXX: could do this much more efficiently with a single random
513
- # term
514
- configs, _ = self.ham.config_coupling(self.config)
515
- i = self.rng.integers(len(configs))
516
- new_config = configs[i]
517
- generate = (new_config == self.config)
518
- return new_config, 1.0
519
-
520
- def accept(self, config):
521
- self.config = config
522
-
523
- def sample(self):
524
- config, omega = self.candidate()
525
- self.accept(config)
526
- return config, omega
527
-
528
- def update(self, **_):
529
- pass
530
-
531
-
532
- class MetropolisHastingsSampler:
533
- """
534
- """
535
-
536
- def __init__(
537
- self,
538
- sub_sampler,
539
- amplitude_factory=None,
540
- initial=None,
541
- burn_in=0,
542
- seed=None,
543
- track=False,
544
- ):
545
- self.sub_sampler = sub_sampler
546
-
547
- if amplitude_factory is not None:
548
- self.prob_fn = amplitude_factory.prob
549
- else:
550
- # will initialize later
551
- self.prob_fn = None
552
-
553
- if initial is not None:
554
- self.config, self.omega, self.prob = initial
555
- else:
556
- self.config = self.omega = self.prob = None
557
-
558
- self.seed = seed
559
- self.rng = np.random.default_rng(self.seed)
560
- self.accepted = 0
561
- self.total = 0
562
- self.burn_in = burn_in
563
-
564
- # should we record the history?
565
- self.track = track
566
- if self.track:
567
- self.omegas = array.array('d')
568
- self.probs = array.array('d')
569
- self.acceptances = array.array('d')
570
- else:
571
- self.omegas = self.probs = self.acceptances = None
572
-
573
- @property
574
- def acceptance_ratio(self):
575
- if self.total == 0:
576
- return 0.0
577
- return self.accepted / self.total
578
-
579
- def sample(self):
580
- if self.config is None:
581
- # check if we are starting from scratch
582
- self.config, self.omega = self.sub_sampler.sample()
583
- self.prob = self.prob_fn(self.config)
584
-
585
- while True:
586
- self.total += 1
587
-
588
- # generate candidate configuration
589
- nconfig, nomega = self.sub_sampler.candidate()
590
- nprob = self.prob_fn(nconfig)
591
-
592
- # compute acceptance probability
593
- acceptance = (nprob * self.omega) / (self.prob * nomega)
594
-
595
- if self.track:
596
- self.omegas.append(nomega)
597
- self.probs.append(nprob)
598
- self.acceptances.append(acceptance)
599
-
600
- if (self.rng.uniform() < acceptance):
601
- self.config = nconfig
602
- self.omega = nomega
603
- self.prob = nprob
604
- self.accepted += 1
605
- self.sub_sampler.accept(nconfig)
606
-
607
- if (self.total > self.burn_in):
608
- return self.config, self.omega
609
-
610
- def update(self, **kwargs):
611
- self.prob_fn = kwargs['amplitude_factory'].prob
612
- self.sub_sampler.update(**kwargs)
613
-
614
- @default_to_neutral_style
615
- def plot(self):
616
- from matplotlib import pyplot as plt
617
-
618
- fig, axs = plt.subplots(ncols=2, figsize=(8, 4))
619
- fig.suptitle(f"acceptance ratio = {100 * self.acceptance_ratio:.2f} %")
620
-
621
- mins = min(self.omegas)
622
- maxs = max(self.omegas)
623
-
624
- axs[0].plot([mins, maxs], [mins, maxs], color="red")
625
- axs[0].scatter(
626
- self.probs, self.omegas, marker=".", alpha=0.5, zorder=-10
627
- )
628
- axs[0].set_rasterization_zorder(0)
629
- axs[0].set_xlabel("$\pi(x)$")
630
- axs[0].set_ylabel("$\omega(x)$")
631
- axs[0].set_xscale("log")
632
- axs[0].set_yscale("log")
633
- axs[0].grid(True, c=(0.97, 0.97, 0.97), which="major")
634
- axs[0].set_axisbelow(True)
635
-
636
- minh = np.log10(min(self.acceptances))
637
- maxh = np.log10(max(self.acceptances))
638
- axs[1].hist(
639
- self.acceptances, bins=np.logspace(minh, maxh), color="green"
640
- )
641
- axs[1].set_xlabel("$A = \dfrac{\pi(x)\omega(y)}{\pi(y)\omega(x)}$")
642
- axs[1].axvline(1.0, color="orange")
643
- axs[1].set_xscale("log")
644
- axs[1].grid(True, c=(0.97, 0.97, 0.97), which="major")
645
- axs[1].set_axisbelow(True)
646
-
647
- return fig, axs
648
-
649
-
650
- # --------------------------------------------------------------------------- #
651
-
652
-
653
- def auto_share_multicall(func, arrays, configs):
654
- """Call the function ``func``, which should be an array
655
- function making use of autoray dispatched calls, multiple
656
- times, automatically reusing shared intermediates.
657
- """
658
- with ar.lazy.shared_intermediates():
659
- lzarrays_all = [
660
- # different variants provided as first dimension
661
- ar.lazy.array(x) if hasattr(x, 'shape') else
662
- # different variants provided as a sequence
663
- list(map(ar.lazy.array, x))
664
- for x in arrays
665
- ]
666
- lzarrays_config = lzarrays_all.copy()
667
-
668
- outs = []
669
- for config in configs:
670
- # for each config, insert the correct inputs
671
- for k, v in config.items():
672
- lzarrays_config[k] = lzarrays_all[k][v]
673
- # evaluate the function
674
- outs.append(func(lzarrays_config))
675
-
676
- # combine into single output object
677
- final = ar.lazy.stack(tuple(outs))
678
-
679
- # evaluate all configs simultaneously
680
- return final.compute()
681
-
682
-
683
- class ComposePartial:
684
-
685
- __slots__ = (
686
- "f",
687
- "f_args",
688
- "f_kwargs",
689
- "g",
690
- )
691
-
692
- def __init__(self, f, f_args, f_kwargs, g):
693
- self.f = f
694
- self.f_args = f_args
695
- self.f_kwargs = f_kwargs
696
- self.g = g
697
-
698
- def __call__(self, *args, **kwargs):
699
- y = self.g(*args, **kwargs)
700
- f_args = (
701
- y if isinstance(v, ar.lazy.LazyArray) else v
702
- for v in self.f_args
703
- )
704
- return self.f(*f_args, **self.f_kwargs)
705
-
706
-
707
- _partial_compose_cache = {}
708
-
709
-
710
- def get_compose_partial(f, f_args, f_kwargs, g):
711
-
712
- key = (
713
- f,
714
- tuple(
715
- '__placeholder__' if isinstance(v, ar.lazy.LazyArray)
716
- else v
717
- for v in f_args
718
- ),
719
- tuple(sorted(f_kwargs.items())),
720
- g,
721
- )
722
-
723
- try:
724
- fg = _partial_compose_cache[key]
725
- except KeyError:
726
- fg = _partial_compose_cache[key] = ComposePartial(f, f_args, f_kwargs, g)
727
- except TypeError:
728
- fg = ComposePartial(f, f_args, f_kwargs, g)
729
-
730
- return fg
731
-
732
-
733
- def fuse_unary_ops_(Z):
734
- queue = [Z]
735
- seen = set()
736
- while queue:
737
- node = queue.pop()
738
- if (
739
- len(node._deps) == 1 and
740
- any(isinstance(v, ar.lazy.LazyArray) for v in node.args)
741
- ):
742
- dep, = node._deps
743
- if dep._nchild == 1 and dep._fn:
744
- node._fn = get_compose_partial(node._fn, node._args, node._kwargs, dep._fn)
745
- node._args = dep._args
746
- node._kwargs = dep._kwargs
747
- node._deps = dep._deps
748
- queue.append(node)
749
- continue
750
-
751
- for dep in node._deps:
752
- if dep not in seen:
753
- queue.append(dep)
754
- seen.add(dep)
755
-
756
-
757
-
758
- class AmplitudeFactory:
759
-
760
- def __init__(
761
- self,
762
- psi=None,
763
- contract_fn=None,
764
- maxsize=2**20,
765
- autojit_opts=(),
766
- **contract_opts,
767
- ):
768
- from quimb.utils import LRU
769
-
770
- self.contract_fn = contract_fn
771
- self.contract_opts = contract_opts
772
- if self.contract_opts.get('max_bond', None) is not None:
773
- self.contract_opts.setdefault('cutoff', 0.0)
774
-
775
- self.autojit_opts = dict(autojit_opts)
776
-
777
- if psi is not None:
778
- self._set_psi(psi)
779
-
780
- self.store = LRU(maxsize=maxsize)
781
- self.hits = 0
782
- self.queries = 0
783
-
784
- def _set_psi(self, psi):
785
- psi0 = psi.copy()
786
-
787
- self.arrays = []
788
- self.sitemap = {}
789
- variables = []
790
-
791
- for site in psi0.sites:
792
- ix = psi0.site_ind(site)
793
- t, = psi0._inds_get(ix)
794
-
795
- # want variable index first
796
- t.moveindex_(ix, 0)
797
- self.sitemap[site] = len(self.arrays)
798
- self.arrays.append(t.data)
799
-
800
- # insert lazy variable for sliced tensor
801
- variable = ar.lazy.Variable(t.shape[1:], backend='autoray.lazy')
802
- variables.append(variable)
803
- t.modify(data=variable, inds=t.inds[1:])
804
-
805
- # trace the function lazily
806
- if self.contract_fn is None:
807
- Z = psi0.contract(..., output_inds=(), **self.contract_opts)
808
- else:
809
- Z = self.contract_fn(psi0, **self.contract_opts)
810
-
811
- # get the functional form of this traced contraction
812
- self.f_lazy = Z.get_function(variables)
813
-
814
- # this can then itself be traced with concrete arrays
815
- self.f = ar.autojit(self.f_lazy, **self.autojit_opts)
816
-
817
- def compute_single(self, config):
818
- """Compute the amplitude of ``config``, making use of autojit.
819
- """
820
- arrays = self.arrays.copy()
821
- for site, v in config.items():
822
- i = self.sitemap[site]
823
- arrays[i] = self.arrays[i][v]
824
- return self.f(arrays)
825
-
826
- def compute_multi(self, configs):
827
- """Compute the amplitudes corresponding to the sequence ``configs``,
828
- making use of shared intermediates.
829
- """
830
- # translate index config to position configs
831
- iconfigs = [
832
- {self.sitemap[site]: v for site, v in config.items()}
833
- for config in configs
834
- ]
835
- return auto_share_multicall(self.f_lazy, self.arrays, iconfigs)
836
-
837
- # def update(self, config, coeff):
838
- # """Update the amplitude cache with a new configuration.
839
- # """
840
- # self.store[tuple(sorted(config.items()))] = coeff
841
-
842
- def amplitude(self, config):
843
- """Get the amplitude of ``config``, either from the cache or by
844
- computing it.
845
- """
846
- key = tuple(sorted(config.items()))
847
- self.queries += 1
848
- if key in self.store:
849
- self.hits += 1
850
- return self.store[key]
851
-
852
- coeff = self.compute_single(self.psi, config)
853
-
854
- self.store[key] = coeff
855
- return coeff
856
-
857
- def amplitudes(self, configs):
858
- """
859
- """
860
- # first parse out the configurations we need to compute
861
- all_keys = []
862
- new_keys = []
863
- new_configs = []
864
- for config in configs:
865
- key = tuple(sorted(config.items()))
866
- all_keys.append(key)
867
- self.queries += 1
868
- if key in self.store:
869
- self.hits += 1
870
- else:
871
- new_keys.append(key)
872
- new_configs.append(config)
873
-
874
- # compute the new configurations
875
- if new_configs:
876
- new_coeffs = self.compute_multi(new_configs)
877
- for key, coeff in zip(new_keys, new_coeffs):
878
- self.store[key] = coeff
879
-
880
- # return the full set of old and new coefficients
881
- return [self.store[key] for key in all_keys]
882
-
883
- def prob(self, config):
884
- """Calculate the probability of a configuration.
885
- """
886
- coeff = self.amplitude(config)
887
- return ar.do("abs", coeff)**2
888
-
889
- def clear(self):
890
- self.store.clear()
891
-
892
- def __contains__(self, config):
893
- return tuple(sorted(config.items())) in self.store
894
-
895
- def __setitem__(self, config, c):
896
- self.store[tuple(sorted(config.items()))] = c
897
-
898
- def __getitem__(self, config):
899
- return self.amplitude(config)
900
-
901
- def __repr__(self):
902
- return (
903
- f"<{self.__class__.__name__}(hits={self.hits}, "
904
- f"queries={self.queries})>"
905
- )
906
-
907
-
908
- # class AmplitudeStore:
909
-
910
- # def __init__(self, psi, amp_fn, maxsize=2**20):
911
- # from quimb.utils import LRU
912
- # self.psi = psi
913
- # self.amp_fn = amp_fn
914
- # self.store = LRU(maxsize=maxsize)
915
- # self.hits = 0
916
- # self.queries = 0
917
-
918
- # def update(self, config, coeff):
919
- # """Update the amplitude cache with a new configuration.
920
- # """
921
- # self.store[tuple(sorted(config.items()))] = coeff
922
-
923
- # def amplitude(self, config):
924
- # """Calculate the amplitude of a configuration.
925
- # """
926
- # self.queries += 1
927
-
928
- # key = tuple(sorted(config.items()))
929
- # if key in self.store:
930
- # self.hits += 1
931
- # return self.store[key]
932
-
933
- # amp = self.amp_fn(self.psi, config)
934
-
935
- # self.store[key] = amp
936
- # return amp
937
-
938
- # def prob(self, config):
939
- # """Calculate the probability of a configuration.
940
- # """
941
- # amp = self.amplitude(config)
942
- # return ar.do("abs", amp)**2
943
-
944
- # def clear(self):
945
- # self.store.clear()
946
-
947
- # def __contains__(self, config):
948
- # return tuple(sorted(config.items())) in self.store
949
-
950
- # def __setitem__(self, config, c):
951
- # self.store[tuple(sorted(config.items()))] = c
952
-
953
- # def __getitem__(self, config):
954
- # return self.amplitude(config)
955
-
956
- # def __repr__(self):
957
- # return (
958
- # f"<{self.__class__.__name__}(hits={self.hits}, "
959
- # f"queries={self.queries})>"
960
- # )
961
-
962
-
963
- # @autojit_tn(backend='torch', check_inputs=False)
964
- # def contract_amplitude_tn(tni, **contract_opts):
965
- # if 'max_bond' in contract_opts:
966
- # contract_opts.setdefault('cutoff', 0.0)
967
- # return tni.contract(..., output_inds=(), **contract_opts)
968
-
969
-
970
- # def compute_amplitude_aj(tn, config, **contract_opts):
971
- # tni = tn.isel({tn.site_ind(site): v for site, v in config.items()})
972
- # return contract_amplitude_tn(tni, **contract_opts)
973
-
974
-
975
- # def compute_local_energy_aj(ham, config, amp):
976
- # en = 0.0
977
- # c_configs, c_coeffs = ham.config_coupling(config)
978
- # cx = amp.amplitude(config)
979
- # for hxy, config_y in zip(c_coeffs, c_configs):
980
- # cy = amp.amplitude(config_y)
981
- # en += hxy * cy / cx
982
- # return en
983
-
984
-
985
- # def compute_amp_and_gradients(psi, config, **contract_opts):
986
- # import torch
987
- # psi_t = psi.copy()
988
- # psi_t.apply_to_arrays(lambda x: torch.tensor(x).requires_grad_())
989
- # c = compute_amplitude_aj(psi_t, config, **contract_opts)
990
- # c.backward()
991
- # c = c.item()
992
- # return [t.data.grad.numpy() / c for t in psi_t], c
993
-
994
-
995
- # --------------------------------------------------------------------------- #
996
-
997
-
998
- class GradientAccumulator:
999
-
1000
- def __init__(self):
1001
- self._grads_logpsi = None
1002
- self._grads_energy = None
1003
- self._batch_energy = None
1004
- self._num_samples = 0
1005
-
1006
- def _init_storage(self, grads):
1007
- self._batch_energy = 0.0
1008
- self._grads_logpsi = [np.zeros_like(g) for g in grads]
1009
- self._grads_energy = [np.zeros_like(g) for g in grads]
1010
-
1011
- def update(self, grads_logpsi_sample, local_energy):
1012
- if self._batch_energy is None:
1013
- self._init_storage(grads_logpsi_sample)
1014
-
1015
- self._batch_energy += local_energy
1016
- for g, ge, g_i in zip(
1017
- self._grads_logpsi, self._grads_energy, grads_logpsi_sample,
1018
- ):
1019
- g += g_i
1020
- ge += g_i * local_energy
1021
- self._num_samples += 1
1022
-
1023
- def extract_grads_energy(self):
1024
- e = self._batch_energy / self._num_samples
1025
- grads_energy_batch = []
1026
- for g, ge in zip(self._grads_logpsi, self._grads_energy):
1027
- g /= self._num_samples
1028
- ge /= self._num_samples
1029
- grads_energy_batch.append(ge - g * e)
1030
- # reset storage
1031
- g.fill(0.0)
1032
- ge.fill(0.0)
1033
- self._batch_energy = 0.0
1034
- self._num_samples = 0
1035
- return grads_energy_batch
1036
-
1037
-
1038
- class SGD(GradientAccumulator):
1039
-
1040
- def __init__(self, learning_rate=0.01):
1041
- self.learning_rate = learning_rate
1042
- super().__init__()
1043
-
1044
- def transform_gradients(self):
1045
- return [
1046
- self.learning_rate * g
1047
- for g in self.extract_grads_energy()
1048
- ]
1049
-
1050
-
1051
- class SignDescent(GradientAccumulator):
1052
-
1053
- def __init__(self, learning_rate=0.01):
1054
- self.learning_rate = learning_rate
1055
- super().__init__()
1056
-
1057
- def transform_gradients(self):
1058
- return [
1059
- self.learning_rate * np.sign(g)
1060
- for g in self.extract_grads_energy()
1061
- ]
1062
-
1063
-
1064
- class RandomSign(GradientAccumulator):
1065
-
1066
- def __init__(self, learning_rate=0.01):
1067
- self.learning_rate = learning_rate
1068
- super().__init__()
1069
-
1070
- def transform_gradients(self):
1071
- return [
1072
- self.learning_rate * np.sign(g) * np.random.uniform(size=g.shape)
1073
- for g in self.extract_grads_energy()
1074
- ]
1075
-
1076
-
1077
- class Adam(GradientAccumulator):
1078
-
1079
- def __init__(
1080
- self,
1081
- learning_rate=0.01,
1082
- beta1=0.9,
1083
- beta2=0.999,
1084
- eps=1e-8,
1085
- ):
1086
- self.learning_rate = learning_rate
1087
- self.beta1 = beta1
1088
- self.beta2 = beta2
1089
- self.eps = eps
1090
- self._num_its = 0
1091
- self._ms = None
1092
- self._vs = None
1093
- super().__init__()
1094
-
1095
- def transform_gradients(self):
1096
- # get the standard SGD gradients
1097
- grads = self.extract_grads_energy()
1098
-
1099
- self._num_its += 1
1100
- if self._num_its == 1:
1101
- # first iteration, initialize storage
1102
- self._ms = [np.zeros_like(g) for g in grads]
1103
- self._vs = [np.zeros_like(g) for g in grads]
1104
-
1105
- deltas = []
1106
- for i, g in enumerate(grads):
1107
- # first moment estimate
1108
- m = (1 - self.beta1) * g + self.beta1 * self._ms[i]
1109
- # second moment estimate
1110
- v = (1 - self.beta2) * (g**2) + self.beta2 * self._vs[i]
1111
- # bias correction
1112
- mhat = m / (1 - self.beta1**(self._num_its))
1113
- vhat = v / (1 - self.beta2**(self._num_its))
1114
- deltas.append(
1115
- self.learning_rate * mhat / (np.sqrt(vhat) + self.eps)
1116
- )
1117
- return deltas
1118
-
1119
-
1120
- from quimb.tensor.optimize import Vectorizer
1121
-
1122
-
1123
- class StochasticReconfigureGradients:
1124
-
1125
- def __init__(self, delta=1e-5):
1126
- self.delta = delta
1127
- self.vectorizer = None
1128
- self.gs = []
1129
-
1130
- def update(self, grads_logpsi_sample, local_energy):
1131
- if self.vectorizer is None:
1132
- # first call, initialize storage
1133
- self.vectorizer = Vectorizer(grads_logpsi_sample)
1134
- self.gs.append(self.vectorizer.pack(grads_logpsi_sample).copy())
1135
- super().update(grads_logpsi_sample, local_energy)
1136
-
1137
- def extract_grads_energy(self):
1138
- # number of samples
1139
- num_samples = len(self.gs)
1140
-
1141
- gs = np.stack(self.gs)
1142
- self.gs.clear()
1143
- # <g_i g_j>
1144
- S = (gs.T / num_samples) @ gs
1145
- # minus <g_i><g_j> to get S
1146
- g = gs.sum(axis=0) / num_samples
1147
- S -= np.outer(g, g)
1148
-
1149
- # condition by adding to diagonal
1150
- S.flat[::S.shape[0] + 1] += self.delta
1151
-
1152
- # the uncorrected energy gradient / 'force' vector
1153
- y = self.vectorizer.pack(super().extract_grads_energy())
1154
-
1155
- # the corrected energy gradient, which we then unvectorize
1156
- x = np.linalg.solve(S, y)
1157
- return self.vectorizer.unpack(x)
1158
-
1159
-
1160
-
1161
- class SR(SGD, StochasticReconfigureGradients):
1162
-
1163
- def __init__(self, learning_rate=0.05, delta=1e-5):
1164
- StochasticReconfigureGradients.__init__(self, delta=delta)
1165
- SGD().__init__(self, learning_rate=learning_rate)
1166
-
1167
-
1168
-
1169
- class SRADAM(Adam, StochasticReconfigureGradients):
1170
-
1171
- def __init__(
1172
- self,
1173
- learning_rate=0.01,
1174
- beta1=0.9,
1175
- beta2=0.999,
1176
- eps=1e-8,
1177
- delta=1e-5,
1178
- ):
1179
- StochasticReconfigureGradients.__init__(self, delta=delta)
1180
- Adam.__init__(
1181
- self, learning_rate=learning_rate,
1182
- beta1=beta1, beta2=beta2, eps=eps,
1183
- )
1184
-
1185
-
1186
- # --------------------------------------------------------------------------- #
1187
-
1188
- class TNVMC:
1189
-
1190
- def __init__(
1191
- self,
1192
- psi,
1193
- ham,
1194
- sampler,
1195
- conditioner='auto',
1196
- learning_rate=1e-2,
1197
- optimizer='adam',
1198
- optimizer_opts=None,
1199
- track_window_size=1000,
1200
- **contract_opts
1201
- ):
1202
- from quimb.utils import ensure_dict
1203
-
1204
- self.psi = psi.copy()
1205
- self.ham = ham
1206
- self.sampler = sampler
1207
-
1208
- if conditioner == 'auto':
1209
-
1210
- def conditioner(psi):
1211
- psi.equalize_norms_(1.0)
1212
-
1213
- else:
1214
- self.conditioner = conditioner
1215
-
1216
- if self.conditioner is not None:
1217
- # want initial arrays to be in conditioned form so that gradients
1218
- # are approximately consistent across runs (e.g. for momentum)
1219
- self.conditioner(self.psi)
1220
-
1221
- optimizer_opts = ensure_dict(optimizer_opts)
1222
- self.optimizer = {
1223
- 'adam': Adam,
1224
- 'sgd': SGD,
1225
- 'sign': SignDescent,
1226
- 'signu': RandomSign,
1227
- 'sr': SR,
1228
- 'sradam': SRADAM,
1229
- }[optimizer.lower()](learning_rate=learning_rate, **optimizer_opts)
1230
- self.contract_opts = contract_opts
1231
-
1232
- self.amplitude_factory = AmplitudeFactory(self.psi, **contract_opts)
1233
- self.sampler.update(psi=self.psi, amplitude_factory=self.amplitude_factory)
1234
-
1235
- # tracking information
1236
- self.moving_stats = MovingStatistics(track_window_size)
1237
- self.local_energies = array.array('d')
1238
- self.energies = array.array('d')
1239
- self.energy_errors = array.array('d')
1240
- self.num_tensors = self.psi.num_tensors
1241
- self.nsites = self.psi.nsites
1242
- self._progbar = None
1243
-
1244
- def _compute_log_gradients_torch(self, config):
1245
- import torch
1246
- psi_t = self.psi.copy()
1247
- psi_t.apply_to_arrays(lambda x: torch.tensor(x).requires_grad_())
1248
- c = self.amplitude_factory(psi_t, config)
1249
- c.backward()
1250
- c = c.item()
1251
- self.amplitude_factory[config] = c
1252
- return [t.data.grad.numpy() / c for t in psi_t]
1253
-
1254
- def _compute_local_energy(self, config):
1255
- en = 0.0
1256
- c_configs, c_coeffs = self.ham.config_coupling(config)
1257
- cx = self.amplitude_factory.amplitude(config)
1258
- for hxy, config_y in zip(c_coeffs, c_configs):
1259
- cy = self.amplitude_factory.amplitude(config_y)
1260
- en += hxy * cy / cx
1261
- return en / self.nsites
1262
-
1263
- def _run(self, steps, batchsize):
1264
- for _ in range(steps):
1265
- for _ in range(batchsize):
1266
- config, omega = self.sampler.sample()
1267
-
1268
- # compute and track local energy
1269
- local_energy = self._compute_local_energy(config)
1270
- self.local_energies.append(local_energy)
1271
- self.moving_stats.update(local_energy)
1272
- self.energies.append(self.moving_stats.mean)
1273
- self.energy_errors.append(self.moving_stats.err)
1274
-
1275
- # compute the sample log amplitude gradients
1276
- grads_logpsi_sample = self._compute_log_gradients_torch(config)
1277
-
1278
- self.optimizer.update(grads_logpsi_sample, local_energy)
1279
-
1280
- if self._progbar is not None:
1281
- self._progbar.update()
1282
- self._progbar.set_description(
1283
- format_number_with_error(
1284
- self.moving_stats.mean,
1285
- self.moving_stats.err))
1286
-
1287
- # apply learning rate and other transforms to gradients
1288
- deltas = self.optimizer.transform_gradients()
1289
-
1290
- # update the actual tensors
1291
- for t, delta in zip(self.psi.tensors, deltas):
1292
- t.modify(data=t.data - delta)
1293
-
1294
- # reset having just performed a gradient step
1295
- if self.conditioner is not None:
1296
- self.conditioner(self.psi)
1297
-
1298
- self.amplitude_factory.clear()
1299
- self.sampler.update(psi=self.psi, amplitude_factory=self.amplitude_factory)
1300
-
1301
- def run(
1302
- self,
1303
- total=10_000,
1304
- batchsize=100,
1305
- progbar=True,
1306
- ):
1307
- steps = total // batchsize
1308
- total = steps * batchsize
1309
-
1310
- if progbar:
1311
- from quimb.utils import progbar as Progbar
1312
- self._progbar = Progbar(total=total)
1313
-
1314
- try:
1315
- self._run(steps, batchsize)
1316
- except KeyboardInterrupt:
1317
- pass
1318
- finally:
1319
- if self._progbar is not None:
1320
- self._progbar.close()
1321
-
1322
- def measure(
1323
- self,
1324
- max_samples=10_000,
1325
- rtol=1e-4,
1326
- progbar=True,
1327
- ):
1328
- from xyzpy import RunningStatistics
1329
-
1330
- rs = RunningStatistics()
1331
- energies = array.array('d')
1332
-
1333
- if progbar:
1334
- from quimb.utils import progbar as Progbar
1335
- pb = Progbar(total=max_samples)
1336
- else:
1337
- pb = None
1338
-
1339
- try:
1340
- for _ in range(max_samples):
1341
- config, _ = self.sampler.sample()
1342
- local_energy = self._compute_local_energy(config)
1343
- rs.update(local_energy)
1344
- energies.append(local_energy)
1345
-
1346
- if pb is not None:
1347
- pb.update()
1348
- err = rs.err
1349
- if err != 0.0:
1350
- pb.set_description(format_number_with_error(rs.mean, err))
1351
-
1352
- if 0.0 < rs.rel_err < rtol:
1353
- break
1354
-
1355
- except KeyboardInterrupt:
1356
- pass
1357
- finally:
1358
- if pb is not None:
1359
- pb.close()
1360
-
1361
- return rs, energies
1362
-
1363
- @default_to_neutral_style
1364
- def plot(
1365
- self,
1366
- figsize=(12, 6),
1367
- yrange_quantile=(0.01, 0.99),
1368
- zoom="auto",
1369
- hlines=(),
1370
- ):
1371
- from matplotlib import pyplot as plt
1372
-
1373
- x = np.arange(len(self.local_energies))
1374
- # these are all views
1375
- y = np.array(self.local_energies)
1376
- ym = np.array(self.energies)
1377
- yerr = np.array(self.energy_errors)
1378
- yplus = ym + yerr
1379
- yminus = ym - yerr
1380
- yv = np.array(self.energy_variances[10:])
1381
-
1382
- fig = plt.figure(figsize=figsize)
1383
- gs = fig.add_gridspec(nrows=2, ncols=3)
1384
-
1385
- ax = fig.add_subplot(gs[:, :2])
1386
- ax.plot(
1387
- x,
1388
- y,
1389
- ".",
1390
- alpha=0.5,
1391
- markersize=1.0,
1392
- zorder=-10,
1393
- color=(0.1, 0.5, 0.7),
1394
- )
1395
- ax.fill_between(
1396
- x, yminus, yplus,
1397
- alpha=0.45,
1398
- color=(0.6, 0.8, 0.6),
1399
- zorder=-11,
1400
- )
1401
- ax.plot(
1402
- x,
1403
- ym,
1404
- "-",
1405
- alpha=0.9,
1406
- zorder=-10,
1407
- linewidth=2,
1408
- color=(0.6, 0.8, 0.6),
1409
- )
1410
- ax.set_ylim(
1411
- np.quantile(y, yrange_quantile[0]),
1412
- np.quantile(y, yrange_quantile[1]),
1413
- )
1414
- ax.set_xlabel("Number of local energy evaluations")
1415
- ax.set_ylabel("Energy per site", color=(0.6, 0.8, 0.6))
1416
-
1417
- if hlines:
1418
- from matplotlib.colors import hsv_to_rgb
1419
-
1420
- hlines = dict(hlines)
1421
- for i, (label, value) in enumerate(hlines.items()):
1422
- color = hsv_to_rgb([(0.1 * i) % 1.0, 0.9, 0.9])
1423
- ax.axhline(value, color=color, ls="--", label=label)
1424
- ax.text(1, value, label, color=color, va="bottom", ha="left")
1425
-
1426
- ax.set_rasterization_zorder(0)
1427
-
1428
- ax_var = fig.add_subplot(gs[1, 2])
1429
- ax_var.plot(
1430
- x[10:],
1431
- yv,
1432
- "-",
1433
- alpha=0.9,
1434
- zorder=-10,
1435
- linewidth=2,
1436
- color=(1.0, 0.7, 0.4),
1437
- )
1438
- ax_var.set_yscale('log')
1439
- ax_var.text(
1440
- 0.9,
1441
- 0.9,
1442
- "Energy variance",
1443
- color=(1.0, 0.7, 0.4),
1444
- horizontalalignment='right',
1445
- verticalalignment='top',
1446
- transform=ax_var.transAxes,
1447
- )
1448
- ax_var.set_rasterization_zorder(0)
1449
-
1450
- if zoom is not None:
1451
- if zoom == "auto":
1452
- zoom = min(10_000, y.size // 2)
1453
-
1454
- ax_zoom = fig.add_subplot(gs[0, 2])
1455
- ax_zoom.fill_between(
1456
- x[-zoom:],
1457
- yminus[-zoom:],
1458
- yplus[-zoom:],
1459
- alpha=0.45,
1460
- color=(0.6, 0.8, 0.6),
1461
- zorder=-11,
1462
- )
1463
- ax_zoom.plot(
1464
- x[-zoom:],
1465
- ym[-zoom:],
1466
- "-",
1467
- alpha=0.9,
1468
- zorder=-10,
1469
- linewidth=2,
1470
- color=(0.6, 0.8, 0.6),
1471
- )
1472
- ax_zoom.text(
1473
- 0.9,
1474
- 0.9,
1475
- "Zoom",
1476
- color=(0.6, 0.8, 0.6),
1477
- horizontalalignment='right',
1478
- verticalalignment='top',
1479
- transform=ax_zoom.transAxes,
1480
- )
1481
- ax_zoom.set_rasterization_zorder(0)
1482
-
1483
- return fig, [ax, ax_zoom, ax_var]