ignis-numerics 0.0.1

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -0,0 +1,358 @@
1
+ # frozen_string_literal: true
2
+
3
+ require "ffi"
4
+
5
+ module Ignis
6
+ module Solver
7
+ # cuSOLVER Dense (Dn) library FFI bindings
8
+     # Binds LU (getrf/getrs), QR (geqrf), Cholesky (potrf/potrs), SVD (gesvd), and symmetric/Hermitian eigenvalue (syevd/heevd) routines
9
+ module CuSolverBindings
10
+ extend FFI::Library
11
+
12
+ # cuSOLVER status codes
13
+ CUSOLVER_STATUS_SUCCESS = 0
14
+ CUSOLVER_STATUS_NOT_INITIALIZED = 1
15
+ CUSOLVER_STATUS_ALLOC_FAILED = 2
16
+ CUSOLVER_STATUS_INVALID_VALUE = 3
17
+ CUSOLVER_STATUS_ARCH_MISMATCH = 4
18
+ CUSOLVER_STATUS_MAPPING_ERROR = 5
19
+ CUSOLVER_STATUS_EXECUTION_FAILED = 6
20
+ CUSOLVER_STATUS_INTERNAL_ERROR = 7
21
+ CUSOLVER_STATUS_MATRIX_TYPE_NOT_SUPPORTED = 8
22
+ CUSOLVER_STATUS_NOT_SUPPORTED = 9
23
+ CUSOLVER_STATUS_ZERO_PIVOT = 10
24
+ CUSOLVER_STATUS_INVALID_LICENSE = 11
25
+ CUSOLVER_STATUS_IRS_PARAMS_NOT_INITIALIZED = 12
26
+ CUSOLVER_STATUS_IRS_PARAMS_INVALID = 13
27
+ CUSOLVER_STATUS_IRS_PARAMS_INVALID_PREC = 14
28
+ CUSOLVER_STATUS_IRS_PARAMS_INVALID_REFINE = 15
29
+ CUSOLVER_STATUS_IRS_PARAMS_INVALID_MAXITER = 16
30
+ CUSOLVER_STATUS_IRS_INTERNAL_ERROR = 20
31
+ CUSOLVER_STATUS_IRS_NOT_SUPPORTED = 21
32
+ CUSOLVER_STATUS_IRS_OUT_OF_RANGE = 22
33
+ CUSOLVER_STATUS_IRS_NRHS_NOT_SUPPORTED_FOR_REFINE_GMRES = 23
34
+ CUSOLVER_STATUS_IRS_INFOS_NOT_INITIALIZED = 25
35
+
36
+ # cuSOLVER EigMode (for eigenvalue/eigenvector computation)
37
+ CUSOLVER_EIG_MODE_NOVECTOR = 0 # Compute eigenvalues only
38
+ CUSOLVER_EIG_MODE_VECTOR = 1 # Compute eigenvalues and eigenvectors
39
+
40
+ # cuSOLVER EigType (for generalized eigenvalue problem)
41
+ CUSOLVER_EIG_TYPE_1 = 1 # A*x = lambda*B*x
42
+ CUSOLVER_EIG_TYPE_2 = 2 # A*B*x = lambda*x
43
+ CUSOLVER_EIG_TYPE_3 = 3 # B*A*x = lambda*x
44
+
45
+ # cuSOLVER EigRange (eigenvalue range selection)
46
+ CUSOLVER_EIG_RANGE_ALL = 0 # All eigenvalues
47
+ CUSOLVER_EIG_RANGE_V = 1 # Eigenvalues in half-open interval (vl, vu]
48
+ CUSOLVER_EIG_RANGE_I = 2 # il-th through iu-th eigenvalues
49
+
50
+ # cuBLAS fill mode (used by cuSOLVER)
51
+ CUBLAS_FILL_MODE_LOWER = 0
52
+ CUBLAS_FILL_MODE_UPPER = 1
53
+ CUBLAS_FILL_MODE_FULL = 2
54
+
55
+ # cuBLAS operation types
56
+ CUBLAS_OP_N = 0 # No transpose
57
+ CUBLAS_OP_T = 1 # Transpose
58
+ CUBLAS_OP_C = 2 # Conjugate transpose
59
+
60
+ # Status code descriptions
61
+ STATUS_DESCRIPTIONS = {
62
+ CUSOLVER_STATUS_SUCCESS => "Success",
63
+ CUSOLVER_STATUS_NOT_INITIALIZED => "Library not initialized",
64
+ CUSOLVER_STATUS_ALLOC_FAILED => "Resource allocation failed",
65
+ CUSOLVER_STATUS_INVALID_VALUE => "Invalid parameter value",
66
+ CUSOLVER_STATUS_ARCH_MISMATCH => "Architecture mismatch",
67
+ CUSOLVER_STATUS_MAPPING_ERROR => "Mapping error",
68
+ CUSOLVER_STATUS_EXECUTION_FAILED => "Execution failed",
69
+ CUSOLVER_STATUS_INTERNAL_ERROR => "Internal error",
70
+ CUSOLVER_STATUS_MATRIX_TYPE_NOT_SUPPORTED => "Matrix type not supported",
71
+ CUSOLVER_STATUS_NOT_SUPPORTED => "Operation not supported",
72
+ CUSOLVER_STATUS_ZERO_PIVOT => "Zero pivot encountered",
73
+ CUSOLVER_STATUS_INVALID_LICENSE => "Invalid license"
74
+ }.freeze
75
+
76
+ @loaded = false
77
+ @handle = nil
78
+ @mutex = Mutex.new
79
+
80
+ class << self
81
+         # @return [FFI::Pointer, nil] cuSOLVER handle; nil until initialize_cusolver! stores it and again after finalize! (the writer lets callers replace it)
82
+         attr_accessor :handle
83
+
84
+ # Check if cuSOLVER is loaded
85
+ # @return [Boolean]
86
+ def loaded?
87
+ @loaded
88
+ end
89
+
90
+ # Ensure cuSOLVER is loaded and initialized
91
+ # @return [void]
92
+ # @raise [CuSolverError] If initialization fails
93
+ def ensure_loaded!
94
+ @mutex.synchronize do
95
+ return if @loaded
96
+
97
+ CUDA::LibraryLoader.load_library(:cusolver)
98
+
99
+ cuda_bin = Ignis.configuration.cuda_bin_path
100
+ dll_path = if cuda_bin
101
+ Dir.glob(File.join(cuda_bin, "cusolver64_*.dll")).max
102
+ else
103
+ "cusolver64_12"
104
+ end
105
+
106
+ raise LibraryNotFoundError, "cusolver" unless dll_path
107
+
108
+ ffi_lib dll_path
109
+ attach_cusolver_functions!
110
+ initialize_cusolver!
111
+
112
+ @loaded = true
113
+ Ignis.logger.info("cuSOLVER initialized successfully")
114
+ end
115
+ end
116
+
117
+ # Get or create cuSOLVER handle
118
+ # @return [FFI::Pointer]
119
+ def get_handle
120
+ ensure_loaded!
121
+ @handle
122
+ end
123
+
124
+ # Set the CUDA stream for cuSOLVER operations
125
+ # @param stream [FFI::Pointer] CUDA stream
126
+ # @return [void]
127
+ def set_stream(stream)
128
+ ensure_loaded!
129
+ status = cusolverDnSetStream(@handle, stream)
130
+ check_status!(status, "cusolverDnSetStream")
131
+ end
132
+
133
+ # Finalize cuSOLVER and release resources
134
+ # @return [void]
135
+ def finalize!
136
+ @mutex.synchronize do
137
+ return unless @handle
138
+
139
+ cusolverDnDestroy(@handle)
140
+ @handle = nil
141
+ @loaded = false
142
+ Ignis.logger.info("cuSOLVER finalized")
143
+ end
144
+ end
145
+
146
+ # Check cuSOLVER status and raise error if not success
147
+ # @param status [Integer] cuSOLVER status code
148
+ # @param context [String] Context for error message
149
+ # @return [void]
150
+ # @raise [CuSolverError] If status indicates an error
151
+ def check_status!(status, context = "cuSOLVER operation")
152
+ return if status == CUSOLVER_STATUS_SUCCESS
153
+
154
+ description = STATUS_DESCRIPTIONS[status] || "Unknown error"
155
+ raise CuSolverError.new("#{context} failed: #{description}", cusolver_code: status)
156
+ end
157
+
158
+ private
159
+
160
+ # Attach all cuSOLVER FFI functions
161
+ # @return [void]
162
+ def attach_cusolver_functions!
163
+ # Handle management
164
+ attach_function :cusolverDnCreate, [:pointer], :int
165
+ attach_function :cusolverDnDestroy, [:pointer], :int
166
+ attach_function :cusolverDnSetStream, [:pointer, :pointer], :int
167
+ attach_function :cusolverDnGetStream, [:pointer, :pointer], :int
168
+
169
+ # LU factorization - buffer size queries
170
+ attach_function :cusolverDnSgetrf_bufferSize, [
171
+ :pointer, :int, :int, :pointer, :int, :pointer
172
+ ], :int
173
+ attach_function :cusolverDnDgetrf_bufferSize, [
174
+ :pointer, :int, :int, :pointer, :int, :pointer
175
+ ], :int
176
+ attach_function :cusolverDnCgetrf_bufferSize, [
177
+ :pointer, :int, :int, :pointer, :int, :pointer
178
+ ], :int
179
+ attach_function :cusolverDnZgetrf_bufferSize, [
180
+ :pointer, :int, :int, :pointer, :int, :pointer
181
+ ], :int
182
+
183
+ # LU factorization
184
+ attach_function :cusolverDnSgetrf, [
185
+ :pointer, :int, :int, :pointer, :int, :pointer, :pointer, :pointer
186
+ ], :int
187
+ attach_function :cusolverDnDgetrf, [
188
+ :pointer, :int, :int, :pointer, :int, :pointer, :pointer, :pointer
189
+ ], :int
190
+ attach_function :cusolverDnCgetrf, [
191
+ :pointer, :int, :int, :pointer, :int, :pointer, :pointer, :pointer
192
+ ], :int
193
+ attach_function :cusolverDnZgetrf, [
194
+ :pointer, :int, :int, :pointer, :int, :pointer, :pointer, :pointer
195
+ ], :int
196
+
197
+ # LU solve (getrs)
198
+ attach_function :cusolverDnSgetrs, [
199
+ :pointer, :int, :int, :int, :pointer, :int, :pointer, :pointer, :int, :pointer
200
+ ], :int
201
+ attach_function :cusolverDnDgetrs, [
202
+ :pointer, :int, :int, :int, :pointer, :int, :pointer, :pointer, :int, :pointer
203
+ ], :int
204
+ attach_function :cusolverDnCgetrs, [
205
+ :pointer, :int, :int, :int, :pointer, :int, :pointer, :pointer, :int, :pointer
206
+ ], :int
207
+ attach_function :cusolverDnZgetrs, [
208
+ :pointer, :int, :int, :int, :pointer, :int, :pointer, :pointer, :int, :pointer
209
+ ], :int
210
+
211
+ # SVD buffer size queries
212
+ attach_function :cusolverDnSgesvd_bufferSize, [
213
+ :pointer, :int, :int, :pointer
214
+ ], :int
215
+ attach_function :cusolverDnDgesvd_bufferSize, [
216
+ :pointer, :int, :int, :pointer
217
+ ], :int
218
+ attach_function :cusolverDnCgesvd_bufferSize, [
219
+ :pointer, :int, :int, :pointer
220
+ ], :int
221
+ attach_function :cusolverDnZgesvd_bufferSize, [
222
+ :pointer, :int, :int, :pointer
223
+ ], :int
224
+
225
+ # SVD computation
226
+ attach_function :cusolverDnSgesvd, [
227
+ :pointer, :char, :char, :int, :int, :pointer, :int,
228
+ :pointer, :pointer, :int, :pointer, :int,
229
+ :pointer, :int, :pointer, :pointer
230
+ ], :int
231
+ attach_function :cusolverDnDgesvd, [
232
+ :pointer, :char, :char, :int, :int, :pointer, :int,
233
+ :pointer, :pointer, :int, :pointer, :int,
234
+ :pointer, :int, :pointer, :pointer
235
+ ], :int
236
+ attach_function :cusolverDnCgesvd, [
237
+ :pointer, :char, :char, :int, :int, :pointer, :int,
238
+ :pointer, :pointer, :int, :pointer, :int,
239
+ :pointer, :int, :pointer, :pointer
240
+ ], :int
241
+ attach_function :cusolverDnZgesvd, [
242
+ :pointer, :char, :char, :int, :int, :pointer, :int,
243
+ :pointer, :pointer, :int, :pointer, :int,
244
+ :pointer, :int, :pointer, :pointer
245
+ ], :int
246
+
247
+ # Symmetric eigenvalue solver buffer size (syevd)
248
+ attach_function :cusolverDnSsyevd_bufferSize, [
249
+ :pointer, :int, :int, :int, :pointer, :int, :pointer, :pointer
250
+ ], :int
251
+ attach_function :cusolverDnDsyevd_bufferSize, [
252
+ :pointer, :int, :int, :int, :pointer, :int, :pointer, :pointer
253
+ ], :int
254
+ attach_function :cusolverDnCheevd_bufferSize, [
255
+ :pointer, :int, :int, :int, :pointer, :int, :pointer, :pointer
256
+ ], :int
257
+ attach_function :cusolverDnZheevd_bufferSize, [
258
+ :pointer, :int, :int, :int, :pointer, :int, :pointer, :pointer
259
+ ], :int
260
+
261
+ # Symmetric eigenvalue solver (syevd/heevd)
262
+ attach_function :cusolverDnSsyevd, [
263
+ :pointer, :int, :int, :int, :pointer, :int, :pointer, :pointer, :int, :pointer
264
+ ], :int
265
+ attach_function :cusolverDnDsyevd, [
266
+ :pointer, :int, :int, :int, :pointer, :int, :pointer, :pointer, :int, :pointer
267
+ ], :int
268
+ attach_function :cusolverDnCheevd, [
269
+ :pointer, :int, :int, :int, :pointer, :int, :pointer, :pointer, :int, :pointer
270
+ ], :int
271
+ attach_function :cusolverDnZheevd, [
272
+ :pointer, :int, :int, :int, :pointer, :int, :pointer, :pointer, :int, :pointer
273
+ ], :int
274
+
275
+ # Cholesky factorization buffer size
276
+ attach_function :cusolverDnSpotrf_bufferSize, [
277
+ :pointer, :int, :int, :pointer, :int, :pointer
278
+ ], :int
279
+ attach_function :cusolverDnDpotrf_bufferSize, [
280
+ :pointer, :int, :int, :pointer, :int, :pointer
281
+ ], :int
282
+ attach_function :cusolverDnCpotrf_bufferSize, [
283
+ :pointer, :int, :int, :pointer, :int, :pointer
284
+ ], :int
285
+ attach_function :cusolverDnZpotrf_bufferSize, [
286
+ :pointer, :int, :int, :pointer, :int, :pointer
287
+ ], :int
288
+
289
+ # Cholesky factorization
290
+ attach_function :cusolverDnSpotrf, [
291
+ :pointer, :int, :int, :pointer, :int, :pointer, :int, :pointer
292
+ ], :int
293
+ attach_function :cusolverDnDpotrf, [
294
+ :pointer, :int, :int, :pointer, :int, :pointer, :int, :pointer
295
+ ], :int
296
+ attach_function :cusolverDnCpotrf, [
297
+ :pointer, :int, :int, :pointer, :int, :pointer, :int, :pointer
298
+ ], :int
299
+ attach_function :cusolverDnZpotrf, [
300
+ :pointer, :int, :int, :pointer, :int, :pointer, :int, :pointer
301
+ ], :int
302
+
303
+ # Cholesky solve
304
+ attach_function :cusolverDnSpotrs, [
305
+ :pointer, :int, :int, :int, :pointer, :int, :pointer, :int, :pointer
306
+ ], :int
307
+ attach_function :cusolverDnDpotrs, [
308
+ :pointer, :int, :int, :int, :pointer, :int, :pointer, :int, :pointer
309
+ ], :int
310
+ attach_function :cusolverDnCpotrs, [
311
+ :pointer, :int, :int, :int, :pointer, :int, :pointer, :int, :pointer
312
+ ], :int
313
+ attach_function :cusolverDnZpotrs, [
314
+ :pointer, :int, :int, :int, :pointer, :int, :pointer, :int, :pointer
315
+ ], :int
316
+
317
+ # QR factorization buffer size
318
+ attach_function :cusolverDnSgeqrf_bufferSize, [
319
+ :pointer, :int, :int, :pointer, :int, :pointer
320
+ ], :int
321
+ attach_function :cusolverDnDgeqrf_bufferSize, [
322
+ :pointer, :int, :int, :pointer, :int, :pointer
323
+ ], :int
324
+ attach_function :cusolverDnCgeqrf_bufferSize, [
325
+ :pointer, :int, :int, :pointer, :int, :pointer
326
+ ], :int
327
+ attach_function :cusolverDnZgeqrf_bufferSize, [
328
+ :pointer, :int, :int, :pointer, :int, :pointer
329
+ ], :int
330
+
331
+ # QR factorization
332
+ attach_function :cusolverDnSgeqrf, [
333
+ :pointer, :int, :int, :pointer, :int, :pointer, :pointer, :int, :pointer
334
+ ], :int
335
+ attach_function :cusolverDnDgeqrf, [
336
+ :pointer, :int, :int, :pointer, :int, :pointer, :pointer, :int, :pointer
337
+ ], :int
338
+ attach_function :cusolverDnCgeqrf, [
339
+ :pointer, :int, :int, :pointer, :int, :pointer, :pointer, :int, :pointer
340
+ ], :int
341
+ attach_function :cusolverDnZgeqrf, [
342
+ :pointer, :int, :int, :pointer, :int, :pointer, :pointer, :int, :pointer
343
+ ], :int
344
+ end
345
+
346
+ # Initialize cuSOLVER handle
347
+ # @return [void]
348
+ def initialize_cusolver!
349
+ handle_ptr = FFI::MemoryPointer.new(:pointer)
350
+ status = cusolverDnCreate(handle_ptr)
351
+ check_status!(status, "cusolverDnCreate")
352
+
353
+ @handle = handle_ptr.read_pointer
354
+ end
355
+ end
356
+ end
357
+ end
358
+ end
@@ -0,0 +1,226 @@
1
+ # frozen_string_literal: true
2
+
3
+ require_relative "cusolver_bindings"
4
+
5
+ module Ignis
6
+ module Solver
7
+ # Eigenvalue and Eigenvector computation using cuSOLVER
8
+     # Supports symmetric/Hermitian matrices (syevd/heevd); general matrices (geev) are not implemented yet and #eig raises NotImplementedError
9
+ module Eigen
10
+ class << self
11
+ # Compute eigenvalues and eigenvectors of a symmetric/Hermitian matrix
12
+ # Uses divide-and-conquer algorithm (syevd/heevd)
13
+ # @param matrix [NvArray] Symmetric (real) or Hermitian (complex) matrix (n x n)
14
+ # @param eigenvectors [Boolean] If true, compute eigenvectors
15
+ # @param uplo [Symbol] :lower or :upper indicating which triangle contains the matrix
16
+ # @return [Hash] Contains :eigenvalues and optionally :eigenvectors
17
+ # @raise [CuSolverError] If computation fails
18
+ def eigh(matrix, eigenvectors: true, uplo: :lower)
19
+ CuSolverBindings.ensure_loaded!
20
+
21
+ validate_symmetric_matrix!(matrix)
22
+ n = matrix.shape[0]
23
+ lda = n
24
+
25
+ # Job mode
26
+ jobz = eigenvectors ? CuSolverBindings::CUSOLVER_EIG_MODE_VECTOR : CuSolverBindings::CUSOLVER_EIG_MODE_NOVECTOR
27
+ fill_mode = uplo == :lower ? CuSolverBindings::CUBLAS_FILL_MODE_LOWER : CuSolverBindings::CUBLAS_FILL_MODE_UPPER
28
+
29
+ # Allocate eigenvalues array (always real)
30
+ real_dtype = real_dtype_for(matrix.dtype)
31
+ eigenvalues = CUDA::Memory.new(n * dtype_size(real_dtype))
32
+
33
+ # Copy matrix to work array (will contain eigenvectors on output if requested)
34
+ work_matrix = matrix.dup
35
+
36
+ # Get workspace size
37
+ lwork_ptr = FFI::MemoryPointer.new(:int)
38
+ get_syevd_buffer_size(jobz, fill_mode, n, work_matrix.device_ptr, lda, eigenvalues, lwork_ptr, matrix.dtype)
39
+ lwork = lwork_ptr.read_int
40
+
41
+ # Allocate workspace and info
42
+ workspace = CUDA::Memory.new(lwork * dtype_size(matrix.dtype))
43
+ info = CUDA::Memory.new(4)
44
+
45
+ # Perform eigenvalue computation
46
+ perform_syevd(jobz, fill_mode, n, work_matrix.device_ptr, lda, eigenvalues, workspace, lwork, info, matrix.dtype)
47
+
48
+ # Check info
49
+ info_value = read_device_int(info)
50
+ if info_value < 0
51
+ raise CuSolverError.new("Eigenvalue computation: parameter #{-info_value} had an illegal value",
52
+ cusolver_code: CuSolverBindings::CUSOLVER_STATUS_INVALID_VALUE)
53
+ elsif info_value > 0
54
+ raise CuSolverError.new("Eigenvalue computation: #{info_value} off-diagonal elements did not converge to zero",
55
+ cusolver_code: CuSolverBindings::CUSOLVER_STATUS_EXECUTION_FAILED)
56
+ end
57
+
58
+ CUDA::RuntimeAPI.cudaDeviceSynchronize
59
+
60
+ result = {
61
+ eigenvalues: NvArray.from_device_ptr(eigenvalues, shape: [n], dtype: real_dtype, take_ownership: true)
62
+ }
63
+
64
+ if eigenvectors
65
+ result[:eigenvectors] = work_matrix
66
+ end
67
+
68
+ result
69
+ ensure
70
+ workspace&.free! if defined?(workspace) && workspace
71
+ info&.free! if defined?(info) && info
72
+ end
73
+
74
+ # Compute eigenvalues only of a symmetric/Hermitian matrix (faster)
75
+ # @param matrix [NvArray] Symmetric/Hermitian matrix
76
+ # @return [NvArray] Eigenvalues in ascending order
77
+ def eigvalsh(matrix)
78
+ result = eigh(matrix, eigenvectors: false)
79
+ result[:eigenvalues]
80
+ end
81
+
82
+ # Compute eigenvalues and eigenvectors of a general (non-symmetric) matrix
83
+ # Note: This is more expensive than eigh for symmetric matrices
84
+ # @param matrix [NvArray] General square matrix (n x n)
85
+ # @return [Hash] Contains :eigenvalues (may be complex), :eigenvectors_right
86
+ def eig(matrix)
87
+ CuSolverBindings.ensure_loaded!
88
+
89
+ validate_square_matrix!(matrix)
90
+ n = matrix.shape[0]
91
+
92
+ # For general matrices, eigenvalues may be complex
93
+ # Even for real input, we may get complex eigenvalues
94
+ # For now, we'll use Schur decomposition approximation
95
+
96
+ # Copy matrix
97
+ work_matrix = matrix.dup
98
+
99
+ # For real matrices, eigenvalues are returned as real/imaginary pairs
100
+ # For complex matrices, eigenvalues are complex directly
101
+
102
+ if complex_dtype?(matrix.dtype)
103
+ compute_complex_eig(work_matrix, n)
104
+ else
105
+ compute_real_eig(work_matrix, n)
106
+ end
107
+ end
108
+
109
+ # Check if a matrix is approximately symmetric
110
+ # @param matrix [NvArray] Matrix to check
111
+ # @param tol [Float] Tolerance for symmetry check
112
+ # @return [Boolean]
113
+ def symmetric?(matrix, tol: 1e-10)
114
+ return false unless matrix.shape[0] == matrix.shape[1]
115
+
116
+ # For a proper check, we'd compare A with A^T
117
+ # This is a simplified placeholder - full implementation would
118
+ # compute max(abs(A - A^T))
119
+ true # Placeholder - assumes caller knows matrix is symmetric
120
+ end
121
+
122
+ private
123
+
124
+ def validate_symmetric_matrix!(matrix)
125
+ raise ArgumentError, "Expected NvArray" unless matrix.is_a?(NvArray)
126
+ raise ArgumentError, "Matrix must be 2D" unless matrix.shape.length == 2
127
+ raise ArgumentError, "Matrix must be square" unless matrix.shape[0] == matrix.shape[1]
128
+ raise ArgumentError, "Matrix must be on device" unless matrix.on_device?
129
+ end
130
+
131
+ def validate_square_matrix!(matrix)
132
+ raise ArgumentError, "Expected NvArray" unless matrix.is_a?(NvArray)
133
+ raise ArgumentError, "Matrix must be 2D" unless matrix.shape.length == 2
134
+ raise ArgumentError, "Matrix must be square" unless matrix.shape[0] == matrix.shape[1]
135
+ raise ArgumentError, "Matrix must be on device" unless matrix.on_device?
136
+ end
137
+
138
+ def dtype_size(dtype)
139
+ case dtype
140
+ when :float32 then 4
141
+ when :float64 then 8
142
+ when :complex64 then 8
143
+ when :complex128 then 16
144
+ else 4
145
+ end
146
+ end
147
+
148
+ def real_dtype_for(dtype)
149
+ case dtype
150
+ when :complex64 then :float32
151
+ when :complex128 then :float64
152
+ else dtype
153
+ end
154
+ end
155
+
156
+ def complex_dtype?(dtype)
157
+ %i[complex64 complex128].include?(dtype)
158
+ end
159
+
160
+ def read_device_int(ptr)
161
+ host_ptr = FFI::MemoryPointer.new(:int)
162
+ CUDA::RuntimeAPI.cudaMemcpy(host_ptr, ptr, 4, CUDA::RuntimeAPI::MEMCPY_DEVICE_TO_HOST)
163
+ host_ptr.read_int
164
+ end
165
+
166
+ def get_syevd_buffer_size(jobz, uplo, n, a_ptr, lda, w_ptr, lwork_ptr, dtype)
167
+ handle = CuSolverBindings.get_handle
168
+
169
+ status = case dtype
170
+ when :float32
171
+ CuSolverBindings.cusolverDnSsyevd_bufferSize(handle, jobz, uplo, n, a_ptr, lda, w_ptr, lwork_ptr)
172
+ when :float64
173
+ CuSolverBindings.cusolverDnDsyevd_bufferSize(handle, jobz, uplo, n, a_ptr, lda, w_ptr, lwork_ptr)
174
+ when :complex64
175
+ CuSolverBindings.cusolverDnCheevd_bufferSize(handle, jobz, uplo, n, a_ptr, lda, w_ptr, lwork_ptr)
176
+ when :complex128
177
+ CuSolverBindings.cusolverDnZheevd_bufferSize(handle, jobz, uplo, n, a_ptr, lda, w_ptr, lwork_ptr)
178
+ else
179
+ CuSolverBindings.cusolverDnSsyevd_bufferSize(handle, jobz, uplo, n, a_ptr, lda, w_ptr, lwork_ptr)
180
+ end
181
+
182
+ CuSolverBindings.check_status!(status, "cusolverDnXsyevd_bufferSize")
183
+ end
184
+
185
+ def perform_syevd(jobz, uplo, n, a_ptr, lda, w_ptr, work, lwork, info, dtype)
186
+ handle = CuSolverBindings.get_handle
187
+
188
+ status = case dtype
189
+ when :float32
190
+ CuSolverBindings.cusolverDnSsyevd(handle, jobz, uplo, n, a_ptr, lda, w_ptr, work, lwork, info)
191
+ when :float64
192
+ CuSolverBindings.cusolverDnDsyevd(handle, jobz, uplo, n, a_ptr, lda, w_ptr, work, lwork, info)
193
+ when :complex64
194
+ CuSolverBindings.cusolverDnCheevd(handle, jobz, uplo, n, a_ptr, lda, w_ptr, work, lwork, info)
195
+ when :complex128
196
+ CuSolverBindings.cusolverDnZheevd(handle, jobz, uplo, n, a_ptr, lda, w_ptr, work, lwork, info)
197
+ else
198
+ CuSolverBindings.cusolverDnSsyevd(handle, jobz, uplo, n, a_ptr, lda, w_ptr, work, lwork, info)
199
+ end
200
+
201
+ CuSolverBindings.check_status!(status, "cusolverDnXsyevd")
202
+ end
203
+
204
+ # General (non-symmetric) real eigensolver.
205
+ #
206
+ # NOT IMPLEMENTED: cuSOLVER's general eigensolver (geev) is not bound. The
207
+ # previous code silently ran the SYMMETRIC solver (syevd) on the raw matrix
208
+ # and only logged a warning, returning the eigenpairs of (A+Aᵀ)/2 — wrong
209
+ # (and possibly real-only) results for any non-symmetric A. Fail loudly
210
+ # instead; use #eigh for genuinely symmetric/Hermitian matrices.
211
+ def compute_real_eig(_matrix, _n)
212
+ raise NotImplementedError,
213
+ "General (non-symmetric) eigenvalue solver is not implemented (geev not bound). " \
214
+ "Use eigh/eigvalsh for symmetric/Hermitian matrices."
215
+ end
216
+
217
+ # General complex eigensolver — not implemented (see compute_real_eig).
218
+ def compute_complex_eig(_matrix, _n)
219
+ raise NotImplementedError,
220
+ "General complex eigenvalue solver is not implemented (geev not bound). " \
221
+ "Use eigh for Hermitian matrices."
222
+ end
223
+ end
224
+ end
225
+ end
226
+ end