pylibsparseir 0.7.4__cp312-cp312-macosx_15_0_arm64.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- pylibsparseir/__init__.py +31 -0
- pylibsparseir/clean_build_artifacts.py +71 -0
- pylibsparseir/constants.py +39 -0
- pylibsparseir/core.py +641 -0
- pylibsparseir/ctypes_autogen.py +117 -0
- pylibsparseir/ctypes_wrapper.py +44 -0
- pylibsparseir/libsparse_ir_capi.dylib +0 -0
- pylibsparseir/sparseir.h +1874 -0
- pylibsparseir-0.7.4.dist-info/METADATA +209 -0
- pylibsparseir-0.7.4.dist-info/RECORD +12 -0
- pylibsparseir-0.7.4.dist-info/WHEEL +6 -0
- pylibsparseir-0.7.4.dist-info/licenses/LICENSE +21 -0
pylibsparseir/core.py
ADDED
|
@@ -0,0 +1,641 @@
|
|
|
1
|
+
"""
|
|
2
|
+
Core functionality for the SparseIR Python bindings.
|
|
3
|
+
"""
|
|
4
|
+
|
|
5
|
+
import os
|
|
6
|
+
import sys
|
|
7
|
+
import ctypes
|
|
8
|
+
from ctypes import c_int, c_double, c_int64, c_size_t, c_bool, POINTER, byref
|
|
9
|
+
from ctypes import CDLL
|
|
10
|
+
import numpy as np
|
|
11
|
+
import platform
|
|
12
|
+
|
|
13
|
+
# Enable only on Linux
|
|
14
|
+
import os
|
|
15
|
+
import sys
|
|
16
|
+
import ctypes
|
|
17
|
+
import platform
|
|
18
|
+
|
|
19
|
+
import os
|
|
20
|
+
import sys
|
|
21
|
+
import ctypes
|
|
22
|
+
|
|
23
|
+
from .ctypes_wrapper import spir_kernel, spir_sve_result, spir_basis, spir_funcs, spir_sampling, spir_gemm_backend
|
|
24
|
+
from pylibsparseir.constants import COMPUTATION_SUCCESS, SPIR_ORDER_ROW_MAJOR, SPIR_ORDER_COLUMN_MAJOR, SPIR_TWORK_FLOAT64, SPIR_TWORK_FLOAT64X2, SPIR_STATISTICS_FERMIONIC, SPIR_STATISTICS_BOSONIC
|
|
25
|
+
|
|
26
|
+
|
|
27
|
+
def _find_library():
|
|
28
|
+
"""Find the SparseIR shared library."""
|
|
29
|
+
if sys.platform == "darwin":
|
|
30
|
+
libname = "libsparse_ir_capi.dylib"
|
|
31
|
+
elif sys.platform == "win32":
|
|
32
|
+
libname = "sparse_ir_capi.dll"
|
|
33
|
+
else:
|
|
34
|
+
libname = "libsparse_ir_capi.so"
|
|
35
|
+
|
|
36
|
+
# Try to find the library in common locations
|
|
37
|
+
script_dir = os.path.dirname(os.path.abspath(__file__))
|
|
38
|
+
search_paths = [
|
|
39
|
+
script_dir, # Same directory as this file
|
|
40
|
+
os.path.join(script_dir, "..", "build"),
|
|
41
|
+
os.path.join(script_dir, "..", "..", "build"),
|
|
42
|
+
]
|
|
43
|
+
|
|
44
|
+
# Try to find workspace root by going up from package location
|
|
45
|
+
# The package might be installed in site-packages, so we need to go up several levels
|
|
46
|
+
current = script_dir
|
|
47
|
+
for _ in range(10): # Limit search depth
|
|
48
|
+
target_release = os.path.join(current, "target", "release")
|
|
49
|
+
target_debug = os.path.join(current, "target", "debug")
|
|
50
|
+
if os.path.exists(os.path.join(current, "sparse-ir-capi")) or os.path.exists(os.path.join(current, "Cargo.toml")):
|
|
51
|
+
# Found workspace root
|
|
52
|
+
search_paths.append(target_release)
|
|
53
|
+
search_paths.append(target_debug)
|
|
54
|
+
break
|
|
55
|
+
parent = os.path.dirname(current)
|
|
56
|
+
if parent == current: # Reached filesystem root
|
|
57
|
+
break
|
|
58
|
+
current = parent
|
|
59
|
+
|
|
60
|
+
# Also check common workspace locations relative to package
|
|
61
|
+
# From site-packages/pylibsparseir, workspace root is typically 6-7 levels up
|
|
62
|
+
for levels in range(3, 8):
|
|
63
|
+
candidate = script_dir
|
|
64
|
+
for _ in range(levels):
|
|
65
|
+
candidate = os.path.dirname(candidate)
|
|
66
|
+
candidate_release = os.path.join(candidate, "target", "release")
|
|
67
|
+
candidate_debug = os.path.join(candidate, "target", "debug")
|
|
68
|
+
if os.path.exists(candidate_release) or os.path.exists(candidate_debug):
|
|
69
|
+
search_paths.append(candidate_release)
|
|
70
|
+
search_paths.append(candidate_debug)
|
|
71
|
+
|
|
72
|
+
for path in search_paths:
|
|
73
|
+
libpath = os.path.join(path, libname)
|
|
74
|
+
if os.path.exists(libpath):
|
|
75
|
+
return libpath
|
|
76
|
+
|
|
77
|
+
raise RuntimeError(f"Could not find {libname} in {search_paths}")
|
|
78
|
+
|
|
79
|
+
|
|
80
|
+
# Load the library
_blas_backend = None
try:
    import scipy.linalg.cython_blas as blas

    # Get the PyCapsule objects SciPy's Cython BLAS exports for dgemm and zgemm.
    capsule = blas.__pyx_capi__["dgemm"]
    capsule_z = blas.__pyx_capi__["zgemm"]

    # Read each capsule's name; PyCapsule_GetPointer must be given the same name.
    ctypes.pythonapi.PyCapsule_GetName.restype = ctypes.c_char_p
    ctypes.pythonapi.PyCapsule_GetName.argtypes = [ctypes.py_object]
    name = ctypes.pythonapi.PyCapsule_GetName(capsule)
    name_z = ctypes.pythonapi.PyCapsule_GetName(capsule_z)

    # Extract the raw function pointers from the capsules.
    ctypes.pythonapi.PyCapsule_GetPointer.restype = ctypes.c_void_p
    ctypes.pythonapi.PyCapsule_GetPointer.argtypes = [
        ctypes.py_object, ctypes.c_char_p]
    ptr = ctypes.pythonapi.PyCapsule_GetPointer(capsule, name)
    ptr_z = ctypes.pythonapi.PyCapsule_GetPointer(capsule_z, name_z)

    _lib = ctypes.CDLL(_find_library())

    # Create GEMM backend handle from SciPy BLAS functions
    # Note: SciPy BLAS typically uses LP64 interface (32-bit integers)
    _lib.spir_gemm_backend_new_from_fblas_lp64.argtypes = [ctypes.c_void_p, ctypes.c_void_p]
    _lib.spir_gemm_backend_new_from_fblas_lp64.restype = spir_gemm_backend
    _blas_backend = _lib.spir_gemm_backend_new_from_fblas_lp64(ptr, ptr_z)

    # ctypes reports a NULL result as None only for a plain c_void_p restype.
    # spir_gemm_backend is declared as a pointer type, for which a NULL result
    # comes back as a falsy pointer instance instead, so `is None` would never
    # fire. A truthiness test catches both representations.
    if not _blas_backend:
        raise RuntimeError("Failed to create BLAS backend handle")

    if os.environ.get("SPARSEIR_DEBUG", "").lower() in ("1", "true", "yes", "on"):
        print("[core.py] Created SciPy BLAS backend handle")
        print(f"[core.py] Registered SciPy BLAS dgemm @ {hex(ptr)}")
        print(f"[core.py] Registered SciPy BLAS zgemm @ {hex(ptr_z)}")
except Exception as e:
    # Chain the cause so the original traceback (missing SciPy, library not
    # found, NULL backend, ...) stays visible to the user.
    raise RuntimeError(f"Failed to load SparseIR library: {e}") from e

# Module-level variable to store the default BLAS backend
# Users can pass None to use default backend, or pass _blas_backend explicitly
_default_blas_backend = _blas_backend
|
|
122
|
+
|
|
123
|
+
def get_default_blas_backend():
    """Return the module-wide BLAS backend handle built from SciPy BLAS at import.

    Returns:
        spir_gemm_backend: The default BLAS backend handle, or None if not available.
    """
    backend = _default_blas_backend
    return backend
|
|
130
|
+
|
|
131
|
+
def release_blas_backend(backend):
    """Release a BLAS backend handle via ``spir_gemm_backend_release``.

    Args:
        backend: The backend handle to release; ``None`` is accepted and ignored.
    """
    if backend is None:
        return
    release = _lib.spir_gemm_backend_release
    # Declare the prototype right before the call so the handle is passed as
    # a spir_gemm_backend and no return value is read.
    release.restype = None
    release.argtypes = [spir_gemm_backend]
    release(backend)
|
|
141
|
+
|
|
142
|
+
|
|
143
|
+
class c_double_complex(ctypes.Structure):
    """ctypes view of a complex number stored as two consecutive doubles.

    ctypes has no native complex type, so, following the ctypes documentation
    (https://docs.python.org/3/library/ctypes.html#module-ctypes), a Structure
    with ``real`` and ``imag`` fields is used to pass complex values.
    See: https://stackoverflow.com/questions/13373291/complex-number-in-ctypes
    """
    _fields_ = [
        ("real", ctypes.c_double),
        ("imag", ctypes.c_double),
    ]

    @property
    def value(self):
        """The stored number as a Python ``complex`` (real + 1j * imag)."""
        re_part, im_part = self.real, self.imag
        return re_part + 1j * im_part
|
|
154
|
+
|
|
155
|
+
# Set up function prototypes using auto-generated bindings
try:
    # The generated module also defines c_double_complex. It is imported under
    # a private alias and not used: _setup_prototypes maps 'c_double_complex'
    # to the class defined above in this module so all wrappers share one type.
    from .ctypes_autogen import FUNCTIONS, c_double_complex as _autogen_c_double_complex
except ImportError:
    # Fallback: without the generated table no prototypes are applied.
    # Use the warnings machinery instead of print() so library users can
    # filter or escalate this message rather than getting stray stdout output.
    import warnings

    FUNCTIONS = {}
    warnings.warn(
        "WARNING: ctypes_autogen.py not found. Run tools/gen_ctypes.py to generate it.",
        RuntimeWarning,
    )
|
|
164
|
+
|
|
165
|
+
|
|
166
|
+
def _normalize_type_string(type_str):
|
|
167
|
+
"""Normalize type string by replacing 'struct X' with 'X' for eval compatibility."""
|
|
168
|
+
# Special case: spir_gemm_backend is already a POINTER type, so POINTER(struct spir_gemm_backend)
|
|
169
|
+
# should map to just spir_gemm_backend, not POINTER(spir_gemm_backend)
|
|
170
|
+
if type_str == 'POINTER(struct spir_gemm_backend)':
|
|
171
|
+
return 'spir_gemm_backend'
|
|
172
|
+
|
|
173
|
+
# Replace 'struct spir_*' with 'spir_*' and 'struct Complex64' with 'c_double_complex'
|
|
174
|
+
# This allows eval() to work with the type_map
|
|
175
|
+
normalized = type_str.replace('struct spir_kernel', 'spir_kernel')
|
|
176
|
+
normalized = normalized.replace('struct spir_funcs', 'spir_funcs')
|
|
177
|
+
normalized = normalized.replace('struct spir_basis', 'spir_basis')
|
|
178
|
+
normalized = normalized.replace('struct spir_sampling', 'spir_sampling')
|
|
179
|
+
normalized = normalized.replace('struct spir_sve_result', 'spir_sve_result')
|
|
180
|
+
normalized = normalized.replace('struct spir_gemm_backend', 'spir_gemm_backend')
|
|
181
|
+
normalized = normalized.replace('struct Complex64', 'c_double_complex')
|
|
182
|
+
return normalized
|
|
183
|
+
|
|
184
|
+
|
|
185
|
+
def _setup_prototypes():
    """Apply restype/argtypes from the generated FUNCTIONS table to ``_lib``.

    Each FUNCTIONS entry maps a symbol name to ``(restype_str, argtypes_list)``,
    where the strings are Python expressions over ctypes and SparseIR handle
    type names. Symbols the loaded library does not export are skipped.

    The restype is assigned before the argtypes are resolved. If a type string
    fails to evaluate (NameError/AttributeError/SyntaxError), that function
    keeps whatever restype was already assigned but gets no argtypes. When
    SPARSEIR_DEBUG is enabled, a warning is printed for the failing argtype and
    again for the function. Does nothing when FUNCTIONS is empty.
    """
    if not FUNCTIONS:
        # Fallback to manual setup if generation failed
        return

    # Names the generated type strings may reference (used as eval locals).
    from ctypes import c_int, c_double, c_int64, c_size_t, c_bool, POINTER, c_char_p
    from .ctypes_wrapper import spir_kernel, spir_funcs, spir_basis, spir_sampling, spir_sve_result, spir_gemm_backend

    namespace = {
        'c_int': c_int, 'c_double': c_double, 'c_int64': c_int64,
        'c_size_t': c_size_t, 'c_bool': c_bool,
        'POINTER': POINTER, 'c_char_p': c_char_p,
        'spir_kernel': spir_kernel, 'spir_funcs': spir_funcs,
        'spir_basis': spir_basis, 'spir_sampling': spir_sampling,
        'spir_sve_result': spir_sve_result,
        'spir_gemm_backend': spir_gemm_backend,
        # This module's class, not the one from ctypes_autogen, so every
        # wrapper in core.py agrees on the complex type.
        'c_double_complex': c_double_complex,
    }
    recoverable = (NameError, AttributeError, SyntaxError)

    def debug_enabled():
        return os.environ.get("SPARSEIR_DEBUG", "").lower() in ("1", "true", "yes", "on")

    for func_name, (restype_str, argtypes_list) in FUNCTIONS.items():
        if not hasattr(_lib, func_name):
            continue
        func = getattr(_lib, func_name)

        try:
            if restype_str == 'None':
                func.restype = None
            else:
                func.restype = eval(_normalize_type_string(restype_str), globals(), namespace)

            resolved_args = []
            for position, raw_type in enumerate(argtypes_list):
                normalized = _normalize_type_string(raw_type)
                try:
                    resolved_args.append(eval(normalized, globals(), namespace))
                except recoverable as exc:
                    if debug_enabled():
                        print(f"WARNING: Could not evaluate argtype {position} '{raw_type}' (normalized: '{normalized}') for {func_name}: {exc}")
                    # Propagate to the outer handler so this function's
                    # argtypes are left unset.
                    raise
            func.argtypes = resolved_args
        except recoverable as exc:
            if debug_enabled():
                print(f"WARNING: Could not set prototype for {func_name}: {exc}")
|
|
239
|
+
|
|
240
|
+
|
|
241
|
+
# Apply the generated restype/argtypes declarations to _lib once, at import
# time, so the wrapper functions below call _lib with those prototypes (for
# every symbol FUNCTIONS describes and the loaded library exports).
_setup_prototypes()
|
|
242
|
+
|
|
243
|
+
# Python wrapper functions
|
|
244
|
+
|
|
245
|
+
|
|
246
|
+
def logistic_kernel_new(lambda_val):
    """Create a new logistic kernel via ``spir_logistic_kernel_new``.

    Args:
        lambda_val: The kernel's lambda parameter, forwarded unchanged.

    Returns:
        The kernel handle returned by the C library.

    Raises:
        RuntimeError: If the C call reports a non-success status.
    """
    err = c_int()
    handle = _lib.spir_logistic_kernel_new(lambda_val, byref(err))
    if err.value == COMPUTATION_SUCCESS:
        return handle
    raise RuntimeError(f"Failed to create logistic kernel: {err.value}")
|
|
253
|
+
|
|
254
|
+
|
|
255
|
+
def reg_bose_kernel_new(lambda_val):
    """Create a new regularized bosonic kernel via ``spir_reg_bose_kernel_new``.

    Args:
        lambda_val: The kernel's lambda parameter, forwarded unchanged.

    Returns:
        The kernel handle returned by the C library.

    Raises:
        RuntimeError: If the C call reports a non-success status.
    """
    err = c_int()
    handle = _lib.spir_reg_bose_kernel_new(lambda_val, byref(err))
    if err.value == COMPUTATION_SUCCESS:
        return handle
    raise RuntimeError(
        f"Failed to create regularized bosonic kernel: {err.value}")
|
|
263
|
+
|
|
264
|
+
|
|
265
|
+
def sve_result_new(kernel, epsilon, cutoff=None, lmax=None, n_gauss=None, Twork=None):
    """Create a new SVE result.

    Note: cutoff parameter is deprecated and ignored (C-API doesn't have it).
    It's kept for backward compatibility but not passed to C-API.

    Args:
        kernel: Kernel handle passed to ``spir_sve_result_new``.
        epsilon: Accuracy parameter; must be a positive number (NaN is rejected).
        cutoff: Deprecated, ignored.
        lmax: Passed as a C int; ``None`` becomes -1.
        n_gauss: Passed as a C int; ``None`` becomes -1.
        Twork: Working-precision selector; ``None`` becomes SPIR_TWORK_FLOAT64X2.

    Returns:
        The SVE result handle returned by the C library.

    Raises:
        RuntimeError: If epsilon is not positive or the C call fails.
    """
    # Written as `not (epsilon > 0)` rather than `epsilon <= 0` so that NaN,
    # for which every comparison is False, is rejected too.
    if not (epsilon > 0):
        raise RuntimeError(
            f"Failed to create SVE result: epsilon must be positive, got {epsilon}")

    # Note: cutoff parameter was removed from C-API, kept for backward compatibility
    if lmax is None:
        lmax = -1
    if n_gauss is None:
        n_gauss = -1
    if Twork is None:
        Twork = SPIR_TWORK_FLOAT64X2

    status = c_int()
    # C-API signature: spir_sve_result_new(kernel, epsilon, lmax, n_gauss, Twork, status)
    sve = _lib.spir_sve_result_new(
        kernel, c_double(epsilon), c_int(lmax), c_int(n_gauss), c_int(Twork), byref(status))
    if status.value != COMPUTATION_SUCCESS:
        raise RuntimeError(f"Failed to create SVE result: {status.value}")
    return sve
|
|
291
|
+
|
|
292
|
+
|
|
293
|
+
def sve_result_get_size(sve):
    """Return the size reported by ``spir_sve_result_get_size`` for an SVE result.

    Raises:
        RuntimeError: If the C call returns a non-success status.
    """
    out_size = c_int()
    rc = _lib.spir_sve_result_get_size(sve, byref(out_size))
    if rc == COMPUTATION_SUCCESS:
        return out_size.value
    raise RuntimeError(f"Failed to get SVE result size: {rc}")
|
|
300
|
+
|
|
301
|
+
def sve_result_truncate(sve, epsilon, max_size):
    """Truncate an SVE result via ``spir_sve_result_truncate``.

    Args:
        sve: SVE result handle.
        epsilon: Truncation threshold, forwarded unchanged.
        max_size: Maximum size, forwarded unchanged.

    Returns:
        The new handle returned by the C library.

    Raises:
        RuntimeError: If the C call reports a non-success status.
    """
    err = c_int()
    truncated = _lib.spir_sve_result_truncate(sve, epsilon, max_size, byref(err))
    if err.value != COMPUTATION_SUCCESS:
        raise RuntimeError(f"Failed to truncate SVE result: {err.value}")
    return truncated
|
|
308
|
+
|
|
309
|
+
def sve_result_get_svals(sve):
    """Return the singular values of an SVE result as a float64 array.

    The array length is ``sve_result_get_size(sve)``.

    Raises:
        RuntimeError: If either underlying C call fails.
    """
    buf = np.zeros(sve_result_get_size(sve), dtype=np.float64)
    rc = _lib.spir_sve_result_get_svals(sve, buf.ctypes.data_as(POINTER(c_double)))
    if rc != COMPUTATION_SUCCESS:
        raise RuntimeError(f"Failed to get singular values: {rc}")
    return buf
|
|
318
|
+
|
|
319
|
+
def basis_new(statistics, beta, omega_max, epsilon, kernel, sve, max_size):
    """Create a new basis via ``spir_basis_new``.

    All arguments are forwarded to the C library in the order given.

    Returns:
        The basis handle returned by the C library.

    Raises:
        RuntimeError: If the C call reports a non-success status.
    """
    err = c_int()
    forwarded = (statistics, beta, omega_max, epsilon, kernel, sve, max_size)
    handle = _lib.spir_basis_new(*forwarded, byref(err))
    if err.value == COMPUTATION_SUCCESS:
        return handle
    raise RuntimeError(f"Failed to create basis: {err.value}")
|
|
328
|
+
|
|
329
|
+
|
|
330
|
+
def basis_get_size(basis):
    """Return the size reported by ``spir_basis_get_size`` for a basis.

    Raises:
        RuntimeError: If the C call returns a non-success status.
    """
    out_size = c_int()
    rc = _lib.spir_basis_get_size(basis, byref(out_size))
    if rc == COMPUTATION_SUCCESS:
        return out_size.value
    raise RuntimeError(f"Failed to get basis size: {rc}")
|
|
337
|
+
|
|
338
|
+
|
|
339
|
+
def basis_get_svals(basis):
    """Return the singular values of a basis as a float64 array.

    The array length is ``basis_get_size(basis)``.

    Raises:
        RuntimeError: If either underlying C call fails.
    """
    buf = np.zeros(basis_get_size(basis), dtype=np.float64)
    rc = _lib.spir_basis_get_svals(basis, buf.ctypes.data_as(POINTER(c_double)))
    if rc != COMPUTATION_SUCCESS:
        raise RuntimeError(f"Failed to get singular values: {rc}")
    return buf
|
|
348
|
+
|
|
349
|
+
|
|
350
|
+
def basis_get_stats(basis):
    """Return the statistics code stored in a basis (as a Python int).

    Raises:
        RuntimeError: If the C call returns a non-success status.
    """
    code = c_int()
    rc = _lib.spir_basis_get_stats(basis, byref(code))
    if rc == COMPUTATION_SUCCESS:
        return code.value
    raise RuntimeError(f"Failed to get basis statistics: {rc}")
|
|
357
|
+
|
|
358
|
+
|
|
359
|
+
def basis_get_u(basis):
    """Return the imaginary-time basis functions (``u``) of a basis.

    Raises:
        RuntimeError: If the C call reports a non-success status.
    """
    err = c_int()
    u_funcs = _lib.spir_basis_get_u(basis, byref(err))
    if err.value == COMPUTATION_SUCCESS:
        return u_funcs
    raise RuntimeError(f"Failed to get u basis functions: {err.value}")
|
|
366
|
+
|
|
367
|
+
|
|
368
|
+
def basis_get_v(basis):
    """Return the real-frequency basis functions (``v``) of a basis.

    Raises:
        RuntimeError: If the C call reports a non-success status.
    """
    err = c_int()
    v_funcs = _lib.spir_basis_get_v(basis, byref(err))
    if err.value == COMPUTATION_SUCCESS:
        return v_funcs
    raise RuntimeError(f"Failed to get v basis functions: {err.value}")
|
|
375
|
+
|
|
376
|
+
|
|
377
|
+
def basis_get_uhat(basis):
    """Return the Matsubara-frequency basis functions (``uhat``) of a basis.

    Raises:
        RuntimeError: If the C call reports a non-success status.
    """
    err = c_int()
    uhat_funcs = _lib.spir_basis_get_uhat(basis, byref(err))
    if err.value == COMPUTATION_SUCCESS:
        return uhat_funcs
    raise RuntimeError(
        f"Failed to get uhat basis functions: {err.value}")
|
|
385
|
+
|
|
386
|
+
|
|
387
|
+
def funcs_get_size(funcs):
    """Return the number of functions in a basis function set.

    Raises:
        RuntimeError: If the C call returns a non-success status.
    """
    count = c_int()
    rc = _lib.spir_funcs_get_size(funcs, byref(count))
    if rc == COMPUTATION_SUCCESS:
        return count.value
    raise RuntimeError(f"Failed to get function size: {rc}")
|
|
394
|
+
|
|
395
|
+
# TODO: Rename funcs_eval_single


def funcs_eval_single_float64(funcs, x):
    """Evaluate every function in ``funcs`` at the single real point ``x``.

    Args:
        funcs: Basis function set handle.
        x: Evaluation point, passed to C as a double.

    Returns:
        float64 array with one value per function (length ``funcs_get_size(funcs)``).

    Raises:
        RuntimeError: If querying the size or evaluating fails.
    """
    values = np.zeros(funcs_get_size(funcs), dtype=np.float64)
    rc = _lib.spir_funcs_eval(
        funcs, c_double(x), values.ctypes.data_as(POINTER(c_double)))
    if rc != COMPUTATION_SUCCESS:
        raise RuntimeError(f"Failed to evaluate functions: {rc}")
    return values
|
|
418
|
+
|
|
419
|
+
# TODO: Rename to funcs_eval_matsu_single


def funcs_eval_single_complex128(funcs, x):
    """Evaluate every function in ``funcs`` at a single Matsubara index.

    Args:
        funcs: Basis function set handle.
        x: Integer frequency index, passed to C as an int64.

    Returns:
        complex128 array with one value per function (length ``funcs_get_size(funcs)``).

    Raises:
        RuntimeError: If querying the size or evaluating fails.
    """
    values = np.zeros(funcs_get_size(funcs), dtype=np.complex128)
    rc = _lib.spir_funcs_eval_matsu(
        funcs, c_int64(x), values.ctypes.data_as(POINTER(c_double_complex)))
    if rc != COMPUTATION_SUCCESS:
        raise RuntimeError(f"Failed to evaluate functions: {rc}")
    return values
|
|
442
|
+
|
|
443
|
+
|
|
444
|
+
def funcs_get_n_knots(funcs):
    """Return the knot count of the underlying piecewise Legendre polynomial.

    Raises:
        RuntimeError: If the C call returns a non-success status.
    """
    count = c_int()
    rc = _lib.spir_funcs_get_n_knots(funcs, byref(count))
    if rc == COMPUTATION_SUCCESS:
        return count.value
    raise RuntimeError(f"Failed to get number of knots: {rc}")
|
|
451
|
+
|
|
452
|
+
|
|
453
|
+
def funcs_get_knots(funcs):
    """Return the knots of the underlying piecewise Legendre polynomial.

    Returns:
        float64 array of length ``funcs_get_n_knots(funcs)``.

    Raises:
        RuntimeError: If either underlying C call fails.
    """
    knots = np.zeros(funcs_get_n_knots(funcs), dtype=np.float64)
    rc = _lib.spir_funcs_get_knots(funcs, knots.ctypes.data_as(POINTER(c_double)))
    if rc != COMPUTATION_SUCCESS:
        raise RuntimeError(f"Failed to get knots: {rc}")
    return knots
|
|
462
|
+
|
|
463
|
+
|
|
464
|
+
def basis_get_default_tau_sampling_points(basis):
    """Return the default tau sampling points of a basis.

    First asks the C library for the point count, then fills a float64 array
    of that length.

    Raises:
        RuntimeError: If either C call returns a non-success status.
    """
    count = c_int()
    rc = _lib.spir_basis_get_n_default_taus(basis, byref(count))
    if rc != COMPUTATION_SUCCESS:
        raise RuntimeError(
            f"Failed to get number of default tau points: {rc}")

    taus = np.zeros(count.value, dtype=np.float64)
    rc = _lib.spir_basis_get_default_taus(basis, taus.ctypes.data_as(POINTER(c_double)))
    if rc != COMPUTATION_SUCCESS:
        raise RuntimeError(f"Failed to get default tau points: {rc}")
    return taus
|
|
481
|
+
|
|
482
|
+
|
|
483
|
+
def basis_get_default_tau_sampling_points_ext(basis, n_points):
    """Get up to ``n_points`` default tau sampling points for a basis.

    Args:
        basis: Basis handle.
        n_points: Requested number of points; also the length of the output
            buffer handed to the C library.

    Returns:
        float64 array holding only the points the C library reports as written
        through ``n_points_returned``.

    Raises:
        RuntimeError: If the C call returns a non-success status.
    """
    points = np.zeros(n_points, dtype=np.float64)
    n_points_returned = c_int()
    status = _lib.spir_basis_get_default_taus_ext(
        basis, n_points, points.ctypes.data_as(POINTER(c_double)), byref(n_points_returned))
    if status != COMPUTATION_SUCCESS:
        raise RuntimeError(f"Failed to get default tau points: {status}")
    # The library reports how many entries it filled. Anything beyond that is
    # the zero initialisation of the buffer, not a sampling point at tau = 0,
    # so it must not be returned.
    return points[:n_points_returned.value]
|
|
492
|
+
|
|
493
|
+
|
|
494
|
+
def basis_get_default_omega_sampling_points(basis):
    """Return the default real-frequency (omega) sampling points of a basis.

    First asks the C library for the point count, then fills a float64 array
    of that length.

    Raises:
        RuntimeError: If either C call returns a non-success status.
    """
    count = c_int()
    rc = _lib.spir_basis_get_n_default_ws(basis, byref(count))
    if rc != COMPUTATION_SUCCESS:
        raise RuntimeError(
            f"Failed to get number of default omega points: {rc}")

    omegas = np.zeros(count.value, dtype=np.float64)
    rc = _lib.spir_basis_get_default_ws(basis, omegas.ctypes.data_as(POINTER(c_double)))
    if rc != COMPUTATION_SUCCESS:
        raise RuntimeError(f"Failed to get default omega points: {rc}")
    return omegas
|
|
511
|
+
|
|
512
|
+
|
|
513
|
+
def basis_get_default_matsubara_sampling_points(basis, positive_only=False):
    """Return the default Matsubara sampling points of a basis as int64.

    Args:
        basis: Basis handle.
        positive_only: Truthiness is passed to both C calls as 1 or 0.

    Raises:
        RuntimeError: If either C call returns a non-success status.
    """
    flag = 1 if positive_only else 0

    count = c_int()
    rc = _lib.spir_basis_get_n_default_matsus(basis, c_int(flag), byref(count))
    if rc != COMPUTATION_SUCCESS:
        raise RuntimeError(
            f"Failed to get number of default Matsubara points: {rc}")

    freqs = np.zeros(count.value, dtype=np.int64)
    rc = _lib.spir_basis_get_default_matsus(
        basis, c_int(flag), freqs.ctypes.data_as(POINTER(c_int64)))
    if rc != COMPUTATION_SUCCESS:
        raise RuntimeError(f"Failed to get default Matsubara points: {rc}")
    return freqs
|
|
531
|
+
|
|
532
|
+
|
|
533
|
+
def basis_get_n_default_matsus_ext(basis, n_points, positive_only):
    """Return how many default Matsubara points the C library will provide.

    Args:
        basis: Basis handle.
        n_points: Requested number of points, forwarded unchanged.
        positive_only: Truthiness is passed to C as 1 or 0.

    Returns:
        int: The count written to ``n_points_returned`` by the C library.

    Raises:
        RuntimeError: If the C call returns a non-success status.
    """
    flag = c_int(1 if positive_only else 0)
    returned = c_int()
    rc = _lib.spir_basis_get_n_default_matsus_ext(basis, flag, n_points, byref(returned))
    if rc == COMPUTATION_SUCCESS:
        return returned.value
    raise RuntimeError(
        f"Failed to get number of default Matsubara points: {rc}")
|
|
542
|
+
|
|
543
|
+
|
|
544
|
+
def basis_get_default_matsus_ext(basis, positive_only, points):
    """Fill ``points`` in place with default Matsubara sampling points.

    Args:
        basis: Basis handle.
        positive_only: Truthiness is passed to C as 1 or 0.
        points: Output buffer; must be a writeable, C-contiguous numpy array
            of dtype int64. Its length is passed to C as the number of points.

    Returns:
        The same ``points`` array, filled by the C library.

    Raises:
        ValueError: If ``points`` is not a suitable int64 output buffer.
        RuntimeError: If the C call returns a non-success status.
    """
    # The C library writes through a raw int64 pointer. A wrong dtype (e.g.
    # int32) would overrun the buffer, and a non-contiguous view or read-only
    # array would be silently corrupted, so reject those up front.
    if (not isinstance(points, np.ndarray)
            or points.dtype != np.int64
            or not points.flags.c_contiguous
            or not points.flags.writeable):
        raise ValueError(
            "points must be a writeable, C-contiguous numpy array of dtype int64")

    n_points = len(points)
    n_points_returned = c_int()
    # The third argument is always passed as 0; its meaning is defined by the
    # C API (not visible in this module).
    status = _lib.spir_basis_get_default_matsus_ext(basis, c_int(
        1 if positive_only else 0), c_int(0), n_points, points.ctypes.data_as(POINTER(c_int64)), byref(n_points_returned))
    if status != COMPUTATION_SUCCESS:
        raise RuntimeError(f"Failed to get default Matsubara points: {status}")
    return points
|
|
552
|
+
|
|
553
|
+
|
|
554
|
+
def tau_sampling_new(basis, sampling_points=None):
    """Create a new tau sampling object.

    Args:
        basis: Basis handle.
        sampling_points: Tau sampling points (any array-like of floats);
            defaults to ``basis_get_default_tau_sampling_points(basis)``.

    Returns:
        The sampling handle returned by the C library.

    Raises:
        RuntimeError: If the C call reports a non-success status.
    """
    if sampling_points is None:
        sampling_points = basis_get_default_tau_sampling_points(basis)

    # Only a data pointer and a count cross into C, so the buffer must be
    # C-contiguous float64. np.asarray would pass a strided view (e.g.
    # points[::2]) through unchanged and C would read the wrong values.
    sampling_points = np.ascontiguousarray(sampling_points, dtype=np.float64)
    n_points = len(sampling_points)

    status = c_int()
    sampling = _lib.spir_tau_sampling_new(
        basis, n_points,
        sampling_points.ctypes.data_as(POINTER(c_double)),
        byref(status)
    )
    if status.value != COMPUTATION_SUCCESS:
        raise RuntimeError(f"Failed to create tau sampling: {status.value}")

    return sampling
|
|
572
|
+
|
|
573
|
+
|
|
574
|
+
def _statistics_to_c(statistics):
|
|
575
|
+
"""Convert statistics to c type."""
|
|
576
|
+
if statistics == "F":
|
|
577
|
+
return SPIR_STATISTICS_FERMIONIC
|
|
578
|
+
elif statistics == "B":
|
|
579
|
+
return SPIR_STATISTICS_BOSONIC
|
|
580
|
+
else:
|
|
581
|
+
raise ValueError(f"Invalid statistics: {statistics}")
|
|
582
|
+
|
|
583
|
+
|
|
584
|
+
def tau_sampling_new_with_matrix(basis, statistics, sampling_points, matrix):
    """Create a new tau sampling object with a matrix.

    Args:
        basis: Object exposing ``.size``, passed to C as the basis size.
        statistics: ``"F"`` or ``"B"`` (see ``_statistics_to_c``).
        sampling_points: Tau sampling points (array-like of floats).
        matrix: Sampling matrix with ``len(points) * basis.size`` float64
            entries, handed to C in row-major order.

    Returns:
        The sampling handle returned by the C library.

    Raises:
        ValueError: If ``matrix`` does not hold ``n_points * basis.size`` entries,
            or ``statistics`` is invalid.
        RuntimeError: If the C call reports a non-success status.
    """
    # Only raw pointers cross into C, and the order flag below says ROW_MAJOR.
    # Normalise to C-contiguous float64 so Fortran-ordered, strided or
    # float32 inputs are not misread.
    points = np.ascontiguousarray(sampling_points, dtype=np.float64)
    mat = np.ascontiguousarray(matrix, dtype=np.float64)
    basis_size = basis.size

    # Guard against C reading past the end of an undersized matrix.
    if mat.size != points.size * basis_size:
        raise ValueError(
            f"matrix has {mat.size} entries, expected {points.size} * {basis_size}")

    status = c_int()
    sampling = _lib.spir_tau_sampling_new_with_matrix(
        SPIR_ORDER_ROW_MAJOR,
        _statistics_to_c(statistics),
        basis_size,
        points.size,
        points.ctypes.data_as(POINTER(c_double)),
        mat.ctypes.data_as(POINTER(c_double)),
        byref(status)
    )
    if status.value != COMPUTATION_SUCCESS:
        raise RuntimeError(f"Failed to create tau sampling: {status.value}")

    return sampling
|
|
600
|
+
|
|
601
|
+
|
|
602
|
+
def matsubara_sampling_new(basis, positive_only=False, sampling_points=None):
    """Create a new Matsubara sampling object.

    Args:
        basis: Basis handle.
        positive_only: Passed to C as a ``c_bool``; also forwarded when the
            default points are fetched.
        sampling_points: Integer Matsubara indices (array-like); defaults to
            ``basis_get_default_matsubara_sampling_points(basis, positive_only)``.

    Returns:
        The sampling handle returned by the C library.

    Raises:
        RuntimeError: If the C call reports a non-success status.
    """
    if sampling_points is None:
        sampling_points = basis_get_default_matsubara_sampling_points(
            basis, positive_only)

    # Only a data pointer and a count cross into C, so the buffer must be
    # C-contiguous int64. np.asarray would pass a strided view through
    # unchanged and C would read the wrong indices.
    sampling_points = np.ascontiguousarray(sampling_points, dtype=np.int64)
    n_points = len(sampling_points)

    status = c_int()
    sampling = _lib.spir_matsu_sampling_new(
        basis, c_bool(positive_only), n_points,
        sampling_points.ctypes.data_as(POINTER(c_int64)),
        byref(status)
    )
    if status.value != COMPUTATION_SUCCESS:
        raise RuntimeError(
            f"Failed to create Matsubara sampling: {status.value}")

    return sampling
|
|
622
|
+
|
|
623
|
+
|
|
624
|
+
def matsubara_sampling_new_with_matrix(statistics, basis_size, positive_only, sampling_points, matrix):
    """Create a new Matsubara sampling object with a matrix.

    Args:
        statistics: ``"F"`` or ``"B"`` (see ``_statistics_to_c``).
        basis_size: Basis size, passed to C as an int.
        positive_only: Passed to C as a ``c_bool``.
        sampling_points: Integer Matsubara indices (array-like).
        matrix: Complex sampling matrix with ``len(points) * basis_size``
            entries, handed to C in row-major order.

    Returns:
        The sampling handle returned by the C library.

    Raises:
        ValueError: If ``matrix`` does not hold ``n_points * basis_size`` entries,
            or ``statistics`` is invalid.
        RuntimeError: If the C call reports a non-success status.
    """
    # Only raw pointers cross into C, and the order flag below says ROW_MAJOR.
    # Normalise to C-contiguous int64 / complex128 so Fortran-ordered,
    # strided or differently-typed inputs are not misread.
    points = np.ascontiguousarray(sampling_points, dtype=np.int64)
    mat = np.ascontiguousarray(matrix, dtype=np.complex128)
    n_points = len(points)

    # Guard against C reading past the end of an undersized matrix.
    if mat.size != n_points * basis_size:
        raise ValueError(
            f"matrix has {mat.size} entries, expected {n_points} * {basis_size}")

    status = c_int()
    sampling = _lib.spir_matsu_sampling_new_with_matrix(
        SPIR_ORDER_ROW_MAJOR,  # order
        _statistics_to_c(statistics),  # statistics
        c_int(basis_size),  # basis_size
        c_bool(positive_only),  # positive_only
        c_int(n_points),  # num_points
        points.ctypes.data_as(POINTER(c_int64)),  # points
        mat.ctypes.data_as(POINTER(c_double_complex)),  # matrix
        byref(status)  # status
    )
    if status.value != COMPUTATION_SUCCESS:
        raise RuntimeError(
            f"Failed to create Matsubara sampling: {status.value}")

    return sampling
|