ignis-numerics 0.0.1
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- checksums.yaml +7 -0
- data/README.md +15 -0
- data/lib/ignis-numerics.rb +62 -0
- data/lib/nvruby/array.rb +646 -0
- data/lib/nvruby/fft/cufft_bindings.rb +134 -0
- data/lib/nvruby/fft/fft_plan.rb +288 -0
- data/lib/nvruby/fft/operations.rb +364 -0
- data/lib/nvruby/linalg/cutensor_bindings.rb +107 -0
- data/lib/nvruby/mathdx/fft_kernel.rb +258 -0
- data/lib/nvruby/mathdx/gemm_kernel.rb +293 -0
- data/lib/nvruby/mathdx.rb +73 -0
- data/lib/nvruby/random/curand_bindings.rb +115 -0
- data/lib/nvruby/random/generator.rb +305 -0
- data/lib/nvruby/solver/amgx_bindings.rb +172 -0
- data/lib/nvruby/solver/amgx_config.rb +142 -0
- data/lib/nvruby/solver/amgx_solver.rb +251 -0
- data/lib/nvruby/solver/cudss_bindings.rb +115 -0
- data/lib/nvruby/solver/cusolver_bindings.rb +358 -0
- data/lib/nvruby/solver/eigen.rb +226 -0
- data/lib/nvruby/solver/lu.rb +265 -0
- data/lib/nvruby/solver/sparse_solver.rb +429 -0
- data/lib/nvruby/solver/svd.rb +266 -0
- data/lib/nvruby/solver.rb +122 -0
- data/lib/nvruby/sparse/cusparse_bindings.rb +231 -0
- data/lib/nvruby/sparse/sparse_matrix.rb +456 -0
- data/lib/nvruby/tensor/contraction.rb +218 -0
- data/lib/nvruby/tensor.rb +42 -0
- metadata +85 -0
|
@@ -0,0 +1,364 @@
|
|
|
1
|
+
# frozen_string_literal: true
|
|
2
|
+
|
|
3
|
+
module Ignis
|
|
4
|
+
module FFT
|
|
5
|
+
# FFT operations using cuFFT
|
|
6
|
+
module Operations
|
|
7
|
+
class << self
|
|
8
|
+
# Forward 1D FFT.
#
# The transform length is the extent of +x+ along +axis+. Note that +axis+ is
# only used to pick that length; it is not passed on to execute_fft.
#
# @param x [NvArray] Input array
# @param axis [Integer] Axis along which to compute FFT (-1 for last)
# @param norm [Symbol] Normalization mode (:forward, :backward, :ortho)
# @param stream [CUDA::Stream, nil] CUDA stream
# @return [NvArray] Complex output
def fft(x, axis: -1, norm: :backward, stream: nil)
  validate_input!(x)

  length = x.shape[normalize_axis(axis, x.ndim)]
  spectrum = execute_fft(x, DType.complex_dtype(x.dtype), :forward, [length], stream)
  apply_normalization(spectrum, length, :forward, norm)
end
|
|
24
|
+
|
|
25
|
+
# Inverse 1D FFT.
#
# The transform length is the extent of +x+ along +axis+; the output keeps the
# input dtype.
#
# @param x [NvArray] Input complex array
# @param axis [Integer] Axis along which to compute IFFT
# @param norm [Symbol] Normalization mode (:forward, :backward, :ortho)
# @param stream [CUDA::Stream, nil] CUDA stream
# @return [NvArray] Complex output
def ifft(x, axis: -1, norm: :backward, stream: nil)
  validate_complex_input!(x)

  length = x.shape[normalize_axis(axis, x.ndim)]
  signal = execute_fft(x, x.dtype, :inverse, [length], stream)
  apply_normalization(signal, length, :inverse, norm)
end
|
|
40
|
+
|
|
41
|
+
# Forward 2D FFT.
#
# The normalization size is the product of the extents along +axes+.
#
# @param x [NvArray] Input array with at least two dimensions
# @param axes [Array<Integer>, nil] Axes for FFT (default last two)
# @param norm [Symbol] Normalization mode
# @param stream [CUDA::Stream, nil] CUDA stream
# @return [NvArray] Complex output
# @raise [DimensionError] if +x+ has fewer than two dimensions
def fft2(x, axes: nil, norm: :backward, stream: nil)
  validate_input!(x)
  raise DimensionError, "Input must be at least 2D for fft2" unless x.ndim >= 2

  resolved = (axes || [-2, -1]).map { |axis| normalize_axis(axis, x.ndim) }
  extents = resolved.map { |axis| x.shape[axis] }

  spectrum = execute_fft(x, DType.complex_dtype(x.dtype), :forward, extents, stream)
  apply_normalization(spectrum, extents.inject(:*), :forward, norm)
end
|
|
60
|
+
|
|
61
|
+
# Inverse 2D FFT.
#
# The normalization size is the product of the extents along +axes+; the output
# keeps the input dtype.
#
# @param x [NvArray] Input complex array with at least two dimensions
# @param axes [Array<Integer>, nil] Axes for IFFT (default last two)
# @param norm [Symbol] Normalization mode
# @param stream [CUDA::Stream, nil] CUDA stream
# @return [NvArray] Complex output
# @raise [DimensionError] if +x+ has fewer than two dimensions
def ifft2(x, axes: nil, norm: :backward, stream: nil)
  validate_complex_input!(x)
  raise DimensionError, "Input must be at least 2D for ifft2" unless x.ndim >= 2

  resolved = (axes || [-2, -1]).map { |axis| normalize_axis(axis, x.ndim) }
  extents = resolved.map { |axis| x.shape[axis] }

  signal = execute_fft(x, x.dtype, :inverse, extents, stream)
  apply_normalization(signal, extents.inject(:*), :inverse, norm)
end
|
|
79
|
+
|
|
80
|
+
# Forward N-dimensional FFT.
#
# When +axes+ is nil every axis of +x+ is transformed. The normalization size
# is the product of the extents along the selected axes.
#
# @param x [NvArray] Input array
# @param axes [Array<Integer>, nil] Axes for FFT (default: all axes)
# @param norm [Symbol] Normalization mode
# @param stream [CUDA::Stream, nil] CUDA stream
# @return [NvArray] Complex output
def fftn(x, axes: nil, norm: :backward, stream: nil)
  validate_input!(x)

  resolved = (axes || [*0...x.ndim]).map { |axis| normalize_axis(axis, x.ndim) }
  extents = resolved.map { |axis| x.shape[axis] }

  spectrum = execute_fft(x, DType.complex_dtype(x.dtype), :forward, extents, stream)
  apply_normalization(spectrum, extents.inject(:*), :forward, norm)
end
|
|
98
|
+
|
|
99
|
+
# Inverse N-dimensional FFT.
#
# When +axes+ is nil every axis of +x+ is transformed. The normalization size
# is the product of the extents along the selected axes; the output keeps the
# input dtype.
#
# @param x [NvArray] Input complex array
# @param axes [Array<Integer>, nil] Axes for IFFT (default: all axes)
# @param norm [Symbol] Normalization mode
# @param stream [CUDA::Stream, nil] CUDA stream
# @return [NvArray] Complex output
def ifftn(x, axes: nil, norm: :backward, stream: nil)
  validate_complex_input!(x)

  resolved = (axes || [*0...x.ndim]).map { |axis| normalize_axis(axis, x.ndim) }
  extents = resolved.map { |axis| x.shape[axis] }

  signal = execute_fft(x, x.dtype, :inverse, extents, stream)
  apply_normalization(signal, extents.inject(:*), :inverse, norm)
end
|
|
116
|
+
|
|
117
|
+
# Real-to-complex FFT.
#
# @param x [NvArray] Real input array
# @param n [Integer, nil] FFT length; defaults to the input length along +axis+.
#   The output holds n/2+1 entries along that axis.
# @param axis [Integer] Axis for FFT (negative values count from the end)
# @param norm [Symbol] Normalization mode (:forward, :backward, :ortho)
# @param stream [CUDA::Stream, nil] CUDA stream
# @return [NvArray] Complex output (hermitian-symmetric)
# @raise [ArgumentError] if +x+ is complex, or +n+ is not a positive Integer
#   no larger than the input length along +axis+
def rfft(x, n: nil, axis: -1, norm: :backward, stream: nil)
  validate_input!(x)
  raise ArgumentError, "Input must be real for rfft" if DType.complex?(x.dtype)

  axis = normalize_axis(axis, x.ndim)
  available = x.shape[axis]
  n ||= available

  # The input buffer is handed to the transform as-is (no zero padding), so a
  # plan longer than the data would read past the end of the input.
  unless n.is_a?(Integer) && n.between?(1, available)
    raise ArgumentError, "FFT length n must be an Integer in 1..#{available}, got #{n.inspect}"
  end

  output_shape = x.shape.dup
  output_shape[axis] = n / 2 + 1
  output_dtype = DType.complex_dtype(x.dtype)

  result = execute_rfft(x, output_shape, output_dtype, n, stream)
  apply_normalization(result, n, :forward, norm)
end
|
|
138
|
+
|
|
139
|
+
# Complex-to-real inverse FFT.
#
# @param x [NvArray] Complex input array (hermitian-symmetric)
# @param n [Integer, nil] Output length along +axis+; defaults to
#   (input length - 1) * 2
# @param axis [Integer] Axis for FFT (negative values count from the end)
# @param norm [Symbol] Normalization mode (:forward, :backward, :ortho)
# @param stream [CUDA::Stream, nil] CUDA stream
# @return [NvArray] Real output
# @raise [ArgumentError] if +n+ is not a positive Integer, or needs more than
#   the available complex inputs along +axis+
def irfft(x, n: nil, axis: -1, norm: :backward, stream: nil)
  validate_complex_input!(x)

  axis = normalize_axis(axis, x.ndim)
  input_size = x.shape[axis]
  n ||= (input_size - 1) * 2

  # A length-n real output consumes n/2+1 complex inputs, and the input buffer
  # is passed through unpadded, so a larger n would read past the end of it.
  # This also rejects the n == 0 default produced by a single-bin input.
  unless n.is_a?(Integer) && n.positive? && n / 2 + 1 <= input_size
    raise ArgumentError,
          "Output length n must be a positive Integer needing at most #{input_size} complex inputs, " \
          "got #{n.inspect}"
  end

  output_shape = x.shape.dup
  output_shape[axis] = n
  output_dtype = DType.real_dtype(x.dtype)

  result = execute_irfft(x, output_shape, output_dtype, n, stream)
  apply_normalization(result, n, :inverse, norm)
end
|
|
160
|
+
|
|
161
|
+
private
|
|
162
|
+
|
|
163
|
+
# Guard: +x+ must be an NvArray with at least one dimension.
#
# @param x [Object] Candidate input
# @return [nil]
# @raise [ArgumentError] if +x+ is not an NvArray
# @raise [DimensionError] if +x+ reports zero dimensions
def validate_input!(x)
  raise ArgumentError, "Expected NvArray, got #{x.class}" unless NvArray === x
  return unless x.ndim.zero?

  raise DimensionError, "Input must have at least 1 dimension"
end
|
|
168
|
+
|
|
169
|
+
# Guard: +x+ must pass validate_input! and carry a complex dtype.
#
# @param x [Object] Candidate input
# @return [nil]
# @raise [ArgumentError] if the dtype of +x+ is not complex
def validate_complex_input!(x)
  validate_input!(x)
  return if DType.complex?(x.dtype)

  raise ArgumentError, "Expected complex dtype, got #{x.dtype}"
end
|
|
174
|
+
|
|
175
|
+
# Resolve a possibly negative axis into an index in 0...ndim.
#
# Negative values are offset by +ndim+ before the bounds check, and the error
# message reports the offset value.
#
# @param axis [Integer] Axis, negative values count from the end
# @param ndim [Integer] Number of dimensions
# @return [Integer] Resolved axis
# @raise [DimensionError] if the resolved axis is outside 0...ndim
def normalize_axis(axis, ndim)
  resolved = axis.negative? ? axis + ndim : axis
  return resolved if (0...ndim).cover?(resolved)

  raise DimensionError, "Axis #{resolved} out of bounds for #{ndim}D array"
end
|
|
185
|
+
|
|
186
|
+
# Run a complex-to-complex transform through cuFFT and return a new array.
#
# Builds a 1D, 2D or 3D plan from +dimensions+, optionally binds it to
# +stream+, executes C2C or Z2Z depending on +output_dtype+, and always
# destroys the plan afterwards, even when stream setup or execution raises.
#
# NOTE(review): the 1D plan is created with a batch count of 1 and no axis
# information reaches this method, so for inputs with more than one dimension
# it looks like only the leading +dimensions+ worth of elements is
# transformed. Confirm the intended behaviour with the callers.
# NOTE(review): real inputs are mapped to C2C/Z2Z by determine_fft_type and
# the input pointer is passed to cufftExec* unchanged. Confirm that real
# arrays are converted to complex somewhere upstream.
#
# @param x [NvArray] Input array (moved to the device if necessary)
# @param output_dtype [Symbol] :complex64 or :complex128
# @param direction [Symbol] :forward, anything else is treated as inverse
# @param dimensions [Array<Integer>] Transform extents (1 to 3 entries)
# @param stream [CUDA::Stream, nil] CUDA stream
# @return [NvArray] Output array with the shape of +x+
# @raise [NotImplementedError] if +dimensions+ does not have 1 to 3 entries
# @raise [UnsupportedDTypeError] for unsupported dtype combinations
# rubocop:disable Metrics/AbcSize, Metrics/MethodLength
def execute_fft(x, output_dtype, direction, dimensions, stream)
  CuFFTBindings.ensure_loaded!

  # Reject unsupported ranks before allocating anything on the device.
  unless (1..3).cover?(dimensions.size)
    raise NotImplementedError, "FFT for #{dimensions.size}D not yet implemented"
  end

  x = x.to_device unless x.on_device?

  # Create output array
  output = NvArray.new(shape: x.shape, dtype: output_dtype, device: x.device_index)
  output.to_device

  # Create plan based on dtype and dimensions
  fft_type = determine_fft_type(x.dtype, output_dtype)
  plan_ptr = FFI::MemoryPointer.new(:pointer)

  status = case dimensions.size
           when 1
             CuFFTBindings.cufftPlan1d(plan_ptr, dimensions[0], fft_type, 1)
           when 2
             CuFFTBindings.cufftPlan2d(plan_ptr, dimensions[0], dimensions[1], fft_type)
           else
             CuFFTBindings.cufftPlan3d(plan_ptr, dimensions[0], dimensions[1], dimensions[2], fft_type)
           end
  CuFFTBindings.check_status!(status, "Create FFT plan")

  plan = plan_ptr.read_pointer

  begin
    # Set stream if provided
    if stream
      status = CuFFTBindings.cufftSetStream(plan, stream.handle)
      CuFFTBindings.check_status!(status, "Set FFT stream")
    end

    dir = direction == :forward ? CuFFTBindings::CUFFT_FORWARD : CuFFTBindings::CUFFT_INVERSE

    status = case output_dtype
             when :complex64
               CuFFTBindings.cufftExecC2C(plan, x.device_ffi_ptr, output.device_ffi_ptr, dir)
             when :complex128
               CuFFTBindings.cufftExecZ2Z(plan, x.device_ffi_ptr, output.device_ffi_ptr, dir)
             else
               raise UnsupportedDTypeError.new(output_dtype, operation: "FFT")
             end
    CuFFTBindings.check_status!(status, "Execute FFT")
  ensure
    # The plan holds cuFFT resources; release it on every exit path.
    CuFFTBindings.cufftDestroy(plan)
  end

  output
end
# rubocop:enable Metrics/AbcSize, Metrics/MethodLength
|
|
244
|
+
|
|
245
|
+
# Run a 1D real-to-complex transform through cuFFT and return a new array.
#
# Uses R2C for :float32 input and D2Z for :float64 input. The plan is always
# destroyed, even when stream setup or execution raises.
#
# NOTE(review): the plan is created with a batch count of 1; confirm the
# behaviour expected for inputs with more than one dimension.
#
# @param x [NvArray] Real input array (moved to the device if necessary)
# @param output_shape [Array<Integer>] Shape of the output array
# @param output_dtype [Symbol] Complex dtype of the output array
# @param n [Integer] Transform length
# @param stream [CUDA::Stream, nil] CUDA stream
# @return [NvArray] Complex output
# @raise [UnsupportedDTypeError] if the input dtype is neither :float32 nor :float64
def execute_rfft(x, output_shape, output_dtype, n, stream)
  CuFFTBindings.ensure_loaded!

  x = x.to_device unless x.on_device?

  # Only these two precisions have a real-to-complex transform; anything else
  # used to fall through to D2Z and reinterpret the buffer as doubles.
  double_precision = case x.dtype
                     when :float32 then false
                     when :float64 then true
                     else raise UnsupportedDTypeError.new(x.dtype, operation: "FFT")
                     end

  output = NvArray.new(shape: output_shape, dtype: output_dtype, device: x.device_index)
  output.to_device

  plan_ptr = FFI::MemoryPointer.new(:pointer)
  fft_type = double_precision ? CuFFTBindings::CUFFT_D2Z : CuFFTBindings::CUFFT_R2C

  status = CuFFTBindings.cufftPlan1d(plan_ptr, n, fft_type, 1)
  CuFFTBindings.check_status!(status, "Create RFFT plan")

  plan = plan_ptr.read_pointer

  begin
    if stream
      status = CuFFTBindings.cufftSetStream(plan, stream.handle)
      CuFFTBindings.check_status!(status, "Set RFFT stream")
    end

    status = if double_precision
               CuFFTBindings.cufftExecD2Z(plan, x.device_ffi_ptr, output.device_ffi_ptr)
             else
               CuFFTBindings.cufftExecR2C(plan, x.device_ffi_ptr, output.device_ffi_ptr)
             end
    CuFFTBindings.check_status!(status, "Execute RFFT")
  ensure
    # Release the plan on every exit path.
    CuFFTBindings.cufftDestroy(plan)
  end

  output
end
|
|
277
|
+
|
|
278
|
+
# Run a 1D complex-to-real inverse transform through cuFFT and return a new array.
#
# Uses C2R for :float32 output and Z2D for :float64 output. The plan is always
# destroyed, even when stream setup or execution raises.
#
# NOTE(review): the plan is created with a batch count of 1; confirm the
# behaviour expected for inputs with more than one dimension.
#
# @param x [NvArray] Complex input array (moved to the device if necessary)
# @param output_shape [Array<Integer>] Shape of the output array
# @param output_dtype [Symbol] Real dtype of the output array
# @param n [Integer] Transform (output) length
# @param stream [CUDA::Stream, nil] CUDA stream
# @return [NvArray] Real output
# @raise [UnsupportedDTypeError] if the output dtype is neither :float32 nor :float64
def execute_irfft(x, output_shape, output_dtype, n, stream)
  CuFFTBindings.ensure_loaded!

  x = x.to_device unless x.on_device?

  # Only these two precisions have a complex-to-real transform; anything else
  # used to fall through to Z2D.
  double_precision = case output_dtype
                     when :float32 then false
                     when :float64 then true
                     else raise UnsupportedDTypeError.new(output_dtype, operation: "FFT")
                     end

  output = NvArray.new(shape: output_shape, dtype: output_dtype, device: x.device_index)
  output.to_device

  plan_ptr = FFI::MemoryPointer.new(:pointer)
  fft_type = double_precision ? CuFFTBindings::CUFFT_Z2D : CuFFTBindings::CUFFT_C2R

  status = CuFFTBindings.cufftPlan1d(plan_ptr, n, fft_type, 1)
  CuFFTBindings.check_status!(status, "Create IRFFT plan")

  plan = plan_ptr.read_pointer

  begin
    if stream
      status = CuFFTBindings.cufftSetStream(plan, stream.handle)
      CuFFTBindings.check_status!(status, "Set IRFFT stream")
    end

    status = if double_precision
               CuFFTBindings.cufftExecZ2D(plan, x.device_ffi_ptr, output.device_ffi_ptr)
             else
               CuFFTBindings.cufftExecC2R(plan, x.device_ffi_ptr, output.device_ffi_ptr)
             end
    CuFFTBindings.check_status!(status, "Execute IRFFT")
  ensure
    # Release the plan on every exit path.
    CuFFTBindings.cufftDestroy(plan)
  end

  output
end
|
|
310
|
+
|
|
311
|
+
# Map an input/output dtype pair to the cuFFT transform type constant.
#
# Single-precision pairs (float32 or complex64 in, complex64 out) select C2C;
# double-precision pairs (float64 or complex128 in, complex128 out) select Z2Z.
#
# @param input_dtype [Symbol] dtype of the input array
# @param output_dtype [Symbol] dtype of the output array
# @return [Integer] CuFFTBindings::CUFFT_C2C or CuFFTBindings::CUFFT_Z2Z
# @raise [UnsupportedDTypeError] for any other combination
def determine_fft_type(input_dtype, output_dtype)
  if output_dtype == :complex64 && %i[float32 complex64].include?(input_dtype)
    CuFFTBindings::CUFFT_C2C
  elsif output_dtype == :complex128 && %i[float64 complex128].include?(input_dtype)
    CuFFTBindings::CUFFT_Z2Z
  else
    raise UnsupportedDTypeError.new("#{input_dtype}->#{output_dtype}", operation: "FFT")
  end
end
|
|
322
|
+
|
|
323
|
+
# Scale an FFT result according to the normalization mode.
#
# :backward scales inverse transforms by 1/n, :forward scales forward
# transforms by 1/n, and :ortho scales both directions by 1/sqrt(n). Any other
# mode leaves the result untouched. When the factor is exactly 1.0 the result
# is returned without calling cuBLAS.
#
# The scaling is done in place with cuBLAS scal over the raw real components
# (a complex element counts as two reals).
#
# NOTE(review): no stream is available in this method, so the ordering of the
# scal relative to an FFT enqueued on a caller-supplied stream should be
# confirmed.
#
# @param result [NvArray] FFT result
# @param n [Integer] Transform size
# @param direction [Symbol] :forward or :inverse
# @param norm [Symbol] :forward, :backward, or :ortho
# @return [NvArray] The same array, scaled in place when needed
def apply_normalization(result, n, direction, norm)
  # Direction that carries the full 1/n factor for each one-sided mode.
  scaled_direction = { backward: :inverse, forward: :forward }[norm]

  factor =
    if norm == :ortho
      1.0 / Math.sqrt(n)
    elsif scaled_direction && direction == scaled_direction
      1.0 / n
    else
      1.0
    end

  return result if factor == 1.0

  blas = Ignis::LinAlg::CuBLASBindings
  blas.ensure_loaded!
  handle = blas.get_handle

  element_count = result.size
  element_count *= 2 if DType.complex?(result.dtype)
  data = result.device_ffi_ptr

  status =
    case result.dtype
    when :float64, :complex128
      alpha = FFI::MemoryPointer.new(:double)
      alpha.put_double(0, factor)
      blas.cublasDscal_v2(handle, element_count, alpha, data, 1)
    else
      alpha = FFI::MemoryPointer.new(:float)
      alpha.put_float(0, factor)
      blas.cublasSscal_v2(handle, element_count, alpha, data, 1)
    end

  blas.check_status!(status, "FFT normalization scal")
  result
end
|
|
361
|
+
end
|
|
362
|
+
end
|
|
363
|
+
end
|
|
364
|
+
end
|
|
@@ -0,0 +1,107 @@
|
|
|
1
|
+
# frozen_string_literal: true
|
|
2
|
+
|
|
3
|
+
require "ffi"
|
|
4
|
+
|
|
5
|
+
module Ignis
|
|
6
|
+
module Linalg
|
|
7
|
+
# FFI bindings for NVIDIA cuTENSOR v2.4+
# High-performance tensor primitives for CUDA
#
# The shared library is loaded lazily: call {ensure_loaded!} before using any
# attached function.
#
# NOTE(review): the attached signatures below are kept exactly as they were
# for caller compatibility. The cuTENSOR 2.x header appears to declare
# cutensorCreateTensorDescriptor with seven parameters (handle, desc*,
# numModes, extent[], stride[], dataType, alignment), to name the contraction
# constructor cutensorCreateContraction, and to take an opaque plan-preference
# handle in cutensorCreatePlan. Confirm each against the installed
# cutensor.h before relying on these bindings.
module CuTensorBindings
  extend FFI::Library

  # cuTENSOR status codes (cutensorStatus_t). The numeric values are not
  # consecutive; they follow the cuBLAS-style numbering in cutensor/types.h.
  CUTENSOR_STATUS_SUCCESS = 0
  CUTENSOR_STATUS_NOT_INITIALIZED = 1
  CUTENSOR_STATUS_ALLOC_FAILED = 3
  CUTENSOR_STATUS_INVALID_VALUE = 7
  CUTENSOR_STATUS_ARCH_MISMATCH = 8
  CUTENSOR_STATUS_MAPPING_ERROR = 11
  CUTENSOR_STATUS_EXECUTION_FAILED = 13
  CUTENSOR_STATUS_INTERNAL_ERROR = 14
  CUTENSOR_STATUS_NOT_SUPPORTED = 15
  CUTENSOR_STATUS_LICENSE_ERROR = 16
  CUTENSOR_STATUS_CUBLAS_ERROR = 17
  CUTENSOR_STATUS_CUDA_ERROR = 18
  CUTENSOR_STATUS_INSUFFICIENT_WORKSPACE = 19
  CUTENSOR_STATUS_INSUFFICIENT_DRIVER = 20
  CUTENSOR_STATUS_IO_ERROR = 21

  # cuTENSOR Algorithms (cutensorAlgo_t)
  CUTENSOR_ALGO_DEFAULT_PATIENT = -6
  CUTENSOR_ALGO_GETT = -4
  CUTENSOR_ALGO_TGETT = -3
  CUTENSOR_ALGO_TTGT = -2
  CUTENSOR_ALGO_DEFAULT = -1

  # Workspace preference enums (cutensorWorksizePreference_t)
  CUTENSOR_WORKSPACE_MIN = 1
  CUTENSOR_WORKSPACE_DEFAULT = 2
  CUTENSOR_WORKSPACE_RECOMMENDED = 2 # older name for the same value
  CUTENSOR_WORKSPACE_MAX = 3

  @loaded = false
  @mutex = Mutex.new

  class << self
    # Ensure cuTENSOR library is loaded and functions attached.
    # Runs under a mutex and is a no-op after the first successful load.
    # @return [void]
    # @raise [LibraryNotFoundError] if no cuTENSOR library path is known
    def ensure_loaded!
      @mutex.synchronize do
        return if @loaded

        CUDA::LibraryLoader.load_library(:cutensor)
        dll_path = CUDA::LibraryLoader.library_paths[:cutensor]

        raise LibraryNotFoundError, "cutensor" unless dll_path

        ffi_lib dll_path
        attach_functions!

        # Only mark as loaded once every function has been attached.
        @loaded = true
        Ignis.logger.debug("cuTENSOR bindings loaded from #{dll_path}")
      end
    end

    # Check status and raise error if not success
    # @param status [Integer] status code
    # @param context [String] Context for error message
    # @return [nil]
    # @raise [CuTensorError] If status is not success
    def check_status!(status, context = "cuTENSOR operation")
      return if status == CUTENSOR_STATUS_SUCCESS

      raise CuTensorError.new("#{context} failed", status_code: status)
    end

    private

    # Attach the cuTENSOR entry points used by this gem.
    # Every function is declared as returning an int status code.
    def attach_functions!
      # Handle management
      attach_function :cutensorCreate, [:pointer], :int
      attach_function :cutensorDestroy, [:pointer], :int

      # Tensor descriptors
      attach_function :cutensorCreateTensorDescriptor, [:pointer, :uint32, :pointer, :pointer, :int, :int], :int
      attach_function :cutensorDestroyTensorDescriptor, [:pointer], :int

      # Operation descriptors
      attach_function :cutensorCreateContractionDescriptor, [:pointer, :pointer, :pointer, :pointer, :pointer, :pointer, :pointer, :pointer, :int, :int], :int
      attach_function :cutensorDestroyOperationDescriptor, [:pointer], :int

      # Plan management
      attach_function :cutensorCreatePlan, [:pointer, :pointer, :pointer, :int, :uint64], :int
      attach_function :cutensorDestroyPlan, [:pointer], :int

      # Execution
      attach_function :cutensorContract, [:pointer, :pointer, :pointer, :pointer, :pointer, :pointer, :pointer, :pointer, :pointer, :uint64, :pointer], :int
      attach_function :cutensorReduce, [:pointer, :pointer, :pointer, :pointer, :pointer, :pointer, :pointer, :pointer, :uint64, :pointer], :int
    end
  end
end
|
|
96
|
+
|
|
97
|
+
# Error raised when a cuTENSOR call reports a non-success status.
# The status code is appended to the message and exposed via #status_code.
class CuTensorError < StandardError
  # @return [Integer, nil] Status code that triggered the error
  attr_reader :status_code

  # @param message [String] Description of the failed operation
  # @param status_code [Integer, nil] Status code reported by cuTENSOR
  def initialize(message, status_code: nil)
    super(format("%s (Status: %s)", message, status_code))
    @status_code = status_code
  end
end
|
|
106
|
+
end
|
|
107
|
+
end
|