ignis-numerics 0.0.1

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -0,0 +1,364 @@
1
+ # frozen_string_literal: true
2
+
3
+ module Ignis
4
+ module FFT
5
+ # FFT operations using cuFFT
6
+ module Operations
7
+ class << self
8
# 1D Forward FFT
#
# Only the last axis is supported: execute_fft builds its cuFFT plan from the
# transform extent alone and is given no axis or stride information, so any
# other axis would silently transform the wrong elements.
#
# @param x [NvArray] Input array
# @param axis [Integer] Axis along which to compute FFT (-1 for last; must resolve to the last axis)
# @param norm [Symbol] Normalization mode (:forward, :backward, :ortho)
# @param stream [CUDA::Stream, nil] CUDA stream
# @return [NvArray] Complex output
# @raise [NotImplementedError] if +axis+ is not the last axis
def fft(x, axis: -1, norm: :backward, stream: nil)
  validate_input!(x)
  axis = normalize_axis(axis, x.ndim)
  unless axis == x.ndim - 1
    raise NotImplementedError, "fft along axis #{axis} is not supported; only the last axis (#{x.ndim - 1}) is"
  end

  n = x.shape[axis]
  output_dtype = DType.complex_dtype(x.dtype)

  result = execute_fft(x, output_dtype, :forward, [n], stream)
  apply_normalization(result, n, :forward, norm)
end
24
+
25
# 1D Inverse FFT
#
# Only the last axis is supported: execute_fft builds its cuFFT plan from the
# transform extent alone and is given no axis or stride information, so any
# other axis would silently transform the wrong elements.
#
# @param x [NvArray] Input complex array
# @param axis [Integer] Axis along which to compute IFFT (must resolve to the last axis)
# @param norm [Symbol] Normalization mode (:forward, :backward, :ortho)
# @param stream [CUDA::Stream, nil] CUDA stream
# @return [NvArray] Complex output
# @raise [NotImplementedError] if +axis+ is not the last axis
def ifft(x, axis: -1, norm: :backward, stream: nil)
  validate_complex_input!(x)
  axis = normalize_axis(axis, x.ndim)
  unless axis == x.ndim - 1
    raise NotImplementedError, "ifft along axis #{axis} is not supported; only the last axis (#{x.ndim - 1}) is"
  end

  n = x.shape[axis]

  result = execute_fft(x, x.dtype, :inverse, [n], stream)
  apply_normalization(result, n, :inverse, norm)
end
40
+
41
# 2D Forward FFT
#
# The transform axes must be the trailing axes of +x+ in ascending order (the
# default). execute_fft derives its plan from the extents alone, with no axis
# or stride information, so other axis choices would be mis-transformed.
#
# @param x [NvArray] Input array with at least 2 dimensions
# @param axes [Array<Integer>, nil] Axes for FFT (default last two)
# @param norm [Symbol] Normalization mode
# @param stream [CUDA::Stream, nil] CUDA stream
# @return [NvArray] Complex output
# @raise [DimensionError] if +x+ has fewer than 2 dimensions
# @raise [NotImplementedError] if +axes+ are not the trailing axes in order
def fft2(x, axes: nil, norm: :backward, stream: nil)
  validate_input!(x)
  raise DimensionError, "Input must be at least 2D for fft2" if x.ndim < 2

  axes ||= [-2, -1]
  axes = axes.map { |a| normalize_axis(a, x.ndim) }
  trailing = ((x.ndim - axes.size)...x.ndim).to_a
  unless axes == trailing
    raise NotImplementedError,
          "fft2 over axes #{axes.inspect} is not supported; axes must be the trailing axes #{trailing.inspect}"
  end

  n = axes.map { |a| x.shape[a] }
  output_dtype = DType.complex_dtype(x.dtype)

  result = execute_fft(x, output_dtype, :forward, n, stream)
  apply_normalization(result, n.reduce(:*), :forward, norm)
end
60
+
61
# 2D Inverse FFT
#
# The transform axes must be the trailing axes of +x+ in ascending order (the
# default). execute_fft derives its plan from the extents alone, with no axis
# or stride information, so other axis choices would be mis-transformed.
#
# @param x [NvArray] Input complex array with at least 2 dimensions
# @param axes [Array<Integer>, nil] Axes for IFFT (default last two)
# @param norm [Symbol] Normalization mode
# @param stream [CUDA::Stream, nil] CUDA stream
# @return [NvArray] Complex output
# @raise [DimensionError] if +x+ has fewer than 2 dimensions
# @raise [NotImplementedError] if +axes+ are not the trailing axes in order
def ifft2(x, axes: nil, norm: :backward, stream: nil)
  validate_complex_input!(x)
  raise DimensionError, "Input must be at least 2D for ifft2" if x.ndim < 2

  axes ||= [-2, -1]
  axes = axes.map { |a| normalize_axis(a, x.ndim) }
  trailing = ((x.ndim - axes.size)...x.ndim).to_a
  unless axes == trailing
    raise NotImplementedError,
          "ifft2 over axes #{axes.inspect} is not supported; axes must be the trailing axes #{trailing.inspect}"
  end

  n = axes.map { |a| x.shape[a] }

  result = execute_fft(x, x.dtype, :inverse, n, stream)
  apply_normalization(result, n.reduce(:*), :inverse, norm)
end
79
+
80
# N-dimensional FFT
#
# The transform axes must be the trailing axes of +x+ in ascending order (the
# default is all axes). execute_fft derives its plan from the extents alone,
# with no axis or stride information, so other axis choices would be
# mis-transformed.
#
# @param x [NvArray] Input array
# @param axes [Array<Integer>, nil] Axes for FFT (default all axes)
# @param norm [Symbol] Normalization mode
# @param stream [CUDA::Stream, nil] CUDA stream
# @return [NvArray] Complex output
# @raise [NotImplementedError] if +axes+ are not the trailing axes in order
def fftn(x, axes: nil, norm: :backward, stream: nil)
  validate_input!(x)

  axes ||= (0...x.ndim).to_a
  axes = axes.map { |a| normalize_axis(a, x.ndim) }
  trailing = ((x.ndim - axes.size)...x.ndim).to_a
  unless axes == trailing
    raise NotImplementedError,
          "fftn over axes #{axes.inspect} is not supported; axes must be the trailing axes #{trailing.inspect}"
  end

  n = axes.map { |a| x.shape[a] }
  output_dtype = DType.complex_dtype(x.dtype)

  result = execute_fft(x, output_dtype, :forward, n, stream)
  apply_normalization(result, n.reduce(:*), :forward, norm)
end
98
+
99
# N-dimensional Inverse FFT
#
# The transform axes must be the trailing axes of +x+ in ascending order (the
# default is all axes). execute_fft derives its plan from the extents alone,
# with no axis or stride information, so other axis choices would be
# mis-transformed.
#
# @param x [NvArray] Input complex array
# @param axes [Array<Integer>, nil] Axes for IFFT (default all axes)
# @param norm [Symbol] Normalization mode
# @param stream [CUDA::Stream, nil] CUDA stream
# @return [NvArray] Complex output
# @raise [NotImplementedError] if +axes+ are not the trailing axes in order
def ifftn(x, axes: nil, norm: :backward, stream: nil)
  validate_complex_input!(x)

  axes ||= (0...x.ndim).to_a
  axes = axes.map { |a| normalize_axis(a, x.ndim) }
  trailing = ((x.ndim - axes.size)...x.ndim).to_a
  unless axes == trailing
    raise NotImplementedError,
          "ifftn over axes #{axes.inspect} is not supported; axes must be the trailing axes #{trailing.inspect}"
  end

  n = axes.map { |a| x.shape[a] }

  result = execute_fft(x, x.dtype, :inverse, n, stream)
  apply_normalization(result, n.reduce(:*), :inverse, norm)
end
116
+
117
# Real-to-complex FFT
#
# Only the last axis is supported: execute_rfft receives just the output shape
# and transform length, with no axis or stride information.
#
# @param x [NvArray] Real input array
# @param n [Integer, nil] FFT length (output length = n/2+1); defaults to the
#   axis length. Lengths larger than the axis (zero-padding) are rejected
#   because the transform would read past the end of the input.
# @param axis [Integer] Axis for FFT (must resolve to the last axis)
# @param norm [Symbol] Normalization mode
# @param stream [CUDA::Stream, nil] CUDA stream
# @return [NvArray] Complex output (hermitian-symmetric)
# @raise [ArgumentError] if +x+ is complex, or +n+ is not a positive Integer
#   no larger than the axis length
# @raise [NotImplementedError] if +axis+ is not the last axis
def rfft(x, n: nil, axis: -1, norm: :backward, stream: nil)
  validate_input!(x)
  raise ArgumentError, "Input must be real for rfft" if DType.complex?(x.dtype)

  axis = normalize_axis(axis, x.ndim)
  unless axis == x.ndim - 1
    raise NotImplementedError, "rfft along axis #{axis} is not supported; only the last axis (#{x.ndim - 1}) is"
  end

  axis_length = x.shape[axis]
  n ||= axis_length
  unless n.is_a?(Integer) && n.positive?
    raise ArgumentError, "Invalid number of FFT data points (#{n.inspect}) specified"
  end
  if n > axis_length
    raise ArgumentError, "rfft length #{n} exceeds axis length #{axis_length}; zero-padding is not supported"
  end

  output_shape = x.shape.dup
  output_shape[axis] = n / 2 + 1
  output_dtype = DType.complex_dtype(x.dtype)

  result = execute_rfft(x, output_shape, output_dtype, n, stream)
  apply_normalization(result, n, :forward, norm)
end
138
+
139
# Complex-to-real Inverse FFT
#
# Only the last axis is supported: execute_irfft receives just the output
# shape and transform length, with no axis or stride information.
#
# @param x [NvArray] Complex input array (hermitian-symmetric)
# @param n [Integer, nil] Output length; defaults to 2 * (m - 1) where m is the
#   input length along +axis+. The transform reads n/2+1 input points, so +n+
#   must not require more points than the input holds.
# @param axis [Integer] Axis for FFT (must resolve to the last axis)
# @param norm [Symbol] Normalization mode
# @param stream [CUDA::Stream, nil] CUDA stream
# @return [NvArray] Real output
# @raise [ArgumentError] if +n+ is not a positive Integer, or n/2+1 exceeds the
#   input length (e.g. the default n for a length-1 input is 0)
# @raise [NotImplementedError] if +axis+ is not the last axis
def irfft(x, n: nil, axis: -1, norm: :backward, stream: nil)
  validate_complex_input!(x)

  axis = normalize_axis(axis, x.ndim)
  unless axis == x.ndim - 1
    raise NotImplementedError, "irfft along axis #{axis} is not supported; only the last axis (#{x.ndim - 1}) is"
  end

  input_size = x.shape[axis]
  n ||= (input_size - 1) * 2
  unless n.is_a?(Integer) && n.positive?
    raise ArgumentError, "Invalid number of FFT data points (#{n.inspect}) specified"
  end
  if n / 2 + 1 > input_size
    raise ArgumentError,
          "irfft length #{n} needs #{n / 2 + 1} input points, got #{input_size}; zero-padding is not supported"
  end

  output_shape = x.shape.dup
  output_shape[axis] = n
  output_dtype = DType.real_dtype(x.dtype)

  result = execute_irfft(x, output_shape, output_dtype, n, stream)
  apply_normalization(result, n, :inverse, norm)
end
160
+
161
+ private
162
+
163
# Guard that +x+ is an NvArray with at least one dimension.
#
# @param x [Object] candidate input
# @return [nil]
# @raise [ArgumentError] unless +x+ is an NvArray
# @raise [DimensionError] if +x+ is zero-dimensional
def validate_input!(x)
  raise ArgumentError, "Expected NvArray, got #{x.class}" unless x.is_a?(NvArray)
  return unless x.ndim.zero?

  raise DimensionError, "Input must have at least 1 dimension"
end
168
+
169
# Guard that +x+ passes validate_input! and carries a complex dtype.
#
# @param x [Object] candidate input
# @return [nil]
# @raise [ArgumentError] if +x+ is not an NvArray or its dtype is not complex
# @raise [DimensionError] if +x+ is zero-dimensional
def validate_complex_input!(x)
  validate_input!(x)
  return if DType.complex?(x.dtype)

  raise ArgumentError, "Expected complex dtype, got #{x.dtype}"
end
174
+
175
# Normalize a possibly negative axis index into the range 0...ndim.
#
# @param axis [Integer] Axis index; negative values count back from the end
# @param ndim [Integer] Number of dimensions of the array
# @return [Integer] Non-negative axis index
# @raise [DimensionError] if the axis is out of range. The message reports the
#   axis exactly as the caller passed it, not the shifted value.
def normalize_axis(axis, ndim)
  normalized = axis.negative? ? axis + ndim : axis
  return normalized if normalized >= 0 && normalized < ndim

  raise DimensionError, "Axis #{axis} out of bounds for #{ndim}D array"
end
185
+
186
# Execute a complex-to-complex cuFFT transform over the trailing axes of +x+.
#
# The plan covers the last +dimensions.size+ axes, whose extents must equal
# +dimensions+. For 1D transforms any leading axes are folded into cuFFT's
# batch count. This assumes NvArray storage is C-contiguous (row-major) —
# NOTE(review): confirm. cufftPlan2d/cufftPlan3d take no batch argument, so 2D
# and 3D transforms must cover the whole array.
#
# All validation happens before any device memory is allocated. The plan is
# destroyed even when stream setup or execution fails.
#
# @param x [NvArray] complex input (moved to the device if needed)
# @param output_dtype [Symbol] :complex64 or :complex128; must equal x.dtype
# @param direction [Symbol] :forward or :inverse
# @param dimensions [Array<Integer>] transform extents (1 to 3 entries)
# @param stream [CUDA::Stream, nil] stream to run the transform on
# @return [NvArray] newly allocated, unnormalized result
# @raise [UnsupportedDTypeError] for real input or an unsupported dtype
# @raise [NotImplementedError] for >3D transforms, extents that do not match
#   the trailing axes, or batched 2D/3D transforms
# rubocop:disable Metrics/AbcSize, Metrics/MethodLength, Metrics/CyclomaticComplexity, Metrics/PerceivedComplexity
def execute_fft(x, output_dtype, direction, dimensions, stream)
  CuFFTBindings.ensure_loaded!

  # C2C/Z2Z plans read the input as interleaved (re, im) pairs. A real buffer
  # would be reinterpreted and read for twice its length, so reject it. Real
  # data must go through rfft or be promoted to complex first.
  unless x.dtype == output_dtype
    raise UnsupportedDTypeError.new("#{x.dtype}->#{output_dtype}", operation: "FFT")
  end
  unless (1..3).cover?(dimensions.size)
    raise NotImplementedError, "FFT for #{dimensions.size}D not yet implemented"
  end
  unless x.shape.last(dimensions.size) == dimensions
    raise NotImplementedError,
          "FFT extents #{dimensions.inspect} must match the trailing axes of shape #{x.shape.inspect}"
  end

  transform_size = dimensions.reduce(:*)
  batch = transform_size.zero? ? 1 : x.size / transform_size
  if dimensions.size > 1 && batch != 1
    raise NotImplementedError, "Batched #{dimensions.size}D FFT (leading axes) not yet implemented"
  end

  fft_type = determine_fft_type(x.dtype, output_dtype)

  x = x.to_device unless x.on_device?

  output = NvArray.new(shape: x.shape, dtype: output_dtype, device: x.device_index)
  output.to_device

  plan_ptr = FFI::MemoryPointer.new(:pointer)
  status =
    case dimensions.size
    when 1 then CuFFTBindings.cufftPlan1d(plan_ptr, dimensions[0], fft_type, batch)
    when 2 then CuFFTBindings.cufftPlan2d(plan_ptr, dimensions[0], dimensions[1], fft_type)
    else CuFFTBindings.cufftPlan3d(plan_ptr, dimensions[0], dimensions[1], dimensions[2], fft_type)
    end
  CuFFTBindings.check_status!(status, "Create FFT plan")

  plan = plan_ptr.read_pointer
  begin
    if stream
      status = CuFFTBindings.cufftSetStream(plan, stream.handle)
      CuFFTBindings.check_status!(status, "Set FFT stream")
    end

    dir = direction == :forward ? CuFFTBindings::CUFFT_FORWARD : CuFFTBindings::CUFFT_INVERSE
    status =
      case output_dtype
      when :complex64 then CuFFTBindings.cufftExecC2C(plan, x.device_ffi_ptr, output.device_ffi_ptr, dir)
      when :complex128 then CuFFTBindings.cufftExecZ2Z(plan, x.device_ffi_ptr, output.device_ffi_ptr, dir)
      else raise UnsupportedDTypeError.new(output_dtype, operation: "FFT")
      end
    CuFFTBindings.check_status!(status, "Execute FFT")
  ensure
    # Release the plan on both success and failure; it is created per call.
    CuFFTBindings.cufftDestroy(plan)
  end

  output
end
# rubocop:enable Metrics/AbcSize, Metrics/MethodLength, Metrics/CyclomaticComplexity, Metrics/PerceivedComplexity
244
+
245
# Execute a real-to-complex cuFFT transform of length +n+ along the last axis.
#
# Leading axes are folded into cuFFT's batch count. This assumes NvArray
# storage is C-contiguous (row-major) — NOTE(review): confirm. With the default
# out-of-place layout, consecutive input rows are +n+ reals apart and output
# rows are n/2+1 complex values apart. Cropping (n < last-axis length) is
# therefore only expressible for a single row.
#
# All validation happens before any device memory is allocated. The plan is
# destroyed even when stream setup or execution fails.
#
# @param x [NvArray] real input (:float32 or :float64)
# @param output_shape [Array<Integer>] result shape (last axis must be n/2+1)
# @param output_dtype [Symbol] :complex64 for float32 input, :complex128 for float64
# @param n [Integer] transform length
# @param stream [CUDA::Stream, nil] stream to run the transform on
# @return [NvArray] newly allocated, unnormalized result
# @raise [UnsupportedDTypeError] for any other dtype pair
# @raise [ArgumentError] if +n+ exceeds the last-axis length (would read past the input)
# @raise [NotImplementedError] for a non-last-axis shape or batched cropping
# rubocop:disable Metrics/AbcSize, Metrics/MethodLength
def execute_rfft(x, output_shape, output_dtype, n, stream)
  CuFFTBindings.ensure_loaded!

  fft_type =
    case [x.dtype, output_dtype]
    when %i[float32 complex64] then CuFFTBindings::CUFFT_R2C
    when %i[float64 complex128] then CuFFTBindings::CUFFT_D2Z
    else raise UnsupportedDTypeError.new("#{x.dtype}->#{output_dtype}", operation: "RFFT")
    end

  row_length = x.shape.last
  if n > row_length
    raise ArgumentError, "RFFT length #{n} exceeds input length #{row_length}; zero-padding is not supported"
  end
  unless output_shape[0...-1] == x.shape[0...-1] && output_shape.last == n / 2 + 1
    raise NotImplementedError,
          "RFFT is only supported along the last axis (input #{x.shape.inspect}, output #{output_shape.inspect})"
  end
  batch = row_length.zero? ? 1 : x.size / row_length
  if batch > 1 && n != row_length
    raise NotImplementedError, "RFFT cropping (n < axis length) is only supported for a single transform"
  end

  x = x.to_device unless x.on_device?

  output = NvArray.new(shape: output_shape, dtype: output_dtype, device: x.device_index)
  output.to_device

  plan_ptr = FFI::MemoryPointer.new(:pointer)
  status = CuFFTBindings.cufftPlan1d(plan_ptr, n, fft_type, batch)
  CuFFTBindings.check_status!(status, "Create RFFT plan")

  plan = plan_ptr.read_pointer
  begin
    if stream
      status = CuFFTBindings.cufftSetStream(plan, stream.handle)
      CuFFTBindings.check_status!(status, "Set RFFT stream")
    end

    status =
      if x.dtype == :float32
        CuFFTBindings.cufftExecR2C(plan, x.device_ffi_ptr, output.device_ffi_ptr)
      else
        CuFFTBindings.cufftExecD2Z(plan, x.device_ffi_ptr, output.device_ffi_ptr)
      end
    CuFFTBindings.check_status!(status, "Execute RFFT")
  ensure
    # Release the plan on both success and failure; it is created per call.
    CuFFTBindings.cufftDestroy(plan)
  end

  output
end
# rubocop:enable Metrics/AbcSize, Metrics/MethodLength
277
+
278
# Execute a complex-to-real cuFFT transform producing +n+ reals along the last axis.
#
# Leading axes are folded into cuFFT's batch count. This assumes NvArray
# storage is C-contiguous (row-major) — NOTE(review): confirm. With the default
# out-of-place layout, consecutive input rows are n/2+1 complex values apart.
# Cropping the input (n/2+1 < last-axis length) is therefore only expressible
# for a single row.
#
# NOTE(review): the cuFFT documentation states that out-of-place
# complex-to-real transforms may overwrite their input buffer, so +x+ may be
# clobbered by this call. Confirm, and copy first if callers rely on +x+
# afterwards.
#
# All validation happens before any device memory is allocated. The plan is
# destroyed even when stream setup or execution fails.
#
# @param x [NvArray] complex input (:complex64 or :complex128)
# @param output_shape [Array<Integer>] result shape (last axis must be n)
# @param output_dtype [Symbol] :float32 for complex64 input, :float64 for complex128
# @param n [Integer] transform (output) length
# @param stream [CUDA::Stream, nil] stream to run the transform on
# @return [NvArray] newly allocated, unnormalized real result
# @raise [UnsupportedDTypeError] for any other dtype pair
# @raise [ArgumentError] if n/2+1 exceeds the last-axis length (would read past the input)
# @raise [NotImplementedError] for a non-last-axis shape or batched cropping
# rubocop:disable Metrics/AbcSize, Metrics/MethodLength
def execute_irfft(x, output_shape, output_dtype, n, stream)
  CuFFTBindings.ensure_loaded!

  fft_type =
    case [x.dtype, output_dtype]
    when %i[complex64 float32] then CuFFTBindings::CUFFT_C2R
    when %i[complex128 float64] then CuFFTBindings::CUFFT_Z2D
    else raise UnsupportedDTypeError.new("#{x.dtype}->#{output_dtype}", operation: "IRFFT")
    end

  row_length = x.shape.last
  needed = n / 2 + 1
  if needed > row_length
    raise ArgumentError,
          "IRFFT length #{n} needs #{needed} input points, got #{row_length}; zero-padding is not supported"
  end
  unless output_shape[0...-1] == x.shape[0...-1] && output_shape.last == n
    raise NotImplementedError,
          "IRFFT is only supported along the last axis (input #{x.shape.inspect}, output #{output_shape.inspect})"
  end
  batch = row_length.zero? ? 1 : x.size / row_length
  if batch > 1 && needed != row_length
    raise NotImplementedError, "IRFFT cropping (n/2+1 < axis length) is only supported for a single transform"
  end

  x = x.to_device unless x.on_device?

  output = NvArray.new(shape: output_shape, dtype: output_dtype, device: x.device_index)
  output.to_device

  plan_ptr = FFI::MemoryPointer.new(:pointer)
  status = CuFFTBindings.cufftPlan1d(plan_ptr, n, fft_type, batch)
  CuFFTBindings.check_status!(status, "Create IRFFT plan")

  plan = plan_ptr.read_pointer
  begin
    if stream
      status = CuFFTBindings.cufftSetStream(plan, stream.handle)
      CuFFTBindings.check_status!(status, "Set IRFFT stream")
    end

    status =
      if output_dtype == :float32
        CuFFTBindings.cufftExecC2R(plan, x.device_ffi_ptr, output.device_ffi_ptr)
      else
        CuFFTBindings.cufftExecZ2D(plan, x.device_ffi_ptr, output.device_ffi_ptr)
      end
    CuFFTBindings.check_status!(status, "Execute IRFFT")
  ensure
    # Release the plan on both success and failure; it is created per call.
    CuFFTBindings.cufftDestroy(plan)
  end

  output
end
# rubocop:enable Metrics/AbcSize, Metrics/MethodLength
310
+
311
# Pick the cuFFT complex-to-complex transform type for a dtype pair.
#
# Only complex -> complex pairs of matching precision are accepted. A real
# input cannot use C2C/Z2Z directly: those plans read the input buffer as
# interleaved (re, im) pairs, i.e. twice as many values as a real array of the
# same shape holds. Real data must go through rfft or be promoted to complex.
#
# @param input_dtype [Symbol]
# @param output_dtype [Symbol]
# @return [Integer] CuFFTBindings::CUFFT_C2C or CuFFTBindings::CUFFT_Z2Z
# @raise [UnsupportedDTypeError] for any other pair
def determine_fft_type(input_dtype, output_dtype)
  case [input_dtype, output_dtype]
  when %i[complex64 complex64] then CuFFTBindings::CUFFT_C2C
  when %i[complex128 complex128] then CuFFTBindings::CUFFT_Z2Z
  else raise UnsupportedDTypeError.new("#{input_dtype}->#{output_dtype}", operation: "FFT")
  end
end
322
+
323
# Apply normalization to an FFT result, in place.
#
# cuFFT does not normalize its transforms, so the scale factor is applied here
# using numpy's conventions:
#   :backward (default) - forward unscaled, inverse scaled by 1/n
#   :forward            - forward scaled by 1/n, inverse unscaled
#   :ortho              - both directions scaled by 1/sqrt(n)
# nil is treated as :backward, and string names ("ortho") are accepted.
#
# Scaling is a cuBLAS scal over the raw float/double components; a complex
# element counts as two reals.
# NOTE(review): scal is issued on the shared cuBLAS handle, not on the stream
# the FFT ran on. Confirm the two are ordered when a custom stream is used.
#
# @param result [NvArray] FFT result (scaled in place)
# @param n [Integer] Transform size (product of the transformed extents)
# @param direction [Symbol] :forward or :inverse
# @param norm [Symbol, String, nil] :forward, :backward, or :ortho
# @return [NvArray] +result+
# @raise [ArgumentError] if +norm+ is not a recognized mode
def apply_normalization(result, n, direction, norm)
  mode =
    case norm
    when nil then :backward
    when String then norm.to_sym
    else norm
    end

  scale =
    case mode
    when :backward then direction == :inverse ? 1.0 / n : 1.0
    when :forward then direction == :forward ? 1.0 / n : 1.0
    when :ortho then 1.0 / Math.sqrt(n)
    else
      # Previously unknown modes silently skipped scaling, producing results
      # off by a factor of n.
      raise ArgumentError, "Invalid norm #{norm.inspect}; expected :backward, :forward, or :ortho"
    end

  return result if scale == 1.0

  cublas = Ignis::LinAlg::CuBLASBindings
  cublas.ensure_loaded!
  handle = cublas.get_handle
  count = DType.complex?(result.dtype) ? result.size * 2 : result.size
  ptr = result.device_ffi_ptr

  status =
    if %i[float64 complex128].include?(result.dtype)
      alpha = FFI::MemoryPointer.new(:double).tap { |p| p.put_double(0, scale) }
      cublas.cublasDscal_v2(handle, count, alpha, ptr, 1)
    else
      alpha = FFI::MemoryPointer.new(:float).tap { |p| p.put_float(0, scale) }
      cublas.cublasSscal_v2(handle, count, alpha, ptr, 1)
    end
  cublas.check_status!(status, "FFT normalization scal")
  result
end
361
+ end
362
+ end
363
+ end
364
+ end
@@ -0,0 +1,107 @@
1
+ # frozen_string_literal: true
2
+
3
+ require "ffi"
4
+
5
+ module Ignis
6
+ module Linalg
7
# FFI bindings for NVIDIA cuTENSOR 2.x
# High-performance tensor primitives for CUDA
#
# Constants and signatures follow the cuTENSOR 2.x headers (cutensor.h and
# cutensor/types.h). Opaque handles and descriptors are passed as :pointer;
# C enums are passed as :int.
#
# NOTE(review): this module is nested under Ignis::Linalg, while other Ignis
# code references Ignis::LinAlg (e.g. Ignis::LinAlg::CuBLASBindings). Ruby
# constants are case-sensitive — confirm which namespace is intended.
module CuTensorBindings
  extend FFI::Library

  # cuTENSOR status codes (cutensorStatus_t). The numbering is not contiguous.
  CUTENSOR_STATUS_SUCCESS = 0
  CUTENSOR_STATUS_NOT_INITIALIZED = 1
  CUTENSOR_STATUS_ALLOC_FAILED = 3
  CUTENSOR_STATUS_INVALID_VALUE = 7
  CUTENSOR_STATUS_ARCH_MISMATCH = 8
  CUTENSOR_STATUS_MAPPING_ERROR = 11
  CUTENSOR_STATUS_EXECUTION_FAILED = 13
  CUTENSOR_STATUS_INTERNAL_ERROR = 14
  CUTENSOR_STATUS_NOT_SUPPORTED = 15
  CUTENSOR_STATUS_LICENSE_ERROR = 16
  CUTENSOR_STATUS_CUBLAS_ERROR = 17
  CUTENSOR_STATUS_CUDA_ERROR = 18
  CUTENSOR_STATUS_INSUFFICIENT_WORKSPACE = 19
  CUTENSOR_STATUS_INSUFFICIENT_DRIVER = 20
  CUTENSOR_STATUS_IO_ERROR = 21

  # cuTENSOR algorithms (cutensorAlgo_t)
  CUTENSOR_ALGO_DEFAULT_PATIENT = -6
  CUTENSOR_ALGO_GETT = -4
  CUTENSOR_ALGO_TGETT = -3
  CUTENSOR_ALGO_TTGT = -2
  CUTENSOR_ALGO_DEFAULT = -1

  # Workspace size preference (cutensorWorksizePreference_t)
  CUTENSOR_WORKSPACE_MIN = 1
  CUTENSOR_WORKSPACE_DEFAULT = 2
  # cuTENSOR 1.x name for CUTENSOR_WORKSPACE_DEFAULT, kept for existing callers
  CUTENSOR_WORKSPACE_RECOMMENDED = CUTENSOR_WORKSPACE_DEFAULT
  CUTENSOR_WORKSPACE_MAX = 3

  # JIT compilation mode for plan preferences (cutensorJitMode_t)
  CUTENSOR_JIT_MODE_NONE = 0
  CUTENSOR_JIT_MODE_DEFAULT = 1

  @loaded = false
  @mutex = Mutex.new

  class << self
    # Ensure the cuTENSOR library is loaded and its functions attached.
    # Loading is guarded by a mutex, so concurrent callers only load once. If
    # attaching fails, @loaded stays false and the next call retries.
    # @return [void]
    # @raise [LibraryNotFoundError] if no cuTENSOR library path is known
    def ensure_loaded!
      @mutex.synchronize do
        return if @loaded

        CUDA::LibraryLoader.load_library(:cutensor)
        dll_path = CUDA::LibraryLoader.library_paths[:cutensor]

        raise LibraryNotFoundError, "cutensor" unless dll_path

        ffi_lib dll_path
        attach_functions!

        @loaded = true
        Ignis.logger.debug("cuTENSOR bindings loaded from #{dll_path}")
      end
    end

    # Check status and raise error if not success
    # @param status [Integer] status code returned by a cuTENSOR call
    # @param context [String] Context for error message
    # @raise [CuTensorError] If status is not success
    def check_status!(status, context = "cuTENSOR operation")
      return if status == CUTENSOR_STATUS_SUCCESS

      raise CuTensorError.new("#{context} failed", status_code: status)
    end

    private

    # Attach the cuTENSOR 2.x entry points used by Ignis.
    # rubocop:disable Metrics/MethodLength
    def attach_functions!
      # Handle management
      attach_function :cutensorCreate, [:pointer], :int
      attach_function :cutensorDestroy, [:pointer], :int

      # Tensor descriptors:
      # (handle, desc*, numModes, extent[], stride[], dataType, alignmentRequirement)
      attach_function :cutensorCreateTensorDescriptor,
                      %i[pointer pointer uint32 pointer pointer int uint32], :int
      attach_function :cutensorDestroyTensorDescriptor, [:pointer], :int

      # Operation descriptors
      contraction_args = [
        :pointer, :pointer,       # handle, desc*
        :pointer, :pointer, :int, # descA, modeA[], opA
        :pointer, :pointer, :int, # descB, modeB[], opB
        :pointer, :pointer, :int, # descC, modeC[], opC
        :pointer, :pointer,       # descD, modeD[]
        :pointer                  # descCompute
      ]
      attach_function :cutensorCreateContraction, contraction_args, :int
      # The library exports no "cutensorCreateContractionDescriptor" symbol;
      # keep that Ruby name as an alias of the real v2 entry point.
      attach_function :cutensorCreateContractionDescriptor, :cutensorCreateContraction, contraction_args, :int

      reduction_args = [
        :pointer, :pointer,       # handle, desc*
        :pointer, :pointer, :int, # descA, modeA[], opA
        :pointer, :pointer, :int, # descC, modeC[], opC
        :pointer, :pointer,       # descD, modeD[]
        :int,                     # opReduce
        :pointer                  # descCompute
      ]
      attach_function :cutensorCreateReduction, reduction_args, :int
      attach_function :cutensorDestroyOperationDescriptor, [:pointer], :int

      # Plan preferences and workspace estimation
      # (handle, pref*, algo, jitMode)
      attach_function :cutensorCreatePlanPreference, %i[pointer pointer int int], :int
      attach_function :cutensorDestroyPlanPreference, [:pointer], :int
      # (handle, desc, planPref, workspacePref, uint64_t* estimate)
      attach_function :cutensorEstimateWorkspaceSize, %i[pointer pointer pointer int pointer], :int

      # Plan management: (handle, plan*, desc, pref, workspaceSizeLimit)
      attach_function :cutensorCreatePlan, %i[pointer pointer pointer pointer uint64], :int
      attach_function :cutensorDestroyPlan, [:pointer], :int

      # Execution
      # (handle, plan, alpha, A, B, beta, C, D, workspace, workspaceSize, stream)
      attach_function :cutensorContract,
                      %i[pointer pointer pointer pointer pointer pointer pointer pointer pointer uint64 pointer], :int
      # (handle, plan, alpha, A, beta, C, D, workspace, workspaceSize, stream)
      attach_function :cutensorReduce,
                      %i[pointer pointer pointer pointer pointer pointer pointer pointer uint64 pointer], :int
    end
    # rubocop:enable Metrics/MethodLength
  end
end
96
+
97
# Specialized error for cuTENSOR failures.
#
# Carries the raw cuTENSOR status code, when known, so callers can branch on it.
class CuTensorError < StandardError
  # @return [Integer, nil] raw cuTENSOR status code
  attr_reader :status_code

  # @param message [String] human-readable description of the failure
  # @param status_code [Integer, nil] cuTENSOR status code. When given it is
  #   appended as " (Status: N)". When nil the message is left as-is, rather
  #   than ending in a dangling "(Status: )".
  def initialize(message, status_code: nil)
    @status_code = status_code
    super(status_code.nil? ? message : "#{message} (Status: #{status_code})")
  end
end
106
+ end
107
+ end