ignis-numerics 0.0.1

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -0,0 +1,266 @@
1
+ # frozen_string_literal: true
2
+
3
+ require_relative "cusolver_bindings"
4
+
5
module Ignis
  module Solver
    # Singular Value Decomposition (SVD) operations using cuSOLVER.
    # Computes the decomposition A = U * Σ * V^T.
    module SVD
      # Job types for SVD computation (the `jobu` / `jobvt` characters of gesvd)
      JOB_ALL = "A".ord       # Compute all m columns of U / all n rows of V^T
      JOB_SLIM = "S".ord      # Compute min(m,n) columns of U / rows of V^T
      JOB_OVERWRITE = "O".ord # Overwrite A with U or V^T
      JOB_NONE = "N".ord      # Do not compute U or V^T

      # Machine epsilon of IEEE-754 single precision (2**-23). Used instead of
      # Float::EPSILON (double precision) for the default rank tolerance of
      # float32 / complex64 inputs, whose singular values carry only single
      # precision.
      FLOAT32_EPSILON = 2.0**-23

      # cuSOLVER routine-name prefix (LAPACK naming) for each supported dtype.
      GESVD_PREFIX = { float32: "S", float64: "D", complex64: "C", complex128: "Z" }.freeze
      private_constant :GESVD_PREFIX

      class << self
        # Compute the SVD of a matrix.
        #
        # NOTE(review): cuSOLVER documents gesvd as supporting only m >= n, so
        # wide matrices are likely rejected by the library. Confirm this, and
        # consider factoring the transpose for m < n.
        # NOTE(review): the buffer is handed to cuSOLVER with lda = m, i.e. it is
        # treated as column-major. Confirm this matches NvArray's memory layout.
        #
        # @param matrix [NvArray] Input matrix (m x n), resident on the device
        # @param full_matrices [Boolean] If true, compute full U (m x m) and Vt (n x n)
        # @param compute_uv [Boolean] If true, compute U and Vt matrices
        # @return [Hash] Contains :s (singular values) and, when compute_uv is true, :u and :vt
        # @raise [ArgumentError] If matrix is not a 2-D NvArray on the device
        # @raise [CuSolverError] If computation fails
        def gesvd(matrix, full_matrices: false, compute_uv: true)
          CuSolverBindings.ensure_loaded!

          validate_matrix!(matrix)
          m, n = matrix.shape
          dtype = matrix.dtype
          lda = m
          min_mn = [m, n].min

          # The same job is requested for U and for V^T.
          job = if compute_uv
                  full_matrices ? JOB_ALL : JOB_SLIM
                else
                  JOB_NONE
                end

          # Output buffers. On success their ownership passes to NvArrays (below);
          # on failure the ensure clause frees them.
          s = allocate_singular_values(min_mn, dtype)
          u = allocate_u_matrix(m, min_mn, full_matrices, compute_uv, dtype)
          vt = allocate_vt_matrix(n, min_mn, full_matrices, compute_uv, dtype)

          # Leading dimensions. U always has m rows. V^T has n rows when full and
          # min(m, n) rows when slim. A dummy 1 is passed when U / V^T are not computed.
          ldu = compute_uv ? m : 1
          ldvt = if compute_uv
                   full_matrices ? n : min_mn
                 else
                   1
                 end

          # Query the workspace size (counted in elements of dtype).
          lwork_ptr = FFI::MemoryPointer.new(:int)
          get_gesvd_buffer_size(m, n, lwork_ptr, dtype)
          lwork = lwork_ptr.read_int

          workspace = CUDA::Memory.new(lwork * dtype_size(dtype))
          rwork = CUDA::Memory.new(min_mn * dtype_size(real_dtype(dtype)))
          info = CUDA::Memory.new(4)

          # gesvd overwrites its input, so factor a copy.
          # NOTE(review): assumes NvArray#dup makes an independent device copy. Confirm.
          work_matrix = matrix.dup

          perform_gesvd(
            job, job, m, n,
            work_matrix.device_ptr, lda,
            s, u, ldu, vt, ldvt,
            workspace, lwork, rwork, info,
            dtype
          )

          # Make sure the factorization has finished before devInfo is read back.
          CUDA::RuntimeAPI.cudaDeviceSynchronize
          check_info!(read_device_int(info))

          result = { s: create_singular_values_array(s, min_mn, dtype) }
          s = nil # now owned by result[:s]
          if compute_uv
            result[:u] = create_u_array(u, m, min_mn, full_matrices, dtype)
            u = nil # now owned by result[:u]
            result[:vt] = create_vt_array(vt, n, min_mn, full_matrices, dtype)
            vt = nil # now owned by result[:vt]
          end

          result
        ensure
          # Scratch buffers are always released. Output buffers are released only
          # when they were not handed to an NvArray, i.e. when an error occurred.
          [workspace, rwork, info, s, u, vt].each { |buffer| buffer&.free! }
        end

        # Compute only singular values (faster than full SVD).
        # @param matrix [NvArray] Input matrix (m x n)
        # @return [NvArray] Singular values in descending order
        def singular_values(matrix)
          result = gesvd(matrix, compute_uv: false)
          result[:s]
        end

        # Compute the numerical rank of a matrix using SVD.
        # @param matrix [NvArray] Input matrix
        # @param tol [Float, nil] Tolerance. Defaults to max(m,n) * eps * largest_singular_value,
        #   where eps is the machine epsilon of the input's precision (single precision
        #   for float32 / complex64, double precision otherwise).
        # @return [Integer] Number of singular values strictly greater than tol
        def rank(matrix, tol: nil)
          s_host = singular_values(matrix).to_a
          return 0 if s_host.empty?

          # Singular values come back in descending order, so the first is the largest.
          tol ||= matrix.shape.max * machine_epsilon(matrix.dtype) * s_host.first
          s_host.count { |sv| sv > tol }
        end

        # Compute condition number using SVD.
        # @param matrix [NvArray] Input matrix
        # @return [Float] Ratio of largest to smallest singular value
        #   (Float::INFINITY when there are none or the smallest is zero)
        def cond(matrix)
          s_host = singular_values(matrix).to_a

          return Float::INFINITY if s_host.empty? || s_host.last.zero?

          s_host.first / s_host.last
        end

        private

        def validate_matrix!(matrix)
          raise ArgumentError, "Expected NvArray" unless matrix.is_a?(NvArray)
          raise ArgumentError, "Matrix must be 2D" unless matrix.shape.length == 2
          raise ArgumentError, "Matrix must be on device" unless matrix.on_device?
        end

        # Translate the devInfo value written by gesvd into an exception.
        # @param info_value [Integer] 0 on success, -i when parameter i was illegal,
        #   > 0 when that many superdiagonals did not converge
        # @raise [CuSolverError] when info_value is non-zero
        def check_info!(info_value)
          if info_value.negative?
            raise CuSolverError.new("SVD: parameter #{-info_value} had an illegal value",
                                    cusolver_code: CuSolverBindings::CUSOLVER_STATUS_INVALID_VALUE)
          elsif info_value.positive?
            raise CuSolverError.new("SVD: #{info_value} superdiagonals did not converge",
                                    cusolver_code: CuSolverBindings::CUSOLVER_STATUS_EXECUTION_FAILED)
          end
        end

        # Machine epsilon matching the precision of dtype's singular values.
        def machine_epsilon(dtype)
          case dtype
          when :float32, :complex64 then FLOAT32_EPSILON
          else Float::EPSILON
          end
        end

        # Element size in bytes. Unknown dtypes default to 4 bytes.
        def dtype_size(dtype)
          case dtype
          when :float32 then 4
          when :float64 then 8
          when :complex64 then 8
          when :complex128 then 16
          else 4
          end
        end

        # Real counterpart of a complex dtype (real dtypes map to themselves).
        def real_dtype(dtype)
          case dtype
          when :complex64 then :float32
          when :complex128 then :float64
          else dtype
          end
        end

        # Copy a single 4-byte int from device memory to the host.
        def read_device_int(ptr)
          host_ptr = FFI::MemoryPointer.new(:int)
          CUDA::RuntimeAPI.cudaMemcpy(host_ptr, ptr, 4, CUDA::RuntimeAPI::MEMCPY_DEVICE_TO_HOST)
          host_ptr.read_int
        end

        def allocate_singular_values(count, dtype)
          size = count * dtype_size(real_dtype(dtype))
          CUDA::Memory.new(size)
        end

        def allocate_u_matrix(m, min_mn, full_matrices, compute_uv, dtype)
          return nil unless compute_uv

          cols = full_matrices ? m : min_mn
          size = m * cols * dtype_size(dtype)
          CUDA::Memory.new(size)
        end

        def allocate_vt_matrix(n, min_mn, full_matrices, compute_uv, dtype)
          return nil unless compute_uv

          rows = full_matrices ? n : min_mn
          size = rows * n * dtype_size(dtype)
          CUDA::Memory.new(size)
        end

        def create_singular_values_array(ptr, count, dtype)
          real = real_dtype(dtype)
          NvArray.from_device_ptr(ptr, shape: [count], dtype: real, take_ownership: true)
        end

        def create_u_array(ptr, m, min_mn, full_matrices, dtype)
          cols = full_matrices ? m : min_mn
          NvArray.from_device_ptr(ptr, shape: [m, cols], dtype: dtype, take_ownership: true)
        end

        def create_vt_array(ptr, n, min_mn, full_matrices, dtype)
          rows = full_matrices ? n : min_mn
          NvArray.from_device_ptr(ptr, shape: [rows, n], dtype: dtype, take_ownership: true)
        end

        # cuSOLVER routine prefix for dtype.
        # NOTE(review): unknown dtypes fall back to single precision ("S"), matching
        # dtype_size's 4-byte default. Consider rejecting them in validate_matrix! instead.
        def gesvd_prefix(dtype)
          GESVD_PREFIX.fetch(dtype, "S")
        end

        # Ask cuSOLVER for the gesvd workspace size; writes it into lwork_ptr.
        def get_gesvd_buffer_size(m, n, lwork_ptr, dtype)
          handle = CuSolverBindings.get_handle
          routine = :"cusolverDn#{gesvd_prefix(dtype)}gesvd_bufferSize"
          status = CuSolverBindings.public_send(routine, handle, m, n, lwork_ptr)
          CuSolverBindings.check_status!(status, "cusolverDnXgesvd_bufferSize")
        end

        # Invoke the dtype-specific cusolverDn?gesvd routine.
        def perform_gesvd(jobu, jobvt, m, n, a, lda, s, u, ldu, vt, ldvt, work, lwork, rwork, info, dtype)
          handle = CuSolverBindings.get_handle
          # Pass NULL for U / V^T when they were not allocated.
          u_ptr = u || FFI::Pointer::NULL
          vt_ptr = vt || FFI::Pointer::NULL

          routine = :"cusolverDn#{gesvd_prefix(dtype)}gesvd"
          status = CuSolverBindings.public_send(
            routine,
            handle, jobu, jobvt, m, n, a, lda,
            s, u_ptr, ldu, vt_ptr, ldvt,
            work, lwork, rwork, info
          )
          CuSolverBindings.check_status!(status, "cusolverDnXgesvd")
        end
      end
    end
  end
end
@@ -0,0 +1,122 @@
1
+ # frozen_string_literal: true
2
+
3
+ # Ignis Solver Module
4
+ # Provides GPU-accelerated linear algebra solvers using cuSOLVER
5
+
6
+ require_relative "solver/cusolver_bindings"
7
+ require_relative "solver/lu"
8
+ require_relative "solver/svd"
9
+ require_relative "solver/eigen"
10
+ require_relative "solver/sparse_solver"
11
+
12
module Ignis
  module Solver
    class << self
      # Solve the linear system A x = b via LU factorization.
      # @param a [NvArray] Coefficient matrix (n x n)
      # @param b [NvArray] Right-hand side (n x nrhs)
      # @return [NvArray] Solution x
      def solve(a, b)
        LU.solve(a, b)
      end

      # LU decomposition with partial pivoting.
      # @param matrix [NvArray] Input matrix (m x n)
      # @return [Hash] { lu:, pivot:, info: }
      def lu(matrix)
        LU.getrf(matrix)
      end

      # Singular Value Decomposition.
      # @param matrix [NvArray] Input matrix (m x n)
      # @param full_matrices [Boolean] Compute full U, V matrices
      # @param compute_uv [Boolean] Compute U and V^T (false returns only :s)
      # @return [Hash] { u:, s:, vt: } (only { s: } when compute_uv is false)
      def svd(matrix, full_matrices: false, compute_uv: true)
        SVD.gesvd(matrix, full_matrices: full_matrices, compute_uv: compute_uv)
      end

      # Compute singular values only.
      # @param matrix [NvArray] Input matrix
      # @return [NvArray] Singular values
      def svdvals(matrix)
        SVD.singular_values(matrix)
      end

      # Eigenvalue decomposition of a symmetric/Hermitian matrix.
      # @param matrix [NvArray] Symmetric/Hermitian matrix
      # @param eigenvectors [Boolean] Compute eigenvectors
      # @return [Hash] { eigenvalues:, eigenvectors: }
      def eigh(matrix, eigenvectors: true)
        Eigen.eigh(matrix, eigenvectors: eigenvectors)
      end

      # Eigenvalues of a symmetric/Hermitian matrix.
      # @param matrix [NvArray] Symmetric/Hermitian matrix
      # @return [NvArray] Eigenvalues
      def eigvalsh(matrix)
        Eigen.eigvalsh(matrix)
      end

      # Eigenvalue decomposition of a general matrix.
      # @param matrix [NvArray] General square matrix
      # @return [Hash] { eigenvalues:, eigenvectors: }
      def eig(matrix)
        Eigen.eig(matrix)
      end

      # Matrix rank via SVD.
      # @param matrix [NvArray] Input matrix
      # @param tol [Float, nil] Tolerance (nil selects the SVD module's default)
      # @return [Integer] Numerical rank
      def matrix_rank(matrix, tol: nil)
        SVD.rank(matrix, tol: tol)
      end

      # Condition number via SVD.
      # @param matrix [NvArray] Input matrix
      # @return [Float] Condition number
      def cond(matrix)
        SVD.cond(matrix)
      end

      # Create an LU solver for repeated solves.
      # @param matrix [NvArray] Coefficient matrix
      # @return [LUSolver]
      def lu_solver(matrix)
        LUSolver.new(matrix)
      end

      # Solve sparse linear system Ax = b using cuDSS.
      # @param sparse_matrix [Sparse::SparseMatrix] Sparse coefficient matrix in CSR format
      # @param b [NvArray] Right-hand side vector/matrix
      # @param matrix_type [Symbol] Type of matrix (:general, :symmetric, :spd)
      # @return [NvArray] Solution x
      def sparse_solve(sparse_matrix, b, matrix_type: :general)
        SparseSolver.solve(sparse_matrix, b, matrix_type: matrix_type)
      end

      # Create a sparse solver for repeated solves with the same sparsity pattern.
      # @param sparse_matrix [Sparse::SparseMatrix] Sparse coefficient matrix
      # @param matrix_type [Symbol] Type of matrix (:general, :symmetric, :spd)
      # @return [SparseSolver]
      def sparse_solver(sparse_matrix, matrix_type: :general)
        SparseSolver.new(sparse_matrix, matrix_type: matrix_type)
      end

      # Check whether cuSOLVER can be loaded.
      #
      # LoadError is rescued explicitly. It is a ScriptError, not a
      # StandardError, and FFI's ffi_lib raises it when the shared library
      # cannot be found.
      # @return [Boolean]
      def available?
        CuSolverBindings.ensure_loaded!
        true
      rescue StandardError, LoadError
        false
      end

      # Finalize cuSOLVER (release resources).
      # @return [void]
      def finalize!
        CuSolverBindings.finalize!
      end
    end
  end
end
@@ -0,0 +1,231 @@
1
+ # frozen_string_literal: true
2
+
3
+ require "ffi"
4
+
5
module Ignis
  module Sparse
    # cuSPARSE library FFI bindings.
    module CuSPARSEBindings
      extend FFI::Library

      # Index base (cusparseIndexBase_t)
      CUSPARSE_INDEX_BASE_ZERO = 0
      CUSPARSE_INDEX_BASE_ONE = 1

      # Matrix types (cusparseMatrixType_t, legacy descriptor API)
      CUSPARSE_MATRIX_TYPE_GENERAL = 0
      CUSPARSE_MATRIX_TYPE_SYMMETRIC = 1
      CUSPARSE_MATRIX_TYPE_HERMITIAN = 2
      CUSPARSE_MATRIX_TYPE_TRIANGULAR = 3

      # Fill modes (cusparseFillMode_t)
      CUSPARSE_FILL_MODE_LOWER = 0
      CUSPARSE_FILL_MODE_UPPER = 1

      # Diagonal types (cusparseDiagType_t)
      CUSPARSE_DIAG_TYPE_NON_UNIT = 0
      CUSPARSE_DIAG_TYPE_UNIT = 1

      # Index types (cusparseIndexType_t)
      CUSPARSE_INDEX_16U = 1
      CUSPARSE_INDEX_32I = 2
      CUSPARSE_INDEX_64I = 3

      # Sparse matrix-vector operation types (cusparseOperation_t)
      CUSPARSE_OPERATION_NON_TRANSPOSE = 0
      CUSPARSE_OPERATION_TRANSPOSE = 1
      CUSPARSE_OPERATION_CONJUGATE_TRANSPOSE = 2

      # SpMV algorithms (cusparseSpMVAlg_t). The C enum is not numbered in
      # declaration order: COO_ALG1 is 1, the CSR algorithms are 2 and 3.
      CUSPARSE_SPMV_ALG_DEFAULT = 0
      CUSPARSE_SPMV_COO_ALG1 = 1
      CUSPARSE_SPMV_CSR_ALG1 = 2
      CUSPARSE_SPMV_CSR_ALG2 = 3
      CUSPARSE_SPMV_COO_ALG2 = 4

      # SpMM algorithms (cusparseSpMMAlg_t), with the values declared in cusparse.h.
      CUSPARSE_SPMM_ALG_DEFAULT = 0
      CUSPARSE_SPMM_COO_ALG1 = 1
      CUSPARSE_SPMM_COO_ALG2 = 2
      CUSPARSE_SPMM_COO_ALG3 = 3
      CUSPARSE_SPMM_CSR_ALG1 = 4
      CUSPARSE_SPMM_COO_ALG4 = 5
      CUSPARSE_SPMM_CSR_ALG2 = 6
      CUSPARSE_SPMM_CSR_ALG3 = 12
      CUSPARSE_SPMM_BLOCKED_ELL_ALG1 = 13

      # Library name handed to ffi_lib when no versioned DLL is found in the
      # configured CUDA bin directory (or none is configured).
      DEFAULT_LIBRARY_NAME = "cusparse64_12"
      private_constant :DEFAULT_LIBRARY_NAME

      @loaded = false
      @handle = nil

      class << self
        # @return [FFI::Pointer, nil] cuSPARSE handle
        attr_accessor :handle

        # Load the cuSPARSE library, attach its functions and create a handle.
        # Subsequent calls are no-ops until finalize! is called.
        # @return [void]
        # @raise [CuSPARSEError] If cusparseCreate fails
        def ensure_loaded!
          return if @loaded

          CUDA::LibraryLoader.load_library(:cusparse)
          ffi_lib(cusparse_library_name)

          attach_cusparse_functions!
          initialize_cusparse!

          @loaded = true
        end

        # Get the cuSPARSE handle, loading the library first if needed.
        # @return [FFI::Pointer]
        def get_handle
          ensure_loaded!
          @handle
        end

        # Destroy the cuSPARSE handle and mark the bindings as unloaded.
        # The status returned by cusparseDestroy is not checked.
        # @return [void]
        def finalize!
          return unless @handle

          cusparseDestroy(@handle)
          @handle = nil
          @loaded = false
        end

        private

        # Library to pass to ffi_lib. With a configured CUDA bin directory,
        # pick the cusparse64_<major>.dll with the highest major version,
        # compared numerically so 12 beats 9. Fall back to the default name
        # when none exists.
        def cusparse_library_name
          cuda_bin = Ignis.configuration.cuda_bin_path
          return DEFAULT_LIBRARY_NAME unless cuda_bin

          # base: treats cuda_bin literally, so Windows backslashes in the
          # configured path are not interpreted as glob escape characters.
          newest = Dir.glob("cusparse64_*.dll", base: cuda_bin).max_by do |name|
            name[/\Acusparse64_(\d+)\.dll\z/i, 1].to_i
          end
          newest ? File.join(cuda_bin, newest) : DEFAULT_LIBRARY_NAME
        end

        # rubocop:disable Metrics/MethodLength
        def attach_cusparse_functions!
          # Handle management
          attach_function :cusparseCreate, [:pointer], :int
          attach_function :cusparseDestroy, [:pointer], :int
          attach_function :cusparseSetStream, [:pointer, :pointer], :int
          attach_function :cusparseGetStream, [:pointer, :pointer], :int

          # Matrix descriptor (legacy API)
          attach_function :cusparseCreateMatDescr, [:pointer], :int
          attach_function :cusparseDestroyMatDescr, [:pointer], :int
          attach_function :cusparseSetMatType, [:pointer, :int], :int
          attach_function :cusparseGetMatType, [:pointer], :int
          attach_function :cusparseSetMatFillMode, [:pointer, :int], :int
          attach_function :cusparseGetMatFillMode, [:pointer], :int
          attach_function :cusparseSetMatDiagType, [:pointer, :int], :int
          attach_function :cusparseGetMatDiagType, [:pointer], :int
          attach_function :cusparseSetMatIndexBase, [:pointer, :int], :int
          attach_function :cusparseGetMatIndexBase, [:pointer], :int

          # Generic API - Sparse Matrix Descriptors
          attach_function :cusparseCreateCsr, [
            :pointer, # spMatDescr
            :int64,   # rows
            :int64,   # cols
            :int64,   # nnz
            :pointer, # csrRowOffsets
            :pointer, # csrColInd
            :pointer, # csrValues
            :int,     # csrRowOffsetsType
            :int,     # csrColIndType
            :int,     # idxBase
            :int      # valueType
          ], :int

          attach_function :cusparseCreateCoo, [
            :pointer, # spMatDescr
            :int64,   # rows
            :int64,   # cols
            :int64,   # nnz
            :pointer, # cooRowInd
            :pointer, # cooColInd
            :pointer, # cooValues
            :int,     # cooIdxType
            :int,     # idxBase
            :int      # valueType
          ], :int

          attach_function :cusparseDestroySpMat, [:pointer], :int

          # Dense Vector/Matrix Descriptors
          attach_function :cusparseCreateDnVec, [:pointer, :int64, :pointer, :int], :int
          attach_function :cusparseDestroyDnVec, [:pointer], :int
          attach_function :cusparseCreateDnMat, [
            :pointer, :int64, :int64, :int64, :pointer, :int, :int
          ], :int
          attach_function :cusparseDestroyDnMat, [:pointer], :int

          # SpMV - Sparse Matrix-Vector Multiplication
          attach_function :cusparseSpMV, [
            :pointer, # handle
            :int,     # opA
            :pointer, # alpha
            :pointer, # matA
            :pointer, # vecX
            :pointer, # beta
            :pointer, # vecY
            :int,     # computeType
            :int,     # alg
            :pointer  # externalBuffer
          ], :int

          attach_function :cusparseSpMV_bufferSize, [
            :pointer, :int, :pointer, :pointer, :pointer, :pointer, :pointer, :int, :int, :pointer
          ], :int

          # SpMM - Sparse Matrix-Matrix Multiplication
          attach_function :cusparseSpMM, [
            :pointer, # handle
            :int,     # opA
            :int,     # opB
            :pointer, # alpha
            :pointer, # matA
            :pointer, # matB
            :pointer, # beta
            :pointer, # matC
            :int,     # computeType
            :int,     # alg
            :pointer  # externalBuffer
          ], :int

          attach_function :cusparseSpMM_bufferSize, [
            :pointer, :int, :int, :pointer, :pointer, :pointer, :pointer, :pointer, :int, :int, :pointer
          ], :int

          # CSR format conversion
          attach_function :cusparseXcoo2csr, [
            :pointer, :pointer, :int, :int, :pointer, :int
          ], :int
          attach_function :cusparseXcsr2coo, [
            :pointer, :pointer, :int, :int, :pointer, :int
          ], :int
        end
        # rubocop:enable Metrics/MethodLength

        # Create the cuSPARSE handle and store it in @handle.
        # @raise [CuSPARSEError] If cusparseCreate returns a non-zero status
        def initialize_cusparse!
          handle_ptr = FFI::MemoryPointer.new(:pointer)
          check_status!(cusparseCreate(handle_ptr), "cusparseCreate")

          @handle = handle_ptr.read_pointer
          Ignis.logger.info("cuSPARSE initialized successfully")
        end
      end

      # Check a cuSPARSE status and raise if it is not success (zero).
      # NOTE(review): `context` is currently unused; the error carries only the
      # status code. Consider including it once CuSPARSEError's constructor is known.
      # @param status [Integer] cuSPARSE status code
      # @param context [String] Context for error message
      # @return [void]
      # @raise [CuSPARSEError] If status is non-zero
      def self.check_status!(status, context = "cuSPARSE operation")
        return if status.zero?

        raise CuSPARSEError, status
      end
    end
  end
end