pylibsparseir 0.7.4__cp312-cp312-macosx_15_0_arm64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
pylibsparseir/core.py ADDED
@@ -0,0 +1,641 @@
1
+ """
2
+ Core functionality for the SparseIR Python bindings.
3
+ """
4
+
5
+ import os
6
+ import sys
7
+ import ctypes
8
+ from ctypes import c_int, c_double, c_int64, c_size_t, c_bool, POINTER, byref
9
+ from ctypes import CDLL
10
+ import numpy as np
11
+ import platform
12
+
13
+ # Duplicate imports (already imported above); nothing below is Linux-specific.
14
+ import os
15
+ import sys
16
+ import ctypes
17
+ import platform
18
+
19
+ import os
20
+ import sys
21
+ import ctypes
22
+
23
+ from .ctypes_wrapper import spir_kernel, spir_sve_result, spir_basis, spir_funcs, spir_sampling, spir_gemm_backend
24
+ from pylibsparseir.constants import COMPUTATION_SUCCESS, SPIR_ORDER_ROW_MAJOR, SPIR_ORDER_COLUMN_MAJOR, SPIR_TWORK_FLOAT64, SPIR_TWORK_FLOAT64X2, SPIR_STATISTICS_FERMIONIC, SPIR_STATISTICS_BOSONIC
25
+
26
+
27
def _find_library():
    """Locate the SparseIR C-API shared library and return its path.

    Candidate directories are tried in this order:

    1. the directory containing this module, then ``../build`` and
       ``../../build`` relative to it;
    2. ``target/release`` and ``target/debug`` of the first ancestor (at most
       10 levels up) that contains ``sparse-ir-capi`` or ``Cargo.toml``;
    3. ``target/release`` and ``target/debug`` of each ancestor 3 to 7 levels
       up, whenever either of those two directories exists.

    Returns:
        str: path of the first candidate that contains the platform-specific
        library file.

    Raises:
        RuntimeError: if no candidate directory contains the library.
    """
    special_names = {
        "darwin": "libsparse_ir_capi.dylib",
        "win32": "sparse_ir_capi.dll",
    }
    libname = special_names.get(sys.platform, "libsparse_ir_capi.so")

    here = os.path.dirname(os.path.abspath(__file__))
    candidates = [
        here,
        os.path.join(here, "..", "build"),
        os.path.join(here, "..", "..", "build"),
    ]

    # Walk upwards looking for a Cargo workspace root; an installed package
    # may sit several levels below it.
    directory = here
    for _ in range(10):
        looks_like_workspace = (
            os.path.exists(os.path.join(directory, "sparse-ir-capi"))
            or os.path.exists(os.path.join(directory, "Cargo.toml"))
        )
        if looks_like_workspace:
            candidates.extend([
                os.path.join(directory, "target", "release"),
                os.path.join(directory, "target", "debug"),
            ])
            break
        parent = os.path.dirname(directory)
        if parent == directory:  # hit the filesystem root
            break
        directory = parent

    # Also probe fixed ancestor depths (3..7 levels up), which covers a
    # site-packages install living inside the workspace.
    ancestor = here
    for _ in range(3):
        ancestor = os.path.dirname(ancestor)
    for _ in range(3, 8):
        release_dir = os.path.join(ancestor, "target", "release")
        debug_dir = os.path.join(ancestor, "target", "debug")
        if os.path.exists(release_dir) or os.path.exists(debug_dir):
            candidates += [release_dir, debug_dir]
        ancestor = os.path.dirname(ancestor)

    found = next(
        (path for path in (os.path.join(d, libname) for d in candidates)
         if os.path.exists(path)),
        None,
    )
    if found is None:
        raise RuntimeError(f"Could not find {libname} in {candidates}")
    return found
78
+
79
+
80
# Load the SparseIR library and register SciPy's BLAS with it.
#
# SciPy exposes its BLAS routines as PyCapsule objects in
# ``scipy.linalg.cython_blas.__pyx_capi__``. The raw dgemm/zgemm pointers are
# extracted from those capsules and handed to the C library, which wraps them
# in a GEMM backend handle. Any failure (SciPy missing, library not found,
# backend creation failing) is re-raised as RuntimeError, so importing this
# module fails loudly.
_blas_backend = None
try:
    import scipy.linalg.cython_blas as blas

    # PyCapsule objects wrapping SciPy's dgemm and zgemm.
    capsule = blas.__pyx_capi__["dgemm"]
    capsule_z = blas.__pyx_capi__["zgemm"]

    # PyCapsule_GetPointer must be given the capsule's own name, so read it first.
    ctypes.pythonapi.PyCapsule_GetName.restype = ctypes.c_char_p
    ctypes.pythonapi.PyCapsule_GetName.argtypes = [ctypes.py_object]
    name = ctypes.pythonapi.PyCapsule_GetName(capsule)
    name_z = ctypes.pythonapi.PyCapsule_GetName(capsule_z)

    # Extract the raw function pointers (as Python ints) from the capsules.
    ctypes.pythonapi.PyCapsule_GetPointer.restype = ctypes.c_void_p
    ctypes.pythonapi.PyCapsule_GetPointer.argtypes = [
        ctypes.py_object, ctypes.c_char_p]
    ptr = ctypes.pythonapi.PyCapsule_GetPointer(capsule, name)
    ptr_z = ctypes.pythonapi.PyCapsule_GetPointer(capsule_z, name_z)

    _lib = ctypes.CDLL(_find_library())

    # Create the GEMM backend handle from SciPy's BLAS functions.
    # Note: SciPy BLAS typically uses the LP64 interface (32-bit integers).
    _lib.spir_gemm_backend_new_from_fblas_lp64.argtypes = [ctypes.c_void_p, ctypes.c_void_p]
    _lib.spir_gemm_backend_new_from_fblas_lp64.restype = spir_gemm_backend
    _blas_backend = _lib.spir_gemm_backend_new_from_fblas_lp64(ptr, ptr_z)

    # spir_gemm_backend is a ctypes POINTER type. For a NULL result ctypes
    # returns a NULL pointer *object*, not None, so ``is None`` never fired.
    # NULL pointers are falsy, so a truthiness test catches NULL (and None).
    if not _blas_backend:
        raise RuntimeError("Failed to create BLAS backend handle")

    if os.environ.get("SPARSEIR_DEBUG", "").lower() in ("1", "true", "yes", "on"):
        print("[core.py] Created SciPy BLAS backend handle")
        print(f"[core.py] Registered SciPy BLAS dgemm @ {hex(ptr)}")
        print(f"[core.py] Registered SciPy BLAS zgemm @ {hex(ptr_z)}")
except Exception as e:
    # Chain the original exception so its traceback is not lost.
    raise RuntimeError(f"Failed to load SparseIR library: {e}") from e

# Module-level variable to store the default BLAS backend.
# Users can pass None to use the default backend, or pass _blas_backend explicitly.
_default_blas_backend = _blas_backend
122
+
123
def get_default_blas_backend():
    """Return the module's default GEMM backend handle.

    This is the handle stored in ``_default_blas_backend`` at import time,
    created from SciPy's BLAS dgemm/zgemm.

    Returns:
        spir_gemm_backend: the default backend handle.
    """
    backend = _default_blas_backend
    return backend
130
+
131
def release_blas_backend(backend):
    """Release a GEMM backend handle via ``spir_gemm_backend_release``.

    Args:
        backend: handle to release; ``None`` is accepted and ignored.

    NOTE(review): this does not reset ``_default_blas_backend``. Releasing the
    handle returned by get_default_blas_backend() leaves that module-level
    reference dangling.
    """
    if backend is None:
        return
    release = _lib.spir_gemm_backend_release
    release.argtypes = [spir_gemm_backend]
    release.restype = None
    release(backend)
141
+
142
+
143
class c_double_complex(ctypes.Structure):
    """ctypes layout of a C complex double: two consecutive c_double fields.

    ctypes has no built-in complex type. The ctypes documentation suggests
    ctypes.Structure for passing structures, so a complex value is modeled
    as ``{double real; double imag;}``.
    See: https://docs.python.org/3/library/ctypes.html#module-ctypes
    See: https://stackoverflow.com/questions/13373291/complex-number-in-ctypes
    """
    _fields_ = [("real", ctypes.c_double), ("imag", ctypes.c_double)]

    @property
    def value(self):
        """Return the stored number as a Python ``complex``.

        Built with ``complex(real, imag)`` rather than ``real + 1j*imag``.
        The latter computes ``0 * imag`` for the real part, so an infinite
        imaginary part used to turn the real part into NaN.
        """
        return complex(self.real, self.imag)

    def __complex__(self):
        """Support ``complex(obj)``; equivalent to :attr:`value`."""
        return self.value
154
+
155
# Set up function prototypes using auto-generated bindings.
try:
    from .ctypes_autogen import FUNCTIONS, c_double_complex as _autogen_c_double_complex
    # The generated c_double_complex is imported only under an alias. The
    # c_double_complex class defined earlier in this module is the one used
    # for prototypes (the two are expected to be layout-identical).
except ImportError:
    # Fallback: without the generated file, no prototypes are applied.
    # NOTE(review): without prototypes ctypes assumes a c_int restype, so
    # handle-returning calls are unlikely to work correctly -- confirm.
    FUNCTIONS = {}
    # Library code should not print to stdout; emit a real warning instead.
    import warnings
    warnings.warn(
        "ctypes_autogen.py not found. Run tools/gen_ctypes.py to generate it.",
        RuntimeWarning,
    )
164
+
165
+
166
def _normalize_type_string(type_str):
    """Rewrite a generated C type spelling into a name eval() can resolve.

    ``struct spir_*`` becomes the bare ``spir_*`` name, and ``struct Complex64``
    becomes ``c_double_complex``. Exactly ``POINTER(struct spir_gemm_backend)``
    collapses to ``spir_gemm_backend``, because that binding is already a
    ctypes POINTER type.
    """
    if type_str == 'POINTER(struct spir_gemm_backend)':
        return 'spir_gemm_backend'

    substitutions = (
        ('struct spir_kernel', 'spir_kernel'),
        ('struct spir_funcs', 'spir_funcs'),
        ('struct spir_basis', 'spir_basis'),
        ('struct spir_sampling', 'spir_sampling'),
        ('struct spir_sve_result', 'spir_sve_result'),
        ('struct spir_gemm_backend', 'spir_gemm_backend'),
        ('struct Complex64', 'c_double_complex'),
    )
    result = type_str
    for c_spelling, py_name in substitutions:
        result = result.replace(c_spelling, py_name)
    return result
183
+
184
+
185
def _setup_prototypes():
    """Apply the generated restype/argtypes prototypes from FUNCTIONS to _lib.

    FUNCTIONS maps a C function name to ``(restype_str, argtypes_list)``,
    where each entry is a Python-expression spelling of a ctypes type. Each
    spelling is normalized with _normalize_type_string and evaluated against
    a fixed table of ctypes/SparseIR types.

    - Names missing from the loaded library are skipped.
    - A function whose types fail to evaluate (NameError, AttributeError,
      SyntaxError) is skipped. A message is printed when SPARSEIR_DEBUG is
      enabled.
    - The restype is assigned before the argtypes are evaluated, so an
      argtype failure leaves the restype already applied.
    """
    if not FUNCTIONS:
        # No generated bindings: leave the library without prototypes.
        return

    from ctypes import c_int, c_double, c_int64, c_size_t, c_bool, POINTER, c_char_p
    from .ctypes_wrapper import spir_kernel, spir_funcs, spir_basis, spir_sampling, spir_sve_result, spir_gemm_backend

    # Names visible to eval(). c_double_complex is this module's own class,
    # not the generated one, so prototypes and wrappers share a single type.
    namespace = {
        'c_int': c_int, 'c_double': c_double, 'c_int64': c_int64,
        'c_size_t': c_size_t, 'c_bool': c_bool,
        'POINTER': POINTER, 'c_char_p': c_char_p,
        'spir_kernel': spir_kernel, 'spir_funcs': spir_funcs,
        'spir_basis': spir_basis, 'spir_sampling': spir_sampling,
        'spir_sve_result': spir_sve_result,
        'spir_gemm_backend': spir_gemm_backend,
        'c_double_complex': c_double_complex,
    }
    debug = os.environ.get("SPARSEIR_DEBUG", "").lower() in ("1", "true", "yes", "on")

    # SECURITY NOTE: eval() runs on strings from the generated ctypes_autogen
    # module (package content). Never route external input into it.
    for fname, (ret_spec, arg_specs) in FUNCTIONS.items():
        cfunc = getattr(_lib, fname, None)
        if cfunc is None:
            continue

        try:
            cfunc.restype = (
                None if ret_spec == 'None'
                else eval(_normalize_type_string(ret_spec), globals(), namespace)
            )

            resolved = []
            for position, spec in enumerate(arg_specs):
                spelled = _normalize_type_string(spec)
                try:
                    resolved.append(eval(spelled, globals(), namespace))
                except (NameError, AttributeError, SyntaxError) as err:
                    if debug:
                        print(f"WARNING: Could not evaluate argtype {position} '{spec}' (normalized: '{spelled}') for {fname}: {err}")
                    raise  # the outer handler skips this function's argtypes
            cfunc.argtypes = resolved
        except (NameError, AttributeError, SyntaxError) as err:
            if debug:
                print(f"WARNING: Could not set prototype for {fname}: {err}")
239
+
240
+
241
+ _setup_prototypes()
242
+
243
+ # Python wrapper functions
244
+
245
+
246
def logistic_kernel_new(lambda_val):
    """Create a logistic kernel via ``spir_logistic_kernel_new``.

    Args:
        lambda_val: kernel parameter, forwarded unchanged.

    Returns:
        The kernel handle returned by the C library.

    Raises:
        RuntimeError: if the C call reports a non-success status.
    """
    err = c_int()
    handle = _lib.spir_logistic_kernel_new(lambda_val, byref(err))
    if err.value == COMPUTATION_SUCCESS:
        return handle
    raise RuntimeError(f"Failed to create logistic kernel: {err.value}")
253
+
254
+
255
def reg_bose_kernel_new(lambda_val):
    """Create a regularized bosonic kernel via ``spir_reg_bose_kernel_new``.

    Args:
        lambda_val: kernel parameter, forwarded unchanged.

    Returns:
        The kernel handle returned by the C library.

    Raises:
        RuntimeError: if the C call reports a non-success status.
    """
    err = c_int()
    handle = _lib.spir_reg_bose_kernel_new(lambda_val, byref(err))
    if err.value == COMPUTATION_SUCCESS:
        return handle
    raise RuntimeError(
        f"Failed to create regularized bosonic kernel: {err.value}")
263
+
264
+
265
def sve_result_new(kernel, epsilon, cutoff=None, lmax=None, n_gauss=None, Twork=None):
    """Create a new SVE result via ``spir_sve_result_new``.

    Args:
        kernel: kernel handle.
        epsilon: accuracy parameter; must be a positive number.
        cutoff: deprecated and ignored (the C API has no such parameter); kept
            only for backward compatibility.
        lmax: passed as ``c_int``; ``None`` maps to -1 (presumably "let the
            library choose" -- confirm against the C API docs).
        n_gauss: passed as ``c_int``; ``None`` maps to -1 (same caveat).
        Twork: working-precision selector passed as ``c_int``; ``None``
            selects ``SPIR_TWORK_FLOAT64X2``.

    Returns:
        The SVE result handle returned by the C library.

    Raises:
        RuntimeError: if ``epsilon`` is not a positive number (including NaN)
            or the C call reports a non-success status.
    """
    # ``not epsilon > 0`` rather than ``epsilon <= 0``: every comparison with
    # NaN is False, so the old check let NaN through to the C library.
    if not epsilon > 0:
        raise RuntimeError(
            f"Failed to create SVE result: epsilon must be positive, got {epsilon}")

    if lmax is None:
        lmax = -1
    if n_gauss is None:
        n_gauss = -1
    if Twork is None:
        Twork = SPIR_TWORK_FLOAT64X2

    status = c_int()
    # C-API signature: spir_sve_result_new(kernel, epsilon, lmax, n_gauss, Twork, status)
    sve = _lib.spir_sve_result_new(
        kernel, c_double(epsilon), c_int(lmax), c_int(n_gauss), c_int(Twork), byref(status))
    if status.value != COMPUTATION_SUCCESS:
        raise RuntimeError(f"Failed to create SVE result: {status.value}")
    return sve
291
+
292
+
293
def sve_result_get_size(sve):
    """Return the size reported by ``spir_sve_result_get_size``.

    Raises:
        RuntimeError: if the C call returns a non-success status.
    """
    out = c_int()
    rc = _lib.spir_sve_result_get_size(sve, byref(out))
    if rc == COMPUTATION_SUCCESS:
        return out.value
    raise RuntimeError(f"Failed to get SVE result size: {rc}")
300
+
301
def sve_result_truncate(sve, epsilon, max_size):
    """Truncate an SVE result via ``spir_sve_result_truncate``.

    ``epsilon`` and ``max_size`` are forwarded unchanged.

    Returns:
        The handle returned by the C call.

    Raises:
        RuntimeError: if the C call reports a non-success status.
    """
    err = c_int()
    truncated = _lib.spir_sve_result_truncate(sve, epsilon, max_size, byref(err))
    if err.value == COMPUTATION_SUCCESS:
        return truncated
    raise RuntimeError(f"Failed to truncate SVE result: {err.value}")
308
+
309
def sve_result_get_svals(sve):
    """Return the singular values of an SVE result as a float64 array.

    The array length is sve_result_get_size(sve); the C library fills it.

    Raises:
        RuntimeError: if either C call fails.
    """
    buf = np.zeros(sve_result_get_size(sve), dtype=np.float64)
    rc = _lib.spir_sve_result_get_svals(sve, buf.ctypes.data_as(POINTER(c_double)))
    if rc == COMPUTATION_SUCCESS:
        return buf
    raise RuntimeError(f"Failed to get singular values: {rc}")
318
+
319
def basis_new(statistics, beta, omega_max, epsilon, kernel, sve, max_size):
    """Create a basis handle via ``spir_basis_new``.

    All arguments are forwarded unchanged, in this order, followed by a
    status out-parameter.

    Raises:
        RuntimeError: if the C call reports a non-success status.
    """
    err = c_int()
    forwarded = (statistics, beta, omega_max, epsilon, kernel, sve, max_size)
    handle = _lib.spir_basis_new(*forwarded, byref(err))
    if err.value != COMPUTATION_SUCCESS:
        raise RuntimeError(f"Failed to create basis: {err.value}")
    return handle
328
+
329
+
330
def basis_get_size(basis):
    """Return the size reported by ``spir_basis_get_size``.

    Raises:
        RuntimeError: if the C call returns a non-success status.
    """
    out = c_int()
    rc = _lib.spir_basis_get_size(basis, byref(out))
    if rc == COMPUTATION_SUCCESS:
        return out.value
    raise RuntimeError(f"Failed to get basis size: {rc}")
337
+
338
+
339
def basis_get_svals(basis):
    """Return the singular values of a basis as a float64 array.

    The array length is basis_get_size(basis); the C library fills it.

    Raises:
        RuntimeError: if either C call fails.
    """
    buf = np.zeros(basis_get_size(basis), dtype=np.float64)
    rc = _lib.spir_basis_get_svals(basis, buf.ctypes.data_as(POINTER(c_double)))
    if rc == COMPUTATION_SUCCESS:
        return buf
    raise RuntimeError(f"Failed to get singular values: {rc}")
348
+
349
+
350
def basis_get_stats(basis):
    """Return the statistics code reported by ``spir_basis_get_stats``.

    Raises:
        RuntimeError: if the C call returns a non-success status.
    """
    out = c_int()
    rc = _lib.spir_basis_get_stats(basis, byref(out))
    if rc == COMPUTATION_SUCCESS:
        return out.value
    raise RuntimeError(f"Failed to get basis statistics: {rc}")
357
+
358
+
359
def basis_get_u(basis):
    """Return the imaginary-time basis functions handle (``spir_basis_get_u``).

    Raises:
        RuntimeError: if the C call reports a non-success status.
    """
    err = c_int()
    handle = _lib.spir_basis_get_u(basis, byref(err))
    if err.value == COMPUTATION_SUCCESS:
        return handle
    raise RuntimeError(f"Failed to get u basis functions: {err.value}")
366
+
367
+
368
def basis_get_v(basis):
    """Return the real-frequency basis functions handle (``spir_basis_get_v``).

    Raises:
        RuntimeError: if the C call reports a non-success status.
    """
    err = c_int()
    handle = _lib.spir_basis_get_v(basis, byref(err))
    if err.value == COMPUTATION_SUCCESS:
        return handle
    raise RuntimeError(f"Failed to get v basis functions: {err.value}")
375
+
376
+
377
def basis_get_uhat(basis):
    """Return the Matsubara-frequency basis functions handle (``spir_basis_get_uhat``).

    Raises:
        RuntimeError: if the C call reports a non-success status.
    """
    err = c_int()
    handle = _lib.spir_basis_get_uhat(basis, byref(err))
    if err.value == COMPUTATION_SUCCESS:
        return handle
    raise RuntimeError(
        f"Failed to get uhat basis functions: {err.value}")
385
+
386
+
387
def funcs_get_size(funcs):
    """Return the number of functions reported by ``spir_funcs_get_size``.

    Raises:
        RuntimeError: if the C call returns a non-success status.
    """
    out = c_int()
    rc = _lib.spir_funcs_get_size(funcs, byref(out))
    if rc == COMPUTATION_SUCCESS:
        return out.value
    raise RuntimeError(f"Failed to get function size: {rc}")
394
+
395
+ # TODO: Rename funcs_eval_single
396
+
397
+
398
def funcs_eval_single_float64(funcs, x):
    """Evaluate every function of ``funcs`` at the real point ``x``.

    Returns:
        float64 array with one value per function, filled by ``spir_funcs_eval``.

    Raises:
        RuntimeError: if the size query or the evaluation fails.
    """
    # funcs_get_size performs the same C call and raises the same error message.
    values = np.zeros(funcs_get_size(funcs), dtype=np.float64)

    rc = _lib.spir_funcs_eval(
        funcs, c_double(x),
        values.ctypes.data_as(POINTER(c_double))
    )
    if rc != COMPUTATION_SUCCESS:
        raise RuntimeError(f"Failed to evaluate functions: {rc}")
    return values
418
+
419
+ # TODO: Rename to funcs_eval_matsu_single
420
+
421
+
422
def funcs_eval_single_complex128(funcs, x):
    """Evaluate every function of ``funcs`` at the integer index ``x`` (``spir_funcs_eval_matsu``).

    ``x`` is passed as ``c_int64``. The complex128 output buffer is handed to
    the C library as a ``c_double_complex`` pointer.

    Returns:
        complex128 array with one value per function.

    Raises:
        RuntimeError: if the size query or the evaluation fails.
    """
    # funcs_get_size performs the same C call and raises the same error message.
    values = np.zeros(funcs_get_size(funcs), dtype=np.complex128)

    rc = _lib.spir_funcs_eval_matsu(
        funcs, c_int64(x),
        values.ctypes.data_as(POINTER(c_double_complex))
    )
    if rc != COMPUTATION_SUCCESS:
        raise RuntimeError(f"Failed to evaluate functions: {rc}")
    return values
442
+
443
+
444
def funcs_get_n_knots(funcs):
    """Return the knot count reported by ``spir_funcs_get_n_knots``.

    Raises:
        RuntimeError: if the C call returns a non-success status.
    """
    count = c_int()
    rc = _lib.spir_funcs_get_n_knots(funcs, byref(count))
    if rc == COMPUTATION_SUCCESS:
        return count.value
    raise RuntimeError(f"Failed to get number of knots: {rc}")
451
+
452
+
453
def funcs_get_knots(funcs):
    """Return the knots of ``funcs`` as a float64 array.

    The length comes from funcs_get_n_knots; ``spir_funcs_get_knots`` fills it.

    Raises:
        RuntimeError: if either C call fails.
    """
    buf = np.zeros(funcs_get_n_knots(funcs), dtype=np.float64)
    rc = _lib.spir_funcs_get_knots(funcs, buf.ctypes.data_as(POINTER(c_double)))
    if rc == COMPUTATION_SUCCESS:
        return buf
    raise RuntimeError(f"Failed to get knots: {rc}")
462
+
463
+
464
def basis_get_default_tau_sampling_points(basis):
    """Return the default tau sampling points of a basis as a float64 array.

    The count comes from ``spir_basis_get_n_default_taus``; the points are
    then written by ``spir_basis_get_default_taus``.

    Raises:
        RuntimeError: if either C call fails.
    """
    count = c_int()
    rc = _lib.spir_basis_get_n_default_taus(basis, byref(count))
    if rc != COMPUTATION_SUCCESS:
        raise RuntimeError(
            f"Failed to get number of default tau points: {rc}")

    taus = np.zeros(count.value, dtype=np.float64)
    rc = _lib.spir_basis_get_default_taus(basis, taus.ctypes.data_as(POINTER(c_double)))
    if rc != COMPUTATION_SUCCESS:
        raise RuntimeError(f"Failed to get default tau points: {rc}")
    return taus
481
+
482
+
483
def basis_get_default_tau_sampling_points_ext(basis, n_points):
    """Get up to ``n_points`` default tau sampling points for a basis.

    Args:
        basis: basis handle.
        n_points: number of points requested (also the output buffer size).

    Returns:
        float64 array of the points written by ``spir_basis_get_default_taus_ext``.
        It is truncated to the count the library reports in ``n_points_returned``,
        so it never ends with unwritten zero padding.

    Raises:
        RuntimeError: if the C call reports a non-success status.
    """
    points = np.zeros(n_points, dtype=np.float64)
    n_points_returned = c_int()
    status = _lib.spir_basis_get_default_taus_ext(
        basis, n_points, points.ctypes.data_as(POINTER(c_double)), byref(n_points_returned))
    if status != COMPUTATION_SUCCESS:
        raise RuntimeError(f"Failed to get default tau points: {status}")
    # Previously the reported count was ignored. If the library filled fewer
    # entries than requested, the zero-initialized tail leaked out as bogus points.
    count = n_points_returned.value
    return points if count == len(points) else points[:count]
492
+
493
+
494
def basis_get_default_omega_sampling_points(basis):
    """Return the default real-frequency sampling points of a basis as float64.

    The count comes from ``spir_basis_get_n_default_ws``; the points are
    then written by ``spir_basis_get_default_ws``.

    Raises:
        RuntimeError: if either C call fails.
    """
    count = c_int()
    rc = _lib.spir_basis_get_n_default_ws(basis, byref(count))
    if rc != COMPUTATION_SUCCESS:
        raise RuntimeError(
            f"Failed to get number of default omega points: {rc}")

    omegas = np.zeros(count.value, dtype=np.float64)
    rc = _lib.spir_basis_get_default_ws(basis, omegas.ctypes.data_as(POINTER(c_double)))
    if rc != COMPUTATION_SUCCESS:
        raise RuntimeError(f"Failed to get default omega points: {rc}")
    return omegas
511
+
512
+
513
def basis_get_default_matsubara_sampling_points(basis, positive_only=False):
    """Return the default Matsubara sampling points of a basis as int64.

    The count comes from ``spir_basis_get_n_default_matsus``; the points are
    then written by ``spir_basis_get_default_matsus``. ``positive_only`` is
    forwarded to both calls as a 1/0 ``c_int``.

    Raises:
        RuntimeError: if either C call fails.
    """
    flag = 1 if positive_only else 0

    count = c_int()
    rc = _lib.spir_basis_get_n_default_matsus(basis, c_int(flag), byref(count))
    if rc != COMPUTATION_SUCCESS:
        raise RuntimeError(
            f"Failed to get number of default Matsubara points: {rc}")

    freqs = np.zeros(count.value, dtype=np.int64)
    rc = _lib.spir_basis_get_default_matsus(
        basis, c_int(flag), freqs.ctypes.data_as(POINTER(c_int64)))
    if rc != COMPUTATION_SUCCESS:
        raise RuntimeError(f"Failed to get default Matsubara points: {rc}")
    return freqs
531
+
532
+
533
def basis_get_n_default_matsus_ext(basis, n_points, positive_only):
    """Return the point count reported by ``spir_basis_get_n_default_matsus_ext``.

    Args:
        basis: basis handle.
        n_points: requested number of points, forwarded unchanged.
        positive_only: forwarded as a 1/0 ``c_int``.

    Raises:
        RuntimeError: if the C call returns a non-success status.
    """
    returned = c_int()
    flag = c_int(1 if positive_only else 0)
    rc = _lib.spir_basis_get_n_default_matsus_ext(basis, flag, n_points, byref(returned))
    if rc == COMPUTATION_SUCCESS:
        return returned.value
    raise RuntimeError(
        f"Failed to get number of default Matsubara points: {rc}")
542
+
543
+
544
def basis_get_default_matsus_ext(basis, positive_only, points):
    """Fill a buffer with default Matsubara sampling points (extended C API).

    Args:
        basis: basis handle.
        positive_only: forwarded to the C call as a 1/0 ``c_int``.
        points: output buffer. Its length is the number of points requested.
            A C-contiguous int64 ndarray is filled in place. Anything else is
            converted to a contiguous int64 copy, which is filled instead.

    Returns:
        The filled int64 array. It is the same object as ``points`` when no
        conversion was needed and the library filled every entry; otherwise
        it is truncated to the count the library reports.

    Raises:
        RuntimeError: if the C call reports a non-success status.
    """
    # The C side writes int64 values consecutively through a raw pointer. A
    # list, a non-int64 dtype or a strided view would previously crash,
    # corrupt memory, or receive values at the wrong positions.
    # ascontiguousarray returns the input itself when it already qualifies.
    points = np.ascontiguousarray(points, dtype=np.int64)
    n_points = len(points)
    n_points_returned = c_int()
    # NOTE(review): the third argument is fixed at 0; its meaning is not
    # visible here (presumably a mitigation flag) -- confirm against the C API.
    status = _lib.spir_basis_get_default_matsus_ext(
        basis, c_int(1 if positive_only else 0), c_int(0), n_points,
        points.ctypes.data_as(POINTER(c_int64)), byref(n_points_returned))
    if status != COMPUTATION_SUCCESS:
        raise RuntimeError(f"Failed to get default Matsubara points: {status}")
    count = n_points_returned.value
    return points if count == n_points else points[:count]
552
+
553
+
554
def tau_sampling_new(basis, sampling_points=None):
    """Create a tau sampling object via ``spir_tau_sampling_new``.

    Args:
        basis: basis handle.
        sampling_points: tau points. Defaults to
            basis_get_default_tau_sampling_points(basis).

    Returns:
        The sampling handle returned by the C library.

    Raises:
        RuntimeError: if the C call reports a non-success status.
    """
    if sampling_points is None:
        sampling_points = basis_get_default_tau_sampling_points(basis)

    # The library receives a bare pointer plus a count, so the buffer must be
    # C-contiguous float64. np.asarray kept strided views (e.g. ``x[::2]``)
    # as-is, and the library then read the wrong values. ascontiguousarray
    # copies only when needed.
    sampling_points = np.ascontiguousarray(sampling_points, dtype=np.float64)
    n_points = len(sampling_points)

    status = c_int()
    sampling = _lib.spir_tau_sampling_new(
        basis, n_points,
        sampling_points.ctypes.data_as(POINTER(c_double)),
        byref(status)
    )
    if status.value != COMPUTATION_SUCCESS:
        raise RuntimeError(f"Failed to create tau sampling: {status.value}")

    return sampling
572
+
573
+
574
def _statistics_to_c(statistics):
    """Map a statistics label to its C-API constant.

    ``"F"`` maps to SPIR_STATISTICS_FERMIONIC and ``"B"`` maps to
    SPIR_STATISTICS_BOSONIC.

    Raises:
        ValueError: for any other value.
    """
    if statistics not in ("F", "B"):
        raise ValueError(f"Invalid statistics: {statistics}")
    return SPIR_STATISTICS_FERMIONIC if statistics == "F" else SPIR_STATISTICS_BOSONIC
582
+
583
+
584
def tau_sampling_new_with_matrix(basis, statistics, sampling_points, matrix):
    """Create a tau sampling object from precomputed points and matrix.

    Args:
        basis: object exposing ``.size`` (passed as the basis size) -- assumes
            a Python-side basis wrapper, not a raw handle; confirm with callers.
        statistics: ``"F"`` or ``"B"``.
        sampling_points: tau points (converted to C-contiguous float64).
        matrix: sampling matrix (converted to C-contiguous float64). It is
            declared to the library as SPIR_ORDER_ROW_MAJOR.

    Returns:
        The sampling handle returned by the C library.

    Raises:
        ValueError: for invalid ``statistics``.
        RuntimeError: if the C call reports a non-success status.
    """
    # Both buffers go to C as raw pointers. The matrix is declared row-major,
    # so a Fortran-ordered or strided array (e.g. ``m.T``) was silently read
    # in the wrong order, and plain lists failed on ``.size``/``.ctypes``.
    # ascontiguousarray copies only when the layout or dtype differs.
    sampling_points = np.ascontiguousarray(sampling_points, dtype=np.float64)
    matrix = np.ascontiguousarray(matrix, dtype=np.float64)

    status = c_int()
    sampling = _lib.spir_tau_sampling_new_with_matrix(
        SPIR_ORDER_ROW_MAJOR,
        _statistics_to_c(statistics),
        basis.size,
        sampling_points.size,
        sampling_points.ctypes.data_as(POINTER(c_double)),
        matrix.ctypes.data_as(POINTER(c_double)),
        byref(status)
    )
    if status.value != COMPUTATION_SUCCESS:
        raise RuntimeError(f"Failed to create tau sampling: {status.value}")

    return sampling
600
+
601
+
602
def matsubara_sampling_new(basis, positive_only=False, sampling_points=None):
    """Create a Matsubara sampling object via ``spir_matsu_sampling_new``.

    Args:
        basis: basis handle.
        positive_only: forwarded as ``c_bool``; also used to pick the default points.
        sampling_points: integer Matsubara indices. Defaults to
            basis_get_default_matsubara_sampling_points(basis, positive_only).

    Returns:
        The sampling handle returned by the C library.

    Raises:
        RuntimeError: if the C call reports a non-success status.
    """
    if sampling_points is None:
        sampling_points = basis_get_default_matsubara_sampling_points(
            basis, positive_only)

    # The library reads ``n_points`` consecutive int64 values from a bare
    # pointer. np.asarray kept strided views as-is (wrong values read), so
    # force a C-contiguous int64 buffer; this copies only when needed.
    sampling_points = np.ascontiguousarray(sampling_points, dtype=np.int64)
    n_points = len(sampling_points)

    status = c_int()
    sampling = _lib.spir_matsu_sampling_new(
        basis, c_bool(positive_only), n_points,
        sampling_points.ctypes.data_as(POINTER(c_int64)),
        byref(status)
    )
    if status.value != COMPUTATION_SUCCESS:
        raise RuntimeError(
            f"Failed to create Matsubara sampling: {status.value}")

    return sampling
622
+
623
+
624
def matsubara_sampling_new_with_matrix(statistics, basis_size, positive_only, sampling_points, matrix):
    """Create a Matsubara sampling object from precomputed points and matrix.

    Args:
        statistics: ``"F"`` or ``"B"``.
        basis_size: passed as ``c_int``.
        positive_only: passed as ``c_bool``.
        sampling_points: integer Matsubara indices (converted to C-contiguous int64).
        matrix: complex sampling matrix (converted to C-contiguous complex128).
            It is declared to the library as SPIR_ORDER_ROW_MAJOR.

    Returns:
        The sampling handle returned by the C library.

    Raises:
        ValueError: for invalid ``statistics``.
        RuntimeError: if the C call reports a non-success status.
    """
    # Both buffers go to C as raw pointers. The matrix is declared row-major
    # and read as pairs of doubles, so it must be C-contiguous complex128.
    # A transposed view, a non-complex128 dtype, or a list previously produced
    # wrong data or failed on ``.ctypes``. ascontiguousarray copies only when needed.
    sampling_points = np.ascontiguousarray(sampling_points, dtype=np.int64)
    matrix = np.ascontiguousarray(matrix, dtype=np.complex128)

    status = c_int()
    sampling = _lib.spir_matsu_sampling_new_with_matrix(
        SPIR_ORDER_ROW_MAJOR,  # order
        _statistics_to_c(statistics),  # statistics
        c_int(basis_size),  # basis_size
        c_bool(positive_only),  # positive_only
        c_int(len(sampling_points)),  # num_points
        sampling_points.ctypes.data_as(POINTER(c_int64)),  # points
        matrix.ctypes.data_as(POINTER(c_double_complex)),  # matrix
        byref(status)  # status
    )
    if status.value != COMPUTATION_SUCCESS:
        raise RuntimeError(
            f"Failed to create Matsubara sampling: {status.value}")

    return sampling