ignis 0.0.1

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (49) hide show
  1. checksums.yaml +7 -0
  2. data/README.md +15 -0
  3. data/lib/ignis.rb +94 -0
  4. data/lib/nnw/platform.rb +304 -0
  5. data/lib/nnw/shared/event_bus.rb +240 -0
  6. data/lib/nnw/shared/ffi_loader.rb +63 -0
  7. data/lib/nnw/shared/memory_contract.rb +204 -0
  8. data/lib/nnw/shared/nv_array.rb +710 -0
  9. data/lib/nnw/shared/recovery_protocol.rb +307 -0
  10. data/lib/nvruby/configuration.rb +217 -0
  11. data/lib/nvruby/cuda/device.rb +275 -0
  12. data/lib/nvruby/cuda/device_props.rb +202 -0
  13. data/lib/nvruby/cuda/graph.rb +265 -0
  14. data/lib/nvruby/cuda/graph_bindings.rb +119 -0
  15. data/lib/nvruby/cuda/library_loader.rb +285 -0
  16. data/lib/nvruby/cuda/memory.rb +410 -0
  17. data/lib/nvruby/cuda/runtime_api.rb +804 -0
  18. data/lib/nvruby/cuda/stream.rb +234 -0
  19. data/lib/nvruby/dtype.rb +139 -0
  20. data/lib/nvruby/epilogues.rb +438 -0
  21. data/lib/nvruby/errors.rb +303 -0
  22. data/lib/nvruby/half.rb +97 -0
  23. data/lib/nvruby/jit/compiled_kernel.rb +80 -0
  24. data/lib/nvruby/jit/compiler.rb +231 -0
  25. data/lib/nvruby/jit/driver_api_bindings.rb +363 -0
  26. data/lib/nvruby/jit/kernel.rb +240 -0
  27. data/lib/nvruby/jit/kernel_module.rb +133 -0
  28. data/lib/nvruby/jit/kernels/activations.rb +179 -0
  29. data/lib/nvruby/jit/kernels/attention.rb +504 -0
  30. data/lib/nvruby/jit/kernels/elementwise.rb +488 -0
  31. data/lib/nvruby/jit/kernels/loss.rb +213 -0
  32. data/lib/nvruby/jit/kernels/normalization.rb +200 -0
  33. data/lib/nvruby/jit/kernels/optimizer.rb +193 -0
  34. data/lib/nvruby/jit/nvrtc_bindings.rb +282 -0
  35. data/lib/nvruby/linalg/cublas_bindings.rb +295 -0
  36. data/lib/nvruby/linalg/cublaslt_bindings.rb +342 -0
  37. data/lib/nvruby/linalg/epilog.rb +67 -0
  38. data/lib/nvruby/linalg/matmul.rb +247 -0
  39. data/lib/nvruby/linalg/matmul_plan.rb +229 -0
  40. data/lib/nvruby/linalg/optimized_matmul.rb +412 -0
  41. data/lib/nvruby/memory/cuda_async_memory_resource.rb +123 -0
  42. data/lib/nvruby/memory/cuda_memory_resource.rb +68 -0
  43. data/lib/nvruby/memory/device_memory_resource.rb +106 -0
  44. data/lib/nvruby/memory/pinned_host_memory_resource.rb +112 -0
  45. data/lib/nvruby/memory/pool_memory_resource.rb +242 -0
  46. data/lib/nvruby/memory/stats.rb +107 -0
  47. data/lib/nvruby/memory.rb +124 -0
  48. data/lib/nvruby/version.rb +5 -0
  49. metadata +108 -0
@@ -0,0 +1,303 @@
1
+ # frozen_string_literal: true
2
+
3
module Ignis
  # Root of the Ignis exception hierarchy; rescue this to catch any library error.
  class Error < StandardError; end

  # Failure reported by a CUDA component, optionally carrying its numeric code.
  class CudaError < Error
    # @return [Integer, nil] CUDA error code, when one was supplied
    attr_reader :cuda_code

    # @param message [String] Error message
    # @param cuda_code [Integer, nil] CUDA error code; appended to the message when given
    def initialize(message, cuda_code: nil)
      @cuda_code = cuda_code
      text = message
      text = "#{message} (CUDA error #{cuda_code})" if cuda_code
      super(text)
    end
  end

  # Raised when the CUDA runtime encounters an error
  class CudaRuntimeError < CudaError; end

  # Raised when a cuBLAS operation fails
  class CuBLASError < CudaError
    # Status code => symbolic name
    STATUSES = {
      0 => :success,
      1 => :not_initialized,
      3 => :alloc_failed,
      7 => :invalid_value,
      8 => :arch_mismatch,
      11 => :mapping_error,
      13 => :execution_failed,
      14 => :internal_error,
      15 => :not_supported,
      16 => :license_error
    }.freeze

    # @param status [Integer] cuBLAS status code (unlisted codes are reported as :unknown)
    def initialize(status)
      label = STATUSES.fetch(status, :unknown)
      super("cuBLAS operation failed: #{label}", cuda_code: status)
    end
  end

  # Raised when a cuFFT operation fails
  class CuFFTError < CudaError
    # Status code => symbolic name
    STATUSES = {
      0 => :success,
      1 => :invalid_plan,
      2 => :alloc_failed,
      3 => :invalid_type,
      4 => :invalid_value,
      5 => :internal_error,
      6 => :exec_failed,
      7 => :setup_failed,
      8 => :invalid_size,
      9 => :unaligned_data,
      10 => :incomplete_parameter_list,
      11 => :invalid_device,
      12 => :parse_error,
      13 => :no_workspace,
      14 => :not_implemented,
      15 => :license_error,
      16 => :not_supported
    }.freeze

    # @param status [Integer] cuFFT status code (unlisted codes are reported as :unknown)
    def initialize(status)
      label = STATUSES.fetch(status, :unknown)
      super("cuFFT operation failed: #{label}", cuda_code: status)
    end
  end

  # Raised when a cuRAND operation fails
  class CuRANDError < CudaError
    # Status code => symbolic name
    STATUSES = {
      0 => :success,
      100 => :version_mismatch,
      101 => :not_initialized,
      102 => :allocation_failed,
      103 => :type_error,
      104 => :out_of_range,
      105 => :length_not_multiple,
      106 => :double_precision_required,
      201 => :launch_failure,
      202 => :preexisting_failure,
      203 => :initialization_failed,
      204 => :arch_mismatch,
      999 => :internal_error
    }.freeze

    # @param status [Integer] cuRAND status code (unlisted codes are reported as :unknown)
    def initialize(status)
      label = STATUSES.fetch(status, :unknown)
      super("cuRAND operation failed: #{label}", cuda_code: status)
    end
  end

  # Raised when a cuSPARSE operation fails
  class CuSPARSEError < CudaError
    # Status code => symbolic name
    STATUSES = {
      0 => :success,
      1 => :not_initialized,
      2 => :alloc_failed,
      3 => :invalid_value,
      4 => :arch_mismatch,
      5 => :mapping_error,
      6 => :execution_failed,
      7 => :internal_error,
      8 => :matrix_type_not_supported,
      9 => :zero_pivot,
      10 => :not_supported,
      11 => :insufficient_resources
    }.freeze

    # @param status [Integer] cuSPARSE status code (unlisted codes are reported as :unknown)
    def initialize(status)
      label = STATUSES.fetch(status, :unknown)
      super("cuSPARSE operation failed: #{label}", cuda_code: status)
    end
  end

  # Raised when a cuSOLVER operation fails
  class CuSolverError < CudaError
    # Status code => symbolic name
    STATUSES = {
      0 => :success,
      1 => :not_initialized,
      2 => :alloc_failed,
      3 => :invalid_value,
      4 => :arch_mismatch,
      5 => :mapping_error,
      6 => :execution_failed,
      7 => :internal_error,
      8 => :matrix_type_not_supported,
      9 => :not_supported,
      10 => :zero_pivot,
      11 => :invalid_license,
      12 => :irs_params_not_initialized,
      13 => :irs_params_invalid,
      14 => :irs_params_invalid_prec,
      15 => :irs_params_invalid_refine,
      16 => :irs_params_invalid_maxiter,
      20 => :irs_internal_error,
      21 => :irs_not_supported,
      22 => :irs_out_of_range,
      23 => :irs_nrhs_not_supported_for_refine_gmres,
      25 => :irs_infos_not_initialized
    }.freeze

    # @return [String, nil] Context where the error occurred
    attr_reader :context

    # Builds "<context>: <message> (<status name>)"; the context prefix and the
    # status suffix are each added only when the corresponding argument is given.
    #
    # @param message [String] Error message
    # @param cusolver_code [Integer, nil] cuSOLVER status code
    # @param context [String, nil] Context for the error
    def initialize(message, cusolver_code: nil, context: nil)
      @context = context
      label = STATUSES.fetch(cusolver_code, :unknown) if cusolver_code

      text = context ? "#{context}: #{message}" : message
      text = "#{text} (#{label})" if label

      super(text, cuda_code: cusolver_code)
    end
  end

  # Raised when memory allocation fails
  class MemoryError < Error; end

  # Raised when device memory allocation fails
  class DeviceMemoryError < MemoryError; end

  # Raised when host memory allocation fails
  class HostMemoryError < MemoryError; end

  # Raised when an invalid operation is attempted
  class InvalidOperationError < Error; end

  # Raised when a required CUDA library is not found
  class LibraryNotFoundError < Error
    # @return [String] Name of the library that could not be found
    attr_reader :library_name

    # @param library_name [String] Name of the missing library
    def initialize(library_name)
      @library_name = library_name
      super("Required CUDA library not found: #{library_name}")
    end
  end

  # Raised when no CUDA-capable device is available
  class NoDeviceError < Error
    # Takes no arguments; the message is fixed.
    def initialize
      super("No CUDA-capable device is available")
    end
  end

  # Raised when array dimensions are incompatible
  class DimensionError < Error; end

  # Raised when a data type is not supported
  class UnsupportedDTypeError < Error
    # @return [Symbol] The unsupported dtype
    attr_reader :dtype

    # @param dtype [Symbol] The unsupported data type
    # @param operation [String, nil] The operation that doesn't support this dtype
    def initialize(dtype, operation: nil)
      @dtype = dtype
      suffix = operation ? " for operation #{operation}" : ""
      super("Unsupported data type: #{dtype}#{suffix}")
    end
  end

  # Raised when autotuning fails
  class AutotuneError < Error; end

  # Raised when plan creation fails
  class PlanError < Error; end

  # Raised when NVRTC compilation fails
  class NVRTCError < CudaError
    # Result code => symbolic name
    STATUSES = {
      0 => :success,
      1 => :out_of_memory,
      2 => :program_creation_failure,
      3 => :invalid_input,
      4 => :invalid_program,
      5 => :invalid_option,
      6 => :compilation_error,
      7 => :builtin_operation_failure,
      8 => :no_name_expressions_after_compilation,
      9 => :no_lowered_names_before_compilation,
      10 => :name_expression_not_valid,
      11 => :internal_error
    }.freeze

    # @return [String, nil] Compilation log with error details
    attr_reader :compilation_log

    # @return [String, nil] Context where the error occurred
    attr_reader :context

    # @param status [Integer] NVRTC result code (unlisted codes are reported as :unknown)
    # @param compilation_log [String, nil] Compilation error log; appended when non-empty
    # @param context [String, nil] Context for the error; prefixed when given
    def initialize(status, compilation_log: nil, context: nil)
      @compilation_log = compilation_log
      @context = context
      label = STATUSES.fetch(status, :unknown)

      headline = "NVRTC error: #{label}"
      headline = "#{context}: #{headline}" if context

      has_log = compilation_log && !compilation_log.empty?
      text = has_log ? "#{headline}\nCompilation log:\n#{compilation_log}" : headline

      super(text, cuda_code: status)
    end
  end

  # Raised when a CUDA Driver API operation fails
  class CudaDriverError < CudaError
    # Result code => symbolic name
    STATUSES = {
      0 => :success,
      1 => :invalid_value,
      2 => :out_of_memory,
      3 => :not_initialized,
      4 => :deinitialized,
      100 => :no_device,
      101 => :invalid_device,
      200 => :invalid_image,
      201 => :invalid_context,
      209 => :no_binary_for_gpu,
      218 => :invalid_ptx,
      300 => :invalid_source,
      301 => :file_not_found,
      400 => :invalid_handle,
      500 => :not_found,
      600 => :not_ready,
      700 => :illegal_address,
      701 => :launch_out_of_resources,
      702 => :launch_timeout,
      719 => :launch_failed,
      999 => :unknown
    }.freeze

    # @return [String, nil] Context where the error occurred
    attr_reader :context

    # @param status [Integer] CUDA Driver result code (unlisted codes are reported as :unknown)
    # @param context [String, nil] Context for the error; prefixed when given
    def initialize(status, context: nil)
      @context = context
      label = STATUSES.fetch(status, :unknown)

      headline = "CUDA Driver error: #{label}"
      text = context ? "#{context}: #{headline}" : headline

      super(text, cuda_code: status)
    end
  end
end
@@ -0,0 +1,97 @@
1
+ # frozen_string_literal: true
2
+
3
+ module Ignis
4
+ # Pure-Ruby IEEE-754 half-precision (binary16) and bfloat16 <-> float32 conversion.
5
+ #
6
+ # Ruby and FFI have no native 16-bit float type, so NvArray stores fp16/bf16 as
7
+ # raw uint16 bit patterns. These helpers convert to/from Ruby Floats with correct
8
+ # round-to-nearest-even rounding and proper subnormal / overflow / inf / NaN handling.
9
+ #
10
+ # This is the single source of truth for half conversion across BOTH NvArray
11
+ # classes (Ignis::NvArray and Ignis::Shared::NvArray) and the safetensors codec,
12
+ # so the math cannot drift between them.
13
module Half
  module_function

  # Convert a number to the bit pattern of the nearest IEEE-754 binary16 (fp16) value.
  # The value is first narrowed to float32, then rounded to half with ties-to-even.
  # @param value [Numeric]
  # @return [Integer] 16-bit unsigned (0..0xFFFF)
  def f32_to_f16(value)
    word = [value.to_f].pack("e").unpack1("V") # float32 bits read back as a little-endian uint32
    sign = (word >> 16) & 0x8000
    exponent = (word >> 23) & 0xFF
    fraction = word & 0x7FFFFF

    # Inf keeps an empty mantissa; every NaN collapses to a single quiet NaN pattern.
    if exponent == 0xFF
      return sign | (fraction.zero? ? 0x7C00 : 0x7E00)
    end

    half_exp = exponent - 127 + 15 # re-bias from float32 to binary16

    # Too large for binary16 -> signed Inf
    return sign | 0x7C00 if half_exp >= 0x1F

    if half_exp > 0
      # Normal result: keep the top 10 mantissa bits and round on the 13 dropped bits.
      packed = sign | (half_exp << 10) | (fraction >> 13)
      dropped = fraction & 0x1FFF
      # Ties go to even; a carry may ripple from mantissa into exponent (up to Inf).
      packed += 1 if dropped > 0x1000 || (dropped == 0x1000 && packed.odd?)
      return packed & 0xFFFF
    end

    # Below half of the smallest subnormal: flush to signed zero.
    return sign if half_exp < -10

    # Subnormal result: shift the full significand (implicit 1 restored) into place.
    significand = fraction | 0x800000
    shift = 14 - half_exp # 14..24
    quotient, dropped = significand.divmod(1 << shift)
    tie = 1 << (shift - 1)
    quotient += 1 if dropped > tie || (dropped == tie && quotient.odd?) # ties to even
    sign | quotient
  end

  # Expand an IEEE-754 binary16 (fp16) bit pattern into a Ruby Float.
  # @param bits [Integer] 16-bit unsigned
  # @return [Float]
  def f16_to_f32(bits)
    negative = (bits >> 15).odd?
    exponent = (bits >> 10) & 0x1F
    fraction = bits & 0x3FF

    # All-ones exponent: NaN when the mantissa is set, otherwise signed infinity.
    if exponent == 0x1F
      return Float::NAN unless fraction.zero?

      return negative ? -Float::INFINITY : Float::INFINITY
    end

    if exponent.zero?
      # Signed zero, or a subnormal (no implicit leading 1, fixed 2**-14 scale).
      return negative ? -0.0 : 0.0 if fraction.zero?

      magnitude = (fraction / 1024.0) * (2.0**-14)
    else
      magnitude = (1.0 + fraction / 1024.0) * (2.0**(exponent - 15))
    end

    negative ? -magnitude : magnitude
  end

  # Convert a number to bfloat16 (the upper 16 bits of float32), rounding ties to even.
  # @param value [Numeric]
  # @return [Integer] 16-bit unsigned (0..0xFFFF)
  def f32_to_bf16(value)
    word = [value.to_f].pack("e").unpack1("V")
    upper = word >> 16

    # NaN: keep sign and exponent, set a mantissa bit so the result stays a (quiet) NaN.
    return (upper | 0x0040) & 0xFFFF if (word & 0x7FFFFFFF) > 0x7F800000

    # Adding 0x7FFF plus the current low kept bit rounds half-way cases to even.
    rounded = word + 0x7FFF + (upper & 1)
    (rounded >> 16) & 0xFFFF
  end

  # Expand a bfloat16 bit pattern into a Ruby Float.
  # @param bits [Integer] 16-bit unsigned
  # @return [Float]
  def bf16_to_f32(bits)
    widened = (bits & 0xFFFF) << 16 # bf16 is the high half of a float32 word
    [widened].pack("V").unpack1("e")
  end
end
97
+ end
@@ -0,0 +1,80 @@
1
+ # frozen_string_literal: true
2
+
3
+ module Ignis
4
+ module JIT
5
+ # Represents compiled CUDA code (CUBIN binary)
6
+ # Device-agnostic representation that can be loaded onto any compatible GPU
7
+ # Thread-safe immutable object suitable for caching
8
class CompiledKernel
  # @return [String] The compiled CUBIN binary data (frozen private copy)
  attr_reader :cubin_data

  # @return [Integer] Size of the CUBIN data in bytes
  attr_reader :cubin_size

  # @return [Integer] Target compute capability (e.g., 89 for sm_89)
  attr_reader :compute_capability

  # @return [String, nil] The original source code (for debugging)
  attr_reader :source_code

  # @return [String] Kernel function name
  attr_reader :kernel_name

  # @return [Array<String>] Compilation options used
  attr_reader :compile_options

  # @return [Time] Compilation timestamp
  attr_reader :compiled_at

  # Create a new CompiledKernel.
  #
  # The instance and everything it stores are frozen. Strings are copied before
  # freezing so that objects owned by the caller (for example a source string
  # that is still being edited) are never frozen as a side effect.
  #
  # @param cubin_data [String] The compiled CUBIN binary
  # @param compute_capability [Integer] Target compute capability
  # @param kernel_name [String] Kernel function name
  # @param source_code [String, nil] Original source code
  # @param compile_options [Array<String>] Options used for compilation
  def initialize(cubin_data:, compute_capability:, kernel_name:, source_code: nil, compile_options: [])
    @cubin_data = cubin_data.dup.freeze
    @cubin_size = @cubin_data.bytesize
    @compute_capability = compute_capability
    @kernel_name = kernel_name.dup.freeze
    @source_code = source_code&.dup&.freeze
    @compile_options = compile_options.map { |option| option.dup.freeze }.freeze
    @compiled_at = Time.now.freeze
    freeze
  end

  # Load this compiled kernel onto a specific device
  # @param device_id [Integer] Device index to load onto
  # @return [KernelModule] The loaded module with function handle
  # @raise [CudaDriverError] If loading fails
  def load(device_id: 0)
    Ignis.set_device(device_id)
    KernelModule.new(self, device_id: device_id)
  end

  # Get a string representation
  # @return [String]
  def to_s
    "#<Ignis::JIT::CompiledKernel #{@kernel_name} sm_#{@compute_capability} #{@cubin_size} bytes>"
  end

  # Get detailed inspection
  # @return [String]
  def inspect
    "#<Ignis::JIT::CompiledKernel:0x#{object_id.to_s(16)} " \
      "kernel=#{@kernel_name.inspect} " \
      "sm=#{@compute_capability} " \
      "size=#{@cubin_size} " \
      "compiled_at=#{@compiled_at.strftime('%Y-%m-%d %H:%M:%S')}>"
  end

  # Check if compatible with a given compute capability.
  #
  # CUBIN is architecture-specific binary code: it runs only on devices of the
  # same major compute capability whose minor revision is equal or higher
  # (an sm_86 cubin runs on sm_89, but not on sm_80 or sm_90).
  # Capabilities are encoded as major * 10 + minor.
  #
  # @param target_cc [Integer] Target compute capability
  # @return [Boolean] True if this kernel can run on target
  def compatible_with?(target_cc)
    target_major, target_minor = target_cc.divmod(10)
    built_major, built_minor = @compute_capability.divmod(10)

    target_major == built_major && target_minor >= built_minor
  end
end
79
+ end
80
+ end
@@ -0,0 +1,231 @@
1
+ # frozen_string_literal: true
2
+
3
+ require "digest"
4
+
5
+ module Ignis
6
+ module JIT
7
# JIT Compiler with multi-level caching
# Provides runtime compilation of CUDA C++ source to executable kernels
# Based on nvmath-python caching strategy:
# - Per-thread cache for fast path (no locking). It is stored through
#   Thread#[], which Ruby scopes to the current Fiber, so every Fiber keeps
#   its own fast-path cache.
# - Shared caches with Mutex for cross-thread reuse
# - Cache key: SHA-256 of kernel name + source + extra compile options,
#   scoped by compute capability (compiled code) or device id (loaded modules)
class Compiler
  # Default compilation options
  DEFAULT_OPTIONS = [
    "--fmad=true",
    "--std=c++17",
    "-default-device"
  ].freeze

  # Thread-local storage key for kernel pointer cache
  THREAD_LOCAL_CACHE_KEY = :nvruby_jit_kernel_cache

  # Shared caches (guarded by mutex)
  @compiled_code_cache = {} # cc -> source_hash -> CompiledKernel
  @kernel_module_cache = {} # device_id -> source_hash -> KernelModule
  @cache_mutex = Mutex.new

  class << self
    # Compile CUDA source code and return an executable Kernel.
    #
    # Results are cached per (device, kernel name, source, options); calling
    # again with different options compiles a separate kernel instead of
    # returning the one built with the earlier options.
    #
    # @param source [String] CUDA C++ source code
    # @param kernel_name [String] Name of the kernel function
    # @param device_id [Integer] Target device (for compute capability)
    # @param options [Array<String>] Additional compilation options
    # @return [Kernel] Executable kernel
    # @raise [NVRTCError] If compilation fails
    def compile(source, kernel_name, device_id: 0, options: [])
      source_hash = compute_source_hash(source, kernel_name, options)

      # Fast path: no locking and no device query on a hit.
      kernel = get_cached_kernel(device_id, source_hash)
      return kernel if kernel

      cc = get_device_cc(device_id)

      @cache_mutex.synchronize do
        kernel = get_kernel_module_cached(source, source_hash, kernel_name, device_id, cc, options)
        cache_kernel_pointer(device_id, source_hash, kernel)
        kernel
      end
    end

    # Compile to CompiledKernel without loading (for caching/serialization).
    # This path does not consult or populate any cache.
    # @param source [String] CUDA C++ source code
    # @param kernel_name [String] Name of the kernel function
    # @param compute_capability [Integer] Target compute capability (e.g., 89)
    # @param options [Array<String>] Additional compilation options
    # @return [CompiledKernel] Compiled kernel (not yet loaded)
    # @raise [NVRTCError] If compilation fails
    def compile_only(source, kernel_name, compute_capability:, options: [])
      compile_to_cubin(source, kernel_name, compute_capability, options)
    end

    # Clear the shared caches and the calling thread's fast-path cache.
    # Fast-path caches held by other threads are not touched.
    # @return [void]
    def clear_cache!
      @cache_mutex.synchronize do
        @compiled_code_cache.clear
        @kernel_module_cache.clear
      end
      clear_thread_local_cache!
      Ignis.logger.info("JIT compiler caches cleared")
    end

    # Clear thread-local cache only (for the calling thread)
    # @return [void]
    def clear_thread_local_cache!
      Thread.current[THREAD_LOCAL_CACHE_KEY] = nil
    end

    # Get cache statistics
    # @return [Hash] Statistics about cache usage
    def cache_stats
      @cache_mutex.synchronize do
        {
          compiled_kernels: @compiled_code_cache.values.sum(&:size),
          loaded_modules: @kernel_module_cache.values.sum(&:size),
          compute_capabilities: @compiled_code_cache.keys,
          devices_with_modules: @kernel_module_cache.keys
        }
      end
    end

    # Get version information
    # @return [Hash] NVRTC and driver version info
    def version_info
      {
        nvrtc: NVRTCBindings.version,
        cuda_driver: Ignis.cuda_version
      }
    end

    private

    # Compute the cache key for a kernel.
    #
    # The key covers the kernel name, the source and the extra compile options,
    # because the same source built with different options (e.g. -D defines)
    # yields different code. With no options the key is the same as hashing
    # "<kernel_name>:<source>" alone. Options are appended NUL-separated.
    #
    # @param source [String]
    # @param kernel_name [String]
    # @param options [Array<String>] Additional compilation options
    # @return [String] Hex-encoded SHA-256 digest
    def compute_source_hash(source, kernel_name, options = [])
      payload = "#{kernel_name}:#{source}"
      payload = "#{payload}\0#{options.join("\0")}" unless options.empty?
      Digest::SHA256.hexdigest(payload)
    end

    # Get device compute capability (memoized per thread)
    # @param device_id [Integer]
    # @return [Integer] Compute capability as integer (e.g., 89)
    def get_device_cc(device_id)
      thread_cache = Thread.current[:nvruby_device_cc_cache] ||= {}
      return thread_cache[device_id] if thread_cache.key?(device_id)

      major, minor = DriverAPIBindings.get_device_compute_capability(device_id)
      cc = major * 10 + minor
      thread_cache[device_id] = cc
      cc
    end

    # Get cached kernel from thread-local cache
    # @param device_id [Integer]
    # @param source_hash [String]
    # @return [Kernel, nil]
    def get_cached_kernel(device_id, source_hash)
      cache = Thread.current[THREAD_LOCAL_CACHE_KEY]
      return nil unless cache

      device_cache = cache[device_id]
      return nil unless device_cache

      device_cache[source_hash]
    end

    # Cache kernel in thread-local cache
    # @param device_id [Integer]
    # @param source_hash [String]
    # @param kernel [Kernel]
    # @return [void]
    def cache_kernel_pointer(device_id, source_hash, kernel)
      cache = Thread.current[THREAD_LOCAL_CACHE_KEY] ||= {}
      device_cache = cache[device_id] ||= {}
      device_cache[source_hash] = kernel
    end

    # Get or create kernel module from shared cache.
    # A cached module that reports destroyed? is reloaded.
    # Must be called with @cache_mutex held.
    # @param source [String]
    # @param source_hash [String]
    # @param kernel_name [String]
    # @param device_id [Integer]
    # @param cc [Integer]
    # @param options [Array<String>]
    # @return [Kernel]
    def get_kernel_module_cached(source, source_hash, kernel_name, device_id, cc, options)
      device_cache = @kernel_module_cache[device_id] ||= {}
      kernel_module = device_cache[source_hash]

      if kernel_module.nil? || kernel_module.destroyed?
        compiled = get_compiled_code_cached(source, source_hash, kernel_name, cc, options)
        kernel_module = compiled.load(device_id: device_id)
        device_cache[source_hash] = kernel_module
        Ignis.logger.debug("Loaded kernel module: #{kernel_name} on device #{device_id}")
      end

      kernel_module.to_kernel
    end

    # Get or create compiled code from shared cache.
    # Must be called with @cache_mutex held.
    # @param source [String]
    # @param source_hash [String]
    # @param kernel_name [String]
    # @param cc [Integer]
    # @param options [Array<String>]
    # @return [CompiledKernel]
    def get_compiled_code_cached(source, source_hash, kernel_name, cc, options)
      cc_cache = @compiled_code_cache[cc] ||= {}
      compiled = cc_cache[source_hash]

      if compiled.nil?
        compiled = compile_to_cubin(source, kernel_name, cc, options)
        cc_cache[source_hash] = compiled
        Ignis.logger.info("Compiled kernel: #{kernel_name} for sm_#{cc}")
      else
        Ignis.logger.debug("Using cached compiled kernel: #{kernel_name} for sm_#{cc}")
      end

      compiled
    end

    # Compile source code to CUBIN.
    # The NVRTC program is always destroyed, even when compilation raises.
    # @param source [String]
    # @param kernel_name [String]
    # @param cc [Integer]
    # @param options [Array<String>]
    # @return [CompiledKernel]
    def compile_to_cubin(source, kernel_name, cc, options)
      major = cc / 10
      minor = cc % 10
      arch_option = "--gpu-architecture=sm_#{major}#{minor}"

      # Caller options come last, after the defaults and the architecture flag.
      all_options = DEFAULT_OPTIONS + [arch_option] + options

      Ignis.logger.debug("Compiling kernel #{kernel_name} with options: #{all_options}")

      program = NVRTCBindings.create_program(source, name: "#{kernel_name}.cu")

      begin
        NVRTCBindings.compile_program(program, options: all_options)
        cubin_data = NVRTCBindings.get_cubin(program)

        CompiledKernel.new(
          cubin_data: cubin_data,
          compute_capability: cc,
          kernel_name: kernel_name,
          source_code: source,
          compile_options: all_options
        )
      ensure
        NVRTCBindings.destroy_program(program)
      end
    end
  end
end
230
+ end
231
+ end