ignis-numerics 0.0.1

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -0,0 +1,265 @@
1
+ # frozen_string_literal: true
2
+
3
+ require_relative "cusolver_bindings"
4
+
5
+ module Ignis
6
+ module Solver
7
# LU Decomposition operations using cuSOLVER
# Computes P*A = L*U factorization for solving linear systems
module LU
  # cuSOLVER routine prefix for each supported dtype
  # (S = float32, D = float64, C = complex64, Z = complex128).
  ROUTINE_PREFIX = {
    float32: "S",
    float64: "D",
    complex64: "C",
    complex128: "Z"
  }.freeze

  # Element size in bytes for each supported dtype.
  DTYPE_SIZES = {
    float32: 4,
    float64: 8,
    complex64: 8,
    complex128: 16
  }.freeze

  # Size in bytes of one int32 (pivot entries and the devInfo status word).
  INT32_BYTES = 4

  class << self
    # Compute the LU factorization P*A = L*U of a device matrix.
    #
    # NOTE(review): lda is set to m, i.e. the device buffer is handed to cuSOLVER
    # as column-major; confirm this matches NvArray's storage order (a row-major
    # buffer would be factored as its transpose).
    #
    # @param matrix [NvArray] Input matrix (m x n), 2D and on device
    # @param overwrite [Boolean] If true, factor in place instead of on a copy
    # @return [Hash] :lu (factored matrix), :pivot (CUDA::Memory holding
    #   min(m, n) int32 pivot indices; ownership passes to the caller),
    #   :pivot_size (min(m, n)), :info (devInfo status, > 0 means U is singular)
    # @raise [ArgumentError] If the input is not a 2D on-device NvArray
    # @raise [UnsupportedDTypeError] If the dtype has no cuSOLVER routine
    # @raise [CuSolverError] If cuSOLVER reports an illegal parameter
    def getrf(matrix, overwrite: false)
      pivot_returned = false
      CuSolverBindings.ensure_loaded!
      validate_matrix!(matrix)

      m, n = matrix.shape
      lda = m
      pivot_size = [m, n].min
      work_matrix = overwrite ? matrix : matrix.dup

      # Workspace size is reported in elements of the matrix dtype.
      lwork_ptr = FFI::MemoryPointer.new(:int)
      get_buffer_size(matrix.dtype, m, n, work_matrix.device_ffi_ptr, lda, lwork_ptr)
      lwork = lwork_ptr.read_int

      workspace = CUDA::Memory.new(lwork * dtype_size(matrix.dtype))
      pivot = CUDA::Memory.new(pivot_size * INT32_BYTES)
      info = CUDA::Memory.new(INT32_BYTES)

      # FFI cuSOLVER bindings need FFI pointers, not Fiddle ones.
      perform_getrf(matrix.dtype, m, n, work_matrix.device_ffi_ptr, lda,
                    workspace.ffi_ptr, pivot.ffi_ptr, info.ffi_ptr)

      # Make sure the factorization has finished before reading devInfo back.
      CUDA::RuntimeAPI.cudaDeviceSynchronize
      info_value = read_device_int(info)
      if info_value.negative?
        raise CuSolverError.new("LU factorization: parameter #{-info_value} had an illegal value",
                                cusolver_code: CuSolverBindings::CUSOLVER_STATUS_INVALID_VALUE)
      elsif info_value.positive?
        Ignis.logger.warn("LU factorization: U(#{info_value},#{info_value}) is exactly zero. " \
                          "The factorization has been completed, but U is singular.")
      end

      result = { lu: work_matrix, pivot: pivot, pivot_size: pivot_size, info: info_value }
      pivot_returned = true # the caller now owns the pivot buffer
      result
    ensure
      # Scratch buffers are always released; the pivot buffer only when it was
      # not handed to the caller (i.e. on any error path).
      workspace&.free!
      info&.free!
      pivot&.free! unless pivot_returned
    end

    # Solve A*x = b (or its transpose/conjugate variant) using LU factors.
    #
    # @param a [NvArray] Coefficient matrix (n x n); when +pivot+ is given it must
    #   already hold the LU factors produced by #getrf
    # @param b [NvArray] Right-hand side, either a vector (n) or a matrix (n x nrhs)
    # @param pivot [CUDA::Memory, FFI::Pointer, nil] Pivot indices from #getrf;
    #   nil factors +a+ first (the temporary pivots are freed before returning)
    # @param trans [Symbol] :none, :transpose, or :conjugate
    # @return [NvArray] Solution x (a new array; +b+ is not modified)
    # @raise [ArgumentError] On invalid inputs, shape/dtype mismatch, or unknown +trans+
    # @raise [CuSolverError] If cuSOLVER reports an illegal parameter
    def getrs(a, b, pivot: nil, trans: :none)
      CuSolverBindings.ensure_loaded!

      validate_matrix!(a)
      validate_rhs!(b)

      n = a.shape[0]
      raise ArgumentError, "Matrix A must be square" unless a.shape[0] == a.shape[1]
      raise ArgumentError, "Matrix dimensions mismatch" unless b.shape[0] == n
      raise ArgumentError, "Matrix dtypes mismatch (#{a.dtype} vs #{b.dtype})" unless a.dtype == b.dtype

      # Resolve trans before any device allocation so a bad option leaks nothing.
      trans_op = transpose_op(trans)
      nrhs = b.shape.length > 1 ? b.shape[1] : 1
      lda = n
      ldb = n

      owned_pivot = nil
      if pivot.nil?
        factorization = getrf(a)
        lu_matrix = factorization[:lu]
        pivot = owned_pivot = factorization[:pivot]
      else
        lu_matrix = a
      end

      # getrs overwrites its right-hand side with the solution, so work on a copy.
      x = b.dup
      info = CUDA::Memory.new(INT32_BYTES)

      # FFI cuSOLVER bindings need FFI pointers, not Fiddle ones.
      pivot_ptr = pivot.respond_to?(:ffi_ptr) ? pivot.ffi_ptr : pivot
      perform_getrs(a.dtype, trans_op, n, nrhs, lu_matrix.device_ffi_ptr, lda,
                    pivot_ptr, x.device_ffi_ptr, ldb, info.ffi_ptr)

      CUDA::RuntimeAPI.cudaDeviceSynchronize
      info_value = read_device_int(info)
      if info_value.negative?
        raise CuSolverError.new("LU solve: parameter #{-info_value} had an illegal value",
                                cusolver_code: CuSolverBindings::CUSOLVER_STATUS_INVALID_VALUE)
      end

      x
    ensure
      info&.free!
      # Only pivots allocated by this call are ours to free.
      owned_pivot&.free!
    end

    # Solve linear system Ax = b (convenience method)
    # @param a [NvArray] Coefficient matrix (n x n)
    # @param b [NvArray] Right-hand side (n or n x nrhs)
    # @return [NvArray] Solution x
    def solve(a, b)
      getrs(a, b)
    end

    private

    # Validate that input is a 2D on-device NvArray.
    # @param matrix [NvArray]
    # @raise [ArgumentError] If matrix is invalid
    def validate_matrix!(matrix)
      raise ArgumentError, "Expected NvArray" unless matrix.is_a?(NvArray)
      raise ArgumentError, "Matrix must be 2D" unless matrix.shape.length == 2
      raise ArgumentError, "Matrix must be on device" unless matrix.on_device?
    end

    # Validate a right-hand side: an on-device NvArray that is a vector or a matrix.
    # @param rhs [NvArray]
    # @raise [ArgumentError] If rhs is invalid
    def validate_rhs!(rhs)
      raise ArgumentError, "Expected NvArray" unless rhs.is_a?(NvArray)
      raise ArgumentError, "Right-hand side must be 1D or 2D" unless [1, 2].include?(rhs.shape.length)
      raise ArgumentError, "Right-hand side must be on device" unless rhs.on_device?
    end

    # Map a trans symbol to the cuBLAS operation constant.
    # @param trans [Symbol] :none, :transpose, or :conjugate
    # @return [Integer]
    # @raise [ArgumentError] For any other value
    def transpose_op(trans)
      case trans
      when :none then CuSolverBindings::CUBLAS_OP_N
      when :transpose then CuSolverBindings::CUBLAS_OP_T
      when :conjugate then CuSolverBindings::CUBLAS_OP_C
      else
        raise ArgumentError,
              "Unknown trans option #{trans.inspect} (expected :none, :transpose or :conjugate)"
      end
    end

    # Get element size for dtype (falls back to 4 bytes for unknown dtypes).
    # @param dtype [Symbol]
    # @return [Integer]
    def dtype_size(dtype)
      DTYPE_SIZES.fetch(dtype, 4)
    end

    # Read an int32 from a CUDA::Memory (info/pivot) via the Fiddle-safe
    # CUDA::Memory#copy_to_host path (avoids mixing FFI and Fiddle pointers).
    # @param mem [CUDA::Memory]
    # @return [Integer]
    def read_device_int(mem)
      mem.copy_to_host(count: INT32_BYTES).unpack1("l")
    end

    # Full cuSOLVER routine name for a dtype, e.g. ("float64", "getrf") -> "cusolverDnDgetrf".
    # @raise [UnsupportedDTypeError] If the dtype has no cuSOLVER routine
    def routine_name(dtype, routine, operation)
      prefix = ROUTINE_PREFIX.fetch(dtype) { raise UnsupportedDTypeError.new(dtype, operation: operation) }
      "cusolverDn#{prefix}#{routine}"
    end

    # Call the dtype-specific cuSOLVER routine (the handle is passed first)
    # and raise on a non-success status, naming the routine actually called.
    def call_cusolver(dtype, routine, operation, *args)
      name = routine_name(dtype, routine, operation)
      status = CuSolverBindings.public_send(name, CuSolverBindings.get_handle, *args)
      CuSolverBindings.check_status!(status, name)
    end

    # Query the getrf workspace size for the given dtype into lwork_ptr.
    def get_buffer_size(dtype, m, n, a_ptr, lda, lwork_ptr)
      call_cusolver(dtype, "getrf_bufferSize", "LU factorization", m, n, a_ptr, lda, lwork_ptr)
    end

    # Run the dtype-specific getrf factorization.
    def perform_getrf(dtype, m, n, a_ptr, lda, workspace, pivot, info)
      call_cusolver(dtype, "getrf", "LU factorization", m, n, a_ptr, lda, workspace, pivot, info)
    end

    # Run the dtype-specific getrs solve.
    def perform_getrs(dtype, trans, n, nrhs, a_ptr, lda, pivot, b_ptr, ldb, info)
      call_cusolver(dtype, "getrs", "LU solve", trans, n, nrhs, a_ptr, lda, pivot, b_ptr, ldb, info)
    end
  end
end
209
+
210
# LU Solver plan for repeated solves with the same matrix.
# The matrix is factored once at construction; each #solve reuses the cached
# LU factors and pivot indices.
class LUSolver
  # @return [NvArray] Original matrix
  attr_reader :matrix

  # @return [NvArray] LU factored matrix
  attr_reader :lu

  # @return [CUDA::Memory, nil] Pivot indices from LU.getrf (nil after #destroy!)
  attr_reader :pivot

  # @return [Integer] Matrix dimension
  attr_reader :n

  # Create an LU solver for the given matrix
  # @param matrix [NvArray] Coefficient matrix (n x n); it is not modified
  # @raise [ArgumentError] If the matrix is not 2D and square
  def initialize(matrix)
    @matrix = matrix
    validate!
    @n = matrix.shape[0]
    factorize!
  end

  # Solve Ax = b using the cached factorization
  # @param b [NvArray] Right-hand side (n or n x nrhs)
  # @param trans [Symbol] :none, :transpose, or :conjugate
  # @return [NvArray] Solution x
  # @raise [RuntimeError] If called after #destroy!
  def solve(b, trans: :none)
    # Without cached pivots LU.getrs would treat @lu as an unfactored matrix and
    # factor it again, silently returning a wrong answer. StateError lives in the
    # sparse-solver file, which this file does not require, so a plain
    # RuntimeError is raised here.
    raise "LUSolver has been destroyed; create a new solver" unless @factorized

    LU.getrs(@lu, b, pivot: @pivot, trans: trans)
  end

  # Release resources
  # @return [void]
  def destroy!
    @pivot&.free!
    @pivot = nil
    @factorized = false
  end

  private

  # Check rank before squareness so a 1D/3D input reports the real problem.
  def validate!
    shape = @matrix.shape
    raise ArgumentError, "Matrix must be 2D" unless shape.length == 2
    raise ArgumentError, "Matrix must be square" unless shape[0] == shape[1]
  end

  # LU.getrf already factors a copy (overwrite: false), so the original matrix
  # is passed directly instead of duplicating it a second time.
  def factorize!
    result = LU.getrf(@matrix)
    @lu = result[:lu]
    @pivot = result[:pivot]
    @factorized = true
  end
end
264
+ end
265
+ end
@@ -0,0 +1,429 @@
1
+ # frozen_string_literal: true
2
+
3
+ require_relative "cudss_bindings"
4
+
5
+ module Ignis
6
+ module Solver
7
# High-level sparse linear solver using cuDSS
# Solves sparse systems of the form Ax = b using direct methods
#
# cuDSS phases:
# 1. Analysis: Reordering and symbolic factorization
# 2. Factorization: Numerical factorization (LU, Cholesky, LDL)
# 3. Solve: Forward/backward substitution
#
# @example Solve a sparse linear system
#   solver = Ignis::Solver::SparseSolver.new(sparse_matrix)
#   solver.analyze!
#   solver.factor!
#   x = solver.solve(b)
#   solver.destroy!
#
# @example Quick solve (combines all phases)
#   x = Ignis::Solver::SparseSolver.solve(sparse_matrix, b)
class SparseSolver
  # Matrix types for cuDSS
  MATRIX_TYPE_GENERAL = 0
  MATRIX_TYPE_SYMMETRIC = 1
  MATRIX_TYPE_HERMITIAN = 2
  MATRIX_TYPE_SPD = 3 # Symmetric Positive Definite
  MATRIX_TYPE_HPD = 4 # Hermitian Positive Definite

  # Accepted matrix_type symbols mapped to the cuDSS matrix-type values above.
  MATRIX_TYPES = {
    general: MATRIX_TYPE_GENERAL,
    symmetric: MATRIX_TYPE_SYMMETRIC,
    hermitian: MATRIX_TYPE_HERMITIAN,
    spd: MATRIX_TYPE_SPD,
    hpd: MATRIX_TYPE_HPD
  }.freeze

  # Index base
  INDEX_BASE_ZERO = 0
  INDEX_BASE_ONE = 1

  # Dense layout and matrix view passed to cuDSS
  LAYOUT_COL_MAJOR = 0 # CUDSS_LAYOUT_COL_MAJOR (row-major is not used here)
  MVIEW_FULL = 0 # CUDSS_MVIEW_FULL

  # CUDA library data types (from library_types.h)
  # cuDSS uses these for both value and index types
  CUDA_R_32F = 0 # float
  CUDA_R_64F = 1 # double
  CUDA_R_16F = 2 # half
  CUDA_C_32F = 4 # cuComplex
  CUDA_C_64F = 5 # cuDoubleComplex
  CUDA_R_32I = 10 # int32
  CUDA_R_64I = 24 # int64

  # Bytes per element for each supported value dtype.
  ELEMENT_SIZES = {
    float32: 4,
    float64: 8,
    complex64: 8,
    complex128: 16
  }.freeze

  # @return [Sparse::SparseMatrix] The sparse coefficient matrix
  attr_reader :matrix

  # @return [Symbol] Matrix type (:general, :symmetric, :hermitian, :spd, :hpd)
  attr_reader :matrix_type

  # @return [Boolean] Whether analysis has been performed
  attr_reader :analyzed

  # @return [Boolean] Whether factorization has been performed
  attr_reader :factored

  # Convenience method to solve a sparse system in one call
  # @param sparse_matrix [Sparse::SparseMatrix] Sparse coefficient matrix A
  # @param b [NvArray] Right-hand side vector/matrix b
  # @param matrix_type [Symbol] Type of matrix (:general, :symmetric, :hermitian, :spd, :hpd)
  # @return [NvArray] Solution vector/matrix x
  def self.solve(sparse_matrix, b, matrix_type: :general)
    solver = new(sparse_matrix, matrix_type: matrix_type)
    begin
      solver.analyze!
      solver.factor!
      solver.solve(b)
    ensure
      solver.destroy!
    end
  end

  # Initialize sparse solver
  # @param sparse_matrix [Sparse::SparseMatrix] Sparse coefficient matrix A in CSR format
  # @param matrix_type [Symbol] Type of matrix (:general, :symmetric, :spd, :hermitian, :hpd)
  # @raise [ArgumentError] If the matrix is not CSR or matrix_type is unknown
  # @raise [UnsupportedDTypeError] If the matrix dtype is not supported
  def initialize(sparse_matrix, matrix_type: :general)
    validate_matrix!(sparse_matrix)
    validate_matrix_type!(matrix_type)
    @matrix = sparse_matrix
    @matrix_type = matrix_type
    @analyzed = false
    @factored = false
    @handle = nil
    @config = nil
    @data = nil
    @matrix_wrapper = nil
    @placeholder_x = nil
    @placeholder_b = nil
    @placeholder_x_mem = nil
    @placeholder_b_mem = nil
    @stream_obj = nil
    @stream = nil

    begin
      initialize_cudss!
    rescue StandardError
      # The caller never receives this object, so release whatever was created.
      destroy!
      raise
    end
  end

  # Perform analysis phase (reordering + symbolic factorization)
  # @return [self]
  def analyze!
    raise StateError, "cuDSS not initialized" unless @handle

    execute_phase(CuDSSBindings::CUDSS_PHASE_ANALYSIS, @placeholder_x, @placeholder_b, "cuDSS analysis")
    @analyzed = true
    self
  end

  # Perform numerical factorization
  # @return [self]
  def factor!
    raise StateError, "Must call analyze! before factor!" unless @analyzed

    execute_phase(CuDSSBindings::CUDSS_PHASE_FACTORIZATION, @placeholder_x, @placeholder_b,
                  "cuDSS factorization")
    @factored = true
    self
  end

  # Solve the system Ax = b
  # @param b [NvArray] Right-hand side vector/matrix (same dtype as the matrix)
  # @return [NvArray] Solution vector/matrix x (device resident)
  def solve(b)
    raise StateError, "Must call factor! before solve!" unless @factored

    validate_rhs!(b)

    b_dev = b.on_device? ? b : b.to_device
    x = NvArray.zeros(b.shape, dtype: b.dtype, device: @matrix.device_index).to_device

    b_wrapper = create_dense_matrix_wrapper(b_dev)
    x_wrapper = create_dense_matrix_wrapper(x)
    execute_phase(CuDSSBindings::CUDSS_PHASE_SOLVE, x_wrapper, b_wrapper, "cuDSS solve")

    # cuDSS work is queued on @stream; wait so the caller sees the finished x.
    CUDA::RuntimeAPI.cudaDeviceSynchronize
    x
  ensure
    # Temporary wrappers are released even when the solve phase fails.
    CuDSSBindings.cudssMatrixDestroy(b_wrapper) if b_wrapper
    CuDSSBindings.cudssMatrixDestroy(x_wrapper) if x_wrapper
  end

  # Refactor with updated numerical values (same sparsity pattern)
  # Useful when solving multiple systems with same structure but different values.
  # NOTE(review): presumably the caller updates the matrix values in place on the
  # device before calling this; confirm against the SparseMatrix API.
  # @return [self]
  def refactor!
    raise StateError, "Must call factor! before refactor!" unless @factored

    # Same placeholder solution/rhs wrappers as the analysis/factorization phases.
    execute_phase(CuDSSBindings::CUDSS_PHASE_REFACTORIZATION, @placeholder_x, @placeholder_b,
                  "cuDSS refactorization")
    self
  end

  # Release all cuDSS resources. Safe to call more than once.
  # @return [void]
  def destroy!
    # Wrappers first: the placeholder wrappers reference the placeholder buffers.
    [@placeholder_x, @placeholder_b, @matrix_wrapper].each do |wrapper|
      CuDSSBindings.cudssMatrixDestroy(wrapper) if wrapper
    end
    @placeholder_x = @placeholder_b = @matrix_wrapper = nil

    @placeholder_x_mem&.free!
    @placeholder_b_mem&.free!
    @placeholder_x_mem = @placeholder_b_mem = nil

    if @data && @handle
      CuDSSBindings.cudssDataDestroy(@handle, @data)
      @data = nil
    end

    if @config
      CuDSSBindings.cudssConfigDestroy(@config)
      @config = nil
    end

    if @handle
      CuDSSBindings.cudssDestroy(@handle)
      @handle = nil
    end

    # NOTE(review): the CUDA::Stream is only dropped here, not explicitly
    # released; confirm whether CUDA::Stream offers an explicit destroy.
    @stream_obj = nil
    @stream = nil

    @analyzed = false
    @factored = false
  end

  private

  # Run one cuDSS phase against the CSR matrix and raise on failure.
  def execute_phase(phase, solution, rhs, label)
    status = CuDSSBindings.cudssExecute(@handle, phase, @config, @data, @matrix_wrapper, solution, rhs)
    CuDSSBindings.check_status!(status, label)
  end

  # Validate input sparse matrix
  def validate_matrix!(matrix)
    unless matrix.is_a?(Sparse::SparseMatrix)
      raise ArgumentError, "Expected SparseMatrix, got #{matrix.class}"
    end

    unless matrix.format == :csr
      raise ArgumentError, "cuDSS requires CSR format, got #{matrix.format}"
    end

    return if ELEMENT_SIZES.key?(matrix.dtype)

    raise UnsupportedDTypeError.new(matrix.dtype, operation: "sparse solve")
  end

  # Reject matrix types that would otherwise silently fall back to :general.
  def validate_matrix_type!(matrix_type)
    return if MATRIX_TYPES.key?(matrix_type)

    raise ArgumentError,
          "Unknown matrix_type #{matrix_type.inspect}; expected one of #{MATRIX_TYPES.keys.join(', ')}"
  end

  # Validate right-hand side shape and dtype against the coefficient matrix.
  def validate_rhs!(b)
    raise ArgumentError, "Expected NvArray, got #{b.class}" unless b.is_a?(NvArray)

    n = @matrix.shape[0]
    case b.ndim
    when 1
      raise DimensionError, "RHS size (#{b.shape[0]}) doesn't match matrix rows (#{n})" unless b.shape[0] == n
    when 2
      raise DimensionError, "RHS rows (#{b.shape[0]}) don't match matrix rows (#{n})" unless b.shape[0] == n
    else
      raise DimensionError, "RHS must be 1D or 2D, got #{b.ndim}D"
    end

    # The dense wrappers are built with the RHS dtype, so it must match A's.
    return if b.dtype == @matrix.dtype

    raise ArgumentError, "RHS dtype (#{b.dtype}) doesn't match matrix dtype (#{@matrix.dtype})"
  end

  # Initialize cuDSS handle and create matrix wrapper
  def initialize_cudss!
    CuDSSBindings.ensure_loaded!

    handle_ptr = FFI::MemoryPointer.new(:pointer)
    status = CuDSSBindings.cudssCreate(handle_ptr)
    CuDSSBindings.check_status!(status, "cuDSS create")
    @handle = handle_ptr.read_pointer

    # Set CUDA stream (required before analyze)
    setup_cuda_stream!

    @config = create_config_object
    @data = create_data_object
    @matrix_wrapper = create_csr_matrix_wrapper

    # Placeholder solution/rhs wrappers (required for analyze/factor phases)
    create_placeholder_vectors!
  end

  # Set up CUDA stream for cuDSS operations
  # @return [void]
  def setup_cuda_stream!
    @stream_obj = CUDA::Stream.new
    @stream = @stream_obj.handle

    status = CuDSSBindings.cudssSetStream(@handle, @stream)
    CuDSSBindings.check_status!(status, "cuDSS set stream")
  end

  # Create placeholder solution/rhs vectors for analyze/factor phases.
  # cuDSS requires valid matrix wrappers even if values aren't used during these phases.
  # @return [void]
  def create_placeholder_vectors!
    n = @matrix.shape[0]
    value_type = cudss_data_type(@matrix.dtype)
    bytes = n * element_size(@matrix.dtype)

    # Stored before wrapping so destroy! can free them if a later step fails.
    @placeholder_x_mem = CUDA::Memory.new(bytes)
    @placeholder_b_mem = CUDA::Memory.new(bytes)

    @placeholder_x = build_dense_wrapper(n, 1, n, @placeholder_x_mem.device_ptr, value_type,
                                         "cuDSS create placeholder x")
    @placeholder_b = build_dense_wrapper(n, 1, n, @placeholder_b_mem.device_ptr, value_type,
                                         "cuDSS create placeholder b")
  end

  # Create cuDSS config object
  # @return [FFI::Pointer]
  def create_config_object
    config_ptr = FFI::MemoryPointer.new(:pointer)
    status = CuDSSBindings.cudssConfigCreate(config_ptr)
    CuDSSBindings.check_status!(status, "cuDSS config create")
    config_ptr.read_pointer
  end

  # Create cuDSS data object
  # @return [FFI::Pointer]
  def create_data_object
    data_ptr = FFI::MemoryPointer.new(:pointer)
    status = CuDSSBindings.cudssDataCreate(@handle, data_ptr)
    CuDSSBindings.check_status!(status, "cuDSS data create")
    data_ptr.read_pointer
  end

  # Create CSR matrix wrapper for cuDSS
  # @return [FFI::Pointer]
  def create_csr_matrix_wrapper
    matrix_ptr = FFI::MemoryPointer.new(:pointer)

    n_rows = @matrix.shape[0]
    n_cols = @matrix.shape[1]
    nnz = @matrix.nnz

    # Ensure CUDA context is initialized (cuDSS requires active context)
    Ignis.synchronize

    # cuDSS requires device pointers. to_device may return a device copy rather
    # than moving in place, so keep its result when it is a sparse matrix.
    unless @matrix.on_device?
      device_matrix = @matrix.to_device
      @matrix = device_matrix if device_matrix.is_a?(Sparse::SparseMatrix)
    end

    raise StateError, "Matrix not on device" unless @matrix.on_device?
    raise StateError, "row_ptr is nil" if @matrix.row_ptr.nil?
    raise StateError, "col_indices is nil" if @matrix.col_indices.nil?
    raise StateError, "values is nil" if @matrix.values.nil?

    status = CuDSSBindings.cudssMatrixCreateCsr(
      matrix_ptr,
      n_rows, n_cols, nnz,
      @matrix.row_ptr.device_ptr,
      FFI::Pointer::NULL, # row_end - NULL for standard CSR
      @matrix.col_indices.device_ptr,
      @matrix.values.device_ptr,
      CUDA_R_32I, # int32 indices
      cudss_data_type(@matrix.dtype),
      MATRIX_TYPES.fetch(@matrix_type),
      MVIEW_FULL,
      INDEX_BASE_ZERO
    )
    CuDSSBindings.check_status!(status, "cuDSS CSR matrix create")

    matrix_ptr.read_pointer
  end

  # Create dense matrix wrapper for cuDSS over a 1D or 2D device array.
  # NOTE(review): ld = rows with column-major layout; confirm this matches
  # NvArray's storage order for 2D right-hand sides.
  # @return [FFI::Pointer]
  def create_dense_matrix_wrapper(array)
    n_rows = array.shape[0]
    n_cols = array.ndim == 1 ? 1 : array.shape[1]
    build_dense_wrapper(n_rows, n_cols, n_rows, array.device_ptr, cudss_data_type(array.dtype),
                        "cuDSS dense matrix create")
  end

  # Wrap a column-major device buffer as a cuDSS dense matrix.
  # @return [FFI::Pointer]
  def build_dense_wrapper(n_rows, n_cols, ld, values_ptr, value_type, label)
    out_ptr = FFI::MemoryPointer.new(:pointer)
    status = CuDSSBindings.cudssMatrixCreateDn(out_ptr, n_rows, n_cols, ld, values_ptr, value_type,
                                               LAYOUT_COL_MAJOR)
    CuDSSBindings.check_status!(status, label)
    out_ptr.read_pointer
  end

  # Bytes per element for a supported dtype.
  # @raise [UnsupportedDTypeError] For any other dtype
  def element_size(dtype)
    ELEMENT_SIZES.fetch(dtype) { raise UnsupportedDTypeError.new(dtype, operation: "cuDSS") }
  end

  # Convert Ignis dtype to CUDA library data type
  # @param dtype [Symbol]
  # @return [Integer]
  def cudss_data_type(dtype)
    case dtype
    when :float32 then CUDA_R_32F
    when :float64 then CUDA_R_64F
    when :complex64 then CUDA_C_32F
    when :complex128 then CUDA_C_64F
    else
      raise UnsupportedDTypeError.new(dtype, operation: "cuDSS")
    end
  end
end
425
+
426
# Raised when a solver method is called out of order (e.g. factor! before
# analyze!) or when the solver's cuDSS handle is not available.
class StateError < StandardError
end
428
+ end
429
+ end