ignis-numerics 0.0.1
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- checksums.yaml +7 -0
- data/README.md +15 -0
- data/lib/ignis-numerics.rb +62 -0
- data/lib/nvruby/array.rb +646 -0
- data/lib/nvruby/fft/cufft_bindings.rb +134 -0
- data/lib/nvruby/fft/fft_plan.rb +288 -0
- data/lib/nvruby/fft/operations.rb +364 -0
- data/lib/nvruby/linalg/cutensor_bindings.rb +107 -0
- data/lib/nvruby/mathdx/fft_kernel.rb +258 -0
- data/lib/nvruby/mathdx/gemm_kernel.rb +293 -0
- data/lib/nvruby/mathdx.rb +73 -0
- data/lib/nvruby/random/curand_bindings.rb +115 -0
- data/lib/nvruby/random/generator.rb +305 -0
- data/lib/nvruby/solver/amgx_bindings.rb +172 -0
- data/lib/nvruby/solver/amgx_config.rb +142 -0
- data/lib/nvruby/solver/amgx_solver.rb +251 -0
- data/lib/nvruby/solver/cudss_bindings.rb +115 -0
- data/lib/nvruby/solver/cusolver_bindings.rb +358 -0
- data/lib/nvruby/solver/eigen.rb +226 -0
- data/lib/nvruby/solver/lu.rb +265 -0
- data/lib/nvruby/solver/sparse_solver.rb +429 -0
- data/lib/nvruby/solver/svd.rb +266 -0
- data/lib/nvruby/solver.rb +122 -0
- data/lib/nvruby/sparse/cusparse_bindings.rb +231 -0
- data/lib/nvruby/sparse/sparse_matrix.rb +456 -0
- data/lib/nvruby/tensor/contraction.rb +218 -0
- data/lib/nvruby/tensor.rb +42 -0
- metadata +85 -0
|
@@ -0,0 +1,358 @@
|
|
|
1
|
+
# frozen_string_literal: true
|
|
2
|
+
|
|
3
|
+
require "ffi"
|
|
4
|
+
|
|
5
|
+
module Ignis
|
|
6
|
+
module Solver
|
|
7
|
+
# cuSOLVER Dense (Dn) library FFI bindings
|
|
8
|
+
# Provides bindings for LU, Cholesky and QR factorizations, SVD, and symmetric/Hermitian eigenvalue solvers
|
|
9
|
+
module CuSolverBindings
|
|
10
|
+
extend FFI::Library
|
|
11
|
+
|
|
12
|
+
# cuSOLVER status codes (cusolverStatus_t). The numbering has gaps; values
# follow the cuSOLVER header rather than being consecutive.
CUSOLVER_STATUS_SUCCESS = 0
CUSOLVER_STATUS_NOT_INITIALIZED = 1
CUSOLVER_STATUS_ALLOC_FAILED = 2
CUSOLVER_STATUS_INVALID_VALUE = 3
CUSOLVER_STATUS_ARCH_MISMATCH = 4
CUSOLVER_STATUS_MAPPING_ERROR = 5
CUSOLVER_STATUS_EXECUTION_FAILED = 6
CUSOLVER_STATUS_INTERNAL_ERROR = 7
CUSOLVER_STATUS_MATRIX_TYPE_NOT_SUPPORTED = 8
CUSOLVER_STATUS_NOT_SUPPORTED = 9
CUSOLVER_STATUS_ZERO_PIVOT = 10
CUSOLVER_STATUS_INVALID_LICENSE = 11
CUSOLVER_STATUS_IRS_PARAMS_NOT_INITIALIZED = 12
CUSOLVER_STATUS_IRS_PARAMS_INVALID = 13
CUSOLVER_STATUS_IRS_PARAMS_INVALID_PREC = 14
CUSOLVER_STATUS_IRS_PARAMS_INVALID_REFINE = 15
CUSOLVER_STATUS_IRS_PARAMS_INVALID_MAXITER = 16
CUSOLVER_STATUS_IRS_INTERNAL_ERROR = 20
CUSOLVER_STATUS_IRS_NOT_SUPPORTED = 21
CUSOLVER_STATUS_IRS_OUT_OF_RANGE = 22
CUSOLVER_STATUS_IRS_NRHS_NOT_SUPPORTED_FOR_REFINE_GMRES = 23
CUSOLVER_STATUS_IRS_INFOS_NOT_INITIALIZED = 25
CUSOLVER_STATUS_IRS_INFOS_NOT_DESTROYED = 26
CUSOLVER_STATUS_IRS_MATRIX_SINGULAR = 30
CUSOLVER_STATUS_INVALID_WORKSPACE = 31

# cuSOLVER EigMode (for eigenvalue/eigenvector computation)
CUSOLVER_EIG_MODE_NOVECTOR = 0 # Compute eigenvalues only
CUSOLVER_EIG_MODE_VECTOR = 1   # Compute eigenvalues and eigenvectors

# cuSOLVER EigType (for generalized eigenvalue problem)
CUSOLVER_EIG_TYPE_1 = 1 # A*x = lambda*B*x
CUSOLVER_EIG_TYPE_2 = 2 # A*B*x = lambda*x
CUSOLVER_EIG_TYPE_3 = 3 # B*A*x = lambda*x

# cuSOLVER EigRange (eigenvalue range selection).
# cusolverEigRange_t is numbered from 1001 in the cuSOLVER header, not from 0;
# passing 0/1/2 to a range-selecting routine would be rejected as invalid.
CUSOLVER_EIG_RANGE_ALL = 1001 # All eigenvalues
CUSOLVER_EIG_RANGE_I = 1002   # il-th through iu-th eigenvalues
CUSOLVER_EIG_RANGE_V = 1003   # Eigenvalues in half-open interval (vl, vu]

# cuBLAS fill mode (used by cuSOLVER)
CUBLAS_FILL_MODE_LOWER = 0
CUBLAS_FILL_MODE_UPPER = 1
CUBLAS_FILL_MODE_FULL = 2

# cuBLAS operation types
CUBLAS_OP_N = 0 # No transpose
CUBLAS_OP_T = 1 # Transpose
CUBLAS_OP_C = 2 # Conjugate transpose

# Human-readable description for every status code declared above, so that
# iterative-refinement (IRS) failures no longer surface as "Unknown error".
STATUS_DESCRIPTIONS = {
  CUSOLVER_STATUS_SUCCESS => "Success",
  CUSOLVER_STATUS_NOT_INITIALIZED => "Library not initialized",
  CUSOLVER_STATUS_ALLOC_FAILED => "Resource allocation failed",
  CUSOLVER_STATUS_INVALID_VALUE => "Invalid parameter value",
  CUSOLVER_STATUS_ARCH_MISMATCH => "Architecture mismatch",
  CUSOLVER_STATUS_MAPPING_ERROR => "Mapping error",
  CUSOLVER_STATUS_EXECUTION_FAILED => "Execution failed",
  CUSOLVER_STATUS_INTERNAL_ERROR => "Internal error",
  CUSOLVER_STATUS_MATRIX_TYPE_NOT_SUPPORTED => "Matrix type not supported",
  CUSOLVER_STATUS_NOT_SUPPORTED => "Operation not supported",
  CUSOLVER_STATUS_ZERO_PIVOT => "Zero pivot encountered",
  CUSOLVER_STATUS_INVALID_LICENSE => "Invalid license",
  CUSOLVER_STATUS_IRS_PARAMS_NOT_INITIALIZED => "IRS parameters not initialized",
  CUSOLVER_STATUS_IRS_PARAMS_INVALID => "IRS parameters invalid",
  CUSOLVER_STATUS_IRS_PARAMS_INVALID_PREC => "IRS parameters: invalid precision",
  CUSOLVER_STATUS_IRS_PARAMS_INVALID_REFINE => "IRS parameters: invalid refinement solver",
  CUSOLVER_STATUS_IRS_PARAMS_INVALID_MAXITER => "IRS parameters: invalid maximum iteration count",
  CUSOLVER_STATUS_IRS_INTERNAL_ERROR => "IRS internal error",
  CUSOLVER_STATUS_IRS_NOT_SUPPORTED => "IRS configuration not supported",
  CUSOLVER_STATUS_IRS_OUT_OF_RANGE => "IRS value out of range",
  CUSOLVER_STATUS_IRS_NRHS_NOT_SUPPORTED_FOR_REFINE_GMRES =>
    "IRS: multiple right-hand sides not supported for GMRES refinement",
  CUSOLVER_STATUS_IRS_INFOS_NOT_INITIALIZED => "IRS infos structure not initialized",
  CUSOLVER_STATUS_IRS_INFOS_NOT_DESTROYED => "IRS infos structure not destroyed",
  CUSOLVER_STATUS_IRS_MATRIX_SINGULAR => "IRS: matrix is singular",
  CUSOLVER_STATUS_INVALID_WORKSPACE => "Invalid workspace"
}.freeze
|
|
75
|
+
|
|
76
|
+
@loaded = false
|
|
77
|
+
@handle = nil
|
|
78
|
+
@mutex = Mutex.new
|
|
79
|
+
|
|
80
|
+
class << self
|
|
81
|
+
# @return [FFI::Pointer, nil] cuSOLVER handle
|
|
82
|
+
attr_accessor :handle
|
|
83
|
+
|
|
84
|
+
# Report the module-level "loaded" flag. The flag is set by ensure_loaded!
# once the handle has been created and cleared again by finalize!.
# @return [Boolean] current value of the flag
def loaded?
  @loaded
end
|
|
89
|
+
|
|
90
|
+
# Load cuSOLVER on first use: resolve the library, attach the FFI functions
# and create the shared handle. The whole sequence runs under the module
# mutex and is skipped when the loaded flag is already set.
# @return [void]
# @raise [LibraryNotFoundError] If a CUDA bin path is configured but contains no cusolver64_*.dll
# @raise [CuSolverError] If handle creation fails
def ensure_loaded!
  @mutex.synchronize do
    unless @loaded
      CUDA::LibraryLoader.load_library(:cusolver)

      # Prefer a DLL found in the configured CUDA bin directory (highest name
      # in string order); without a configured directory use the default name.
      bin_dir = Ignis.configuration.cuda_bin_path
      library = "cusolver64_12"
      library = Dir.glob(File.join(bin_dir, "cusolver64_*.dll")).max if bin_dir

      raise LibraryNotFoundError, "cusolver" unless library

      ffi_lib library
      attach_cusolver_functions!
      initialize_cusolver!

      @loaded = true
      Ignis.logger.info("cuSOLVER initialized successfully")
    end
  end
end
|
|
116
|
+
|
|
117
|
+
# Return the shared cuSOLVER handle, loading the library first if needed.
# @return [FFI::Pointer] handle stored by initialize_cusolver!
def get_handle
  ensure_loaded!
  handle
end
|
|
123
|
+
|
|
124
|
+
# Bind the shared cuSOLVER handle to a CUDA stream via cusolverDnSetStream.
# @param stream [FFI::Pointer] CUDA stream
# @return [void]
# @raise [CuSolverError] If cuSOLVER reports a non-success status
def set_stream(stream)
  ensure_loaded!
  check_status!(cusolverDnSetStream(@handle, stream), "cusolverDnSetStream")
end
|
|
132
|
+
|
|
133
|
+
# Destroy the shared cuSOLVER handle and reset the module state so that a
# later ensure_loaded! call starts from scratch. Does nothing when no handle
# exists. Runs under the module mutex.
#
# A failing cusolverDnDestroy is reported through the logger instead of being
# silently dropped; the state is still reset because the handle is unusable
# either way, and finalization stays best-effort (it does not raise).
# NOTE(review): assumes Ignis.logger responds to #warn like Ruby's Logger — confirm.
# @return [void]
def finalize!
  @mutex.synchronize do
    next unless @handle

    status = cusolverDnDestroy(@handle)
    @handle = nil
    @loaded = false

    if status == CUSOLVER_STATUS_SUCCESS
      Ignis.logger.info("cuSOLVER finalized")
    else
      Ignis.logger.warn("cuSOLVER finalized, but cusolverDnDestroy returned status #{status}")
    end
  end
end
|
|
145
|
+
|
|
146
|
+
# Translate a cuSOLVER status code into an exception.
# @param status [Integer] status code returned by a cuSOLVER call
# @param context [String] name of the operation, used as the message prefix
# @return [nil] when the status is CUSOLVER_STATUS_SUCCESS
# @raise [CuSolverError] for any other status; the code is passed as cusolver_code
def check_status!(status, context = "cuSOLVER operation")
  unless status == CUSOLVER_STATUS_SUCCESS
    # Codes without an entry in STATUS_DESCRIPTIONS fall back to a generic text.
    reason = STATUS_DESCRIPTIONS.fetch(status, "Unknown error")
    raise CuSolverError.new("#{context} failed: #{reason}", cusolver_code: status)
  end
end
|
|
157
|
+
|
|
158
|
+
private
|
|
159
|
+
|
|
160
|
+
# Attach every cuSOLVER dense (cusolverDn*) function used by this library.
#
# Each function is attached with an :int return type (the status code). The
# routines are described as data: handle management first, then one entry per
# routine family, expanded over the S/D/C/Z precision prefixes. Attachment
# order is family by family, S, D, C, Z within each family.
# @return [Object] whatever the final attach_function call returns (unused by callers)
def attach_cusolver_functions!
  # Expand a routine stem into its four precision variants, e.g. "getrf" ->
  # Sgetrf, Dgetrf, Cgetrf, Zgetrf.
  per_precision = ->(stem) { %w[S D C Z].map { |prefix| "#{prefix}#{stem}" } }
  # The eigensolver is named syevd for real and heevd for complex precisions.
  eigen_names = ->(suffix) { %w[Ssyevd Dsyevd Cheevd Zheevd].map { |routine| "#{routine}#{suffix}" } }

  # (handle, int, int, A, lda, out-size) layout shared by the getrf/potrf/geqrf
  # buffer size queries.
  matrix_query = %i[pointer int int pointer int pointer]
  gesvd = %i[
    pointer char char int int pointer int
    pointer pointer int pointer int
    pointer int pointer pointer
  ]

  families = [
    # LU factorization: buffer size query, factorization, solve (getrs)
    [per_precision.call("getrf_bufferSize"), matrix_query],
    [per_precision.call("getrf"), %i[pointer int int pointer int pointer pointer pointer]],
    [per_precision.call("getrs"), %i[pointer int int int pointer int pointer pointer int pointer]],
    # SVD: buffer size query and computation
    [per_precision.call("gesvd_bufferSize"), %i[pointer int int pointer]],
    [per_precision.call("gesvd"), gesvd],
    # Symmetric/Hermitian eigenvalue solver (syevd/heevd)
    [eigen_names.call("_bufferSize"), %i[pointer int int int pointer int pointer pointer]],
    [eigen_names.call(""), %i[pointer int int int pointer int pointer pointer int pointer]],
    # Cholesky: buffer size query, factorization, solve
    [per_precision.call("potrf_bufferSize"), matrix_query],
    [per_precision.call("potrf"), %i[pointer int int pointer int pointer int pointer]],
    [per_precision.call("potrs"), %i[pointer int int int pointer int pointer int pointer]],
    # QR: buffer size query and factorization
    [per_precision.call("geqrf_bufferSize"), matrix_query],
    [per_precision.call("geqrf"), %i[pointer int int pointer int pointer pointer int pointer]]
  ]

  # Handle management
  routines = [
    [:cusolverDnCreate, %i[pointer]],
    [:cusolverDnDestroy, %i[pointer]],
    [:cusolverDnSetStream, %i[pointer pointer]],
    [:cusolverDnGetStream, %i[pointer pointer]]
  ]
  families.each do |names, signature|
    names.each { |name| routines << [:"cusolverDn#{name}", signature] }
  end

  # Every call gets its own copy of the argument list, as with separate literals.
  routines.map { |name, signature| attach_function(name, signature.dup, :int) }.last
end
|
|
345
|
+
|
|
346
|
+
# Create the cuSOLVER handle and store it in @handle.
# cusolverDnCreate writes the handle into a pointer-sized out parameter,
# which is read back only after the status has been checked.
# @return [FFI::Pointer] the newly created handle
# @raise [CuSolverError] If cusolverDnCreate reports a non-success status
def initialize_cusolver!
  out_param = FFI::MemoryPointer.new(:pointer)
  check_status!(cusolverDnCreate(out_param), "cusolverDnCreate")

  @handle = out_param.read_pointer
end
|
|
355
|
+
end
|
|
356
|
+
end
|
|
357
|
+
end
|
|
358
|
+
end
|
|
@@ -0,0 +1,226 @@
|
|
|
1
|
+
# frozen_string_literal: true
|
|
2
|
+
|
|
3
|
+
require_relative "cusolver_bindings"
|
|
4
|
+
|
|
5
|
+
module Ignis
|
|
6
|
+
module Solver
|
|
7
|
+
# Eigenvalue and Eigenvector computation using cuSOLVER
|
|
8
|
+
# Supports symmetric/Hermitian matrices (syevd/heevd); general matrices (geev) are not implemented yet — #eig raises NotImplementedError
|
|
9
|
+
module Eigen
|
|
10
|
+
class << self
|
|
11
|
+
# Compute eigenvalues and, optionally, eigenvectors of a symmetric/Hermitian
# matrix with cuSOLVER's divide-and-conquer routines (syevd/heevd).
#
# The input is never modified: the solver runs on a copy, which is returned
# as :eigenvectors when they are requested.
# NOTE(review): cuSOLVER addresses matrices in column-major order; confirm how
# that interacts with NvArray's layout for +uplo+ and the eigenvector columns.
#
# @param matrix [NvArray] Symmetric (real) or Hermitian (complex) matrix (n x n), on device
# @param eigenvectors [Boolean] If true, compute eigenvectors
# @param uplo [Symbol] :lower or :upper indicating which triangle contains the matrix
# @return [Hash] :eigenvalues (NvArray of shape [n] with the real dtype) and,
#   when requested, :eigenvectors
# @raise [ArgumentError] If the matrix fails validation or uplo is not :lower/:upper
# @raise [CuSolverError] If computation fails
def eigh(matrix, eigenvectors: true, uplo: :lower)
  CuSolverBindings.ensure_loaded!

  validate_symmetric_matrix!(matrix)

  # Reject unknown values explicitly; anything other than :lower used to be
  # treated as :upper without notice.
  fill_mode = case uplo
              when :lower then CuSolverBindings::CUBLAS_FILL_MODE_LOWER
              when :upper then CuSolverBindings::CUBLAS_FILL_MODE_UPPER
              else raise ArgumentError, "uplo must be :lower or :upper, got #{uplo.inspect}"
              end
  jobz = eigenvectors ? CuSolverBindings::CUSOLVER_EIG_MODE_VECTOR : CuSolverBindings::CUSOLVER_EIG_MODE_NOVECTOR

  n = matrix.shape[0]
  lda = n

  # Eigenvalues of a symmetric/Hermitian matrix are always real
  real_dtype = real_dtype_for(matrix.dtype)
  eigenvalues = CUDA::Memory.new(n * dtype_size(real_dtype))

  # Work on a copy (will contain eigenvectors on output if requested)
  work_matrix = matrix.dup

  # Query the workspace size (counted in elements of the matrix dtype)
  lwork_ptr = FFI::MemoryPointer.new(:int)
  get_syevd_buffer_size(jobz, fill_mode, n, work_matrix.device_ptr, lda, eigenvalues, lwork_ptr, matrix.dtype)
  lwork = lwork_ptr.read_int

  workspace = CUDA::Memory.new(lwork * dtype_size(matrix.dtype))
  info = CUDA::Memory.new(4) # one device-side int for devInfo

  perform_syevd(jobz, fill_mode, n, work_matrix.device_ptr, lda, eigenvalues, workspace, lwork, info, matrix.dtype)

  # Wait for the solver to finish before reading devInfo, so the check below
  # cannot observe a value the device has not written yet.
  CUDA::RuntimeAPI.cudaDeviceSynchronize

  info_value = read_device_int(info)
  if info_value < 0
    raise CuSolverError.new("Eigenvalue computation: parameter #{-info_value} had an illegal value",
                            cusolver_code: CuSolverBindings::CUSOLVER_STATUS_INVALID_VALUE)
  elsif info_value > 0
    raise CuSolverError.new("Eigenvalue computation: #{info_value} off-diagonal elements did not converge to zero",
                            cusolver_code: CuSolverBindings::CUSOLVER_STATUS_EXECUTION_FAILED)
  end

  values = NvArray.from_device_ptr(eigenvalues, shape: [n], dtype: real_dtype, take_ownership: true)
  # Ownership has moved to the NvArray; clear the local so the cleanup below
  # only frees the buffer on error paths.
  eigenvalues = nil

  result = { eigenvalues: values }
  result[:eigenvectors] = work_matrix if eigenvectors
  result
ensure
  # Locals that were never assigned are nil here, so safe navigation suffices.
  workspace&.free!
  info&.free!
  eigenvalues&.free!
end
|
|
73
|
+
|
|
74
|
+
# Compute eigenvalues only of a symmetric/Hermitian matrix (faster than
# requesting eigenvectors as well). Delegates to #eigh.
# @param matrix [NvArray] Symmetric/Hermitian matrix
# @param uplo [Symbol] :lower or :upper, forwarded to #eigh (defaults to :lower,
#   matching the previous hard-coded behaviour)
# @return [NvArray] Eigenvalues in ascending order
def eigvalsh(matrix, uplo: :lower)
  eigh(matrix, eigenvectors: false, uplo: uplo)[:eigenvalues]
end
|
|
81
|
+
|
|
82
|
+
# Compute eigenvalues and eigenvectors of a general (non-symmetric) matrix.
#
# Dispatches on the dtype to the real or complex general eigensolver helper.
# The matrix is handed over without a device-side copy: duplicating it first
# only allocated and copied device memory that was thrown away whenever the
# helper cannot run. A helper that factorizes in place must take its own copy.
# @param matrix [NvArray] General square matrix (n x n), on device
# @return [Hash] Contains :eigenvalues (may be complex), :eigenvectors_right
# @raise [ArgumentError] If the matrix fails validation
def eig(matrix)
  CuSolverBindings.ensure_loaded!

  validate_square_matrix!(matrix)
  n = matrix.shape[0]

  # Real input reports eigenvalues as real/imaginary pairs; complex input
  # yields complex eigenvalues directly, hence two separate code paths.
  if complex_dtype?(matrix.dtype)
    compute_complex_eig(matrix, n)
  else
    compute_real_eig(matrix, n)
  end
end
|
|
108
|
+
|
|
109
|
+
# Placeholder symmetry check: only verifies that the first two dimensions
# match. The entries are NOT compared, and +tol+ is currently unused; a full
# implementation would compute max(abs(A - A^T)) and compare it with +tol+.
# Callers must already know that the matrix is symmetric.
# @param matrix [NvArray] Matrix to check
# @param tol [Float] Tolerance for symmetry check (ignored for now)
# @return [Boolean] true when shape[0] equals shape[1]
def symmetric?(matrix, tol: 1e-10)
  rows = matrix.shape[0]
  cols = matrix.shape[1]
  rows == cols
end
|
|
121
|
+
|
|
122
|
+
private
|
|
123
|
+
|
|
124
|
+
# Validate the argument of the symmetric/Hermitian solvers. Checks type,
# rank, squareness and device residency in that order and reports the first
# failure. Symmetry of the entries themselves is not verified.
# @param matrix [Object] candidate matrix
# @return [nil]
# @raise [ArgumentError] describing the first failed check
def validate_symmetric_matrix!(matrix)
  problem =
    if !matrix.is_a?(NvArray)
      "Expected NvArray"
    elsif matrix.shape.length != 2
      "Matrix must be 2D"
    elsif matrix.shape[0] != matrix.shape[1]
      "Matrix must be square"
    elsif !matrix.on_device?
      "Matrix must be on device"
    end

  raise ArgumentError, problem if problem
end
|
|
130
|
+
|
|
131
|
+
# Validate the argument of the general eigensolver: it must be an NvArray,
# two-dimensional, square and resident on the device. Checks run in that
# order and stop at the first failure.
# @param matrix [Object] candidate matrix
# @return [nil]
# @raise [ArgumentError] describing the first failed check
def validate_square_matrix!(matrix)
  # Each predicate is lazy so later checks never run on an invalid object.
  requirements = [
    ["Expected NvArray", -> { matrix.is_a?(NvArray) }],
    ["Matrix must be 2D", -> { matrix.shape.length == 2 }],
    ["Matrix must be square", -> { matrix.shape[0] == matrix.shape[1] }],
    ["Matrix must be on device", -> { matrix.on_device? }]
  ]

  violated = requirements.find { |_message, satisfied| !satisfied.call }
  raise ArgumentError, violated.first if violated
end
|
|
137
|
+
|
|
138
|
+
# Size in bytes of one element of the given dtype.
#
# Only the four dtypes handled by the syevd/heevd routines are accepted. An
# unknown dtype used to be sized as 4 bytes, which under-allocates device
# buffers for anything wider; it is now rejected.
# @param dtype [Symbol] :float32, :float64, :complex64 or :complex128
# @return [Integer] element size in bytes
# @raise [ArgumentError] If the dtype is not one of the supported symbols
def dtype_size(dtype)
  case dtype
  when :float32 then 4
  when :float64, :complex64 then 8
  when :complex128 then 16
  else raise ArgumentError, "Unsupported dtype: #{dtype.inspect}"
  end
end
|
|
147
|
+
|
|
148
|
+
# Map a dtype to the real dtype of matching precision: complex64 becomes
# float32 and complex128 becomes float64. Every other value is returned
# unchanged.
# @param dtype [Symbol] element dtype
# @return [Symbol] the corresponding real dtype
def real_dtype_for(dtype)
  { complex64: :float32, complex128: :float64 }.fetch(dtype, dtype)
end
|
|
155
|
+
|
|
156
|
+
# Whether the dtype is one of the two complex dtypes.
# @param dtype [Symbol] element dtype
# @return [Boolean] true for :complex64 and :complex128, false otherwise
def complex_dtype?(dtype)
  case dtype
  when :complex64, :complex128 then true
  else false
  end
end
|
|
159
|
+
|
|
160
|
+
# Copy one 4-byte int from device memory to the host and return it.
#
# The copy status is checked so that a failed copy is not mistaken for a
# devInfo value read from an untouched host buffer.
# NOTE(review): assumes the binding returns the raw cudaError_t integer
# (0 on success); any non-Integer return value is passed through unchecked.
# @param ptr [Object] device pointer/buffer holding the int
# @return [Integer] the value read from the device
# @raise [CuSolverError] If cudaMemcpy returns a non-zero integer status
def read_device_int(ptr)
  host_ptr = FFI::MemoryPointer.new(:int)
  status = CUDA::RuntimeAPI.cudaMemcpy(host_ptr, ptr, 4, CUDA::RuntimeAPI::MEMCPY_DEVICE_TO_HOST)

  if status.is_a?(Integer) && !status.zero?
    raise CuSolverError.new("cudaMemcpy (device to host) failed with CUDA error #{status}",
                            cusolver_code: CuSolverBindings::CUSOLVER_STATUS_EXECUTION_FAILED)
  end

  host_ptr.read_int
end
|
|
165
|
+
|
|
166
|
+
# Query the workspace size required by syevd/heevd for the given dtype.
# The size is written by cuSOLVER into +lwork_ptr+.
#
# An unsupported dtype used to fall back to the float32 routine, which would
# read the buffer with the wrong element type; it is now rejected before any
# cuSOLVER call is made.
# @param jobz [Integer] eigen mode constant
# @param uplo [Integer] fill mode constant
# @param n [Integer] matrix order
# @param a_ptr [Object] device pointer to the matrix
# @param lda [Integer] leading dimension of the matrix
# @param w_ptr [Object] device buffer for the eigenvalues
# @param lwork_ptr [FFI::MemoryPointer] host int receiving the workspace size
# @param dtype [Symbol] :float32, :float64, :complex64 or :complex128
# @return [void]
# @raise [ArgumentError] If the dtype is not supported
# @raise [CuSolverError] If the query reports a non-success status
def get_syevd_buffer_size(jobz, uplo, n, a_ptr, lda, w_ptr, lwork_ptr, dtype)
  handle = CuSolverBindings.get_handle

  status = case dtype
           when :float32
             CuSolverBindings.cusolverDnSsyevd_bufferSize(handle, jobz, uplo, n, a_ptr, lda, w_ptr, lwork_ptr)
           when :float64
             CuSolverBindings.cusolverDnDsyevd_bufferSize(handle, jobz, uplo, n, a_ptr, lda, w_ptr, lwork_ptr)
           when :complex64
             CuSolverBindings.cusolverDnCheevd_bufferSize(handle, jobz, uplo, n, a_ptr, lda, w_ptr, lwork_ptr)
           when :complex128
             CuSolverBindings.cusolverDnZheevd_bufferSize(handle, jobz, uplo, n, a_ptr, lda, w_ptr, lwork_ptr)
           else
             raise ArgumentError, "Unsupported dtype for syevd/heevd: #{dtype.inspect}"
           end

  CuSolverBindings.check_status!(status, "cusolverDnXsyevd_bufferSize")
end
|
|
184
|
+
|
|
185
|
+
# Run the syevd/heevd eigensolver that matches the dtype.
#
# An unsupported dtype used to be routed to the float32 solver, which would
# interpret the buffers with the wrong element type; it is now rejected
# before any cuSOLVER call is made.
# @param jobz [Integer] eigen mode constant
# @param uplo [Integer] fill mode constant
# @param n [Integer] matrix order
# @param a_ptr [Object] device pointer to the matrix (overwritten by the solver)
# @param lda [Integer] leading dimension of the matrix
# @param w_ptr [Object] device buffer receiving the eigenvalues
# @param work [Object] device workspace
# @param lwork [Integer] workspace size as returned by the buffer size query
# @param info [Object] device int receiving devInfo
# @param dtype [Symbol] :float32, :float64, :complex64 or :complex128
# @return [void]
# @raise [ArgumentError] If the dtype is not supported
# @raise [CuSolverError] If the solver call reports a non-success status
def perform_syevd(jobz, uplo, n, a_ptr, lda, w_ptr, work, lwork, info, dtype)
  handle = CuSolverBindings.get_handle

  status = case dtype
           when :float32
             CuSolverBindings.cusolverDnSsyevd(handle, jobz, uplo, n, a_ptr, lda, w_ptr, work, lwork, info)
           when :float64
             CuSolverBindings.cusolverDnDsyevd(handle, jobz, uplo, n, a_ptr, lda, w_ptr, work, lwork, info)
           when :complex64
             CuSolverBindings.cusolverDnCheevd(handle, jobz, uplo, n, a_ptr, lda, w_ptr, work, lwork, info)
           when :complex128
             CuSolverBindings.cusolverDnZheevd(handle, jobz, uplo, n, a_ptr, lda, w_ptr, work, lwork, info)
           else
             raise ArgumentError, "Unsupported dtype for syevd/heevd: #{dtype.inspect}"
           end

  CuSolverBindings.check_status!(status, "cusolverDnXsyevd")
end
|
|
203
|
+
|
|
204
|
+
# General (non-symmetric) real eigensolver — NOT IMPLEMENTED.
#
# cuSOLVER's general eigensolver (geev) is not bound, so this always raises.
# Running the symmetric solver (syevd) here instead would return the
# eigenpairs of (A+Aᵀ)/2, which is wrong for any non-symmetric A; failing
# loudly is deliberate. Use #eigh for genuinely symmetric/Hermitian input.
# Both arguments are ignored.
# @raise [NotImplementedError] always
def compute_real_eig(_matrix, _n)
  message = "General (non-symmetric) eigenvalue solver is not implemented (geev not bound). " \
            "Use eigh/eigvalsh for symmetric/Hermitian matrices."
  raise NotImplementedError, message
end
|
|
216
|
+
|
|
217
|
+
# General complex eigensolver — NOT IMPLEMENTED because geev is not bound.
# Always raises; both arguments are ignored. Use #eigh for Hermitian input.
# @raise [NotImplementedError] always
def compute_complex_eig(_matrix, _n)
  message = "General complex eigenvalue solver is not implemented (geev not bound). " \
            "Use eigh for Hermitian matrices."
  raise NotImplementedError, message
end
|
|
223
|
+
end
|
|
224
|
+
end
|
|
225
|
+
end
|
|
226
|
+
end
|