ignis-numerics 0.0.1

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -0,0 +1,134 @@
1
+ # frozen_string_literal: true
2
+
3
+ require "ffi"
4
+
5
module Ignis
  module FFT
    # FFI bindings for NVIDIA's cuFFT library.
    #
    # The native library is loaded lazily: the cufft* functions below are only
    # attached to this module once {ensure_loaded!} has run.
    module CuFFTBindings
      extend FFI::Library

      # FFT transform types (cufftType values)
      CUFFT_R2C = 0x2a # Real to Complex (float)
      CUFFT_C2R = 0x2c # Complex to Real (float)
      CUFFT_C2C = 0x29 # Complex to Complex (float)
      CUFFT_D2Z = 0x6a # Real to Complex (double)
      CUFFT_Z2D = 0x6c # Complex to Real (double)
      CUFFT_Z2Z = 0x69 # Complex to Complex (double)

      # FFT directions, passed to the complex-to-complex exec functions
      CUFFT_FORWARD = -1
      CUFFT_INVERSE = 1

      # Callback types
      CUFFT_CB_LD_COMPLEX = 0
      CUFFT_CB_LD_COMPLEX_DOUBLE = 1
      CUFFT_CB_LD_REAL = 2
      CUFFT_CB_LD_REAL_DOUBLE = 3
      CUFFT_CB_ST_COMPLEX = 4
      CUFFT_CB_ST_COMPLEX_DOUBLE = 5
      CUFFT_CB_ST_REAL = 6
      CUFFT_CB_ST_REAL_DOUBLE = 7

      # Library name handed to ffi_lib when no versioned DLL is found in the CUDA bin directory
      DEFAULT_LIBRARY = "cufft64_11"
      # Captures the numeric version suffix of names like "cufft64_11.dll"
      DLL_VERSION_PATTERN = /\Acufft64_(\d+)\.dll\z/i
      private_constant :DEFAULT_LIBRARY, :DLL_VERSION_PATTERN

      @loaded = false

      class << self
        # Ensure cuFFT is loaded and its functions are attached (runs once).
        #
        # @return [void]
        def ensure_loaded!
          return if @loaded

          CUDA::LibraryLoader.load_library(:cufft)
          ffi_lib cufft_library_name
          attach_cufft_functions!
          @loaded = true
        end

        private

        # Choose the cuFFT library to open.
        #
        # With a configured CUDA bin directory, the cufft64_<N>.dll with the
        # highest N wins. N is compared as an integer, so 11 beats 9. Falls back
        # to the default library name when no directory is configured or no DLL
        # matches, instead of handing nil to ffi_lib.
        #
        # @return [String] Path or library name for ffi_lib
        def cufft_library_name
          cuda_bin = Ignis.configuration.cuda_bin_path
          return DEFAULT_LIBRARY unless cuda_bin

          # Dir.glob treats "\" as an escape character; "/" separators also work on Windows
          pattern = File.join(cuda_bin.to_s.tr("\\", "/"), "cufft64_*.dll")
          newest = Dir.glob(pattern).max_by { |path| File.basename(path)[DLL_VERSION_PATTERN, 1].to_i }
          newest || DEFAULT_LIBRARY
        end

        # Attach every cuFFT entry point used by the gem.
        #
        # NOTE(review): cuFFT declares cufftHandle as a plain int, but the plan
        # argument is bound as :pointer throughout. This relies on the handle
        # being passed in a 64-bit register/slot (x86-64). Confirm before porting,
        # and change FFTPlan's handle handling together with these signatures.
        # rubocop:disable Metrics/MethodLength
        def attach_cufft_functions!
          # Plan management
          attach_function :cufftCreate, [:pointer], :int
          attach_function :cufftDestroy, [:pointer], :int
          attach_function :cufftSetStream, [:pointer, :pointer], :int
          attach_function :cufftGetSize, [:pointer, :pointer], :int

          # 1D plan
          attach_function :cufftPlan1d, [:pointer, :int, :int, :int], :int
          # 2D plan
          attach_function :cufftPlan2d, [:pointer, :int, :int, :int], :int
          # 3D plan
          attach_function :cufftPlan3d, [:pointer, :int, :int, :int, :int], :int
          # Many (batched) plan
          attach_function :cufftPlanMany, [
            :pointer, # plan
            :int,     # rank
            :pointer, # n (dimensions)
            :pointer, # inembed
            :int,     # istride
            :int,     # idist
            :pointer, # onembed
            :int,     # ostride
            :int,     # odist
            :int,     # type
            :int      # batch
          ], :int

          # Plan with explicit memory
          attach_function :cufftMakePlan1d, [:pointer, :int, :int, :int, :pointer], :int
          attach_function :cufftMakePlan2d, [:pointer, :int, :int, :int, :pointer], :int
          attach_function :cufftMakePlan3d, [:pointer, :int, :int, :int, :int, :pointer], :int
          attach_function :cufftMakePlanMany, [
            :pointer, :int, :pointer, :pointer, :int, :int,
            :pointer, :int, :int, :int, :int, :pointer
          ], :int

          # Set workspace
          attach_function :cufftSetAutoAllocation, [:pointer, :int], :int
          attach_function :cufftSetWorkArea, [:pointer, :pointer], :int

          # Execute - Complex to Complex (single precision)
          attach_function :cufftExecC2C, [:pointer, :pointer, :pointer, :int], :int
          # Execute - Complex to Complex (double precision)
          attach_function :cufftExecZ2Z, [:pointer, :pointer, :pointer, :int], :int
          # Execute - Real to Complex (single precision)
          attach_function :cufftExecR2C, [:pointer, :pointer, :pointer], :int
          # Execute - Complex to Real (single precision)
          attach_function :cufftExecC2R, [:pointer, :pointer, :pointer], :int
          # Execute - Real to Complex (double precision)
          attach_function :cufftExecD2Z, [:pointer, :pointer, :pointer], :int
          # Execute - Complex to Real (double precision)
          attach_function :cufftExecZ2D, [:pointer, :pointer, :pointer], :int

          # Estimate
          attach_function :cufftEstimate1d, [:int, :int, :int, :pointer], :int
          attach_function :cufftEstimate2d, [:int, :int, :int, :pointer], :int
          attach_function :cufftEstimate3d, [:int, :int, :int, :int, :pointer], :int
          attach_function :cufftEstimateMany, [
            :int, :pointer, :pointer, :int, :int,
            :pointer, :int, :int, :int, :int, :pointer
          ], :int
        end
        # rubocop:enable Metrics/MethodLength
      end

      # Check a cuFFT status and raise if it is not success.
      #
      # @param status [Integer] cuFFT status code (0 means success)
      # @param context [String] Description of the operation that produced the status.
      #   NOTE(review): not currently included in the raised error; confirm
      #   CuFFTError's constructor before forwarding it.
      # @return [void]
      # @raise [CuFFTError] when status is non-zero
      def self.check_status!(status, context = "cuFFT operation")
        return if status.zero?

        raise CuFFTError, status
      end
    end
  end
end
@@ -0,0 +1,288 @@
1
+ # frozen_string_literal: true
2
+
3
module Ignis
  module FFT
    # Reusable cuFFT plan for repeated transforms over the same shape.
    #
    # The native plan is created in the constructor and released either by
    # {#destroy!} or, if that is never called, by a GC finalizer.
    class FFTPlan
      # @return [Array<Integer>] Shape of a single transform's input (no batch dimension)
      attr_reader :shape

      # @return [Symbol] Data type the plan was built for
      attr_reader :dtype

      # @return [Symbol] Transform direction (:forward, :inverse)
      attr_reader :direction

      # @return [Symbol] Transform type (:c2c, :r2c, :c2r)
      attr_reader :transform_type

      # @return [FFI::Pointer, nil] cuFFT plan handle; nil after {#destroy!}
      attr_reader :plan_handle

      # @return [Integer] Number of transforms performed by each {#execute} call
      attr_reader :batch

      # @param shape [Array<Integer>, Integer] Shape of one transform's input (1 to 3 dims).
      #   For :c2r this is the half-spectrum shape; the real output's last
      #   dimension is (shape[-1] - 1) * 2.
      # @param dtype [Symbol] :complex64/:complex128 for :c2c and :c2r,
      #   :float32/:float64 for :r2c
      # @param direction [Symbol] :forward or :inverse (any value other than
      #   :forward runs the inverse transform); only consulted for :c2c
      # @param transform_type [Symbol] :c2c, :r2c, or :c2r
      # @param batch [Integer] Number of transforms per {#execute} call
      # @raise [ArgumentError] if batch is not a positive Integer
      # @raise [DimensionError] if shape is not 1-3 positive Integer extents
      # @raise [UnsupportedDTypeError] if dtype does not fit transform_type
      # @raise [InvalidOperationError] if transform_type is unknown
      def initialize(shape:, dtype: :complex64, direction: :forward, transform_type: :c2c, batch: 1)
        @shape = Array(shape)
        @dtype = DType.validate!(dtype)
        @direction = direction
        @transform_type = transform_type
        @batch = batch

        @plan_handle = nil
        @destroyed = false

        create_plan!
      end

      # @return [Integer] Number of dimensions of a single transform
      def ndim
        @shape.size
      end

      # @return [Array<Integer>] Shape of one transform's output (no batch dimension)
      def output_shape
        out = @shape.dup
        case @transform_type
        when :r2c
          out[-1] = out[-1] / 2 + 1
        when :c2r
          # Assumes an even-length real signal: (n - 1) * 2
          out[-1] = (out[-1] - 1) * 2
        end
        out
      end

      # @return [Symbol] Output data type
      def output_dtype
        case @transform_type
        when :r2c
          DType.complex_dtype(@dtype)
        when :c2r
          DType.real_dtype(@dtype)
        else
          @dtype
        end
      end

      # Execute the FFT plan.
      #
      # The input is either a single transform ([*shape]) when the plan's batch is 1,
      # or carries a leading batch dimension equal to the plan's batch
      # ([batch, *shape]). The allocated output mirrors that layout.
      #
      # @param input [NvArray] Input array; copied to the device if not already there
      # @param output [NvArray, nil] Output array; allocated on the input's device if nil
      # @param stream [CUDA::Stream, nil] CUDA stream. It is attached to the plan via
      #   cufftSetStream and is not reset here afterwards.
      # @return [NvArray] Device-resident result
      # @raise [InvalidOperationError] if the plan has been destroyed
      # @raise [DimensionError] if input or output shapes do not fit the plan
      # @raise [ArgumentError] if input or output types/dtypes do not fit the plan
      def execute(input, output: nil, stream: nil)
        raise InvalidOperationError, "FFTPlan has been destroyed" if @destroyed

        validate_input!(input)
        input = input.to_device unless input.on_device?

        # cuFFT writes @batch transforms; size the output for all of them
        expected_shape = input.shape.size > ndim ? [@batch, *output_shape] : output_shape
        output ||= NvArray.new(shape: expected_shape, dtype: output_dtype, device: input.device_index)
        validate_output!(output, expected_shape)
        output = output.to_device unless output.on_device?

        if stream
          status = CuFFTBindings.cufftSetStream(@plan_handle, stream.handle)
          CuFFTBindings.check_status!(status, "Set FFT stream")
        end

        direction_flag = @direction == :forward ? CuFFTBindings::CUFFT_FORWARD : CuFFTBindings::CUFFT_INVERSE

        status = case @transform_type
                 when :c2c
                   execute_c2c(input, output, direction_flag)
                 when :r2c
                   execute_r2c(input, output)
                 when :c2r
                   # NOTE(review): cuFFT documents that C2R transforms may overwrite
                   # their input buffer; confirm callers do not rely on it afterwards.
                   execute_c2r(input, output)
                 else
                   raise InvalidOperationError, "Unknown transform type: #{@transform_type}"
                 end

        CuFFTBindings.check_status!(status, "Execute FFT")

        output
      end

      # Get the workspace size cuFFT reports for this plan.
      #
      # @return [Integer] Workspace size in bytes
      # @raise [InvalidOperationError] if the plan has been destroyed
      def workspace_size
        raise InvalidOperationError, "FFTPlan has been destroyed" if @destroyed

        size_ptr = FFI::MemoryPointer.new(:size_t)
        status = CuFFTBindings.cufftGetSize(@plan_handle, size_ptr)
        CuFFTBindings.check_status!(status, "Get FFT workspace size")

        size_ptr.read(:size_t)
      end

      # Destroy the plan and free resources. Safe to call more than once.
      #
      # @return [void]
      def destroy!
        return if @destroyed || @plan_handle.nil?

        # Drop the GC finalizer first so the same handle is never released twice
        ObjectSpace.undefine_finalizer(self)
        # Best-effort release; the status code is intentionally not checked
        CuFFTBindings.cufftDestroy(@plan_handle)
        @plan_handle = nil
        @destroyed = true
      end

      # Check if plan has been destroyed
      # @return [Boolean]
      def destroyed?
        @destroyed
      end

      # @return [String]
      def to_s
        status = @destroyed ? "destroyed" : "active"
        "FFTPlan(shape=#{@shape}, dtype=#{@dtype}, #{@transform_type}, #{@direction}, #{status})"
      end

      private

      # Create the cuFFT plan and register a finalizer that releases it.
      def create_plan!
        validate_plan_params!
        CuFFTBindings.ensure_loaded!

        plan_ptr = FFI::MemoryPointer.new(:pointer)
        status = build_plan(plan_ptr, transform_dims, determine_fft_type)
        CuFFTBindings.check_status!(status, "Create FFT plan")

        @plan_handle = plan_ptr.read_pointer

        ObjectSpace.define_finalizer(self, self.class.release_finalizer(@plan_handle))
      end

      # Reject plan parameters cuFFT cannot honor before touching the native library.
      def validate_plan_params!
        unless (1..3).cover?(ndim)
          raise DimensionError, "FFT plan supports 1D, 2D, or 3D only, got #{ndim}D"
        end
        unless @batch.is_a?(Integer) && @batch.positive?
          raise ArgumentError, "batch must be a positive Integer, got #{@batch.inspect}"
        end
        unless @shape.all? { |extent| extent.is_a?(Integer) && extent.positive? }
          raise DimensionError, "FFT shape must contain positive Integers, got #{@shape}"
        end
        # A half-spectrum of length 1 would describe a real output of length 0
        return unless @transform_type == :c2r && @shape[-1] < 2

        raise DimensionError, "C2R input last dimension must be at least 2, got #{@shape[-1]}"
      end

      # Logical (real-domain) transform sizes used to build the plan.
      # cuFFT sizes C2R plans by the real output, not the complex input.
      # @return [Array<Integer>]
      def transform_dims
        @transform_type == :c2r ? output_shape : @shape
      end

      # Call the cuFFT plan constructor matching the rank and batch size.
      # @return [Integer] cuFFT status code
      def build_plan(plan_ptr, dims, fft_type)
        if dims.size == 1
          CuFFTBindings.cufftPlan1d(plan_ptr, dims[0], fft_type, @batch)
        elsif @batch > 1
          # cufftPlan2d/3d take no batch count; PlanMany with NULL embeds uses the
          # packed layout (strides/distances are ignored in that case)
          n_ptr = FFI::MemoryPointer.new(:int, dims.size)
          n_ptr.write_array_of_int(dims)
          CuFFTBindings.cufftPlanMany(plan_ptr, dims.size, n_ptr, nil, 1, 0, nil, 1, 0, fft_type, @batch)
        elsif dims.size == 2
          CuFFTBindings.cufftPlan2d(plan_ptr, dims[0], dims[1], fft_type)
        else
          CuFFTBindings.cufftPlan3d(plan_ptr, dims[0], dims[1], dims[2], fft_type)
        end
      end

      # Determine cuFFT type constant
      # @return [Integer]
      def determine_fft_type
        case @transform_type
        when :c2c
          case @dtype
          when :complex64 then CuFFTBindings::CUFFT_C2C
          when :complex128 then CuFFTBindings::CUFFT_Z2Z
          else raise UnsupportedDTypeError.new(@dtype, operation: "FFT C2C")
          end
        when :r2c
          case @dtype
          when :float32 then CuFFTBindings::CUFFT_R2C
          when :float64 then CuFFTBindings::CUFFT_D2Z
          else raise UnsupportedDTypeError.new(@dtype, operation: "FFT R2C")
          end
        when :c2r
          case @dtype
          when :complex64 then CuFFTBindings::CUFFT_C2R
          when :complex128 then CuFFTBindings::CUFFT_Z2D
          else raise UnsupportedDTypeError.new(@dtype, operation: "FFT C2R")
          end
        else
          raise InvalidOperationError, "Unknown transform type: #{@transform_type}"
        end
      end

      # Validate the input array against the plan.
      # Input is [*plan_shape] (only when batch is 1) or [batch, *plan_shape].
      def validate_input!(input)
        raise ArgumentError, "Expected NvArray, got #{input.class}" unless input.is_a?(NvArray)

        validate_input_shape!(input.shape)

        case @transform_type
        when :c2c, :c2r
          unless input.dtype == @dtype
            raise ArgumentError, "Dtype mismatch: expected #{@dtype}, got #{input.dtype}"
          end
        when :r2c
          real_dtype = DType.real_dtype(@dtype)
          unless input.dtype == real_dtype || input.dtype == @dtype
            raise ArgumentError, "Dtype mismatch: expected #{real_dtype}, got #{input.dtype}"
          end
        end
      end

      # The plan reads exactly @batch transforms, so the input must hold exactly that many.
      def validate_input_shape!(in_shape)
        plan_dims = @shape.size

        if in_shape.size == plan_dims
          raise DimensionError, "Shape mismatch: expected #{@shape}, got #{in_shape}" unless in_shape == @shape
          # An unbatched array is too small for a plan that processes several transforms
          if @batch > 1
            raise DimensionError, "Batch mismatch: plan has #{@batch}, input has no batch dimension"
          end
        elsif in_shape.size == plan_dims + 1
          trailing_shape = in_shape[-plan_dims..]
          unless trailing_shape == @shape
            raise DimensionError, "Shape mismatch: expected trailing dims #{@shape}, got #{trailing_shape}"
          end
          unless in_shape[0] == @batch
            raise DimensionError, "Batch mismatch: plan has #{@batch}, input has #{in_shape[0]}"
          end
        else
          raise DimensionError, "Shape mismatch: expected #{plan_dims} or #{plan_dims + 1} dims, got #{in_shape.size}"
        end
      end

      # Reject outputs that cuFFT would write past or misinterpret.
      # Any shape with the expected element count is accepted.
      def validate_output!(output, expected_shape)
        raise ArgumentError, "Expected NvArray output, got #{output.class}" unless output.is_a?(NvArray)
        unless output.dtype == output_dtype
          raise ArgumentError, "Output dtype mismatch: expected #{output_dtype}, got #{output.dtype}"
        end
        return if output.shape.reduce(1, :*) == expected_shape.reduce(1, :*)

        raise DimensionError, "Output shape mismatch: expected #{expected_shape}, got #{output.shape}"
      end

      # Execute C2C transform
      def execute_c2c(input, output, direction)
        case @dtype
        when :complex64
          CuFFTBindings.cufftExecC2C(@plan_handle, input.device_ptr, output.device_ptr, direction)
        when :complex128
          CuFFTBindings.cufftExecZ2Z(@plan_handle, input.device_ptr, output.device_ptr, direction)
        end
      end

      # Execute R2C transform
      def execute_r2c(input, output)
        real_dtype = DType.real_dtype(@dtype)
        case real_dtype
        when :float32
          CuFFTBindings.cufftExecR2C(@plan_handle, input.device_ptr, output.device_ptr)
        when :float64
          CuFFTBindings.cufftExecD2Z(@plan_handle, input.device_ptr, output.device_ptr)
        end
      end

      # Execute C2R transform
      def execute_c2r(input, output)
        case @dtype
        when :complex64
          CuFFTBindings.cufftExecC2R(@plan_handle, input.device_ptr, output.device_ptr)
        when :complex128
          CuFFTBindings.cufftExecZ2D(@plan_handle, input.device_ptr, output.device_ptr)
        end
      end

      class << self
        # Create finalizer for plan cleanup.
        # The proc closes over only the handle, not the plan, so it does not keep
        # the plan reachable.
        # @param handle [FFI::Pointer] Plan handle
        # @return [Proc]
        def release_finalizer(handle)
          proc do
            CuFFTBindings.ensure_loaded!
            CuFFTBindings.cufftDestroy(handle)
          end
        end
      end
    end
  end
end