ignis-numerics 0.0.1
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- checksums.yaml +7 -0
- data/README.md +15 -0
- data/lib/ignis-numerics.rb +62 -0
- data/lib/nvruby/array.rb +646 -0
- data/lib/nvruby/fft/cufft_bindings.rb +134 -0
- data/lib/nvruby/fft/fft_plan.rb +288 -0
- data/lib/nvruby/fft/operations.rb +364 -0
- data/lib/nvruby/linalg/cutensor_bindings.rb +107 -0
- data/lib/nvruby/mathdx/fft_kernel.rb +258 -0
- data/lib/nvruby/mathdx/gemm_kernel.rb +293 -0
- data/lib/nvruby/mathdx.rb +73 -0
- data/lib/nvruby/random/curand_bindings.rb +115 -0
- data/lib/nvruby/random/generator.rb +305 -0
- data/lib/nvruby/solver/amgx_bindings.rb +172 -0
- data/lib/nvruby/solver/amgx_config.rb +142 -0
- data/lib/nvruby/solver/amgx_solver.rb +251 -0
- data/lib/nvruby/solver/cudss_bindings.rb +115 -0
- data/lib/nvruby/solver/cusolver_bindings.rb +358 -0
- data/lib/nvruby/solver/eigen.rb +226 -0
- data/lib/nvruby/solver/lu.rb +265 -0
- data/lib/nvruby/solver/sparse_solver.rb +429 -0
- data/lib/nvruby/solver/svd.rb +266 -0
- data/lib/nvruby/solver.rb +122 -0
- data/lib/nvruby/sparse/cusparse_bindings.rb +231 -0
- data/lib/nvruby/sparse/sparse_matrix.rb +456 -0
- data/lib/nvruby/tensor/contraction.rb +218 -0
- data/lib/nvruby/tensor.rb +42 -0
- metadata +85 -0
|
@@ -0,0 +1,142 @@
|
|
|
1
|
+
# frozen_string_literal: true
|
|
2
|
+
|
|
3
|
+
module Ignis
  module Solver
    # Configuration builder for AMGX solver
    # Provides preset configurations and custom config strings
    #
    # The resulting configuration is a single comma-separated string of
    # `[scope:]parameter=value` entries, exposed via #config_string / #to_s.
    class AMGXConfig
      # Predefined solver configurations. Whitespace (including the newlines of
      # the heredocs) is stripped when a preset is selected.
      PRESETS = {
        # Classical AMG with V-cycle, block-Jacobi smoother
        classical_v_cycle: <<~CONFIG,
          config_version=2,
          solver(main)=AMG,
          main:algorithm=CLASSICAL,
          main:selector=PMIS,
          main:interpolator=D2,
          main:presweeps=2,
          main:postsweeps=2,
          main:max_iters=100,
          main:convergence=RELATIVE_INI_CORE,
          main:tolerance=1e-8,
          main:norm=L2,
          main:cycle=V,
          main:smoother=BLOCK_JACOBI,
          main:coarse_solver=DENSE_LU_SOLVER,
          main:max_levels=20
        CONFIG

        # Aggregation AMG with V-cycle
        aggregation_v_cycle: <<~CONFIG,
          config_version=2,
          solver(main)=AMG,
          main:algorithm=AGGREGATION,
          main:selector=SIZE_2,
          main:presweeps=2,
          main:postsweeps=2,
          main:max_iters=100,
          main:convergence=RELATIVE_INI_CORE,
          main:tolerance=1e-8,
          main:norm=L2,
          main:cycle=V,
          main:smoother=MULTICOLOR_GS,
          main:coarse_solver=DENSE_LU_SOLVER
        CONFIG

        # PCG with AMG preconditioner (preconditioner scope is "amg")
        pcg_amg: <<~CONFIG,
          config_version=2,
          solver(main)=PCG,
          main:preconditioner(amg)=AMG,
          main:max_iters=1000,
          main:convergence=RELATIVE_INI_CORE,
          main:tolerance=1e-10,
          main:norm=L2,
          amg:algorithm=CLASSICAL,
          amg:max_iters=1,
          amg:cycle=V,
          amg:smoother=BLOCK_JACOBI
        CONFIG

        # PBICGSTAB with AMG preconditioner (preconditioner scope is "amg")
        pbicgstab_amg: <<~CONFIG,
          config_version=2,
          solver(main)=PBICGSTAB,
          main:preconditioner(amg)=AMG,
          main:max_iters=1000,
          main:convergence=RELATIVE_INI_CORE,
          main:tolerance=1e-10,
          main:norm=L2,
          amg:algorithm=AGGREGATION,
          amg:max_iters=1,
          amg:cycle=V
        CONFIG

        # Simple Jacobi iteration
        jacobi: <<~CONFIG,
          config_version=2,
          solver(main)=BLOCK_JACOBI,
          main:max_iters=1000,
          main:convergence=RELATIVE_INI_CORE,
          main:tolerance=1e-8
        CONFIG

        # Gauss-Seidel
        gauss_seidel: <<~CONFIG
          config_version=2,
          solver(main)=MULTICOLOR_GS,
          main:max_iters=1000,
          main:convergence=RELATIVE_INI_CORE,
          main:tolerance=1e-8
        CONFIG
      }.freeze

      # @return [String] Configuration string
      attr_reader :config_string

      # Create configuration from preset or custom string.
      #
      # If +custom+ is given it is used verbatim and +preset+ is ignored.
      # Otherwise the named preset (default :classical_v_cycle) is used with all
      # whitespace removed. Overrides in +options+ are applied afterwards.
      #
      # @param preset [Symbol, String, nil] Preset name (:classical_v_cycle, :aggregation_v_cycle, etc.)
      # @param custom [String, nil] Custom configuration string
      # @param options [Hash] Override options. An unscoped key such as
      #   +max_iters:+ targets the +main:+ entry (falling back to a bare entry of
      #   that name, e.g. +config_version+). A scoped key such as
      #   +"amg:max_iters":+ targets exactly that entry.
      # @raise [ArgumentError] if +preset+ does not name a known preset
      def initialize(preset: nil, custom: nil, **options)
        @config_string =
          if custom
            custom
          else
            name = preset || :classical_v_cycle
            name = name.to_sym if name.respond_to?(:to_sym)
            raise ArgumentError, "Unknown preset: #{preset}" unless PRESETS.key?(name)

            PRESETS[name].gsub(/\s+/, "")
          end

        apply_options!(options) unless options.empty?
      end

      # @return [String]
      def to_s
        @config_string
      end

      # List available presets
      # @return [Array<Symbol>]
      def self.presets
        PRESETS.keys
      end

      private

      # Apply override options to @config_string.
      #
      # Keys are matched against whole entry names, never substrings. So
      # +max_iters:+ does not also rewrite +amg:max_iters+, and +solver:+ does
      # not touch +coarse_solver+. If no matching entry exists, the override is
      # appended: unscoped keys are added to the +main:+ scope, and scoped keys
      # are appended as given.
      #
      # @param options [Hash{Symbol,String => Object}] entry name => new value
      # @return [void]
      def apply_options!(options)
        options.each do |key, value|
          name = key.to_s
          scoped = name.include?(":")
          candidates = scoped ? [name] : ["main:#{name}", name]
          target = candidates.find { |candidate| entry_pattern(candidate).match?(@config_string) }

          @config_string =
            if target
              # Block form: +value+ is inserted literally (no backslash-reference expansion).
              @config_string.gsub(entry_pattern(target)) { "#{Regexp.last_match(1)}#{target}=#{value}" }
            else
              "#{@config_string},#{scoped ? name : "main:#{name}"}=#{value}"
            end
        end
      end

      # Regexp matching one complete `name=value` entry whose name is exactly
      # +name+. Group 1 captures the leading separator (start of string or a
      # comma, plus any whitespace) so it is preserved on replacement.
      #
      # @param name [String] full entry name, e.g. "main:max_iters"
      # @return [Regexp]
      def entry_pattern(name)
        /(\A\s*|,\s*)#{Regexp.escape(name)}\s*=[^,]*/
      end
    end
  end
end
|
|
@@ -0,0 +1,251 @@
|
|
|
1
|
+
# frozen_string_literal: true
|
|
2
|
+
|
|
3
|
+
require_relative 'amgx_bindings'
|
|
4
|
+
require_relative 'amgx_config'
|
|
5
|
+
|
|
6
|
+
module Ignis
  module Solver
    # GPU-accelerated algebraic multigrid solver using NVIDIA AMGX
    # Solves sparse linear systems Ax = b
    #
    # Each instance owns its own AMGX config, resources, matrix, two vectors
    # (rhs and solution) and a solver handle. Call {#destroy!} to release them
    # deterministically; otherwise a GC finalizer releases them.
    #
    # NOTE(review): every instance calls AMGX_initialize on construction and
    # AMGX_finalize on release. Confirm AMGX tolerates this when several
    # solvers are alive at the same time.
    class AMGXSolver
      # @return [AMGXConfig] Solver configuration
      attr_reader :config

      # @return [Symbol] Precision mode (:double or :float)
      attr_reader :precision

      # @return [Integer] Number of iterations from last solve
      attr_reader :iterations

      # @return [Float] Final residual from last solve
      attr_reader :residual

      # @param config [Symbol, String, AMGXConfig] Configuration (preset symbol, string, or AMGXConfig)
      # @param precision [Symbol] :double or :float (anything other than :double selects float)
      # @param device [Boolean] Use GPU (true) or CPU (false)
      # @raise [ArgumentError] if +config+ is not a Symbol, String or AMGXConfig
      def initialize(config: :classical_v_cycle, precision: :double, device: true)
        AMGXBindings.ensure_loaded!

        @config = case config
                  when Symbol then AMGXConfig.new(preset: config)
                  when String then AMGXConfig.new(custom: config)
                  when AMGXConfig then config
                  else raise ArgumentError, "Invalid config type: #{config.class}"
                  end

        @precision = precision
        @device = device
        @mode = determine_mode
        @initialized = false
        @iterations = 0
        @residual = 0.0

        initialize_amgx!
      end

      # Solve sparse linear system Ax = b
      #
      # The matrix is uploaded and the solver setup is rerun on every call, so
      # successive calls may pass different matrices.
      #
      # @param a [Hash] Sparse matrix in CSR format { row_ptr:, col_idx:, values:, n:, nnz: }
      # @param b [Array<Float>, NvArray] Right-hand side vector
      # @param x0 [Array<Float>, NvArray, nil] Initial guess (nil for zero)
      # @return [Array<Float>] Solution vector
      # @raise [AMGXError] if the solver has been destroyed
      # @raise [ArgumentError] if the CSR arrays are shorter than +n+/+nnz+ require
      def solve(a, b, x0: nil)
        raise AMGXError, "Solver not initialized" unless @initialized

        upload_matrix(a)
        upload_vector(@rhs_handle, b)

        rc =
          if x0
            upload_vector(@sol_handle, x0)
            AMGXBindings.AMGX_solver_solve(@solver_handle, @rhs_handle, @sol_handle)
          else
            AMGXBindings.AMGX_solver_solve_with_0_initial_guess(@solver_handle, @rhs_handle, @sol_handle)
          end
        AMGXBindings.check_rc!(rc, "AMGX_solver_solve")

        update_solve_stats!
        download_solution(a[:n])
      end

      # Get solve status
      # @return [Symbol] :success, :failed, :diverged, :not_converged, or :unknown
      # @raise [AMGXError] if the solver has been destroyed (its handles are freed)
      def status
        raise AMGXError, "Solver not initialized" unless @initialized

        status_ptr = FFI::MemoryPointer.new(:int)
        AMGXBindings.AMGX_solver_get_status(@solver_handle, status_ptr)

        case status_ptr.read_int
        when AMGXBindings::SolveStatus::SUCCESS then :success
        when AMGXBindings::SolveStatus::FAILED then :failed
        when AMGXBindings::SolveStatus::DIVERGED then :diverged
        when AMGXBindings::SolveStatus::NOT_CONVERGED then :not_converged
        else :unknown
        end
      end

      # Release all AMGX resources. Safe to call more than once.
      # @return [void]
      def destroy!
        return unless @initialized

        # The GC finalizer holds the same handles; drop it so they are not
        # destroyed (and AMGX_finalize called) a second time later.
        ObjectSpace.undefine_finalizer(self)
        release_handles!
        @initialized = false
      end

      private

      # Map precision/device flags to an AMGX mode constant.
      def determine_mode
        if @device
          @precision == :double ? AMGXBindings::Mode::DEVICE_DDI : AMGXBindings::Mode::DEVICE_FFI
        else
          @precision == :double ? AMGXBindings::Mode::HOST_DDI : AMGXBindings::Mode::HOST_FFI
        end
      end

      # Initialize AMGX and create all handles, then register a GC finalizer.
      # If any creation step fails, everything created so far is released
      # (including the AMGX_initialize) before the error is re-raised.
      def initialize_amgx!
        rc = AMGXBindings.AMGX_initialize
        AMGXBindings.check_rc!(rc, "AMGX_initialize")

        begin
          @config_handle = create_handle("AMGX_config_create") do |out|
            AMGXBindings.AMGX_config_create(out, @config.to_s)
          end
          @resources_handle = create_handle("AMGX_resources_create_simple") do |out|
            AMGXBindings.AMGX_resources_create_simple(out, @config_handle)
          end
          @matrix_handle = create_handle("AMGX_matrix_create") do |out|
            AMGXBindings.AMGX_matrix_create(out, @resources_handle, @mode)
          end
          @rhs_handle = create_handle("AMGX_vector_create (rhs)") do |out|
            AMGXBindings.AMGX_vector_create(out, @resources_handle, @mode)
          end
          @sol_handle = create_handle("AMGX_vector_create (sol)") do |out|
            AMGXBindings.AMGX_vector_create(out, @resources_handle, @mode)
          end
          @solver_handle = create_handle("AMGX_solver_create") do |out|
            AMGXBindings.AMGX_solver_create(out, @resources_handle, @mode, @config_handle)
          end
        rescue StandardError
          release_handles!
          raise
        end

        @initialized = true
        @matrix_uploaded = false

        ObjectSpace.define_finalizer(self, self.class.release_finalizer(
          @solver_handle, @sol_handle, @rhs_handle, @matrix_handle,
          @resources_handle, @config_handle
        ))
      end

      # Call an AMGX "create" function that writes a handle through an
      # out-pointer, check its return code, and return the created handle.
      # @param context [String] name used in the error message
      # @yieldparam out [FFI::MemoryPointer] pointer-sized out-parameter
      # @yieldreturn [Integer] AMGX return code
      # @return [FFI::Pointer] the created handle
      def create_handle(context)
        out = FFI::MemoryPointer.new(:pointer)
        AMGXBindings.check_rc!(yield(out), context)
        out.read_pointer
      end

      # Destroy whichever handles exist (nil ones are skipped), call
      # AMGX_finalize, and forget the handles.
      def release_handles!
        self.class.release_finalizer(
          @solver_handle, @sol_handle, @rhs_handle, @matrix_handle,
          @resources_handle, @config_handle
        ).call
        @solver_handle = @sol_handle = @rhs_handle = @matrix_handle = @resources_handle = @config_handle = nil
      end

      # Upload a CSR matrix (block size 1x1, no separate diagonal), then run
      # solver setup. Setup builds solver state from the matrix, so it is rerun
      # on every upload rather than only the first one; otherwise a new matrix
      # would be solved with the previous matrix's setup.
      def upload_matrix(a)
        n = a[:n]
        nnz = a[:nnz]
        row_ptr = a[:row_ptr]
        col_idx = a[:col_idx]
        values = a[:values]

        validate_csr!(n, nnz, row_ptr, col_idx, values)

        row_ptr_buf = int_buffer(row_ptr)
        col_idx_buf = int_buffer(col_idx)
        values_buf = scalar_buffer(values)

        rc = AMGXBindings.AMGX_matrix_upload_all(
          @matrix_handle, n, nnz, 1, 1,
          row_ptr_buf, col_idx_buf, values_buf, FFI::Pointer::NULL
        )
        AMGXBindings.check_rc!(rc, "AMGX_matrix_upload_all")

        rc = AMGXBindings.AMGX_solver_setup(@solver_handle, @matrix_handle)
        AMGXBindings.check_rc!(rc, "AMGX_solver_setup")
        @matrix_uploaded = true
      end

      # AMGX reads n + 1 row pointers and nnz column indices/values from the
      # buffers built from these arrays; reject arrays that are too short
      # instead of letting the library read past the end of the buffers.
      # @raise [ArgumentError]
      def validate_csr!(n, nnz, row_ptr, col_idx, values)
        unless n.is_a?(Integer) && nnz.is_a?(Integer)
          raise ArgumentError, "CSR matrix must provide Integer :n and :nnz"
        end
        if row_ptr.length < n + 1
          raise ArgumentError, "row_ptr must have n + 1 (#{n + 1}) entries, got #{row_ptr.length}"
        end
        if col_idx.length < nnz
          raise ArgumentError, "col_idx must have at least nnz (#{nnz}) entries, got #{col_idx.length}"
        end
        return unless values.length < nnz

        raise ArgumentError, "values must have at least nnz (#{nnz}) entries, got #{values.length}"
      end

      # Upload a dense vector (Array or anything responding to #to_a) into +handle+.
      def upload_vector(handle, data)
        data = data.to_a if data.respond_to?(:to_a)
        rc = AMGXBindings.AMGX_vector_upload(handle, data.length, 1, scalar_buffer(data))
        AMGXBindings.check_rc!(rc, "AMGX_vector_upload")
      end

      # Download the solution vector (length +n+) as a Ruby Array of Floats.
      def download_solution(n)
        buf = FFI::MemoryPointer.new(scalar_type, n)
        rc = AMGXBindings.AMGX_vector_download(@sol_handle, buf)
        AMGXBindings.check_rc!(rc, "AMGX_vector_download")
        @precision == :double ? buf.read_array_of_double(n) : buf.read_array_of_float(n)
      end

      # Refresh @iterations and @residual from the last solve.
      def update_solve_stats!
        iter_ptr = FFI::MemoryPointer.new(:int)
        AMGXBindings.AMGX_solver_get_iterations_number(@solver_handle, iter_ptr)
        @iterations = iter_ptr.read_int

        # With zero iterations there is no "last iteration" to query; avoid
        # passing index -1 to AMGX.
        if @iterations <= 0
          @residual = 0.0
          return
        end

        # NOTE(review): AMGX presumably records residual history only when the
        # config enables residual monitoring (e.g. monitor_residual=1). The
        # return code is not checked, so without it this reads back the
        # (zero-initialized) buffer.
        res_ptr = FFI::MemoryPointer.new(:double)
        AMGXBindings.AMGX_solver_get_iteration_residual(@solver_handle, @iterations - 1, 0, res_ptr)
        @residual = res_ptr.read_double
      end

      # FFI element type for matrix values and vectors.
      def scalar_type
        @precision == :double ? :double : :float
      end

      # Copy +data+ into a freshly allocated native buffer of #scalar_type.
      def scalar_buffer(data)
        buf = FFI::MemoryPointer.new(scalar_type, data.length)
        if @precision == :double
          buf.write_array_of_double(data)
        else
          buf.write_array_of_float(data)
        end
        buf
      end

      # Copy +data+ into a freshly allocated native int buffer.
      def int_buffer(data)
        buf = FFI::MemoryPointer.new(:int, data.length)
        buf.write_array_of_int(data)
        buf
      end

      class << self
        # Build a proc that destroys the given handles in dependency order and
        # then calls AMGX_finalize. It closes over its arguments only (not the
        # solver instance), so it can be registered as a GC finalizer. Nil or
        # NULL handles are skipped.
        # @return [Proc]
        def release_finalizer(solver, sol, rhs, matrix, resources, config)
          proc do
            AMGXBindings.AMGX_solver_destroy(solver) if solver && !solver.null?
            AMGXBindings.AMGX_vector_destroy(sol) if sol && !sol.null?
            AMGXBindings.AMGX_vector_destroy(rhs) if rhs && !rhs.null?
            AMGXBindings.AMGX_matrix_destroy(matrix) if matrix && !matrix.null?
            AMGXBindings.AMGX_resources_destroy(resources) if resources && !resources.null?
            AMGXBindings.AMGX_config_destroy(config) if config && !config.null?
            AMGXBindings.AMGX_finalize
          end
        end
      end
    end
  end
end
|
|
@@ -0,0 +1,115 @@
|
|
|
1
|
+
# frozen_string_literal: true
|
|
2
|
+
|
|
3
|
+
require "ffi"
|
|
4
|
+
|
|
5
|
+
module Ignis
  module Solver
    # FFI bindings for NVIDIA cuDSS v0.7+
    # High-performance direct sparse solvers for CUDA
    module CuDSSBindings
      extend FFI::Library

      # cuDSS status codes (cudssStatus_t). cuDSS defines only these values.
      # The previous table used cuSPARSE numbering from 4 upward, which
      # mislabeled NOT_SUPPORTED, EXECUTION_FAILED and INTERNAL_ERROR.
      CUDSS_STATUS_SUCCESS = 0
      CUDSS_STATUS_NOT_INITIALIZED = 1
      CUDSS_STATUS_ALLOC_FAILED = 2
      CUDSS_STATUS_INVALID_VALUE = 3
      CUDSS_STATUS_NOT_SUPPORTED = 4
      CUDSS_STATUS_EXECUTION_FAILED = 5
      CUDSS_STATUS_INTERNAL_ERROR = 6

      # @deprecated These are cuSPARSE status codes, not members of
      #   cudssStatus_t; cuDSS never returns them. They are kept with their
      #   previous values only so existing references keep resolving. Note that
      #   they coincide with CUDSS_STATUS_NOT_SUPPORTED and
      #   CUDSS_STATUS_EXECUTION_FAILED; do not compare cuDSS results against them.
      CUDSS_STATUS_ARCH_MISMATCH = 4
      CUDSS_STATUS_MAPPING_ERROR = 5

      # cuDSS Phases (from cudss.h - these are bitmask values)
      CUDSS_PHASE_REORDERING = 1 # 1 << 0
      CUDSS_PHASE_SYMBOLIC_FACTORIZATION = 2 # 1 << 1
      CUDSS_PHASE_ANALYSIS = 3 # REORDERING | SYMBOLIC_FACTORIZATION = 1 | 2 = 3
      CUDSS_PHASE_FACTORIZATION = 4 # 1 << 2
      CUDSS_PHASE_REFACTORIZATION = 8 # 1 << 3
      CUDSS_PHASE_SOLVE_FWD_PERM = 16 # 1 << 4
      CUDSS_PHASE_SOLVE_FWD = 32 # 1 << 5
      CUDSS_PHASE_SOLVE_DIAG = 64 # 1 << 6
      CUDSS_PHASE_SOLVE_BWD = 128 # 1 << 7
      CUDSS_PHASE_SOLVE_BWD_PERM = 256 # 1 << 8
      CUDSS_PHASE_SOLVE_REFINEMENT = 512 # 1 << 9
      CUDSS_PHASE_SOLVE = 1008 # All solve phases combined (16|32|64|128|256|512)

      @loaded = false
      @mutex = Mutex.new

      class << self
        # Ensure cuDSS library is loaded and functions attached.
        # Serialized by a mutex; later calls return immediately once loaded.
        # @return [void]
        # @raise [LibraryNotFoundError] if the library path cannot be resolved
        def ensure_loaded!
          @mutex.synchronize do
            return if @loaded

            CUDA::LibraryLoader.load_library(:cudss)
            dll_path = CUDA::LibraryLoader.library_paths[:cudss]

            raise LibraryNotFoundError, "cudss" unless dll_path

            ffi_lib dll_path
            attach_functions!

            @loaded = true
            Ignis.logger.debug("cuDSS bindings loaded from #{dll_path}")
          end
        end

        # Check status and raise error if not success
        # @param status [Integer] status code
        # @param context [String] Context for error message
        # @raise [CuDSSError] If status is not success
        def check_status!(status, context = "cuDSS operation")
          return if status == CUDSS_STATUS_SUCCESS

          raise CuDSSError.new("#{context} failed", status_code: status)
        end

        # Read a cuDSS configuration parameter.
        #
        # The C function cudssConfigGet ends with a `size_t* sizeWritten`
        # out-parameter. The old 4-argument binding left that argument
        # unspecified, so the library could write through an arbitrary value.
        # This wrapper keeps the 4-argument call form working and always
        # passes a valid pointer (a scratch buffer when none is supplied).
        #
        # @param config [FFI::Pointer] cudssConfig_t
        # @param param [Integer] cudssConfigParam_t
        # @param value [FFI::Pointer] destination buffer
        # @param size_in_bytes [Integer] size of +value+
        # @param size_written [FFI::Pointer, nil] optional size_t* receiving the byte count
        # @return [Integer] cuDSS status code
        def cudssConfigGet(config, param, value, size_in_bytes, size_written = nil)
          size_written ||= FFI::MemoryPointer.new(:size_t)
          cudss_config_get_raw(config, param, value, size_in_bytes, size_written)
        end

        private

        def attach_functions!
          # Handle management
          attach_function :cudssCreate, [:pointer], :int
          attach_function :cudssDestroy, [:pointer], :int

          # Stream management
          attach_function :cudssSetStream, [:pointer, :pointer], :int

          # Config get/set. cudssConfigGet is exposed through the
          # cudssConfigGet wrapper above, which always supplies sizeWritten.
          attach_function :cudssConfigSet, [:pointer, :int, :pointer, :size_t], :int
          attach_function :cudss_config_get_raw, :cudssConfigGet, [:pointer, :int, :pointer, :size_t, :pointer], :int

          # Matrix object creation (wrappers around caller-owned raw buffers)
          # cudssMatrixCreateDn: matrixPtr, nRows, nCols, ld, values, valueType, order
          attach_function :cudssMatrixCreateDn, [:pointer, :int64, :int64, :int64, :pointer, :int, :int], :int
          # cudssMatrixCreateCsr: matrixPtr, nRows, nCols, nnz, rowStart, rowEnd, colIndices, values, indexType, valueType, mType, mView, indexBase
          attach_function :cudssMatrixCreateCsr, [:pointer, :int64, :int64, :int64, :pointer, :pointer, :pointer, :pointer, :int, :int, :int, :int, :int], :int
          attach_function :cudssMatrixDestroy, [:pointer], :int

          # Config object: cudssConfigCreate(config) -> 1 output pointer
          attach_function :cudssConfigCreate, [:pointer], :int
          attach_function :cudssConfigDestroy, [:pointer], :int

          # Data object: cudssDataCreate(handle, data) -> handle in, data out
          attach_function :cudssDataCreate, [:pointer, :pointer], :int
          attach_function :cudssDataDestroy, [:pointer, :pointer], :int

          # Execution: handle, phase bitmask, config, data, matrix, solution, rhs
          attach_function :cudssExecute, [:pointer, :int, :pointer, :pointer, :pointer, :pointer, :pointer], :int
        end
      end
    end

    # Specialized error for cuDSS
    class CuDSSError < StandardError
      # @return [Integer, nil] raw cuDSS status code
      attr_reader :status_code

      # @param message [String]
      # @param status_code [Integer, nil]
      def initialize(message, status_code: nil)
        @status_code = status_code
        super("#{message} (Status: #{status_code})")
      end
    end
  end
end
|