sparse-ir 1.1.7__py3-none-any.whl → 2.0.0a2__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- sparse_ir/__init__.py +33 -15
- sparse_ir/abstract.py +70 -82
- sparse_ir/augment.py +27 -17
- sparse_ir/basis.py +130 -239
- sparse_ir/basis_set.py +99 -57
- sparse_ir/dlr.py +131 -88
- sparse_ir/kernel.py +54 -477
- sparse_ir/poly.py +221 -476
- sparse_ir/sampling.py +260 -371
- sparse_ir/sve.py +56 -358
- sparse_ir-2.0.0a2.dist-info/METADATA +23 -0
- sparse_ir-2.0.0a2.dist-info/RECORD +16 -0
- {sparse_ir-1.1.7.dist-info → sparse_ir-2.0.0a2.dist-info}/WHEEL +1 -1
- sparse_ir/_gauss.py +0 -260
- sparse_ir/_roots.py +0 -140
- sparse_ir/adapter.py +0 -267
- sparse_ir/svd.py +0 -102
- sparse_ir-1.1.7.dist-info/METADATA +0 -155
- sparse_ir-1.1.7.dist-info/RECORD +0 -20
- {sparse_ir-1.1.7.dist-info → sparse_ir-2.0.0a2.dist-info/licenses}/LICENSE.txt +0 -0
- {sparse_ir-1.1.7.dist-info → sparse_ir-2.0.0a2.dist-info}/top_level.txt +0 -0
sparse_ir/poly.py
CHANGED
@@ -1,14 +1,137 @@
|
|
1
|
-
|
2
|
-
|
1
|
+
"""
|
2
|
+
Piecewise polynomial functionality for SparseIR.
|
3
|
+
|
4
|
+
This module provides piecewise Legendre polynomial representation and
|
5
|
+
their Fourier transforms, which serve as core mathematical infrastructure
|
6
|
+
for IR basis functions.
|
7
|
+
"""
|
8
|
+
|
9
|
+
from ctypes import c_int, POINTER
|
3
10
|
import numpy as np
|
4
|
-
|
5
|
-
import
|
6
|
-
|
11
|
+
import weakref
|
12
|
+
import threading
|
13
|
+
|
14
|
+
from pylibsparseir.core import _lib
|
15
|
+
from pylibsparseir.core import funcs_eval_single_float64, funcs_eval_single_complex128
|
16
|
+
from pylibsparseir.core import funcs_get_size, funcs_get_roots
|
17
|
+
|
18
|
+
# Global registry to track pointer usage
|
19
|
+
_pointer_registry = weakref.WeakSet()
|
20
|
+
_registry_lock = threading.Lock()
|
21
|
+
|
22
|
+
def funcs_get_slice(funcs_ptr, indices):
|
23
|
+
status = c_int()
|
24
|
+
indices = np.asarray(indices, dtype=np.int32)
|
25
|
+
funcs = _lib.spir_funcs_get_slice(funcs_ptr, len(indices), indices.ctypes.data_as(POINTER(c_int)), status)
|
26
|
+
if status.value != 0:
|
27
|
+
raise RuntimeError(f"Failed to get basis function {indices}: {status.value}")
|
28
|
+
return FunctionSet(funcs)
|
29
|
+
|
30
|
+
def funcs_ft_get_slice(funcs_ptr, indices):
|
31
|
+
status = c_int()
|
32
|
+
indices = np.asarray(indices, dtype=np.int32)
|
33
|
+
funcs = _lib.spir_funcs_get_slice(funcs_ptr, len(indices), indices.ctypes.data_as(POINTER(c_int)), status)
|
34
|
+
if status.value != 0:
|
35
|
+
raise RuntimeError(f"Failed to get basis function {indices}: {status.value}")
|
36
|
+
return FunctionSetFT(funcs)
|
37
|
+
|
38
|
+
class FunctionSet:
|
39
|
+
"""Wrapper for basis function evaluation."""
|
40
|
+
|
41
|
+
def __init__(self, funcs_ptr):
|
42
|
+
self._ptr = funcs_ptr
|
43
|
+
self._released = False
|
44
|
+
# Register this object for safe cleanup
|
45
|
+
with _registry_lock:
|
46
|
+
_pointer_registry.add(self)
|
7
47
|
|
8
|
-
|
9
|
-
|
10
|
-
|
48
|
+
def __call__(self, x):
|
49
|
+
"""Evaluate basis functions at given points."""
|
50
|
+
if self._released:
|
51
|
+
raise RuntimeError("Function set has been released")
|
52
|
+
if not isinstance(x, np.ndarray):
|
53
|
+
o = funcs_eval_single_float64(self._ptr, x)
|
54
|
+
if len(o) == 1:
|
55
|
+
return o[0]
|
56
|
+
else:
|
57
|
+
return o
|
58
|
+
else:
|
59
|
+
o = np.stack([funcs_eval_single_float64(self._ptr, e) for e in x]).T
|
60
|
+
if len(o) == 1:
|
61
|
+
return o[0]
|
62
|
+
else:
|
63
|
+
return o
|
64
|
+
|
65
|
+
def __getitem__(self, index):
|
66
|
+
"""Get a single basis function."""
|
67
|
+
if self._released:
|
68
|
+
raise RuntimeError("Function set has been released")
|
69
|
+
sz = funcs_get_size(self._ptr)
|
70
|
+
return funcs_get_slice(self._ptr, [index % sz])
|
71
|
+
|
72
|
+
def release(self):
|
73
|
+
"""Manually release the function set."""
|
74
|
+
if not self._released and self._ptr:
|
75
|
+
try:
|
76
|
+
_lib.spir_funcs_release(self._ptr)
|
77
|
+
except:
|
78
|
+
pass
|
79
|
+
self._released = True
|
80
|
+
self._ptr = None
|
81
|
+
|
82
|
+
def __del__(self):
|
83
|
+
# Only release if we haven't been released yet
|
84
|
+
if not self._released:
|
85
|
+
self.release()
|
86
|
+
|
87
|
+
class FunctionSetFT:
|
88
|
+
"""Wrapper for basis function evaluation."""
|
89
|
+
|
90
|
+
def __init__(self, funcs_ptr):
|
91
|
+
self._ptr = funcs_ptr
|
92
|
+
self._released = False
|
93
|
+
# Register this object for safe cleanup
|
94
|
+
with _registry_lock:
|
95
|
+
_pointer_registry.add(self)
|
11
96
|
|
97
|
+
def __call__(self, x):
|
98
|
+
"""Evaluate basis functions at given points."""
|
99
|
+
if self._released:
|
100
|
+
raise RuntimeError("Function set has been released")
|
101
|
+
if not isinstance(x, np.ndarray):
|
102
|
+
o = funcs_eval_single_complex128(self._ptr, x)
|
103
|
+
if len(o) == 1:
|
104
|
+
return o[0]
|
105
|
+
else:
|
106
|
+
return o
|
107
|
+
else:
|
108
|
+
o = np.stack([funcs_eval_single_complex128(self._ptr, e) for e in x]).T
|
109
|
+
if len(o) == 1:
|
110
|
+
return o[0]
|
111
|
+
else:
|
112
|
+
return o
|
113
|
+
|
114
|
+
def __getitem__(self, index):
|
115
|
+
"""Get a single basis function."""
|
116
|
+
if self._released:
|
117
|
+
raise RuntimeError("Function set has been released")
|
118
|
+
sz = funcs_get_size(self._ptr)
|
119
|
+
return funcs_ft_get_slice(self._ptr, [index % sz])
|
120
|
+
|
121
|
+
def release(self):
|
122
|
+
"""Manually release the function set."""
|
123
|
+
if not self._released and self._ptr:
|
124
|
+
try:
|
125
|
+
_lib.spir_funcs_release(self._ptr)
|
126
|
+
except:
|
127
|
+
pass
|
128
|
+
self._released = True
|
129
|
+
self._ptr = None
|
130
|
+
|
131
|
+
def __del__(self):
|
132
|
+
# Only release if we haven't been released yet
|
133
|
+
if not self._released:
|
134
|
+
self.release()
|
12
135
|
|
13
136
|
class PiecewiseLegendrePoly:
|
14
137
|
"""Piecewise Legendre polynomial.
|
@@ -17,92 +140,34 @@ class PiecewiseLegendrePoly:
|
|
17
140
|
intervals ``S[i] = [a[i], a[i+1]]``, where on each interval the function
|
18
141
|
is expanded in scaled Legendre polynomials.
|
19
142
|
"""
|
20
|
-
|
21
|
-
|
22
|
-
|
23
|
-
|
24
|
-
|
25
|
-
if dx is not None or symm is None:
|
26
|
-
raise RuntimeError("wrong arguments")
|
27
|
-
self.__dict__.update(knots.__dict__)
|
28
|
-
self.data = data
|
29
|
-
self.symm = symm
|
30
|
-
return
|
31
|
-
|
32
|
-
data = np.array(data)
|
33
|
-
knots = np.array(knots)
|
34
|
-
polyorder, nsegments = data.shape[:2]
|
35
|
-
if knots.shape != (nsegments+1,):
|
36
|
-
raise ValueError("Invalid knots array")
|
37
|
-
if not (knots[1:] >= knots[:-1]).all():
|
38
|
-
raise ValueError("Knots must be monotonically increasing")
|
39
|
-
if symm is None:
|
40
|
-
# TODO: infer symmetry from data
|
41
|
-
symm = np.zeros(data.shape[2:])
|
42
|
-
else:
|
43
|
-
symm = np.array(symm)
|
44
|
-
if symm.shape != data.shape[2:]:
|
45
|
-
raise ValueError("shape mismatch")
|
46
|
-
if dx is None:
|
47
|
-
dx = knots[1:] - knots[:-1]
|
48
|
-
else:
|
49
|
-
dx = np.array(dx)
|
50
|
-
if not np.allclose(dx, knots[1:] - knots[:-1]):
|
51
|
-
raise ValueError("dx must work with knots")
|
52
|
-
|
53
|
-
self.nsegments = nsegments
|
54
|
-
self.polyorder = polyorder
|
55
|
-
self.xmin = knots[0]
|
56
|
-
self.xmax = knots[-1]
|
57
|
-
|
58
|
-
self.knots = knots
|
59
|
-
self.dx = dx
|
60
|
-
self.data = data
|
61
|
-
self.symm = symm
|
62
|
-
self._xm = .5 * (knots[1:] + knots[:-1])
|
63
|
-
self._inv_xs = 2/dx
|
64
|
-
self._norm = np.sqrt(self._inv_xs)
|
65
|
-
|
66
|
-
def __getitem__(self, l):
|
67
|
-
"""Return part of a set of piecewise polynomials"""
|
68
|
-
new_symm = self.symm[l]
|
69
|
-
if isinstance(l, tuple):
|
70
|
-
new_data = self.data[(slice(None), slice(None), *l)]
|
71
|
-
else:
|
72
|
-
new_data = self.data[:,:,l]
|
73
|
-
return self.__class__(new_data, self, symm=new_symm)
|
143
|
+
|
144
|
+
def __init__(self, funcs: FunctionSet, xmin: float, xmax: float):
|
145
|
+
self._funcs = funcs
|
146
|
+
self._xmin = xmin
|
147
|
+
self._xmax = xmax
|
74
148
|
|
75
149
|
def __call__(self, x):
|
76
|
-
"""Evaluate
|
77
|
-
|
78
|
-
|
79
|
-
|
80
|
-
|
81
|
-
|
82
|
-
|
83
|
-
|
84
|
-
|
85
|
-
|
86
|
-
|
87
|
-
|
88
|
-
|
89
|
-
|
90
|
-
|
91
|
-
|
92
|
-
|
93
|
-
|
94
|
-
|
95
|
-
|
96
|
-
|
97
|
-
i, xtilde = self._split(x)
|
98
|
-
data = self.data[:, i, l]
|
99
|
-
|
100
|
-
# This should now neatly broadcast against each other
|
101
|
-
res = np_legendre.legval(xtilde, data, tensor=False)
|
102
|
-
res *= self._norm[i]
|
103
|
-
return res
|
104
|
-
|
105
|
-
def overlap(self, f, *, rtol=2.3e-16, return_error=False, points=None):
|
150
|
+
"""Evaluate basis functions at given points."""
|
151
|
+
return self._funcs(x)
|
152
|
+
|
153
|
+
|
154
|
+
class PiecewiseLegendrePolyVector:
|
155
|
+
"""Piecewise Legendre polynomial vector."""
|
156
|
+
|
157
|
+
def __init__(self, funcs: FunctionSet, xmin: float, xmax: float):
|
158
|
+
self._funcs = funcs
|
159
|
+
self._xmin = xmin
|
160
|
+
self._xmax = xmax
|
161
|
+
|
162
|
+
def __call__(self, x):
|
163
|
+
"""Evaluate basis functions at given points."""
|
164
|
+
return self._funcs(x)
|
165
|
+
|
166
|
+
def __getitem__(self, index):
|
167
|
+
"""Get a single basis function."""
|
168
|
+
return PiecewiseLegendrePoly(self._funcs[index], self._xmin, self._xmax)
|
169
|
+
|
170
|
+
def overlap(self, f, n_points=100):
|
106
171
|
r"""Evaluate overlap integral of this polynomial with function ``f``.
|
107
172
|
|
108
173
|
Given the function ``f``, evaluate the integral::
|
@@ -116,86 +181,59 @@ class PiecewiseLegendrePoly:
|
|
116
181
|
f (callable):
|
117
182
|
function that is called with a point ``x`` and returns ``f(x)``
|
118
183
|
at that position.
|
119
|
-
|
120
|
-
|
121
|
-
A sequence of break points in the integration interval
|
122
|
-
where local difficulties of the integrand may occur
|
123
|
-
(e.g., singularities, discontinuities)
|
184
|
+
n_points (int):
|
185
|
+
Number of quadrature points per integration segment.
|
124
186
|
|
125
187
|
Return:
|
126
188
|
array-like object with shape (poly_dims, f_dims)
|
127
189
|
poly_dims are the shape of the polynomial and f_dims are those
|
128
190
|
of the function f(x).
|
129
191
|
"""
|
130
|
-
|
131
|
-
|
132
|
-
|
133
|
-
|
134
|
-
|
135
|
-
|
136
|
-
|
137
|
-
|
138
|
-
|
139
|
-
|
140
|
-
|
141
|
-
|
142
|
-
|
143
|
-
|
144
|
-
|
145
|
-
|
146
|
-
|
147
|
-
|
148
|
-
|
149
|
-
|
150
|
-
|
151
|
-
|
152
|
-
|
153
|
-
|
154
|
-
|
155
|
-
|
156
|
-
|
157
|
-
|
158
|
-
|
159
|
-
|
160
|
-
|
161
|
-
|
162
|
-
|
163
|
-
|
164
|
-
|
165
|
-
|
166
|
-
|
167
|
-
|
168
|
-
|
169
|
-
|
170
|
-
|
171
|
-
|
172
|
-
|
173
|
-
|
174
|
-
|
175
|
-
roots = np.hstack((revroots, xmid, roots))
|
176
|
-
|
177
|
-
return roots
|
178
|
-
|
179
|
-
@property
|
180
|
-
def shape(self): return self.data.shape[2:]
|
181
|
-
|
182
|
-
@property
|
183
|
-
def size(self): return self.data[:1,:1].size
|
184
|
-
|
185
|
-
@property
|
186
|
-
def ndim(self): return self.data.ndim - 2
|
187
|
-
|
188
|
-
def _split(self, x):
|
189
|
-
"""Split segment"""
|
190
|
-
x = _util.check_range(x, self.xmin, self.xmax)
|
191
|
-
i = self.knots.searchsorted(x, 'right').clip(None, self.nsegments)
|
192
|
-
i -= 1
|
193
|
-
xtilde = x - self._xm[i]
|
194
|
-
xtilde *= self._inv_xs[i]
|
195
|
-
return i, xtilde
|
196
|
-
|
197
|
-
|
198
|
-
class PiecewiseLegendreFT:
|
192
|
+
from scipy.integrate import fixed_quad
|
193
|
+
|
194
|
+
xmin = self._xmin
|
195
|
+
xmax = self._xmax
|
196
|
+
roots = funcs_get_roots(self._funcs._ptr).tolist()
|
197
|
+
roots.sort()
|
198
|
+
|
199
|
+
# Create integration segments
|
200
|
+
segments = [xmin] + roots + [xmax]
|
201
|
+
segments = sorted(list(set(segments))) # Remove duplicates and sort
|
202
|
+
|
203
|
+
# Collect all quadrature points and weights
|
204
|
+
all_x = []
|
205
|
+
all_weights = []
|
206
|
+
|
207
|
+
for j in range(len(segments) - 1):
|
208
|
+
a, b = segments[j], segments[j+1]
|
209
|
+
if abs(b - a) < 1e-14: # Skip zero-length segments
|
210
|
+
continue
|
211
|
+
|
212
|
+
# Get Gauss-Legendre quadrature points and weights
|
213
|
+
from scipy.special import roots_legendre
|
214
|
+
x_quad, w_quad = roots_legendre(n_points)
|
215
|
+
# Scale to actual interval
|
216
|
+
x_scaled = (b - a) / 2 * x_quad + (a + b) / 2
|
217
|
+
w_scaled = w_quad * (b - a) / 2
|
218
|
+
|
219
|
+
all_x.extend(x_scaled)
|
220
|
+
all_weights.extend(w_scaled)
|
221
|
+
|
222
|
+
# Convert to numpy arrays for batch processing
|
223
|
+
all_x = np.array(all_x)
|
224
|
+
all_weights = np.array(all_weights)
|
225
|
+
|
226
|
+
# Evaluate function and polynomials at all points
|
227
|
+
f_values = f(all_x) # This should work with array input
|
228
|
+
poly_values = self._funcs(all_x) # Shape: (n_polys, n_points)
|
229
|
+
|
230
|
+
# Compute overlap integrals
|
231
|
+
output = np.sum(poly_values * f_values * all_weights, axis=1)
|
232
|
+
|
233
|
+
return output
|
234
|
+
|
235
|
+
|
236
|
+
class PiecewiseLegendrePolyFT:
|
199
237
|
"""Fourier transform of a piecewise Legendre polynomial.
|
200
238
|
|
201
239
|
For a given frequency index ``n``, the Fourier transform of the Legendre
|
@@ -207,319 +245,26 @@ class PiecewiseLegendreFT:
|
|
207
245
|
case ``n`` must be even, or antiperiodically (``freq='odd'``), in which case
|
208
246
|
``n`` must be odd.
|
209
247
|
"""
|
210
|
-
_DEFAULT_GRID = np.hstack([np.arange(2**6),
|
211
|
-
(2**np.linspace(6, 35, 16*(35-6)+1)).astype(int)])
|
212
|
-
|
213
|
-
def __init__(self, poly, freq='even', n_asymp=None, power_model=None):
|
214
|
-
if poly.xmin != -1 or poly.xmax != 1:
|
215
|
-
raise NotImplementedError("Only interval [-1, 1] supported")
|
216
|
-
self.poly = poly
|
217
|
-
self.freq = freq
|
218
|
-
self.zeta = {'any': None, 'even': 0, 'odd': 1}[freq]
|
219
|
-
if n_asymp is None:
|
220
|
-
self.n_asymp = np.inf
|
221
|
-
self._model = None
|
222
|
-
else:
|
223
|
-
self.n_asymp = n_asymp
|
224
|
-
if power_model is None:
|
225
|
-
self._model = _power_model(freq, poly)
|
226
|
-
else:
|
227
|
-
self._model = power_model
|
228
|
-
|
229
|
-
@property
|
230
|
-
def shape(self): return self.poly.shape
|
231
|
-
|
232
|
-
@property
|
233
|
-
def size(self): return self.poly.size
|
234
|
-
|
235
|
-
@property
|
236
|
-
def ndim(self): return self.poly.ndim
|
237
|
-
|
238
|
-
def __getitem__(self, l):
|
239
|
-
model = self._model if self._model is None else self._model[l]
|
240
|
-
return self.__class__(self.poly[l], self.freq, self.n_asymp, model)
|
241
|
-
|
242
|
-
@_util.ravel_argument(last_dim=True)
|
243
|
-
def __call__(self, n):
|
244
|
-
"""Obtain Fourier transform of polynomial for given frequencies"""
|
245
|
-
n = _util.check_reduced_matsubara(n, self.zeta)
|
246
|
-
result = _compute_unl_inner(self.poly, n)
|
247
|
-
|
248
|
-
# We use the asymptotics at frequencies larger than conv_radius
|
249
|
-
# since it has lower relative error.
|
250
|
-
cond_outer = np.abs(n) >= self.n_asymp
|
251
|
-
if cond_outer.any():
|
252
|
-
n_outer = n[cond_outer]
|
253
|
-
result[..., cond_outer] = self._model.giw(n_outer).T
|
254
|
-
|
255
|
-
return result
|
256
|
-
|
257
|
-
def extrema(self, *, part=None, grid=None, positive_only=False):
|
258
|
-
"""Obtain extrema of Fourier-transformed polynomial."""
|
259
|
-
if self.poly.shape:
|
260
|
-
raise ValueError("select single polynomial")
|
261
|
-
if grid is None:
|
262
|
-
grid = self._DEFAULT_GRID
|
263
|
-
|
264
|
-
f = self._func_for_part(part)
|
265
|
-
x0 = _roots.discrete_extrema(f, grid)
|
266
|
-
x0 = 2 * x0 + self.zeta
|
267
|
-
if not positive_only:
|
268
|
-
x0 = _symmetrize_matsubara(x0)
|
269
|
-
return x0
|
270
|
-
|
271
|
-
def sign_changes(self, *, part=None, grid=None, positive_only=False):
|
272
|
-
"""Obtain sign changes of Fourier-transformed polynomial."""
|
273
|
-
if self.poly.shape:
|
274
|
-
raise ValueError("select single polynomial")
|
275
|
-
if grid is None:
|
276
|
-
grid = self._DEFAULT_GRID
|
277
|
-
|
278
|
-
f = self._func_for_part(part)
|
279
|
-
x0 = _roots.find_all(f, grid, type='discrete')
|
280
|
-
x0 = 2 * x0 + self.zeta
|
281
|
-
if not positive_only:
|
282
|
-
x0 = _symmetrize_matsubara(x0)
|
283
|
-
return x0
|
284
|
-
|
285
|
-
def _func_for_part(self, part=None):
|
286
|
-
if part is None:
|
287
|
-
parity = self.poly.symm
|
288
|
-
if np.allclose(parity, 1):
|
289
|
-
part = 'real' if self.zeta == 0 else 'imag'
|
290
|
-
elif np.allclose(parity, -1):
|
291
|
-
part = 'imag' if self.zeta == 0 else 'real'
|
292
|
-
else:
|
293
|
-
raise ValueError("cannot detect parity.")
|
294
|
-
if part == 'real':
|
295
|
-
return lambda n: self(2*n + self.zeta).real
|
296
|
-
elif part == 'imag':
|
297
|
-
return lambda n: self(2*n + self.zeta).imag
|
298
|
-
else:
|
299
|
-
raise ValueError("part must be either 'real' or 'imag'")
|
300
248
|
|
249
|
+
def __init__(self, funcs: FunctionSetFT):
|
250
|
+
assert isinstance(funcs, FunctionSetFT), "funcs must be a FunctionSetFT"
|
251
|
+
self._funcs = funcs
|
301
252
|
|
253
|
+
def __call__(self, x):
|
254
|
+
"""Evaluate basis functions at given points."""
|
255
|
+
return self._funcs(x)
|
302
256
|
|
303
|
-
|
304
|
-
"""
|
305
|
-
n = np.asarray(n)
|
306
|
-
if not np.issubdtype(n.dtype, np.integer):
|
307
|
-
raise ValueError("expecting set of integers here")
|
308
|
-
cycle = np.array([1, 0+1j, -1, 0-1j], complex)
|
309
|
-
return cycle[n % 4]
|
310
|
-
|
311
|
-
|
312
|
-
def _get_tnl(l, w):
|
313
|
-
r"""Fourier integral of the l-th Legendre polynomial::
|
314
|
-
|
315
|
-
T_l(w) == \int_{-1}^1 dx \exp(iwx) P_l(x)
|
316
|
-
"""
|
317
|
-
# spherical_jn gives NaN for w < 0, but since we know that P_l(x) is real,
|
318
|
-
# we simply conjugate the result for w > 0 in these cases.
|
319
|
-
result = 2 * _imag_power(l) * sp_special.spherical_jn(l, np.abs(w))
|
320
|
-
np.conjugate(result, out=result, where=w < 0)
|
321
|
-
return result
|
322
|
-
|
323
|
-
|
324
|
-
def _shift_xmid(knots, dx):
|
325
|
-
r"""Return midpoint relative to the nearest integer plus a shift.
|
326
|
-
|
327
|
-
Return the midpoints ``xmid`` of the segments, as pair ``(diff, shift)``,
|
328
|
-
where shift is in ``(0,1,-1)`` and ``diff`` is a float such that
|
329
|
-
``xmid == shift + diff`` to floating point accuracy.
|
330
|
-
"""
|
331
|
-
dx_half = dx / 2
|
332
|
-
xmid_m1 = dx.cumsum() - dx_half
|
333
|
-
xmid_p1 = -dx[::-1].cumsum()[::-1] + dx_half
|
334
|
-
xmid_0 = knots[1:] - dx_half
|
335
|
-
|
336
|
-
shift = np.round(xmid_0).astype(int)
|
337
|
-
diff = np.choose(shift+1, (xmid_m1, xmid_0, xmid_p1))
|
338
|
-
return diff, shift
|
339
|
-
|
257
|
+
class PiecewiseLegendrePolyFTVector:
|
258
|
+
"""Fourier transform of a piecewise Legendre polynomial vector."""
|
340
259
|
|
341
|
-
def
|
342
|
-
|
260
|
+
def __init__(self, funcs: FunctionSetFT):
|
261
|
+
assert isinstance(funcs, FunctionSetFT), "funcs must be a FunctionSetFT"
|
262
|
+
self._funcs = funcs
|
343
263
|
|
344
|
-
|
264
|
+
def __call__(self, x: np.ndarray) -> np.ndarray:
|
265
|
+
"""Evaluate basis functions at given points."""
|
266
|
+
return self._funcs(x)
|
345
267
|
|
346
|
-
|
347
|
-
|
348
|
-
|
349
|
-
# there, the multiplication with `wn` results in `wn//4` almost extra turns
|
350
|
-
# around the unit circle. The cosine and sines will first map those
|
351
|
-
# back to the interval [-pi, pi) before doing the computation, which loses
|
352
|
-
# digits in dx. To avoid this, we extract the nearest integer dx.cumsum()
|
353
|
-
# and rewrite above expression like below.
|
354
|
-
#
|
355
|
-
# Now `wn` still results in extra revolutions, but the mapping back does
|
356
|
-
# not cut digits that were not there in the first place.
|
357
|
-
xmid_diff, extra_shift = _shift_xmid(poly.knots, poly.dx)
|
358
|
-
|
359
|
-
if np.issubdtype(wn.dtype, np.integer):
|
360
|
-
shift_arg = wn[None,:] * xmid_diff[:,None]
|
361
|
-
else:
|
362
|
-
delta_wn, wn = np.modf(wn)
|
363
|
-
wn = wn.astype(int)
|
364
|
-
shift_arg = wn[None,:] * xmid_diff[:,None]
|
365
|
-
shift_arg += delta_wn[None,:] * (extra_shift + xmid_diff)[:,None]
|
366
|
-
|
367
|
-
phase_shifted = np.exp(0.5j * np.pi * shift_arg)
|
368
|
-
corr = _imag_power((extra_shift[:,None] + 1) * wn[None,:])
|
369
|
-
return corr * phase_shifted
|
370
|
-
|
371
|
-
|
372
|
-
def _compute_unl_inner(poly, wn):
|
373
|
-
"""Compute piecewise Legendre to Matsubara transform."""
|
374
|
-
dx_half = poly.dx / 2
|
375
|
-
|
376
|
-
data_flat = poly.data.reshape(*poly.data.shape[:2], -1)
|
377
|
-
data_sc = data_flat * np.sqrt(dx_half/2)[None,:,None]
|
378
|
-
p = np.arange(poly.polyorder)
|
379
|
-
|
380
|
-
wred = np.pi/2 * wn
|
381
|
-
phase_wi = _phase_stable(poly, wn)
|
382
|
-
t_pin = _get_tnl(p[:,None,None], wred[None,:] * dx_half[:,None]) * phase_wi
|
383
|
-
|
384
|
-
# Perform the following, but faster:
|
385
|
-
# resulth = einsum('pin,pil->nl', t_pin, data_sc)
|
386
|
-
npi = poly.polyorder * poly.nsegments
|
387
|
-
result_flat = (t_pin.reshape(npi,-1).T @ data_sc.reshape(npi,-1)).T
|
388
|
-
return result_flat.reshape(*poly.data.shape[2:], wn.size)
|
389
|
-
|
390
|
-
|
391
|
-
class _PowerModel:
|
392
|
-
"""Model from a high-frequency series expansion::
|
393
|
-
|
394
|
-
A(iw) == sum(A[n] / (iw)**(n+1) for n in range(1, N))
|
395
|
-
|
396
|
-
where ``iw == 1j * pi/2 * wn`` is a reduced imaginary frequency, i.e.,
|
397
|
-
``wn`` is an odd/even number for fermionic/bosonic frequencies.
|
398
|
-
"""
|
399
|
-
def __init__(self, moments):
|
400
|
-
"""Initialize model"""
|
401
|
-
if moments.ndim == 1:
|
402
|
-
moments = moments[:, None]
|
403
|
-
self.moments = np.asarray(moments)
|
404
|
-
self.nmom, self.nl = self.moments.shape
|
405
|
-
|
406
|
-
@_util.ravel_argument()
|
407
|
-
def giw(self, wn):
|
408
|
-
"""Return model Green's function for vector of frequencies"""
|
409
|
-
wn = _util.check_reduced_matsubara(wn)
|
410
|
-
result_dtype = np.result_type(1j, wn, self.moments)
|
411
|
-
result = np.zeros((wn.size, self.nl), result_dtype)
|
412
|
-
inv_iw = 1j * np.pi/2 * wn
|
413
|
-
np.reciprocal(inv_iw, out=inv_iw, where=(wn != 0))
|
414
|
-
for mom in self.moments[::-1]:
|
415
|
-
result += mom
|
416
|
-
result *= inv_iw[:, None]
|
417
|
-
return result
|
418
|
-
|
419
|
-
def __getitem__(self, l):
|
420
|
-
return self.__class__(self.moments[:,l])
|
421
|
-
|
422
|
-
|
423
|
-
def _derivs(ppoly, x):
|
424
|
-
"""Evaluate polynomial and its derivatives at specific x"""
|
425
|
-
yield ppoly(x)
|
426
|
-
for _ in range(ppoly.polyorder-1):
|
427
|
-
ppoly = ppoly.deriv()
|
428
|
-
yield ppoly(x)
|
429
|
-
|
430
|
-
|
431
|
-
def _power_moments(stat, deriv_x1):
|
432
|
-
"""Return moments"""
|
433
|
-
statsign = {'odd': -1, 'even': 1}[stat]
|
434
|
-
mmax, lmax = deriv_x1.shape
|
435
|
-
m = np.arange(mmax)[:,None]
|
436
|
-
l = np.arange(lmax)[None,:]
|
437
|
-
coeff_lm = ((-1.0)**(m+1) + statsign * (-1.0)**l) * deriv_x1
|
438
|
-
return -statsign/np.sqrt(2.0) * coeff_lm
|
439
|
-
|
440
|
-
|
441
|
-
def _power_model(stat, poly):
|
442
|
-
deriv_x1 = np.asarray(list(_derivs(poly, x=1)))
|
443
|
-
if deriv_x1.ndim == 1:
|
444
|
-
deriv_x1 = deriv_x1[:,None]
|
445
|
-
moments = _power_moments(stat, deriv_x1)
|
446
|
-
return _PowerModel(moments)
|
447
|
-
|
448
|
-
|
449
|
-
def _refine_grid(knots, alpha):
|
450
|
-
"""Linear refinement of grid"""
|
451
|
-
result = np.linspace(knots[:-1], knots[1:], alpha, endpoint=False)
|
452
|
-
return np.hstack((result.T.ravel(), knots[-1]))
|
453
|
-
|
454
|
-
|
455
|
-
def _symmetrize_matsubara(x0):
|
456
|
-
if not x0.size:
|
457
|
-
return x0
|
458
|
-
if not (x0[1:] >= x0[:-1]).all():
|
459
|
-
raise ValueError("set of Matsubara points not ordered")
|
460
|
-
if not (x0[0] >= 0):
|
461
|
-
raise ValueError("points must be non-negative")
|
462
|
-
if x0[0] == 0:
|
463
|
-
x0 = np.hstack([-x0[::-1], x0[1:]])
|
464
|
-
else:
|
465
|
-
x0 = np.hstack([-x0[::-1], x0])
|
466
|
-
return x0
|
467
|
-
|
468
|
-
|
469
|
-
def _compute_overlap(poly, f, rtol=2.3e-16, radix=2, max_refine_levels=40,
|
470
|
-
max_refine_points=2000, points=None):
|
471
|
-
base_rule = _gauss.kronrod_31_15()
|
472
|
-
if points is None:
|
473
|
-
knots = poly.knots
|
474
|
-
else:
|
475
|
-
points = np.asarray(points)
|
476
|
-
knots = np.unique(np.hstack((poly.knots, points)))
|
477
|
-
xstart = knots[:-1]
|
478
|
-
xstop = knots[1:]
|
479
|
-
|
480
|
-
f_shape = None
|
481
|
-
res_value = 0
|
482
|
-
res_error = 0
|
483
|
-
res_magn = 0
|
484
|
-
for _ in range(max_refine_levels):
|
485
|
-
#print(f"Level {_}: {xstart.size} segments")
|
486
|
-
if xstart.size > max_refine_points:
|
487
|
-
warn("Refinement is too broad, aborting (increase rtol)")
|
488
|
-
break
|
489
|
-
|
490
|
-
rule = base_rule.reseat(xstart[:, None], xstop[:, None])
|
491
|
-
|
492
|
-
fx = np.array(list(map(f, rule.x.ravel())))
|
493
|
-
if f_shape is None:
|
494
|
-
f_shape = fx.shape[1:]
|
495
|
-
elif fx.shape[1:] != f_shape:
|
496
|
-
raise ValueError("inconsistent shapes")
|
497
|
-
fx = fx.reshape(rule.x.shape + (-1,))
|
498
|
-
|
499
|
-
valx = poly(rule.x).reshape(-1, *rule.x.shape, 1) * fx
|
500
|
-
int21 = (valx[:, :, :, :] * rule.w[:, :, None]).sum(2)
|
501
|
-
int10 = (valx[:, :, rule.vsel, :] * rule.v[:, :, None]).sum(2)
|
502
|
-
intdiff = np.abs(int21 - int10)
|
503
|
-
intmagn = np.abs(int10)
|
504
|
-
|
505
|
-
magn = res_magn + intmagn.sum(1).max(1)
|
506
|
-
relerror = intdiff.max(2) / magn[:, None]
|
507
|
-
|
508
|
-
xconverged = (relerror <= rtol).all(0)
|
509
|
-
res_value += int10[:, xconverged].sum(1)
|
510
|
-
res_error += intdiff[:, xconverged].sum(1)
|
511
|
-
res_magn += intmagn[:, xconverged].sum(1).max(1)
|
512
|
-
if xconverged.all():
|
513
|
-
break
|
514
|
-
|
515
|
-
xrefine = ~xconverged
|
516
|
-
xstart = xstart[xrefine]
|
517
|
-
xstop = xstop[xrefine]
|
518
|
-
xedge = np.linspace(xstart, xstop, radix + 1, axis=-1)
|
519
|
-
xstart = xedge[:, :-1].ravel()
|
520
|
-
xstop = xedge[:, 1:].ravel()
|
521
|
-
else:
|
522
|
-
warn("Integration did not converge after refinement")
|
523
|
-
|
524
|
-
res_shape = poly.shape + f_shape
|
525
|
-
return res_value.reshape(res_shape), res_error.reshape(res_shape)
|
268
|
+
def __getitem__(self, index):
|
269
|
+
"""Get a single basis function."""
|
270
|
+
return PiecewiseLegendrePolyFT(self._funcs[index])
|