sparse-ir 1.1.6__py3-none-any.whl → 2.0.0a2__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
sparse_ir/poly.py CHANGED
@@ -1,14 +1,137 @@
1
- # Copyright (C) 2020-2022 Markus Wallerberger, Hiroshi Shinaoka, and others
2
- # SPDX-License-Identifier: MIT
1
+ """
2
+ Piecewise polynomial functionality for SparseIR.
3
+
4
+ This module provides piecewise Legendre polynomial representation and
5
+ their Fourier transforms, which serve as core mathematical infrastructure
6
+ for IR basis functions.
7
+ """
8
+
9
+ from ctypes import c_int, POINTER
3
10
  import numpy as np
4
- from warnings import warn
5
- import numpy.polynomial.legendre as np_legendre
6
- import scipy.special as sp_special
11
+ import weakref
12
+ import threading
13
+
14
+ from pylibsparseir.core import _lib
15
+ from pylibsparseir.core import funcs_eval_single_float64, funcs_eval_single_complex128
16
+ from pylibsparseir.core import funcs_get_size, funcs_get_roots
17
+
18
+ # Global registry to track pointer usage
19
+ _pointer_registry = weakref.WeakSet()
20
+ _registry_lock = threading.Lock()
21
+
22
+ def funcs_get_slice(funcs_ptr, indices):
23
+ status = c_int()
24
+ indices = np.asarray(indices, dtype=np.int32)
25
+ funcs = _lib.spir_funcs_get_slice(funcs_ptr, len(indices), indices.ctypes.data_as(POINTER(c_int)), status)
26
+ if status.value != 0:
27
+ raise RuntimeError(f"Failed to get basis function {indices}: {status.value}")
28
+ return FunctionSet(funcs)
29
+
30
+ def funcs_ft_get_slice(funcs_ptr, indices):
31
+ status = c_int()
32
+ indices = np.asarray(indices, dtype=np.int32)
33
+ funcs = _lib.spir_funcs_get_slice(funcs_ptr, len(indices), indices.ctypes.data_as(POINTER(c_int)), status)
34
+ if status.value != 0:
35
+ raise RuntimeError(f"Failed to get basis function {indices}: {status.value}")
36
+ return FunctionSetFT(funcs)
37
+
38
+ class FunctionSet:
39
+ """Wrapper for basis function evaluation."""
40
+
41
+ def __init__(self, funcs_ptr):
42
+ self._ptr = funcs_ptr
43
+ self._released = False
44
+ # Register this object for safe cleanup
45
+ with _registry_lock:
46
+ _pointer_registry.add(self)
7
47
 
8
- from . import _util
9
- from . import _roots
10
- from . import _gauss
48
+ def __call__(self, x):
49
+ """Evaluate basis functions at given points."""
50
+ if self._released:
51
+ raise RuntimeError("Function set has been released")
52
+ if not isinstance(x, np.ndarray):
53
+ o = funcs_eval_single_float64(self._ptr, x)
54
+ if len(o) == 1:
55
+ return o[0]
56
+ else:
57
+ return o
58
+ else:
59
+ o = np.stack([funcs_eval_single_float64(self._ptr, e) for e in x]).T
60
+ if len(o) == 1:
61
+ return o[0]
62
+ else:
63
+ return o
64
+
65
+ def __getitem__(self, index):
66
+ """Get a single basis function."""
67
+ if self._released:
68
+ raise RuntimeError("Function set has been released")
69
+ sz = funcs_get_size(self._ptr)
70
+ return funcs_get_slice(self._ptr, [index % sz])
71
+
72
+ def release(self):
73
+ """Manually release the function set."""
74
+ if not self._released and self._ptr:
75
+ try:
76
+ _lib.spir_funcs_release(self._ptr)
77
+ except:
78
+ pass
79
+ self._released = True
80
+ self._ptr = None
81
+
82
+ def __del__(self):
83
+ # Only release if we haven't been released yet
84
+ if not self._released:
85
+ self.release()
86
+
87
+ class FunctionSetFT:
88
+ """Wrapper for basis function evaluation."""
89
+
90
+ def __init__(self, funcs_ptr):
91
+ self._ptr = funcs_ptr
92
+ self._released = False
93
+ # Register this object for safe cleanup
94
+ with _registry_lock:
95
+ _pointer_registry.add(self)
11
96
 
97
+ def __call__(self, x):
98
+ """Evaluate basis functions at given points."""
99
+ if self._released:
100
+ raise RuntimeError("Function set has been released")
101
+ if not isinstance(x, np.ndarray):
102
+ o = funcs_eval_single_complex128(self._ptr, x)
103
+ if len(o) == 1:
104
+ return o[0]
105
+ else:
106
+ return o
107
+ else:
108
+ o = np.stack([funcs_eval_single_complex128(self._ptr, e) for e in x]).T
109
+ if len(o) == 1:
110
+ return o[0]
111
+ else:
112
+ return o
113
+
114
+ def __getitem__(self, index):
115
+ """Get a single basis function."""
116
+ if self._released:
117
+ raise RuntimeError("Function set has been released")
118
+ sz = funcs_get_size(self._ptr)
119
+ return funcs_ft_get_slice(self._ptr, [index % sz])
120
+
121
+ def release(self):
122
+ """Manually release the function set."""
123
+ if not self._released and self._ptr:
124
+ try:
125
+ _lib.spir_funcs_release(self._ptr)
126
+ except:
127
+ pass
128
+ self._released = True
129
+ self._ptr = None
130
+
131
+ def __del__(self):
132
+ # Only release if we haven't been released yet
133
+ if not self._released:
134
+ self.release()
12
135
 
13
136
  class PiecewiseLegendrePoly:
14
137
  """Piecewise Legendre polynomial.
@@ -17,92 +140,34 @@ class PiecewiseLegendrePoly:
17
140
  intervals ``S[i] = [a[i], a[i+1]]``, where on each interval the function
18
141
  is expanded in scaled Legendre polynomials.
19
142
  """
20
- def __init__(self, data, knots, dx=None, symm=None):
21
- """Piecewise Legendre polynomial"""
22
- if np.isnan(data).any():
23
- raise ValueError("PiecewiseLegendrePoly: data contains NaN!")
24
- if isinstance(knots, self.__class__):
25
- if dx is not None or symm is None:
26
- raise RuntimeError("wrong arguments")
27
- self.__dict__.update(knots.__dict__)
28
- self.data = data
29
- self.symm = symm
30
- return
31
-
32
- data = np.array(data)
33
- knots = np.array(knots)
34
- polyorder, nsegments = data.shape[:2]
35
- if knots.shape != (nsegments+1,):
36
- raise ValueError("Invalid knots array")
37
- if not (knots[1:] >= knots[:-1]).all():
38
- raise ValueError("Knots must be monotonically increasing")
39
- if symm is None:
40
- # TODO: infer symmetry from data
41
- symm = np.zeros(data.shape[2:])
42
- else:
43
- symm = np.array(symm)
44
- if symm.shape != data.shape[2:]:
45
- raise ValueError("shape mismatch")
46
- if dx is None:
47
- dx = knots[1:] - knots[:-1]
48
- else:
49
- dx = np.array(dx)
50
- if not np.allclose(dx, knots[1:] - knots[:-1]):
51
- raise ValueError("dx must work with knots")
52
-
53
- self.nsegments = nsegments
54
- self.polyorder = polyorder
55
- self.xmin = knots[0]
56
- self.xmax = knots[-1]
57
-
58
- self.knots = knots
59
- self.dx = dx
60
- self.data = data
61
- self.symm = symm
62
- self._xm = .5 * (knots[1:] + knots[:-1])
63
- self._inv_xs = 2/dx
64
- self._norm = np.sqrt(self._inv_xs)
65
-
66
- def __getitem__(self, l):
67
- """Return part of a set of piecewise polynomials"""
68
- new_symm = self.symm[l]
69
- if isinstance(l, tuple):
70
- new_data = self.data[(slice(None), slice(None), *l)]
71
- else:
72
- new_data = self.data[:,:,l]
73
- return self.__class__(new_data, self, symm=new_symm)
143
+
144
+ def __init__(self, funcs: FunctionSet, xmin: float, xmax: float):
145
+ self._funcs = funcs
146
+ self._xmin = xmin
147
+ self._xmax = xmax
74
148
 
75
149
  def __call__(self, x):
76
- """Evaluate polynomial at position x"""
77
- i, xtilde = self._split(np.asarray(x))
78
- data = self.data[:, i]
79
-
80
- # Evaluate for all values of l. x and data array must be
81
- # broadcast'able against each other, so we append dimensions here
82
- func_dims = self.data.ndim - 2
83
- datashape = i.shape + (1,) * func_dims
84
- res = np_legendre.legval(xtilde.reshape(datashape), data, tensor=False)
85
- res *= self._norm[i.reshape(datashape)]
86
-
87
- # Finally, exchange the x and vector dimensions
88
- order = tuple(range(i.ndim, i.ndim + func_dims)) + tuple(range(i.ndim))
89
- return res.transpose(*order)
90
-
91
- def value(self, l, x):
92
- """Return value for l and x."""
93
- if self.data.ndim != 3:
94
- raise ValueError("Only allowed for vector of data")
95
-
96
- l, x = np.broadcast_arrays(l, x)
97
- i, xtilde = self._split(x)
98
- data = self.data[:, i, l]
99
-
100
- # This should now neatly broadcast against each other
101
- res = np_legendre.legval(xtilde, data, tensor=False)
102
- res *= self._norm[i]
103
- return res
104
-
105
- def overlap(self, f, *, rtol=2.3e-16, return_error=False, points=None):
150
+ """Evaluate basis functions at given points."""
151
+ return self._funcs(x)
152
+
153
+
154
+ class PiecewiseLegendrePolyVector:
155
+ """Piecewise Legendre polynomial vector."""
156
+
157
+ def __init__(self, funcs: FunctionSet, xmin: float, xmax: float):
158
+ self._funcs = funcs
159
+ self._xmin = xmin
160
+ self._xmax = xmax
161
+
162
+ def __call__(self, x):
163
+ """Evaluate basis functions at given points."""
164
+ return self._funcs(x)
165
+
166
+ def __getitem__(self, index):
167
+ """Get a single basis function."""
168
+ return PiecewiseLegendrePoly(self._funcs[index], self._xmin, self._xmax)
169
+
170
+ def overlap(self, f, n_points=100):
106
171
  r"""Evaluate overlap integral of this polynomial with function ``f``.
107
172
 
108
173
  Given the function ``f``, evaluate the integral::
@@ -116,86 +181,59 @@ class PiecewiseLegendrePoly:
116
181
  f (callable):
117
182
  function that is called with a point ``x`` and returns ``f(x)``
118
183
  at that position.
119
-
120
- points (sequence of floats)
121
- A sequence of break points in the integration interval
122
- where local difficulties of the integrand may occur
123
- (e.g., singularities, discontinuities)
184
+ n_points (int):
185
+ Number of quadrature points per integration segment.
124
186
 
125
187
  Return:
126
188
  array-like object with shape (poly_dims, f_dims)
127
189
  poly_dims are the shape of the polynomial and f_dims are those
128
190
  of the function f(x).
129
191
  """
130
- int_result, int_error = _compute_overlap(self, f, rtol=rtol, points=points)
131
- if return_error:
132
- return int_result, int_error
133
- else:
134
- return int_result
135
-
136
- def deriv(self, n=1):
137
- """Get polynomial for the n'th derivative"""
138
- ddata = np_legendre.legder(self.data, n)
139
-
140
- _scale_shape = (1, -1) + (1,) * (self.data.ndim - 2)
141
- scale = self._inv_xs ** n
142
- ddata *= scale.reshape(_scale_shape)
143
- return self.__class__(ddata, self, symm=(-1)**n * self.symm)
144
-
145
- def roots(self, alpha=2):
146
- """Find all roots of the piecewise polynomial
147
-
148
- Assume that between each two knots (pieces) there are at most ``alpha``
149
- roots.
150
- """
151
- if self.data.ndim > 2:
152
- raise ValueError("select single polynomial before calling roots()")
153
-
154
- grid = self.knots
155
- xmid = (self.xmax + self.xmin) / 2
156
- if self.symm:
157
- if grid[self.nsegments // 2] == xmid:
158
- grid = grid[self.nsegments//2:]
159
- else:
160
- grid = np.hstack((xmid, grid[grid > xmid]))
161
-
162
- grid = _refine_grid(grid, alpha)
163
- roots = _roots.find_all(self, grid)
164
-
165
- if self.symm == 1:
166
- revroots = (self.xmax + self.xmin) - roots[::-1]
167
- roots = np.hstack((revroots, roots))
168
- elif self.symm == -1:
169
- # There must be a zero at exactly the midpoint, but we may either
170
- # slightly miss it or have a spurious zero
171
- if roots.size:
172
- if roots[0] == xmid or self(xmid) * self.deriv()(xmid) < 0:
173
- roots = roots[1:]
174
- revroots = (self.xmax + self.xmin) - roots[::-1]
175
- roots = np.hstack((revroots, xmid, roots))
176
-
177
- return roots
178
-
179
- @property
180
- def shape(self): return self.data.shape[2:]
181
-
182
- @property
183
- def size(self): return self.data[:1,:1].size
184
-
185
- @property
186
- def ndim(self): return self.data.ndim - 2
187
-
188
- def _split(self, x):
189
- """Split segment"""
190
- x = _util.check_range(x, self.xmin, self.xmax)
191
- i = self.knots.searchsorted(x, 'right').clip(None, self.nsegments)
192
- i -= 1
193
- xtilde = x - self._xm[i]
194
- xtilde *= self._inv_xs[i]
195
- return i, xtilde
196
-
197
-
198
- class PiecewiseLegendreFT:
192
+ from scipy.integrate import fixed_quad
193
+
194
+ xmin = self._xmin
195
+ xmax = self._xmax
196
+ roots = funcs_get_roots(self._funcs._ptr).tolist()
197
+ roots.sort()
198
+
199
+ # Create integration segments
200
+ segments = [xmin] + roots + [xmax]
201
+ segments = sorted(list(set(segments))) # Remove duplicates and sort
202
+
203
+ # Collect all quadrature points and weights
204
+ all_x = []
205
+ all_weights = []
206
+
207
+ for j in range(len(segments) - 1):
208
+ a, b = segments[j], segments[j+1]
209
+ if abs(b - a) < 1e-14: # Skip zero-length segments
210
+ continue
211
+
212
+ # Get Gauss-Legendre quadrature points and weights
213
+ from scipy.special import roots_legendre
214
+ x_quad, w_quad = roots_legendre(n_points)
215
+ # Scale to actual interval
216
+ x_scaled = (b - a) / 2 * x_quad + (a + b) / 2
217
+ w_scaled = w_quad * (b - a) / 2
218
+
219
+ all_x.extend(x_scaled)
220
+ all_weights.extend(w_scaled)
221
+
222
+ # Convert to numpy arrays for batch processing
223
+ all_x = np.array(all_x)
224
+ all_weights = np.array(all_weights)
225
+
226
+ # Evaluate function and polynomials at all points
227
+ f_values = f(all_x) # This should work with array input
228
+ poly_values = self._funcs(all_x) # Shape: (n_polys, n_points)
229
+
230
+ # Compute overlap integrals
231
+ output = np.sum(poly_values * f_values * all_weights, axis=1)
232
+
233
+ return output
234
+
235
+
236
+ class PiecewiseLegendrePolyFT:
199
237
  """Fourier transform of a piecewise Legendre polynomial.
200
238
 
201
239
  For a given frequency index ``n``, the Fourier transform of the Legendre
@@ -207,319 +245,26 @@ class PiecewiseLegendreFT:
207
245
  case ``n`` must be even, or antiperiodically (``freq='odd'``), in which case
208
246
  ``n`` must be odd.
209
247
  """
210
- _DEFAULT_GRID = np.hstack([np.arange(2**6),
211
- (2**np.linspace(6, 35, 16*(35-6)+1)).astype(int)])
212
-
213
- def __init__(self, poly, freq='even', n_asymp=None, power_model=None):
214
- if poly.xmin != -1 or poly.xmax != 1:
215
- raise NotImplementedError("Only interval [-1, 1] supported")
216
- self.poly = poly
217
- self.freq = freq
218
- self.zeta = {'any': None, 'even': 0, 'odd': 1}[freq]
219
- if n_asymp is None:
220
- self.n_asymp = np.inf
221
- self._model = None
222
- else:
223
- self.n_asymp = n_asymp
224
- if power_model is None:
225
- self._model = _power_model(freq, poly)
226
- else:
227
- self._model = power_model
228
-
229
- @property
230
- def shape(self): return self.poly.shape
231
-
232
- @property
233
- def size(self): return self.poly.size
234
-
235
- @property
236
- def ndim(self): return self.poly.ndim
237
-
238
- def __getitem__(self, l):
239
- model = self._model if self._model is None else self._model[l]
240
- return self.__class__(self.poly[l], self.freq, self.n_asymp, model)
241
-
242
- @_util.ravel_argument(last_dim=True)
243
- def __call__(self, n):
244
- """Obtain Fourier transform of polynomial for given frequencies"""
245
- n = _util.check_reduced_matsubara(n, self.zeta)
246
- result = _compute_unl_inner(self.poly, n)
247
-
248
- # We use the asymptotics at frequencies larger than conv_radius
249
- # since it has lower relative error.
250
- cond_outer = np.abs(n) >= self.n_asymp
251
- if cond_outer.any():
252
- n_outer = n[cond_outer]
253
- result[..., cond_outer] = self._model.giw(n_outer).T
254
-
255
- return result
256
-
257
- def extrema(self, *, part=None, grid=None, positive_only=False):
258
- """Obtain extrema of Fourier-transformed polynomial."""
259
- if self.poly.shape:
260
- raise ValueError("select single polynomial")
261
- if grid is None:
262
- grid = self._DEFAULT_GRID
263
-
264
- f = self._func_for_part(part)
265
- x0 = _roots.discrete_extrema(f, grid)
266
- x0 = 2 * x0 + self.zeta
267
- if not positive_only:
268
- x0 = _symmetrize_matsubara(x0)
269
- return x0
270
-
271
- def sign_changes(self, *, part=None, grid=None, positive_only=False):
272
- """Obtain sign changes of Fourier-transformed polynomial."""
273
- if self.poly.shape:
274
- raise ValueError("select single polynomial")
275
- if grid is None:
276
- grid = self._DEFAULT_GRID
277
-
278
- f = self._func_for_part(part)
279
- x0 = _roots.find_all(f, grid, type='discrete')
280
- x0 = 2 * x0 + self.zeta
281
- if not positive_only:
282
- x0 = _symmetrize_matsubara(x0)
283
- return x0
284
-
285
- def _func_for_part(self, part=None):
286
- if part is None:
287
- parity = self.poly.symm
288
- if np.allclose(parity, 1):
289
- part = 'real' if self.zeta == 0 else 'imag'
290
- elif np.allclose(parity, -1):
291
- part = 'imag' if self.zeta == 0 else 'real'
292
- else:
293
- raise ValueError("cannot detect parity.")
294
- if part == 'real':
295
- return lambda n: self(2*n + self.zeta).real
296
- elif part == 'imag':
297
- return lambda n: self(2*n + self.zeta).imag
298
- else:
299
- raise ValueError("part must be either 'real' or 'imag'")
300
248
 
249
+ def __init__(self, funcs: FunctionSetFT):
250
+ assert isinstance(funcs, FunctionSetFT), "funcs must be a FunctionSetFT"
251
+ self._funcs = funcs
301
252
 
253
+ def __call__(self, x):
254
+ """Evaluate basis functions at given points."""
255
+ return self._funcs(x)
302
256
 
303
- def _imag_power(n):
304
- """Imaginary unit raised to an integer power without numerical error"""
305
- n = np.asarray(n)
306
- if not np.issubdtype(n.dtype, np.integer):
307
- raise ValueError("expecting set of integers here")
308
- cycle = np.array([1, 0+1j, -1, 0-1j], complex)
309
- return cycle[n % 4]
310
-
311
-
312
- def _get_tnl(l, w):
313
- r"""Fourier integral of the l-th Legendre polynomial::
314
-
315
- T_l(w) == \int_{-1}^1 dx \exp(iwx) P_l(x)
316
- """
317
- # spherical_jn gives NaN for w < 0, but since we know that P_l(x) is real,
318
- # we simply conjugate the result for w > 0 in these cases.
319
- result = 2 * _imag_power(l) * sp_special.spherical_jn(l, np.abs(w))
320
- np.conjugate(result, out=result, where=w < 0)
321
- return result
322
-
323
-
324
- def _shift_xmid(knots, dx):
325
- r"""Return midpoint relative to the nearest integer plus a shift.
326
-
327
- Return the midpoints ``xmid`` of the segments, as pair ``(diff, shift)``,
328
- where shift is in ``(0,1,-1)`` and ``diff`` is a float such that
329
- ``xmid == shift + diff`` to floating point accuracy.
330
- """
331
- dx_half = dx / 2
332
- xmid_m1 = dx.cumsum() - dx_half
333
- xmid_p1 = -dx[::-1].cumsum()[::-1] + dx_half
334
- xmid_0 = knots[1:] - dx_half
335
-
336
- shift = np.round(xmid_0).astype(int)
337
- diff = np.choose(shift+1, (xmid_m1, xmid_0, xmid_p1))
338
- return diff, shift
339
-
257
+ class PiecewiseLegendrePolyFTVector:
258
+ """Fourier transform of a piecewise Legendre polynomial vector."""
340
259
 
341
- def _phase_stable(poly, wn):
342
- """Phase factor for the piecewise Legendre to Matsubara transform.
260
+ def __init__(self, funcs: FunctionSetFT):
261
+ assert isinstance(funcs, FunctionSetFT), "funcs must be a FunctionSetFT"
262
+ self._funcs = funcs
343
263
 
344
- Compute the following phase factor in a stable way::
264
+ def __call__(self, x: np.ndarray) -> np.ndarray:
265
+ """Evaluate basis functions at given points."""
266
+ return self._funcs(x)
345
267
 
346
- np.exp(1j * np.pi/2 * wn[:,None] * poly.dx.cumsum()[None,:])
347
- """
348
- # A naive implementation is losing precision close to x=1 and/or x=-1:
349
- # there, the multiplication with `wn` results in `wn//4` almost extra turns
350
- # around the unit circle. The cosine and sines will first map those
351
- # back to the interval [-pi, pi) before doing the computation, which loses
352
- # digits in dx. To avoid this, we extract the nearest integer dx.cumsum()
353
- # and rewrite above expression like below.
354
- #
355
- # Now `wn` still results in extra revolutions, but the mapping back does
356
- # not cut digits that were not there in the first place.
357
- xmid_diff, extra_shift = _shift_xmid(poly.knots, poly.dx)
358
-
359
- if np.issubdtype(wn.dtype, np.integer):
360
- shift_arg = wn[None,:] * xmid_diff[:,None]
361
- else:
362
- delta_wn, wn = np.modf(wn)
363
- wn = wn.astype(int)
364
- shift_arg = wn[None,:] * xmid_diff[:,None]
365
- shift_arg += delta_wn[None,:] * (extra_shift + xmid_diff)[:,None]
366
-
367
- phase_shifted = np.exp(0.5j * np.pi * shift_arg)
368
- corr = _imag_power((extra_shift[:,None] + 1) * wn[None,:])
369
- return corr * phase_shifted
370
-
371
-
372
- def _compute_unl_inner(poly, wn):
373
- """Compute piecewise Legendre to Matsubara transform."""
374
- dx_half = poly.dx / 2
375
-
376
- data_flat = poly.data.reshape(*poly.data.shape[:2], -1)
377
- data_sc = data_flat * np.sqrt(dx_half/2)[None,:,None]
378
- p = np.arange(poly.polyorder)
379
-
380
- wred = np.pi/2 * wn
381
- phase_wi = _phase_stable(poly, wn)
382
- t_pin = _get_tnl(p[:,None,None], wred[None,:] * dx_half[:,None]) * phase_wi
383
-
384
- # Perform the following, but faster:
385
- # resulth = einsum('pin,pil->nl', t_pin, data_sc)
386
- npi = poly.polyorder * poly.nsegments
387
- result_flat = (t_pin.reshape(npi,-1).T @ data_sc.reshape(npi,-1)).T
388
- return result_flat.reshape(*poly.data.shape[2:], wn.size)
389
-
390
-
391
- class _PowerModel:
392
- """Model from a high-frequency series expansion::
393
-
394
- A(iw) == sum(A[n] / (iw)**(n+1) for n in range(1, N))
395
-
396
- where ``iw == 1j * pi/2 * wn`` is a reduced imaginary frequency, i.e.,
397
- ``wn`` is an odd/even number for fermionic/bosonic frequencies.
398
- """
399
- def __init__(self, moments):
400
- """Initialize model"""
401
- if moments.ndim == 1:
402
- moments = moments[:, None]
403
- self.moments = np.asarray(moments)
404
- self.nmom, self.nl = self.moments.shape
405
-
406
- @_util.ravel_argument()
407
- def giw(self, wn):
408
- """Return model Green's function for vector of frequencies"""
409
- wn = _util.check_reduced_matsubara(wn)
410
- result_dtype = np.result_type(1j, wn, self.moments)
411
- result = np.zeros((wn.size, self.nl), result_dtype)
412
- inv_iw = 1j * np.pi/2 * wn
413
- np.reciprocal(inv_iw, out=inv_iw, where=(wn != 0))
414
- for mom in self.moments[::-1]:
415
- result += mom
416
- result *= inv_iw[:, None]
417
- return result
418
-
419
- def __getitem__(self, l):
420
- return self.__class__(self.moments[:,l])
421
-
422
-
423
- def _derivs(ppoly, x):
424
- """Evaluate polynomial and its derivatives at specific x"""
425
- yield ppoly(x)
426
- for _ in range(ppoly.polyorder-1):
427
- ppoly = ppoly.deriv()
428
- yield ppoly(x)
429
-
430
-
431
- def _power_moments(stat, deriv_x1):
432
- """Return moments"""
433
- statsign = {'odd': -1, 'even': 1}[stat]
434
- mmax, lmax = deriv_x1.shape
435
- m = np.arange(mmax)[:,None]
436
- l = np.arange(lmax)[None,:]
437
- coeff_lm = ((-1.0)**(m+1) + statsign * (-1.0)**l) * deriv_x1
438
- return -statsign/np.sqrt(2.0) * coeff_lm
439
-
440
-
441
- def _power_model(stat, poly):
442
- deriv_x1 = np.asarray(list(_derivs(poly, x=1)))
443
- if deriv_x1.ndim == 1:
444
- deriv_x1 = deriv_x1[:,None]
445
- moments = _power_moments(stat, deriv_x1)
446
- return _PowerModel(moments)
447
-
448
-
449
- def _refine_grid(knots, alpha):
450
- """Linear refinement of grid"""
451
- result = np.linspace(knots[:-1], knots[1:], alpha, endpoint=False)
452
- return np.hstack((result.T.ravel(), knots[-1]))
453
-
454
-
455
- def _symmetrize_matsubara(x0):
456
- if not x0.size:
457
- return x0
458
- if not (x0[1:] >= x0[:-1]).all():
459
- raise ValueError("set of Matsubara points not ordered")
460
- if not (x0[0] >= 0):
461
- raise ValueError("points must be non-negative")
462
- if x0[0] == 0:
463
- x0 = np.hstack([-x0[::-1], x0[1:]])
464
- else:
465
- x0 = np.hstack([-x0[::-1], x0])
466
- return x0
467
-
468
-
469
- def _compute_overlap(poly, f, rtol=2.3e-16, radix=2, max_refine_levels=40,
470
- max_refine_points=2000, points=None):
471
- base_rule = _gauss.kronrod_31_15()
472
- if points is None:
473
- knots = poly.knots
474
- else:
475
- points = np.asarray(points)
476
- knots = np.unique(np.hstack((poly.knots, points)))
477
- xstart = knots[:-1]
478
- xstop = knots[1:]
479
-
480
- f_shape = None
481
- res_value = 0
482
- res_error = 0
483
- res_magn = 0
484
- for _ in range(max_refine_levels):
485
- #print(f"Level {_}: {xstart.size} segments")
486
- if xstart.size > max_refine_points:
487
- warn("Refinement is too broad, aborting (increase rtol)")
488
- break
489
-
490
- rule = base_rule.reseat(xstart[:, None], xstop[:, None])
491
-
492
- fx = np.array(list(map(f, rule.x.ravel())))
493
- if f_shape is None:
494
- f_shape = fx.shape[1:]
495
- elif fx.shape[1:] != f_shape:
496
- raise ValueError("inconsistent shapes")
497
- fx = fx.reshape(rule.x.shape + (-1,))
498
-
499
- valx = poly(rule.x).reshape(-1, *rule.x.shape, 1) * fx
500
- int21 = (valx[:, :, :, :] * rule.w[:, :, None]).sum(2)
501
- int10 = (valx[:, :, rule.vsel, :] * rule.v[:, :, None]).sum(2)
502
- intdiff = np.abs(int21 - int10)
503
- intmagn = np.abs(int10)
504
-
505
- magn = res_magn + intmagn.sum(1).max(1)
506
- relerror = intdiff.max(2) / magn[:, None]
507
-
508
- xconverged = (relerror <= rtol).all(0)
509
- res_value += int10[:, xconverged].sum(1)
510
- res_error += intdiff[:, xconverged].sum(1)
511
- res_magn += intmagn[:, xconverged].sum(1).max(1)
512
- if xconverged.all():
513
- break
514
-
515
- xrefine = ~xconverged
516
- xstart = xstart[xrefine]
517
- xstop = xstop[xrefine]
518
- xedge = np.linspace(xstart, xstop, radix + 1, axis=-1)
519
- xstart = xedge[:, :-1].ravel()
520
- xstop = xedge[:, 1:].ravel()
521
- else:
522
- warn("Integration did not converge after refinement")
523
-
524
- res_shape = poly.shape + f_shape
525
- return res_value.reshape(res_shape), res_error.reshape(res_shape)
268
+ def __getitem__(self, index):
269
+ """Get a single basis function."""
270
+ return PiecewiseLegendrePolyFT(self._funcs[index])