ignis-numerics 0.0.1
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- checksums.yaml +7 -0
- data/README.md +15 -0
- data/lib/ignis-numerics.rb +62 -0
- data/lib/nvruby/array.rb +646 -0
- data/lib/nvruby/fft/cufft_bindings.rb +134 -0
- data/lib/nvruby/fft/fft_plan.rb +288 -0
- data/lib/nvruby/fft/operations.rb +364 -0
- data/lib/nvruby/linalg/cutensor_bindings.rb +107 -0
- data/lib/nvruby/mathdx/fft_kernel.rb +258 -0
- data/lib/nvruby/mathdx/gemm_kernel.rb +293 -0
- data/lib/nvruby/mathdx.rb +73 -0
- data/lib/nvruby/random/curand_bindings.rb +115 -0
- data/lib/nvruby/random/generator.rb +305 -0
- data/lib/nvruby/solver/amgx_bindings.rb +172 -0
- data/lib/nvruby/solver/amgx_config.rb +142 -0
- data/lib/nvruby/solver/amgx_solver.rb +251 -0
- data/lib/nvruby/solver/cudss_bindings.rb +115 -0
- data/lib/nvruby/solver/cusolver_bindings.rb +358 -0
- data/lib/nvruby/solver/eigen.rb +226 -0
- data/lib/nvruby/solver/lu.rb +265 -0
- data/lib/nvruby/solver/sparse_solver.rb +429 -0
- data/lib/nvruby/solver/svd.rb +266 -0
- data/lib/nvruby/solver.rb +122 -0
- data/lib/nvruby/sparse/cusparse_bindings.rb +231 -0
- data/lib/nvruby/sparse/sparse_matrix.rb +456 -0
- data/lib/nvruby/tensor/contraction.rb +218 -0
- data/lib/nvruby/tensor.rb +42 -0
- metadata +85 -0
|
@@ -0,0 +1,265 @@
|
|
|
1
|
+
# frozen_string_literal: true
|
|
2
|
+
|
|
3
|
+
require_relative "cusolver_bindings"
|
|
4
|
+
|
|
5
|
+
module Ignis
|
|
6
|
+
module Solver
|
|
7
|
+
# LU Decomposition operations using cuSOLVER
|
|
8
|
+
# Computes P*A = L*U factorization for solving linear systems
|
|
9
|
+
module LU
|
|
10
|
+
class << self
|
|
11
|
+
# Compute LU factorization of a matrix
|
|
12
|
+
# @param matrix [NvArray] Input matrix (m x n)
|
|
13
|
+
# @param overwrite [Boolean] If true, overwrite input matrix with result
|
|
14
|
+
# @return [Hash] Contains :lu (factored matrix), :pivot (pivot indices), :info (status)
|
|
15
|
+
# @raise [CuSolverError] If factorization fails
|
|
16
|
+
def getrf(matrix, overwrite: false)
  CuSolverBindings.ensure_loaded!

  validate_matrix!(matrix)
  m, n = matrix.shape
  lda = m
  pivot_count = [m, n].min

  # Factor in place only when the caller opted in; otherwise work on a copy.
  work_matrix = overwrite ? matrix : matrix.dup

  # Ask cuSOLVER how many workspace elements this dtype/shape needs.
  lwork_ptr = FFI::MemoryPointer.new(:int)
  get_buffer_size(matrix.dtype, m, n, work_matrix.device_ffi_ptr, lda, lwork_ptr)
  lwork = lwork_ptr.read_int

  workspace = CUDA::Memory.new(lwork * dtype_size(matrix.dtype))
  pivot = CUDA::Memory.new(pivot_count * 4) # int32 pivot indices
  info = CUDA::Memory.new(4)                # int32 status word

  # The cuSOLVER bindings are FFI-based, so hand them FFI pointers (not Fiddle).
  perform_getrf(matrix.dtype, m, n, work_matrix.device_ffi_ptr, lda,
                workspace.ffi_ptr, pivot.ffi_ptr, info.ffi_ptr)

  # Negative info: illegal argument (fatal). Positive info: a zero on U's
  # diagonal; the factorization finished, so this is only logged.
  info_value = read_device_int(info)
  if info_value < 0
    raise CuSolverError.new("LU factorization: parameter #{-info_value} had an illegal value",
                            cusolver_code: CuSolverBindings::CUSOLVER_STATUS_INVALID_VALUE)
  elsif info_value > 0
    Ignis.logger.warn("LU factorization: U(#{info_value},#{info_value}) is exactly zero. " \
                      "The factorization has been completed, but U is singular.")
  end

  CUDA::RuntimeAPI.cudaDeviceSynchronize

  # From here on the caller owns the pivot buffer and must free it.
  pivot_handed_off = true
  {
    lu: work_matrix,
    pivot: pivot,
    pivot_size: pivot_count,
    info: info_value
  }
ensure
  # Scratch buffers are always released, including when a step above raised.
  workspace&.free!
  info&.free!
  # The pivot buffer is only released here when it never reached the caller.
  pivot&.free! unless pivot_handed_off
end
|
|
63
|
+
|
|
64
|
+
# Solve linear system Ax = b using LU factorization
|
|
65
|
+
# @param a [NvArray] Coefficient matrix (n x n) - can be pre-factored LU
|
|
66
|
+
# @param b [NvArray] Right-hand side matrix/vector (n x nrhs)
|
|
67
|
+
# @param pivot [CUDA::Memory, FFI::Pointer, nil] Pivot indices from getrf (nil to compute a fresh factorization)
|
|
68
|
+
# @param trans [Symbol] :none, :transpose, or :conjugate
|
|
69
|
+
# @return [NvArray] Solution x
|
|
70
|
+
# @raise [CuSolverError] If solve fails
|
|
71
|
+
def getrs(a, b, pivot: nil, trans: :none)
  CuSolverBindings.ensure_loaded!

  validate_matrix!(a)
  validate_matrix!(b)

  n = a.shape[0]
  raise ArgumentError, "Matrix A must be square" unless a.shape[0] == a.shape[1]
  raise ArgumentError, "Matrix dimensions mismatch" unless b.shape[0] == n
  # The cuSOLVER routine is chosen from a.dtype alone, so b must hold the
  # same element type or its buffer would be read with the wrong element size.
  unless a.dtype == b.dtype
    raise ArgumentError, "dtype mismatch: A is #{a.dtype}, b is #{b.dtype}"
  end

  nrhs = b.shape.length > 1 ? b.shape[1] : 1
  lda = n
  ldb = n

  # Without a caller-supplied pivot, `a` is treated as unfactored and is
  # factorized here. That pivot buffer belongs to this call and is released
  # in the ensure clause below.
  owned_pivot = nil
  if pivot.nil?
    result = getrf(a)
    lu_matrix = result[:lu]
    pivot = owned_pivot = result[:pivot]
  else
    lu_matrix = a
  end

  # The solve overwrites its right-hand side, so work on a copy of b.
  x = b.dup

  info = CUDA::Memory.new(4) # int32 status word

  # Unrecognised trans values fall back to "no transpose".
  trans_op = case trans
             when :none then CuSolverBindings::CUBLAS_OP_N
             when :transpose then CuSolverBindings::CUBLAS_OP_T
             when :conjugate then CuSolverBindings::CUBLAS_OP_C
             else CuSolverBindings::CUBLAS_OP_N
             end

  # Accept either a CUDA::Memory (as returned by getrf) or a raw FFI pointer.
  pivot_ptr = pivot.respond_to?(:ffi_ptr) ? pivot.ffi_ptr : pivot
  perform_getrs(a.dtype, trans_op, n, nrhs, lu_matrix.device_ffi_ptr, lda,
                pivot_ptr, x.device_ffi_ptr, ldb, info.ffi_ptr)

  info_value = read_device_int(info)
  if info_value < 0
    raise CuSolverError.new("LU solve: parameter #{-info_value} had an illegal value",
                            cusolver_code: CuSolverBindings::CUSOLVER_STATUS_INVALID_VALUE)
  end

  CUDA::RuntimeAPI.cudaDeviceSynchronize

  x
ensure
  info&.free!
  # Only the pivot allocated by the internal getrf call is freed here; a
  # pivot passed in by the caller stays under the caller's control.
  owned_pivot&.free!
end
|
|
126
|
+
|
|
127
|
+
# Solve linear system Ax = b (convenience method)
|
|
128
|
+
# @param a [NvArray] Coefficient matrix (n x n)
|
|
129
|
+
# @param b [NvArray] Right-hand side (n x nrhs)
|
|
130
|
+
# @return [NvArray] Solution x
|
|
131
|
+
def solve(a, b)
  # One-shot path: no pivot is supplied, so getrs factorizes `a` itself,
  # and the system is solved without transposition.
  getrs(a, b, pivot: nil, trans: :none)
end
|
|
134
|
+
|
|
135
|
+
private
|
|
136
|
+
|
|
137
|
+
# Validate that input is a valid matrix
|
|
138
|
+
# @param matrix [NvArray]
|
|
139
|
+
# @raise [ArgumentError] If matrix is invalid
|
|
140
|
+
def validate_matrix!(matrix)
  # Checks run in a fixed order (type, rank, residency); the first failing
  # check decides the error message.
  problem =
    if !matrix.is_a?(NvArray) then "Expected NvArray"
    elsif matrix.shape.length != 2 then "Matrix must be 2D"
    elsif !matrix.on_device? then "Matrix must be on device"
    end

  raise ArgumentError, problem if problem
end
|
|
145
|
+
|
|
146
|
+
# Get element size for dtype
|
|
147
|
+
# @param dtype [Symbol]
|
|
148
|
+
# @return [Integer]
|
|
149
|
+
def dtype_size(dtype)
  # Bytes per element. Any dtype not listed is sized like float32 (4 bytes).
  { float32: 4, float64: 8, complex64: 8, complex128: 16 }.fetch(dtype, 4)
end
|
|
158
|
+
|
|
159
|
+
# Read an int32 from a CUDA::Memory (info/pivot) via the Fiddle-safe
|
|
160
|
+
# CUDA::Memory#copy_to_host path (avoids mixing FFI and Fiddle pointers).
|
|
161
|
+
# @param mem [CUDA::Memory]
|
|
162
|
+
# @return [Integer]
|
|
163
|
+
def read_device_int(mem)
  # Copy four bytes back to the host, then decode them with the "l"
  # directive (signed 32-bit integer, native byte order).
  host_bytes = mem.copy_to_host(count: 4)
  host_bytes[0, 4].unpack("l").first
end
|
|
167
|
+
|
|
168
|
+
# Get buffer size for LU factorization — dispatched by dtype (was hardcoded
|
|
169
|
+
# to Sgetrf, so float64/complex matrices sized their workspace for float32).
|
|
170
|
+
def get_buffer_size(dtype, m, n, a_ptr, lda, lwork_ptr)
  h = CuSolverBindings.get_handle

  # cuSOLVER encodes the element type as one letter in the routine name
  # (S/D/C/Z). Resolving the name first lets the status check report the
  # routine that actually ran instead of interpolating the dtype symbol.
  routine = case dtype
            when :float32 then :cusolverDnSgetrf_bufferSize
            when :float64 then :cusolverDnDgetrf_bufferSize
            when :complex64 then :cusolverDnCgetrf_bufferSize
            when :complex128 then :cusolverDnZgetrf_bufferSize
            else raise UnsupportedDTypeError.new(dtype, operation: "LU factorization")
            end

  # The required workspace element count is written through lwork_ptr.
  status = CuSolverBindings.public_send(routine, h, m, n, a_ptr, lda, lwork_ptr)
  CuSolverBindings.check_status!(status, routine.to_s)
end
|
|
181
|
+
|
|
182
|
+
# Perform LU factorization — dispatched by dtype.
|
|
183
|
+
def perform_getrf(dtype, m, n, a_ptr, lda, workspace, pivot, info)
  h = CuSolverBindings.get_handle

  # Pick the precision-specific routine (S/D/C/Z) up front so the status
  # check names the real cuSOLVER function rather than "cusolverDn<dtype>getrf".
  routine = case dtype
            when :float32 then :cusolverDnSgetrf
            when :float64 then :cusolverDnDgetrf
            when :complex64 then :cusolverDnCgetrf
            when :complex128 then :cusolverDnZgetrf
            else raise UnsupportedDTypeError.new(dtype, operation: "LU factorization")
            end

  status = CuSolverBindings.public_send(routine, h, m, n, a_ptr, lda, workspace, pivot, info)
  CuSolverBindings.check_status!(status, routine.to_s)
end
|
|
194
|
+
|
|
195
|
+
# Perform LU solve — dispatched by dtype.
|
|
196
|
+
def perform_getrs(dtype, trans, n, nrhs, a_ptr, lda, pivot, b_ptr, ldb, info)
  h = CuSolverBindings.get_handle

  # Pick the precision-specific routine (S/D/C/Z) up front so the status
  # check names the real cuSOLVER function rather than "cusolverDn<dtype>getrs".
  routine = case dtype
            when :float32 then :cusolverDnSgetrs
            when :float64 then :cusolverDnDgetrs
            when :complex64 then :cusolverDnCgetrs
            when :complex128 then :cusolverDnZgetrs
            else raise UnsupportedDTypeError.new(dtype, operation: "LU solve")
            end

  status = CuSolverBindings.public_send(routine, h, trans, n, nrhs, a_ptr, lda, pivot, b_ptr, ldb, info)
  CuSolverBindings.check_status!(status, routine.to_s)
end
|
|
207
|
+
end
|
|
208
|
+
end
|
|
209
|
+
|
|
210
|
+
# LU Solver plan for repeated solves with the same matrix
|
|
211
|
+
# Caches the LU factorization for efficiency
|
|
212
|
+
class LUSolver
  # @return [NvArray] Original matrix
  attr_reader :matrix

  # @return [NvArray] LU factored matrix
  attr_reader :lu

  # @return [CUDA::Memory, nil] Device buffer of pivot indices produced by
  #   LU.getrf; nil once destroy! has been called
  attr_reader :pivot

  # @return [Integer] Matrix dimension
  attr_reader :n

  # Create an LU solver for the given matrix and factorize it immediately.
  # @param matrix [NvArray] Coefficient matrix (n x n)
  # @raise [ArgumentError] If the matrix is not 2D or not square
  def initialize(matrix)
    @matrix = matrix
    @factorized = false
    validate!
    @n = matrix.shape[0]
    factorize!
  end

  # Solve Ax = b using the cached factorization
  # @param b [NvArray] Right-hand side (n x nrhs)
  # @param trans [Symbol] :none, :transpose, or :conjugate
  # @return [NvArray] Solution x
  # @raise [RuntimeError] If the solver has already been destroyed
  def solve(b, trans: :none)
    # After destroy! the pivot is nil. Passing pivot: nil to LU.getrs makes
    # it factorize its first argument, which here is the already-factored
    # matrix, so the call must be refused instead.
    raise "LUSolver has been destroyed; create a new instance to solve again" unless @factorized

    LU.getrs(@lu, b, pivot: @pivot, trans: trans)
  end

  # Release the pivot buffer. Calling it more than once is harmless.
  # @return [void]
  def destroy!
    @pivot&.free!
    @pivot = nil
    @factorized = false
  end

  private

  # Rank is checked before squareness so a 1-D input reports the real problem.
  def validate!
    raise ArgumentError, "Matrix must be 2D" unless @matrix.shape.length == 2
    raise ArgumentError, "Matrix must be square" unless @matrix.shape[0] == @matrix.shape[1]
  end

  # Factorize once and cache the result. LU.getrf is called with its default
  # overwrite: false, so it works on its own copy and @matrix is left intact;
  # no extra dup is needed here.
  def factorize!
    result = LU.getrf(@matrix)
    @lu = result[:lu]
    @pivot = result[:pivot]
    @factorized = true
  end
end
|
|
264
|
+
end
|
|
265
|
+
end
|
|
@@ -0,0 +1,429 @@
|
|
|
1
|
+
# frozen_string_literal: true
|
|
2
|
+
|
|
3
|
+
require_relative "cudss_bindings"
|
|
4
|
+
|
|
5
|
+
module Ignis
|
|
6
|
+
module Solver
|
|
7
|
+
# High-level sparse linear solver using cuDSS
|
|
8
|
+
# Solves sparse systems of the form Ax = b using direct methods
|
|
9
|
+
#
|
|
10
|
+
# cuDSS phases:
|
|
11
|
+
# 1. Analysis: Reordering and symbolic factorization
|
|
12
|
+
# 2. Factorization: Numerical factorization (LU, Cholesky, LDL)
|
|
13
|
+
# 3. Solve: Forward/backward substitution
|
|
14
|
+
#
|
|
15
|
+
# @example Solve a sparse linear system
|
|
16
|
+
# solver = Ignis::Solver::SparseSolver.new(sparse_matrix)
|
|
17
|
+
# solver.analyze!
|
|
18
|
+
# solver.factor!
|
|
19
|
+
# x = solver.solve(b)
|
|
20
|
+
# solver.destroy!
|
|
21
|
+
#
|
|
22
|
+
# @example Quick solve (combines all phases)
|
|
23
|
+
# x = Ignis::Solver::SparseSolver.solve(sparse_matrix, b)
|
|
24
|
+
class SparseSolver
|
|
25
|
+
# Matrix types for cuDSS
|
|
26
|
+
MATRIX_TYPE_GENERAL = 0
|
|
27
|
+
MATRIX_TYPE_SYMMETRIC = 1
|
|
28
|
+
MATRIX_TYPE_HERMITIAN = 2
|
|
29
|
+
MATRIX_TYPE_SPD = 3 # Symmetric Positive Definite
|
|
30
|
+
MATRIX_TYPE_HPD = 4 # Hermitian Positive Definite
|
|
31
|
+
|
|
32
|
+
# Index base
|
|
33
|
+
INDEX_BASE_ZERO = 0
|
|
34
|
+
INDEX_BASE_ONE = 1
|
|
35
|
+
|
|
36
|
+
# CUDA library data types (from library_types.h)
|
|
37
|
+
# cuDSS uses these for both value and index types
|
|
38
|
+
CUDA_R_32F = 0 # float
|
|
39
|
+
CUDA_R_64F = 1 # double
|
|
40
|
+
CUDA_R_16F = 2 # half
|
|
41
|
+
CUDA_C_32F = 4 # cuComplex
|
|
42
|
+
CUDA_C_64F = 5 # cuDoubleComplex
|
|
43
|
+
CUDA_R_32I = 10 # int32
|
|
44
|
+
CUDA_R_64I = 24 # int64
|
|
45
|
+
|
|
46
|
+
# @return [Sparse::SparseMatrix] The sparse coefficient matrix
|
|
47
|
+
attr_reader :matrix
|
|
48
|
+
|
|
49
|
+
# @return [Symbol] Matrix type (:general, :symmetric, :hermitian, :spd, or :hpd)
|
|
50
|
+
attr_reader :matrix_type
|
|
51
|
+
|
|
52
|
+
# @return [Boolean] Whether analysis has been performed
|
|
53
|
+
attr_reader :analyzed
|
|
54
|
+
|
|
55
|
+
# @return [Boolean] Whether factorization has been performed
|
|
56
|
+
attr_reader :factored
|
|
57
|
+
|
|
58
|
+
# Convenience method to solve a sparse system in one call
|
|
59
|
+
# @param sparse_matrix [Sparse::SparseMatrix] Sparse coefficient matrix A
|
|
60
|
+
# @param b [NvArray] Right-hand side vector/matrix b
|
|
61
|
+
# @param matrix_type [Symbol] Type of matrix (:general, :symmetric, :hermitian, :spd, :hpd)
|
|
62
|
+
# @return [NvArray] Solution vector/matrix x
|
|
63
|
+
def self.solve(sparse_matrix, b, matrix_type: :general)
  one_shot = new(sparse_matrix, matrix_type: matrix_type)
  begin
    # analyze! and factor! each return the solver, so the phases chain.
    one_shot.analyze!.factor!.solve(b)
  ensure
    # Release the solver's resources whether or not a phase raised.
    one_shot.destroy!
  end
end
|
|
73
|
+
|
|
74
|
+
# Initialize sparse solver
|
|
75
|
+
# @param sparse_matrix [Sparse::SparseMatrix] Sparse coefficient matrix A in CSR format
|
|
76
|
+
# @param matrix_type [Symbol] Type of matrix (:general, :symmetric, :spd, :hermitian, :hpd)
|
|
77
|
+
def initialize(sparse_matrix, matrix_type: :general)
  validate_matrix!(sparse_matrix)
  @matrix = sparse_matrix
  @matrix_type = matrix_type
  @analyzed = false
  @factored = false

  # Every native resource starts out nil so destroy! can be called safely
  # at any point, including after a partially completed setup.
  @handle = nil
  @config = nil
  @data = nil
  @matrix_wrapper = nil
  @placeholder_x = nil
  @placeholder_b = nil
  @placeholder_x_mem = nil
  @placeholder_b_mem = nil
  @stream = nil
  @stream_obj = nil

  begin
    initialize_cudss!
  rescue StandardError
    # Setup creates several cuDSS objects in sequence. If a later step
    # fails, release the ones already created before re-raising, because
    # the caller never receives an instance to call destroy! on.
    destroy!
    raise
  end
end
|
|
93
|
+
|
|
94
|
+
# Perform analysis phase (reordering + symbolic factorization)
|
|
95
|
+
# @return [self]
|
|
96
|
+
def analyze!
  raise StateError, "cuDSS not initialized" unless @handle

  # The placeholder wrappers fill the x and b argument slots for this phase.
  status = CuDSSBindings.cudssExecute(
    @handle, CuDSSBindings::CUDSS_PHASE_ANALYSIS, @config, @data,
    @matrix_wrapper, @placeholder_x, @placeholder_b
  )
  CuDSSBindings.check_status!(status, "cuDSS analysis")

  @analyzed = true
  self
end
|
|
108
|
+
|
|
109
|
+
# Perform numerical factorization
|
|
110
|
+
# @return [self]
|
|
111
|
+
def factor!
  raise StateError, "Must call analyze! before factor!" unless @analyzed

  # The placeholder wrappers fill the x and b argument slots for this phase.
  status = CuDSSBindings.cudssExecute(
    @handle, CuDSSBindings::CUDSS_PHASE_FACTORIZATION, @config, @data,
    @matrix_wrapper, @placeholder_x, @placeholder_b
  )
  CuDSSBindings.check_status!(status, "cuDSS factorization")

  @factored = true
  self
end
|
|
123
|
+
|
|
124
|
+
# Solve the system Ax = b
|
|
125
|
+
# @param b [NvArray] Right-hand side vector/matrix
|
|
126
|
+
# @return [NvArray] Solution vector/matrix x
|
|
127
|
+
def solve(b)
  raise StateError, "Must call factor! before solve!" unless @factored

  validate_rhs!(b)

  # cuDSS is handed device pointers, so move b over if it is not there yet.
  b_dev = b.on_device? ? b : b.to_device

  # The solution has the same shape and dtype as the right-hand side.
  x = NvArray.zeros(b.shape, dtype: b.dtype, device: @matrix.device_index)
  x = x.to_device

  # Temporary cuDSS views over b and x; they only need to live for this call.
  b_wrapper = create_dense_matrix_wrapper(b_dev)
  x_wrapper = create_dense_matrix_wrapper(x)

  status = CuDSSBindings.cudssExecute(
    @handle, CuDSSBindings::CUDSS_PHASE_SOLVE, @config, @data, @matrix_wrapper, x_wrapper, b_wrapper
  )
  CuDSSBindings.check_status!(status, "cuDSS solve")

  x
ensure
  # Destroy the temporary wrappers on every exit path. Previously a failed
  # execute or status check skipped the cleanup and leaked both.
  CuDSSBindings.cudssMatrixDestroy(b_wrapper) if b_wrapper
  CuDSSBindings.cudssMatrixDestroy(x_wrapper) if x_wrapper
end
|
|
156
|
+
|
|
157
|
+
# Refactor with updated numerical values (same sparsity pattern)
|
|
158
|
+
# Useful when solving multiple systems with same structure but different values
|
|
159
|
+
# @return [self]
|
|
160
|
+
def refactor!
  raise StateError, "Must call factor! before refactor!" unless @factored

  # Pass the same placeholder x/b wrappers that analyze! and factor! use.
  # The setup code in this class notes that cuDSS needs valid dense wrappers
  # in non-solve phases, so nil is not passed in their place.
  status = CuDSSBindings.cudssExecute(
    @handle, CuDSSBindings::CUDSS_PHASE_REFACTORIZATION, @config, @data,
    @matrix_wrapper, @placeholder_x, @placeholder_b
  )
  CuDSSBindings.check_status!(status, "cuDSS refactorization")

  self
end
|
|
171
|
+
|
|
172
|
+
# Release all cuDSS resources
|
|
173
|
+
# @return [void]
|
|
174
|
+
def destroy!
  # Dense placeholder wrappers and the CSR wrapper go first.
  if @placeholder_x
    CuDSSBindings.cudssMatrixDestroy(@placeholder_x)
    @placeholder_x = nil
  end

  if @placeholder_b
    CuDSSBindings.cudssMatrixDestroy(@placeholder_b)
    @placeholder_b = nil
  end

  if @matrix_wrapper
    CuDSSBindings.cudssMatrixDestroy(@matrix_wrapper)
    @matrix_wrapper = nil
  end

  # The data object is destroyed through the handle, so it must go before it.
  if @data && @handle
    CuDSSBindings.cudssDataDestroy(@handle, @data)
    @data = nil
  end

  if @config
    CuDSSBindings.cudssConfigDestroy(@config)
    @config = nil
  end

  if @handle
    CuDSSBindings.cudssDestroy(@handle)
    @handle = nil
  end

  # Free the device buffers that backed the placeholder wrappers. They were
  # allocated during setup and were previously never released.
  @placeholder_x_mem&.free!
  @placeholder_x_mem = nil
  @placeholder_b_mem&.free!
  @placeholder_b_mem = nil

  # NOTE(review): CUDA::Stream's release API is not visible from this file,
  # so the stream is only dereferenced here; confirm whether it needs an
  # explicit destroy call.
  @stream_obj = nil
  @stream = nil

  @analyzed = false
  @factored = false
end
|
|
208
|
+
|
|
209
|
+
private
|
|
210
|
+
|
|
211
|
+
# Validate input sparse matrix
|
|
212
|
+
def validate_matrix!(matrix)
  # Guard clauses in order: class, storage format, element type.
  raise ArgumentError, "Expected SparseMatrix, got #{matrix.class}" unless matrix.is_a?(Sparse::SparseMatrix)
  raise ArgumentError, "cuDSS requires CSR format, got #{matrix.format}" unless matrix.format == :csr
  return if %i[float32 float64 complex64 complex128].include?(matrix.dtype)

  raise UnsupportedDTypeError.new(matrix.dtype, operation: "sparse solve")
end
|
|
225
|
+
|
|
226
|
+
# Validate right-hand side
|
|
227
|
+
def validate_rhs!(b)
  raise ArgumentError, "Expected NvArray, got #{b.class}" unless b.is_a?(NvArray)

  # A vector and a matrix are both accepted; either way the leading
  # dimension must equal the coefficient matrix's row count.
  expected_rows = @matrix.shape[0]

  case b.ndim
  when 1
    return if b.shape[0] == expected_rows

    raise DimensionError, "RHS size (#{b.shape[0]}) doesn't match matrix rows (#{expected_rows})"
  when 2
    return if b.shape[0] == expected_rows

    raise DimensionError, "RHS rows (#{b.shape[0]}) don't match matrix rows (#{expected_rows})"
  else
    raise DimensionError, "RHS must be 1D or 2D, got #{b.ndim}D"
  end
end
|
|
245
|
+
|
|
246
|
+
# Initialize cuDSS handle and create matrix wrapper
|
|
247
|
+
def initialize_cudss!
  CuDSSBindings.ensure_loaded!

  # cudssCreate writes the new handle through an out-pointer.
  handle_out = FFI::MemoryPointer.new(:pointer)
  CuDSSBindings.check_status!(CuDSSBindings.cudssCreate(handle_out), "cuDSS create")
  @handle = handle_out.read_pointer

  # The stream is attached before anything else is built on the handle.
  setup_cuda_stream!

  @config = create_config_object
  @data = create_data_object
  @matrix_wrapper = create_csr_matrix_wrapper

  # Placeholder x/b wrappers are handed to the analyze and factor phases.
  create_placeholder_vectors!
end
|
|
269
|
+
|
|
270
|
+
# Set up CUDA stream for cuDSS operations
|
|
271
|
+
# @return [void]
|
|
272
|
+
def setup_cuda_stream!
  # Keep the wrapper object as well as its raw handle; only the raw handle
  # is passed to cuDSS.
  @stream_obj = CUDA::Stream.new
  @stream = @stream_obj.handle

  CuDSSBindings.check_status!(CuDSSBindings.cudssSetStream(@handle, @stream), "cuDSS set stream")
end
|
|
279
|
+
|
|
280
|
+
# Create placeholder solution/rhs vectors for analyze/factor phases
|
|
281
|
+
# cuDSS requires valid matrix wrappers even if values aren't used during these phases
|
|
282
|
+
# @return [void]
|
|
283
|
+
def create_placeholder_vectors!
  n = @matrix.shape[0]
  value_type = cudss_data_type(@matrix.dtype)
  layout = 0 # CUDSS_LAYOUT_COL_MAJOR (0 = column major, 1 = row major not supported)

  # Bytes per element of the matrix dtype. Complex types are twice the size
  # of their real counterparts; sizing them as 4 bytes (as before) allocated
  # buffers smaller than the n-element wrappers created over them.
  elem_size = case @matrix.dtype
              when :float64, :complex64 then 8
              when :complex128 then 16
              else 4
              end

  @placeholder_x_mem = CUDA::Memory.new(n * elem_size)
  @placeholder_b_mem = CUDA::Memory.new(n * elem_size)

  # Both placeholders are n x 1 dense wrappers with leading dimension n;
  # only the backing buffer and the error label differ.
  make_wrapper = lambda do |buffer, label|
    out = FFI::MemoryPointer.new(:pointer)
    status = CuDSSBindings.cudssMatrixCreateDn(out, n, 1, n, buffer.device_ptr, value_type, layout)
    CuDSSBindings.check_status!(status, label)
    out.read_pointer
  end

  @placeholder_x = make_wrapper.call(@placeholder_x_mem, "cuDSS create placeholder x")
  @placeholder_b = make_wrapper.call(@placeholder_b_mem, "cuDSS create placeholder b")
end
|
|
308
|
+
|
|
309
|
+
# Create cuDSS config object
|
|
310
|
+
# @return [FFI::Pointer]
|
|
311
|
+
def create_config_object
  # cudssConfigCreate writes the new config through an out-pointer.
  out = FFI::MemoryPointer.new(:pointer)
  CuDSSBindings.check_status!(CuDSSBindings.cudssConfigCreate(out), "cuDSS config create")
  out.read_pointer
end
|
|
319
|
+
|
|
320
|
+
# Create cuDSS data object
|
|
321
|
+
# @return [FFI::Pointer]
|
|
322
|
+
def create_data_object
  # The data object is created against the solver's handle and written
  # through an out-pointer.
  out = FFI::MemoryPointer.new(:pointer)
  CuDSSBindings.check_status!(CuDSSBindings.cudssDataCreate(@handle, out), "cuDSS data create")
  out.read_pointer
end
|
|
330
|
+
|
|
331
|
+
# Create CSR matrix wrapper for cuDSS
|
|
332
|
+
# @return [FFI::Pointer]
|
|
333
|
+
def create_csr_matrix_wrapper
  out = FFI::MemoryPointer.new(:pointer)

  row_count = @matrix.shape[0]
  col_count = @matrix.shape[1]
  nonzeros = @matrix.nnz

  # Ensure CUDA context is initialized (cuDSS requires active context)
  Ignis.synchronize

  # cuDSS is given device pointers, so the matrix has to live on the device.
  # NOTE(review): the return value of to_device is discarded; this relies on
  # SparseMatrix#to_device moving the data in place - confirm.
  @matrix.to_device unless @matrix.on_device?

  raise StateError, "Matrix not on device" unless @matrix.on_device?
  raise StateError, "row_ptr is nil" if @matrix.row_ptr.nil?
  raise StateError, "col_indices is nil" if @matrix.col_indices.nil?
  raise StateError, "values is nil" if @matrix.values.nil?

  value_type = cudss_data_type(@matrix.dtype)

  # Any matrix_type outside this table is treated as a general matrix.
  structure = {
    general: MATRIX_TYPE_GENERAL,
    symmetric: MATRIX_TYPE_SYMMETRIC,
    hermitian: MATRIX_TYPE_HERMITIAN,
    spd: MATRIX_TYPE_SPD,
    hpd: MATRIX_TYPE_HPD
  }.fetch(@matrix_type, MATRIX_TYPE_GENERAL)

  status = CuDSSBindings.cudssMatrixCreateCsr(
    out,
    row_count, col_count, nonzeros,
    @matrix.row_ptr.device_ptr,
    FFI::Pointer::NULL,             # row_end - NULL for standard CSR
    @matrix.col_indices.device_ptr,
    @matrix.values.device_ptr,
    CUDA_R_32I,                     # index type: int32 indices
    value_type,
    structure,
    0,                              # mview: CUDSS_MVIEW_FULL = 0 (full matrix view)
    INDEX_BASE_ZERO                 # indexBase - zero-based indexing
  )
  CuDSSBindings.check_status!(status, "cuDSS CSR matrix create")

  out.read_pointer
end
|
|
386
|
+
|
|
387
|
+
# Create dense matrix wrapper for cuDSS
|
|
388
|
+
# @return [FFI::Pointer]
|
|
389
|
+
def create_dense_matrix_wrapper(array)
  out = FFI::MemoryPointer.new(:pointer)

  # A 1-D array is presented to cuDSS as a single column.
  rows = array.shape[0]
  cols = array.ndim == 1 ? 1 : array.shape[1]

  # NOTE(review): the wrapper is declared column-major with leading dimension
  # equal to the row count - confirm this matches NvArray's memory layout
  # when there is more than one column.
  status = CuDSSBindings.cudssMatrixCreateDn(
    out,
    rows, cols, rows,
    array.device_ptr,
    cudss_data_type(array.dtype),
    0 # CUDSS_LAYOUT_COL_MAJOR (0 = column major, 1 = row major not supported)
  )
  CuDSSBindings.check_status!(status, "cuDSS dense matrix create")

  out.read_pointer
end
|
|
410
|
+
|
|
411
|
+
# Convert Ignis dtype to CUDA library data type
|
|
412
|
+
# @param dtype [Symbol]
|
|
413
|
+
# @return [Integer]
|
|
414
|
+
def cudss_data_type(dtype)
  # Lookup table from Ignis dtype symbols to the CUDA library type codes;
  # anything missing from the table is rejected.
  {
    float32: CUDA_R_32F,
    float64: CUDA_R_64F,
    complex64: CUDA_C_32F,
    complex128: CUDA_C_64F
  }.fetch(dtype) { raise UnsupportedDTypeError.new(dtype, operation: "cuDSS") }
end
|
|
424
|
+
end
|
|
425
|
+
|
|
426
|
+
# State error for solver operations
|
|
427
|
+
# Raised by SparseSolver when a phase is requested out of order (for example
# factor! before analyze!) or when the matrix's device data is missing.
class StateError < StandardError; end
|
|
428
|
+
end
|
|
429
|
+
end
|