ignis 0.0.1

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (49) hide show
  1. checksums.yaml +7 -0
  2. data/README.md +15 -0
  3. data/lib/ignis.rb +94 -0
  4. data/lib/nnw/platform.rb +304 -0
  5. data/lib/nnw/shared/event_bus.rb +240 -0
  6. data/lib/nnw/shared/ffi_loader.rb +63 -0
  7. data/lib/nnw/shared/memory_contract.rb +204 -0
  8. data/lib/nnw/shared/nv_array.rb +710 -0
  9. data/lib/nnw/shared/recovery_protocol.rb +307 -0
  10. data/lib/nvruby/configuration.rb +217 -0
  11. data/lib/nvruby/cuda/device.rb +275 -0
  12. data/lib/nvruby/cuda/device_props.rb +202 -0
  13. data/lib/nvruby/cuda/graph.rb +265 -0
  14. data/lib/nvruby/cuda/graph_bindings.rb +119 -0
  15. data/lib/nvruby/cuda/library_loader.rb +285 -0
  16. data/lib/nvruby/cuda/memory.rb +410 -0
  17. data/lib/nvruby/cuda/runtime_api.rb +804 -0
  18. data/lib/nvruby/cuda/stream.rb +234 -0
  19. data/lib/nvruby/dtype.rb +139 -0
  20. data/lib/nvruby/epilogues.rb +438 -0
  21. data/lib/nvruby/errors.rb +303 -0
  22. data/lib/nvruby/half.rb +97 -0
  23. data/lib/nvruby/jit/compiled_kernel.rb +80 -0
  24. data/lib/nvruby/jit/compiler.rb +231 -0
  25. data/lib/nvruby/jit/driver_api_bindings.rb +363 -0
  26. data/lib/nvruby/jit/kernel.rb +240 -0
  27. data/lib/nvruby/jit/kernel_module.rb +133 -0
  28. data/lib/nvruby/jit/kernels/activations.rb +179 -0
  29. data/lib/nvruby/jit/kernels/attention.rb +504 -0
  30. data/lib/nvruby/jit/kernels/elementwise.rb +488 -0
  31. data/lib/nvruby/jit/kernels/loss.rb +213 -0
  32. data/lib/nvruby/jit/kernels/normalization.rb +200 -0
  33. data/lib/nvruby/jit/kernels/optimizer.rb +193 -0
  34. data/lib/nvruby/jit/nvrtc_bindings.rb +282 -0
  35. data/lib/nvruby/linalg/cublas_bindings.rb +295 -0
  36. data/lib/nvruby/linalg/cublaslt_bindings.rb +342 -0
  37. data/lib/nvruby/linalg/epilog.rb +67 -0
  38. data/lib/nvruby/linalg/matmul.rb +247 -0
  39. data/lib/nvruby/linalg/matmul_plan.rb +229 -0
  40. data/lib/nvruby/linalg/optimized_matmul.rb +412 -0
  41. data/lib/nvruby/memory/cuda_async_memory_resource.rb +123 -0
  42. data/lib/nvruby/memory/cuda_memory_resource.rb +68 -0
  43. data/lib/nvruby/memory/device_memory_resource.rb +106 -0
  44. data/lib/nvruby/memory/pinned_host_memory_resource.rb +112 -0
  45. data/lib/nvruby/memory/pool_memory_resource.rb +242 -0
  46. data/lib/nvruby/memory/stats.rb +107 -0
  47. data/lib/nvruby/memory.rb +124 -0
  48. data/lib/nvruby/version.rb +5 -0
  49. metadata +108 -0
@@ -0,0 +1,295 @@
1
+ # frozen_string_literal: true
2
+
3
+ require "ffi"
4
+
5
module Ignis
  module LinAlg
    # cuBLAS library FFI bindings.
    #
    # Declares the cuBLAS / cuBLASLt enum constants used elsewhere in Ignis::LinAlg,
    # attaches the cuBLAS entry points lazily via {ensure_loaded!}, and owns the single
    # cuBLAS handle returned by {get_handle}. Every constant is passed verbatim to the
    # C API, so its value must match NVIDIA's cublas_api.h / cublasLt.h exactly.
    module CuBLASBindings
      extend FFI::Library

      # Library handed to ffi_lib when Ignis.configuration.cuda_bin_path is unset or
      # contains no cublas64_*.dll.
      DEFAULT_LIBRARY_NAME = "cublas64_12"

      # cuBLAS operation types (cublasOperation_t)
      CUBLAS_OP_N = 0 # No transpose
      CUBLAS_OP_T = 1 # Transpose
      CUBLAS_OP_C = 2 # Conjugate transpose

      # cuBLAS fill modes (cublasFillMode_t)
      CUBLAS_FILL_MODE_LOWER = 0
      CUBLAS_FILL_MODE_UPPER = 1
      CUBLAS_FILL_MODE_FULL = 2

      # cuBLAS diagonal types (cublasDiagType_t)
      CUBLAS_DIAG_NON_UNIT = 0
      CUBLAS_DIAG_UNIT = 1

      # cuBLAS side modes (cublasSideMode_t)
      CUBLAS_SIDE_LEFT = 0
      CUBLAS_SIDE_RIGHT = 1

      # cuBLAS pointer modes (cublasPointerMode_t)
      CUBLAS_POINTER_MODE_HOST = 0
      CUBLAS_POINTER_MODE_DEVICE = 1

      # cuBLAS atomics modes (cublasAtomicsMode_t)
      CUBLAS_ATOMICS_NOT_ALLOWED = 0
      CUBLAS_ATOMICS_ALLOWED = 1

      # cuBLAS compute types (cublasComputeType_t)
      CUBLAS_COMPUTE_16F = 64
      CUBLAS_COMPUTE_16F_PEDANTIC = 65
      CUBLAS_COMPUTE_32F = 68
      CUBLAS_COMPUTE_32F_PEDANTIC = 69
      CUBLAS_COMPUTE_32F_FAST_16F = 74
      CUBLAS_COMPUTE_32F_FAST_16BF = 75
      CUBLAS_COMPUTE_32F_FAST_TF32 = 77
      CUBLAS_COMPUTE_64F = 70
      CUBLAS_COMPUTE_64F_PEDANTIC = 71
      CUBLAS_COMPUTE_32I = 72
      CUBLAS_COMPUTE_32I_PEDANTIC = 73

      # cuBLASLt epilogue types (cublasLtEpilogue_t).
      #
      # Expressed with the same bit combinations cublasLt.h uses: bit 128 is shared by
      # the *_AUX and backward (D*) variants, bit 16 by the *_BGRAD variants, and the
      # *_BIAS variants OR in CUBLASLT_EPILOGUE_BIAS. Resulting integers are noted inline.
      CUBLASLT_EPILOGUE_DEFAULT = 1
      CUBLASLT_EPILOGUE_RELU = 2
      CUBLASLT_EPILOGUE_RELU_AUX = CUBLASLT_EPILOGUE_RELU | 128 # 130
      CUBLASLT_EPILOGUE_BIAS = 4
      CUBLASLT_EPILOGUE_RELU_BIAS = CUBLASLT_EPILOGUE_RELU | CUBLASLT_EPILOGUE_BIAS # 6
      CUBLASLT_EPILOGUE_RELU_AUX_BIAS = CUBLASLT_EPILOGUE_RELU_AUX | CUBLASLT_EPILOGUE_BIAS # 134
      CUBLASLT_EPILOGUE_DRELU = 8 | 128 # 136
      CUBLASLT_EPILOGUE_DRELU_BGRAD = CUBLASLT_EPILOGUE_DRELU | 16 # 152
      CUBLASLT_EPILOGUE_GELU = 32
      CUBLASLT_EPILOGUE_GELU_AUX = CUBLASLT_EPILOGUE_GELU | 128 # 160
      CUBLASLT_EPILOGUE_GELU_BIAS = CUBLASLT_EPILOGUE_GELU | CUBLASLT_EPILOGUE_BIAS # 36
      CUBLASLT_EPILOGUE_GELU_AUX_BIAS = CUBLASLT_EPILOGUE_GELU_AUX | CUBLASLT_EPILOGUE_BIAS # 164
      CUBLASLT_EPILOGUE_DGELU = 64 | 128 # 192
      CUBLASLT_EPILOGUE_DGELU_BGRAD = CUBLASLT_EPILOGUE_DGELU | 16 # 208
      CUBLASLT_EPILOGUE_BGRADA = 256
      CUBLASLT_EPILOGUE_BGRADB = 512

      # cuBLAS Math Modes for Tensor Core control (cublasMath_t)
      CUBLAS_DEFAULT_MATH = 0
      CUBLAS_TENSOR_OP_MATH = 1
      CUBLAS_PEDANTIC_MATH = 2
      CUBLAS_TF32_TENSOR_OP_MATH = 3
      CUBLAS_MATH_DISALLOW_REDUCED_PRECISION_REDUCTION = 16

      @loaded = false
      @handle = nil
      @lt_handle = nil

      class << self
        # @return [FFI::Pointer, nil] cuBLAS handle
        attr_accessor :handle

        # @return [FFI::Pointer, nil] cuBLASLt handle
        # NOTE(review): never assigned in this module (CuBLASLtBindings keeps its own
        # lt_handle); retained for API compatibility.
        attr_accessor :lt_handle

        # Load the cuBLAS library, attach its functions and create the cuBLAS handle.
        # Subsequent calls are no-ops until {finalize!} resets the loaded state.
        # @return [void]
        # @raise [CuBLASError] if cublasCreate_v2 fails
        def ensure_loaded!
          return if @loaded

          CUDA::LibraryLoader.load_library(:cublas)
          ffi_lib cublas_library_path

          attach_cublas_functions!
          initialize_cublas!

          @loaded = true
        end

        # Get or create cuBLAS handle (loads the library on first use).
        # @return [FFI::Pointer]
        def get_handle
          ensure_loaded!
          @handle
        end

        # Destroy the cuBLAS handle and mark the bindings as unloaded.
        # A failing cublasDestroy_v2 is logged rather than raised so teardown completes.
        # @return [void]
        def finalize!
          return unless @handle

          status = cublasDestroy_v2(@handle)
          Ignis.logger.warn("cublasDestroy_v2 failed: status=#{status}") unless status.zero?
          @handle = nil
          @loaded = false
        end

        private

        # Resolve which cuBLAS library ffi_lib should open.
        #
        # Picks the lexically greatest cublas64_*.dll in the configured CUDA bin
        # directory, or DEFAULT_LIBRARY_NAME when the directory is unset or has no match.
        # The directory goes in +base:+ instead of the pattern because Dir.glob treats
        # backslashes in a pattern as escape characters (breaking Windows paths).
        # @return [String]
        def cublas_library_path
          cuda_bin = Ignis.configuration.cuda_bin_path
          return DEFAULT_LIBRARY_NAME unless cuda_bin

          dll = Dir.glob("cublas64_*.dll", base: cuda_bin).max
          return File.join(cuda_bin, dll) if dll

          Ignis.logger.warn("No cublas64_*.dll found in #{cuda_bin}; falling back to #{DEFAULT_LIBRARY_NAME}")
          DEFAULT_LIBRARY_NAME
        end

        # Attach the cuBLAS C entry points as module functions on CuBLASBindings.
        # rubocop:disable Metrics/MethodLength
        def attach_cublas_functions!
          # Handle management
          attach_function :cublasCreate_v2, [:pointer], :int
          attach_function :cublasDestroy_v2, [:pointer], :int
          attach_function :cublasSetStream_v2, [:pointer, :pointer], :int
          attach_function :cublasGetStream_v2, [:pointer, :pointer], :int
          attach_function :cublasSetPointerMode_v2, [:pointer, :int], :int
          attach_function :cublasGetPointerMode_v2, [:pointer, :pointer], :int
          attach_function :cublasSetMathMode, [:pointer, :int], :int
          attach_function :cublasGetMathMode, [:pointer, :pointer], :int

          # Level 1 BLAS - Vector operations
          attach_function :cublasSscal_v2, [:pointer, :int, :pointer, :pointer, :int], :int
          attach_function :cublasDscal_v2, [:pointer, :int, :pointer, :pointer, :int], :int
          attach_function :cublasSaxpy_v2, [:pointer, :int, :pointer, :pointer, :int, :pointer, :int], :int
          attach_function :cublasDaxpy_v2, [:pointer, :int, :pointer, :pointer, :int, :pointer, :int], :int
          attach_function :cublasSdot_v2, [:pointer, :int, :pointer, :int, :pointer, :int, :pointer], :int
          attach_function :cublasDdot_v2, [:pointer, :int, :pointer, :int, :pointer, :int, :pointer], :int
          attach_function :cublasSnrm2_v2, [:pointer, :int, :pointer, :int, :pointer], :int
          attach_function :cublasDnrm2_v2, [:pointer, :int, :pointer, :int, :pointer], :int
          attach_function :cublasSasum_v2, [:pointer, :int, :pointer, :int, :pointer], :int
          attach_function :cublasDasum_v2, [:pointer, :int, :pointer, :int, :pointer], :int
          attach_function :cublasIsamax_v2, [:pointer, :int, :pointer, :int, :pointer], :int
          attach_function :cublasIdamax_v2, [:pointer, :int, :pointer, :int, :pointer], :int
          attach_function :cublasIsamin_v2, [:pointer, :int, :pointer, :int, :pointer], :int
          attach_function :cublasIdamin_v2, [:pointer, :int, :pointer, :int, :pointer], :int
          attach_function :cublasScopy_v2, [:pointer, :int, :pointer, :int, :pointer, :int], :int
          attach_function :cublasDcopy_v2, [:pointer, :int, :pointer, :int, :pointer, :int], :int
          attach_function :cublasSswap_v2, [:pointer, :int, :pointer, :int, :pointer, :int], :int
          attach_function :cublasDswap_v2, [:pointer, :int, :pointer, :int, :pointer, :int], :int

          # Level 2 BLAS - Matrix-Vector operations
          attach_function :cublasSgemv_v2, [
            :pointer, :int, :int, :int,
            :pointer, :pointer, :int, :pointer, :int,
            :pointer, :pointer, :int
          ], :int
          attach_function :cublasDgemv_v2, [
            :pointer, :int, :int, :int,
            :pointer, :pointer, :int, :pointer, :int,
            :pointer, :pointer, :int
          ], :int

          # Level 3 BLAS - Matrix-Matrix operations
          attach_function :cublasSgemm_v2, [
            :pointer, :int, :int,
            :int, :int, :int,
            :pointer, :pointer, :int,
            :pointer, :int,
            :pointer, :pointer, :int
          ], :int
          attach_function :cublasDgemm_v2, [
            :pointer, :int, :int,
            :int, :int, :int,
            :pointer, :pointer, :int,
            :pointer, :int,
            :pointer, :pointer, :int
          ], :int
          attach_function :cublasCgemm_v2, [
            :pointer, :int, :int,
            :int, :int, :int,
            :pointer, :pointer, :int,
            :pointer, :int,
            :pointer, :pointer, :int
          ], :int
          attach_function :cublasZgemm_v2, [
            :pointer, :int, :int,
            :int, :int, :int,
            :pointer, :pointer, :int,
            :pointer, :int,
            :pointer, :pointer, :int
          ], :int

          # Strided batched GEMM
          attach_function :cublasSgemmStridedBatched, [
            :pointer, :int, :int,
            :int, :int, :int,
            :pointer, :pointer, :int, :long_long,
            :pointer, :int, :long_long,
            :pointer, :pointer, :int, :long_long,
            :int
          ], :int
          attach_function :cublasDgemmStridedBatched, [
            :pointer, :int, :int,
            :int, :int, :int,
            :pointer, :pointer, :int, :long_long,
            :pointer, :int, :long_long,
            :pointer, :pointer, :int, :long_long,
            :int
          ], :int

          # TRSM - Triangular solve
          attach_function :cublasStrsm_v2, [
            :pointer, :int, :int, :int, :int,
            :int, :int, :pointer, :pointer, :int, :pointer, :int
          ], :int
          attach_function :cublasDtrsm_v2, [
            :pointer, :int, :int, :int, :int,
            :int, :int, :pointer, :pointer, :int, :pointer, :int
          ], :int

          # SYRK - Symmetric rank-k update
          attach_function :cublasSsyrk_v2, [
            :pointer, :int, :int, :int, :int,
            :pointer, :pointer, :int, :pointer, :pointer, :int
          ], :int
          attach_function :cublasDsyrk_v2, [
            :pointer, :int, :int, :int, :int,
            :pointer, :pointer, :int, :pointer, :pointer, :int
          ], :int

          # cublasGemmEx - Mixed precision GEMM with Tensor Core support
          # Signature: cublasGemmEx(handle, transa, transb, m, n, k,
          #                         alpha, A, Atype, lda, B, Btype, ldb,
          #                         beta, C, Ctype, ldc, computeType, algo)
          attach_function :cublasGemmEx, [
            :pointer, # handle
            :int,     # transa
            :int,     # transb
            :int,     # m
            :int,     # n
            :int,     # k
            :pointer, # alpha
            :pointer, # A
            :int,     # Atype (cudaDataType_t)
            :int,     # lda
            :pointer, # B
            :int,     # Btype (cudaDataType_t)
            :int,     # ldb
            :pointer, # beta
            :pointer, # C
            :int,     # Ctype (cudaDataType_t)
            :int,     # ldc
            :int,     # computeType (cublasComputeType_t)
            :int      # algo (cublasGemmAlgo_t)
          ], :int
        end
        # rubocop:enable Metrics/MethodLength

        # Create the cuBLAS handle and request Tensor Core math mode.
        # @return [void]
        # @raise [CuBLASError] if cublasCreate_v2 returns a non-zero status
        def initialize_cublas!
          handle_ptr = FFI::MemoryPointer.new(:pointer)
          check_status!(cublasCreate_v2(handle_ptr), "cublasCreate_v2")
          @handle = handle_ptr.read_pointer

          # Failure to switch math mode is non-fatal; only the log message differs.
          # NOTE: CUBLAS_TENSOR_OP_MATH is deprecated since CUDA 11 (still accepted).
          math_status = cublasSetMathMode(@handle, CUBLAS_TENSOR_OP_MATH)
          if math_status.zero?
            Ignis.logger.info("cuBLAS initialized successfully with Tensor Core math enabled")
          else
            Ignis.logger.info("cuBLAS initialized successfully (Tensor Core math mode unavailable)")
          end
        end
      end

      # Raise CuBLASError unless +status+ is CUBLAS_STATUS_SUCCESS (0).
      # @param status [Integer] cuBLAS status code
      # @param context [String] Operation name prefixed to the error message
      # @return [void]
      # @raise [CuBLASError] carrying +status+ as its code
      def self.check_status!(status, context = "cuBLAS operation")
        return if status.zero?

        raise CuBLASError.new("#{context}: status=#{status}", code: status)
      end
    end
  end
end
@@ -0,0 +1,342 @@
1
+ # frozen_string_literal: true
2
+
3
+ require "ffi"
4
+
5
module Ignis
  module LinAlg
    # cuBLASLt (Light) library FFI bindings for advanced GEMM optimization
    #
    # cuBLASLt provides:
    # - Descriptor-based matrix multiplication
    # - Heuristic algorithm selection
    # - Workspace optimization
    # - Custom epilog operations (bias, activation fusion)
    #
    # Attribute IDs below are passed verbatim to the cublasLt*SetAttribute calls and must
    # match the enum values in NVIDIA's cublasLt.h exactly.
    module CuBLASLtBindings
      extend FFI::Library

      # Library handed to ffi_lib when Ignis.configuration.cuda_bin_path is unset or
      # contains no cublasLt64_*.dll.
      DEFAULT_LIBRARY_NAME = "cublasLt64_12"

      # ==========================================================================
      # cuBLASLt Constants
      # ==========================================================================

      # Matrix layout order (cublasLtOrder_t)
      CUBLASLT_ORDER_COL = 0
      CUBLASLT_ORDER_ROW = 1
      CUBLASLT_ORDER_COL32 = 2
      CUBLASLT_ORDER_COL4_4R2_8C = 3
      CUBLASLT_ORDER_COL32_2R_4R4 = 4

      # Matmul descriptor attributes (cublasLtMatmulDescAttributes_t).
      # IDs 9 and 16 are not assigned in the enum.
      CUBLASLT_MATMUL_DESC_COMPUTE_TYPE = 0
      CUBLASLT_MATMUL_DESC_SCALE_TYPE = 1
      CUBLASLT_MATMUL_DESC_POINTER_MODE = 2
      CUBLASLT_MATMUL_DESC_TRANSA = 3
      CUBLASLT_MATMUL_DESC_TRANSB = 4
      CUBLASLT_MATMUL_DESC_TRANSC = 5
      CUBLASLT_MATMUL_DESC_FILL_MODE = 6
      CUBLASLT_MATMUL_DESC_EPILOGUE = 7
      CUBLASLT_MATMUL_DESC_BIAS_POINTER = 8
      CUBLASLT_MATMUL_DESC_BIAS_BATCH_STRIDE = 10
      CUBLASLT_MATMUL_DESC_EPILOGUE_AUX_POINTER = 11
      CUBLASLT_MATMUL_DESC_EPILOGUE_AUX_LD = 12
      CUBLASLT_MATMUL_DESC_EPILOGUE_AUX_BATCH_STRIDE = 13
      CUBLASLT_MATMUL_DESC_ALPHA_VECTOR_BATCH_STRIDE = 14
      CUBLASLT_MATMUL_DESC_SM_COUNT_TARGET = 15
      CUBLASLT_MATMUL_DESC_A_SCALE_POINTER = 17
      CUBLASLT_MATMUL_DESC_B_SCALE_POINTER = 18
      CUBLASLT_MATMUL_DESC_C_SCALE_POINTER = 19
      CUBLASLT_MATMUL_DESC_D_SCALE_POINTER = 20
      CUBLASLT_MATMUL_DESC_AMAX_D_POINTER = 21
      CUBLASLT_MATMUL_DESC_FAST_ACCUM = 25
      CUBLASLT_MATMUL_DESC_BIAS_DATA_TYPE = 26

      # Matrix layout attributes (cublasLtMatrixLayoutAttribute_t)
      CUBLASLT_MATRIX_LAYOUT_TYPE = 0
      CUBLASLT_MATRIX_LAYOUT_ORDER = 1
      CUBLASLT_MATRIX_LAYOUT_ROWS = 2
      CUBLASLT_MATRIX_LAYOUT_COLS = 3
      CUBLASLT_MATRIX_LAYOUT_LD = 4
      CUBLASLT_MATRIX_LAYOUT_BATCH_COUNT = 5
      CUBLASLT_MATRIX_LAYOUT_STRIDED_BATCH_OFFSET = 6
      CUBLASLT_MATRIX_LAYOUT_PLANE_OFFSET = 7

      # Matmul preference attributes (cublasLtMatmulPreferenceAttributes_t)
      CUBLASLT_MATMUL_PREF_SEARCH_MODE = 0
      CUBLASLT_MATMUL_PREF_MAX_WORKSPACE_BYTES = 1
      CUBLASLT_MATMUL_PREF_REDUCTION_SCHEME_MASK = 3
      CUBLASLT_MATMUL_PREF_MIN_ALIGNMENT_A_BYTES = 5
      CUBLASLT_MATMUL_PREF_MIN_ALIGNMENT_B_BYTES = 6
      CUBLASLT_MATMUL_PREF_MIN_ALIGNMENT_C_BYTES = 7
      CUBLASLT_MATMUL_PREF_MIN_ALIGNMENT_D_BYTES = 8
      CUBLASLT_MATMUL_PREF_MAX_WAVES_COUNT = 9
      CUBLASLT_MATMUL_PREF_IMPL_MASK = 12
      # NOTE(review): no SM_COUNT_TARGET preference exists in the CUDA 12 cublasLt.h;
      # prefer CUBLASLT_MATMUL_DESC_SM_COUNT_TARGET. Value kept unchanged — verify
      # before use.
      CUBLASLT_MATMUL_PREF_SM_COUNT_TARGET = 14

      # Search modes (cublasLtMatmulSearch_t)
      CUBLASLT_SEARCH_BEST_FIT = 0
      CUBLASLT_SEARCH_LIMITED_BY_ALGO_ID = 1

      # Numerical implementation flags for IMPL_MASK
      CUBLASLT_NUMERICAL_IMPL_FLAGS_FMA = 0x01
      CUBLASLT_NUMERICAL_IMPL_FLAGS_HMMA = 0x02
      CUBLASLT_NUMERICAL_IMPL_FLAGS_IMMA = 0x04
      CUBLASLT_NUMERICAL_IMPL_FLAGS_DMMA = 0x08
      CUBLASLT_NUMERICAL_IMPL_FLAGS_TENSOR_OP_MASK = 0x0E

      # Reduction schemes (cublasLtReductionScheme_t)
      CUBLASLT_REDUCTION_SCHEME_NONE = 0
      CUBLASLT_REDUCTION_SCHEME_INPLACE = 1
      CUBLASLT_REDUCTION_SCHEME_COMPUTE_TYPE = 2
      CUBLASLT_REDUCTION_SCHEME_OUTPUT_TYPE = 4
      CUBLASLT_REDUCTION_SCHEME_MASK = 7

      # Epilogue operations (cublasLtEpilogue_t subset)
      CUBLASLT_EPILOGUE_DEFAULT = 1
      CUBLASLT_EPILOGUE_RELU = 2
      CUBLASLT_EPILOGUE_BIAS = 4
      CUBLASLT_EPILOGUE_RELU_BIAS = 6
      CUBLASLT_EPILOGUE_GELU = 32
      CUBLASLT_EPILOGUE_GELU_BIAS = 36

      # CUDA data types (cudaDataType_t, for compatibility)
      CUDA_R_16F = 2        # FP16
      CUDA_R_32F = 0        # FP32
      CUDA_R_64F = 1        # FP64
      CUDA_R_16BF = 14      # BF16
      CUDA_R_8F_E4M3 = 28   # FP8 E4M3
      CUDA_R_8F_E5M2 = 29   # FP8 E5M2

      # ==========================================================================
      # Heuristic result structure
      # ==========================================================================

      # cublasLtMatmulHeuristicResult_t
      class MatmulHeuristicResult < FFI::Struct
        layout :algo, [:uint8, 64], # cublasLtMatmulAlgo_t (opaque, 64 bytes)
               :workspaceSize, :size_t,
               :state, :int, # cublasStatus_t
               :wavesCount, :float,
               :reserved, [:int, 4]
      end

      # ==========================================================================
      # Library state
      # ==========================================================================

      @loaded = false
      @lt_handle = nil
      @workspace_ptr = nil
      @workspace_size = 0

      class << self
        # @return [FFI::Pointer, nil] cuBLASLt handle
        attr_accessor :lt_handle

        # @return [FFI::Pointer, nil] Workspace memory
        attr_accessor :workspace_ptr

        # @return [Integer] Workspace size in bytes
        attr_accessor :workspace_size

        # Load the cuBLASLt library, attach its functions and create the handle.
        # Subsequent calls are no-ops until {finalize!} resets the loaded state.
        # @return [void]
        # @raise [CuBLASError] if cublasLtCreate fails
        def ensure_loaded!
          return if @loaded

          ffi_lib cublaslt_library_path

          attach_cublaslt_functions!
          initialize_cublaslt!

          @loaded = true
        end

        # Get or create cuBLASLt handle (loads the library on first use).
        # @return [FFI::Pointer]
        def get_handle
          ensure_loaded!
          @lt_handle
        end

        # Get or allocate workspace.
        #
        # Reuses the cached buffer when it is at least +min_size+ bytes; otherwise frees
        # it and allocates +min_size+ bytes with cudaMalloc. On allocation failure a
        # warning is logged and FFI::Pointer::NULL is returned with workspace_size 0.
        # @param min_size [Integer] Minimum workspace size in bytes
        # @return [FFI::Pointer] Workspace pointer
        def get_workspace(min_size = 256 * 1024 * 1024)
          ensure_loaded!

          if @workspace_ptr.nil? || @workspace_size < min_size
            # Free existing workspace
            CUDA::RuntimeAPI.cudaFree(@workspace_ptr) if @workspace_ptr

            # Allocate new workspace
            ptr_ptr = FFI::MemoryPointer.new(:pointer)
            status = CUDA::RuntimeAPI.cudaMalloc(ptr_ptr, min_size)
            if status.zero?
              @workspace_ptr = ptr_ptr.read_pointer
              @workspace_size = min_size
              Ignis.logger.info { "cuBLASLt workspace allocated: #{min_size / 1024 / 1024}MB" }
            else
              Ignis.logger.warn { "Failed to allocate cuBLASLt workspace: #{status}" }
              @workspace_ptr = FFI::Pointer::NULL
              @workspace_size = 0
            end
          end

          @workspace_ptr
        end

        # Free the workspace, destroy the handle and mark the bindings as unloaded.
        # Teardown is best-effort: failures are logged, never raised.
        # @return [void]
        def finalize!
          if @workspace_ptr && !@workspace_ptr.null?
            release_quietly("cudaFree(workspace)") { CUDA::RuntimeAPI.cudaFree(@workspace_ptr) }
          end
          # Reset unconditionally so a NULL placeholder left by a failed allocation is
          # cleared as well.
          @workspace_ptr = nil
          @workspace_size = 0

          if @lt_handle
            release_quietly("cublasLtDestroy") { cublasLtDestroy(@lt_handle) }
            @lt_handle = nil
          end

          @loaded = false
        end

        private

        # Resolve which cuBLASLt library ffi_lib should open.
        #
        # Picks the lexically greatest cublasLt64_*.dll in the configured CUDA bin
        # directory, or DEFAULT_LIBRARY_NAME when the directory is unset or has no match
        # (so ffi_lib is always called before attach_function). The directory goes in
        # +base:+ because Dir.glob treats backslashes in a pattern as escapes.
        # @return [String]
        def cublaslt_library_path
          cuda_bin = Ignis.configuration.cuda_bin_path
          return DEFAULT_LIBRARY_NAME unless cuda_bin

          dll = Dir.glob("cublasLt64_*.dll", base: cuda_bin).max
          return File.join(cuda_bin, dll) if dll

          Ignis.logger.warn { "No cublasLt64_*.dll found in #{cuda_bin}; falling back to #{DEFAULT_LIBRARY_NAME}" }
          DEFAULT_LIBRARY_NAME
        end

        # Run a teardown call, logging (not raising) a non-zero status or exception.
        # @param what [String] Label used in the log message
        # @return [void]
        def release_quietly(what)
          status = yield
          return unless status.is_a?(Integer) && !status.zero?

          Ignis.logger.warn { "cuBLASLt finalize: #{what} returned status #{status}" }
        rescue StandardError => e
          Ignis.logger.warn { "cuBLASLt finalize: #{what} raised #{e.class}: #{e.message}" }
        end

        # Attach the cuBLASLt C entry points as module functions on CuBLASLtBindings.
        # rubocop:disable Metrics/MethodLength, Metrics/AbcSize
        def attach_cublaslt_functions!
          # Handle management
          attach_function :cublasLtCreate, [:pointer], :int
          attach_function :cublasLtDestroy, [:pointer], :int

          # Matmul descriptor
          attach_function :cublasLtMatmulDescCreate, [:pointer, :int, :int], :int
          attach_function :cublasLtMatmulDescDestroy, [:pointer], :int
          attach_function :cublasLtMatmulDescSetAttribute, [:pointer, :int, :pointer, :size_t], :int
          attach_function :cublasLtMatmulDescGetAttribute, [:pointer, :int, :pointer, :size_t, :pointer], :int

          # Matrix layout descriptor
          attach_function :cublasLtMatrixLayoutCreate, [:pointer, :int, :uint64, :uint64, :int64], :int
          attach_function :cublasLtMatrixLayoutDestroy, [:pointer], :int
          attach_function :cublasLtMatrixLayoutSetAttribute, [:pointer, :int, :pointer, :size_t], :int
          attach_function :cublasLtMatrixLayoutGetAttribute, [:pointer, :int, :pointer, :size_t, :pointer], :int

          # Matmul preference
          attach_function :cublasLtMatmulPreferenceCreate, [:pointer], :int
          attach_function :cublasLtMatmulPreferenceDestroy, [:pointer], :int
          attach_function :cublasLtMatmulPreferenceSetAttribute, [:pointer, :int, :pointer, :size_t], :int
          attach_function :cublasLtMatmulPreferenceGetAttribute, [:pointer, :int, :pointer, :size_t, :pointer], :int

          # Algorithm heuristic
          attach_function :cublasLtMatmulAlgoGetHeuristic, [
            :pointer, # lightHandle
            :pointer, # operationDesc
            :pointer, # Adesc
            :pointer, # Bdesc
            :pointer, # Cdesc
            :pointer, # Ddesc
            :pointer, # preference
            :int,     # requestedAlgoCount
            :pointer, # heuristicResultsArray (array of MatmulHeuristicResult)
            :pointer  # returnAlgoCount
          ], :int

          # Matrix multiplication
          attach_function :cublasLtMatmul, [
            :pointer, # lightHandle
            :pointer, # computeDesc
            :pointer, # alpha
            :pointer, # A
            :pointer, # Adesc
            :pointer, # B
            :pointer, # Bdesc
            :pointer, # beta
            :pointer, # C
            :pointer, # Cdesc
            :pointer, # D
            :pointer, # Ddesc
            :pointer, # algo (cublasLtMatmulAlgo_t)
            :pointer, # workspace
            :size_t,  # workspaceSizeInBytes
            :pointer  # stream
          ], :int
        end
        # rubocop:enable Metrics/MethodLength, Metrics/AbcSize

        # Create the cuBLASLt handle.
        # @return [void]
        # @raise [CuBLASError] if cublasLtCreate returns a non-zero status
        def initialize_cublaslt!
          handle_ptr = FFI::MemoryPointer.new(:pointer)
          status = cublasLtCreate(handle_ptr)

          if status.zero?
            @lt_handle = handle_ptr.read_pointer
            Ignis.logger.info { "cuBLASLt initialized successfully" }
          else
            raise CuBLASError.new("Failed to initialize cuBLASLt", code: status)
          end
        end
      end

      # ==========================================================================
      # Helper Methods
      # ==========================================================================

      # Convert dtype symbol to CUDA data type
      # @param dtype [Symbol] Data type (:float16, :float32, :float64, :bfloat16)
      # @return [Integer] CUDA data type constant (unknown dtypes map to CUDA_R_32F)
      def self.dtype_to_cuda_type(dtype)
        case dtype
        when :float16, :half then CUDA_R_16F
        when :float32, :float then CUDA_R_32F
        when :float64, :double then CUDA_R_64F
        when :bfloat16 then CUDA_R_16BF
        else CUDA_R_32F
        end
      end

      # Get compute type for dtype
      # @param dtype [Symbol] Data type
      # @return [Integer] Compute type constant (unknown dtypes map to CUBLAS_COMPUTE_32F)
      def self.compute_type_for_dtype(dtype)
        case dtype
        when :float16, :half
          CuBLASBindings::CUBLAS_COMPUTE_32F_FAST_16F
        when :float32, :float
          CuBLASBindings::CUBLAS_COMPUTE_32F_FAST_TF32
        when :float64, :double
          CuBLASBindings::CUBLAS_COMPUTE_64F
        else
          CuBLASBindings::CUBLAS_COMPUTE_32F
        end
      end

      # Get scale type for dtype (type of alpha/beta scalars)
      #
      # Per NVIDIA cuBLAS documentation, scaleType must match the accumulator type:
      # - CUBLAS_COMPUTE_32F_FAST_16F requires CUDA_R_32F scale type
      # - CUBLAS_COMPUTE_32F_FAST_TF32 requires CUDA_R_32F scale type
      # - CUBLAS_COMPUTE_64F requires CUDA_R_64F scale type
      #
      # @param dtype [Symbol] Data type of input matrices
      # @return [Integer] CUDA scale type constant for alpha/beta
      def self.scale_type_for_dtype(dtype)
        case dtype
        when :float64, :double
          CUDA_R_64F
        else
          # FP16, BF16, FP32, TF32 all use FP32 accumulation and require FP32 scale type
          CUDA_R_32F
        end
      end

      # Check status and raise error if not success
      # @param status [Integer] cuBLAS status code
      # @param context [String] Context for error message
      # @return [void]
      # @raise [CuBLASError] carrying +status+ as its code
      def self.check_status!(status, context = "cuBLASLt operation")
        return if status.zero?

        raise CuBLASError.new("#{context}: status=#{status}", code: status)
      end
    end
  end
end
@@ -0,0 +1,67 @@
1
+ # frozen_string_literal: true
2
+
3
module Ignis
  module LinAlg
    # Epilog operations for fused GEMM.
    #
    # Maps symbolic epilog names to the cublasLtEpilogue_t constants declared in
    # CuBLASBindings, and classifies them (bias input, auxiliary output, backward pass).
    # Every lookup accepts either a Symbol or a String name.
    module Epilog
      # Available epilog types
      TYPES = {
        default: CuBLASBindings::CUBLASLT_EPILOGUE_DEFAULT,
        relu: CuBLASBindings::CUBLASLT_EPILOGUE_RELU,
        relu_aux: CuBLASBindings::CUBLASLT_EPILOGUE_RELU_AUX,
        bias: CuBLASBindings::CUBLASLT_EPILOGUE_BIAS,
        relu_bias: CuBLASBindings::CUBLASLT_EPILOGUE_RELU_BIAS,
        relu_aux_bias: CuBLASBindings::CUBLASLT_EPILOGUE_RELU_AUX_BIAS,
        drelu: CuBLASBindings::CUBLASLT_EPILOGUE_DRELU,
        drelu_bgrad: CuBLASBindings::CUBLASLT_EPILOGUE_DRELU_BGRAD,
        gelu: CuBLASBindings::CUBLASLT_EPILOGUE_GELU,
        gelu_aux: CuBLASBindings::CUBLASLT_EPILOGUE_GELU_AUX,
        gelu_bias: CuBLASBindings::CUBLASLT_EPILOGUE_GELU_BIAS,
        gelu_aux_bias: CuBLASBindings::CUBLASLT_EPILOGUE_GELU_AUX_BIAS,
        dgelu: CuBLASBindings::CUBLASLT_EPILOGUE_DGELU,
        dgelu_bgrad: CuBLASBindings::CUBLASLT_EPILOGUE_DGELU_BGRAD,
        bgrada: CuBLASBindings::CUBLASLT_EPILOGUE_BGRADA,
        bgradb: CuBLASBindings::CUBLASLT_EPILOGUE_BGRADB
      }.freeze

      # Epilog types that take a bias input (see {supports_bias?}).
      BIAS_TYPES = %i[bias relu_bias relu_aux_bias gelu_bias gelu_aux_bias].freeze

      # Epilog types that produce an auxiliary output (see {supports_aux?}).
      AUX_TYPES = %i[relu_aux relu_aux_bias gelu_aux gelu_aux_bias].freeze

      # Epilog types intended for the backward pass (see {backward?}).
      BACKWARD_TYPES = %i[drelu drelu_bgrad dgelu dgelu_bgrad bgrada bgradb].freeze

      class << self
        # Get epilog type constant
        # @param type [Symbol, String] Epilog type
        # @return [Integer] cuBLASLt epilog constant
        # @raise [ArgumentError] if +type+ is not a key of TYPES
        def get(type)
          TYPES.fetch(normalize(type)) { raise ArgumentError, "Unknown epilog type: #{type}" }
        end

        # List available epilog types
        # @return [Array<Symbol>]
        def available
          TYPES.keys
        end

        # Check if epilog type supports bias
        # @param type [Symbol, String] Epilog type
        # @return [Boolean]
        def supports_bias?(type)
          BIAS_TYPES.include?(normalize(type))
        end

        # Check if epilog type supports auxiliary output
        # @param type [Symbol, String] Epilog type
        # @return [Boolean]
        def supports_aux?(type)
          AUX_TYPES.include?(normalize(type))
        end

        # Check if epilog type is for backward pass
        # @param type [Symbol, String] Epilog type
        # @return [Boolean]
        def backward?(type)
          BACKWARD_TYPES.include?(normalize(type))
        end

        private

        # Convert String names to Symbols; other values (nil, Symbol, ...) pass through
        # unchanged so unknown input still fails the lookup instead of raising NoMethodError.
        def normalize(type)
          type.respond_to?(:to_sym) ? type.to_sym : type
        end
      end
    end
  end
end