ignis 0.0.1

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (49) hide show
  1. checksums.yaml +7 -0
  2. data/README.md +15 -0
  3. data/lib/ignis.rb +94 -0
  4. data/lib/nnw/platform.rb +304 -0
  5. data/lib/nnw/shared/event_bus.rb +240 -0
  6. data/lib/nnw/shared/ffi_loader.rb +63 -0
  7. data/lib/nnw/shared/memory_contract.rb +204 -0
  8. data/lib/nnw/shared/nv_array.rb +710 -0
  9. data/lib/nnw/shared/recovery_protocol.rb +307 -0
  10. data/lib/nvruby/configuration.rb +217 -0
  11. data/lib/nvruby/cuda/device.rb +275 -0
  12. data/lib/nvruby/cuda/device_props.rb +202 -0
  13. data/lib/nvruby/cuda/graph.rb +265 -0
  14. data/lib/nvruby/cuda/graph_bindings.rb +119 -0
  15. data/lib/nvruby/cuda/library_loader.rb +285 -0
  16. data/lib/nvruby/cuda/memory.rb +410 -0
  17. data/lib/nvruby/cuda/runtime_api.rb +804 -0
  18. data/lib/nvruby/cuda/stream.rb +234 -0
  19. data/lib/nvruby/dtype.rb +139 -0
  20. data/lib/nvruby/epilogues.rb +438 -0
  21. data/lib/nvruby/errors.rb +303 -0
  22. data/lib/nvruby/half.rb +97 -0
  23. data/lib/nvruby/jit/compiled_kernel.rb +80 -0
  24. data/lib/nvruby/jit/compiler.rb +231 -0
  25. data/lib/nvruby/jit/driver_api_bindings.rb +363 -0
  26. data/lib/nvruby/jit/kernel.rb +240 -0
  27. data/lib/nvruby/jit/kernel_module.rb +133 -0
  28. data/lib/nvruby/jit/kernels/activations.rb +179 -0
  29. data/lib/nvruby/jit/kernels/attention.rb +504 -0
  30. data/lib/nvruby/jit/kernels/elementwise.rb +488 -0
  31. data/lib/nvruby/jit/kernels/loss.rb +213 -0
  32. data/lib/nvruby/jit/kernels/normalization.rb +200 -0
  33. data/lib/nvruby/jit/kernels/optimizer.rb +193 -0
  34. data/lib/nvruby/jit/nvrtc_bindings.rb +282 -0
  35. data/lib/nvruby/linalg/cublas_bindings.rb +295 -0
  36. data/lib/nvruby/linalg/cublaslt_bindings.rb +342 -0
  37. data/lib/nvruby/linalg/epilog.rb +67 -0
  38. data/lib/nvruby/linalg/matmul.rb +247 -0
  39. data/lib/nvruby/linalg/matmul_plan.rb +229 -0
  40. data/lib/nvruby/linalg/optimized_matmul.rb +412 -0
  41. data/lib/nvruby/memory/cuda_async_memory_resource.rb +123 -0
  42. data/lib/nvruby/memory/cuda_memory_resource.rb +68 -0
  43. data/lib/nvruby/memory/device_memory_resource.rb +106 -0
  44. data/lib/nvruby/memory/pinned_host_memory_resource.rb +112 -0
  45. data/lib/nvruby/memory/pool_memory_resource.rb +242 -0
  46. data/lib/nvruby/memory/stats.rb +107 -0
  47. data/lib/nvruby/memory.rb +124 -0
  48. data/lib/nvruby/version.rb +5 -0
  49. metadata +108 -0
@@ -0,0 +1,234 @@
1
+ # frozen_string_literal: true
2
+
3
+ require 'fiddle'
4
+
5
+ module Ignis
6
+ module CUDA
7
# CUDA stream for asynchronous operations.
#
# Wraps a stream handle obtained through the Fiddle-based RuntimeAPI methods.
# A stream built with +synchronous: true+ does not create a new stream; it
# holds a null (address 0) handle, for which {#default?} reports true.
class Stream
  # @return [Fiddle::Pointer] Stream handle
  attr_reader :handle

  # Symbolic name for the default stream. This is a Symbol, not a Stream
  # object; use {.default} to obtain the shared default Stream instance.
  DEFAULT = :default

  # @param synchronous [Boolean] If true, no new stream is created and the
  #   instance wraps the null (default) stream handle instead.
  def initialize(synchronous: false)
    @synchronous = synchronous
    @handle = create_stream
    @destroyed = false

    # The null handle was not created here, so it gets no cleanup finalizer.
    return if @synchronous

    # The finalizer proc is built by a class method from the handle alone so
    # that it holds no reference to this instance.
    ObjectSpace.define_finalizer(self, self.class.release_finalizer(@handle))
  end

  # @return [Fiddle::Pointer] for interop
  def to_ptr
    @handle
  end

  # Check if stream is the default stream.
  # @return [Boolean] true when the handle is nil or has address 0
  def default?
    @handle.nil? || @handle.to_i.zero?
  end

  # @return [Boolean] true once {#destroy!} has released the stream
  def destroyed?
    @destroyed
  end

  # Synchronize the stream (wait for all operations to complete).
  #
  # The default (null) stream is synchronized as well: work queued on it is
  # still asynchronous with respect to the host, so skipping the call would
  # return before that work has finished. Only destroyed streams are skipped.
  # NOTE(review): this assumes RuntimeAPI.stream_synchronize forwards the
  # handle to cudaStreamSynchronize, which accepts the null stream - confirm.
  #
  # @return [void]
  def synchronize
    return if @destroyed

    RuntimeAPI.ensure_loaded!
    RuntimeAPI.stream_synchronize(native_handle)
  end

  # Query if stream has completed.
  #
  # A destroyed stream reports true. Every other stream, including the
  # default (null) stream, is queried through RuntimeAPI.stream_query.
  #
  # @return [Boolean]
  def completed?
    return true if @destroyed

    RuntimeAPI.ensure_loaded!
    RuntimeAPI.stream_query(native_handle)
  end

  # Destroy the stream. Does nothing for an already destroyed stream or for
  # a stream wrapping the null handle.
  # @return [void]
  def destroy!
    return if @destroyed || default?

    RuntimeAPI.ensure_loaded!
    RuntimeAPI.stream_destroy(@handle)

    @destroyed = true
    # The handle is gone; make sure the finalizer cannot destroy it again.
    ObjectSpace.undefine_finalizer(self)
  end

  # @return [String] 'CudaStream[default]', 'CudaStream[destroyed]' or the
  #   handle address in hexadecimal
  def to_s
    return 'CudaStream[default]' if default?
    return 'CudaStream[destroyed]' if @destroyed

    "CudaStream[0x#{@handle.to_i.to_s(16)}]"
  end

  class << self
    # Get the default (null) stream. The instance is created on first use
    # and then reused.
    # @return [Stream]
    def default
      @default_stream ||= DefaultStream.new
    end

    # Create a finalizer for stream cleanup.
    #
    # Only the integer address is captured, so the returned proc keeps
    # neither the Stream nor the original pointer object alive.
    #
    # @param handle [Fiddle::Pointer]
    # @return [Proc]
    def release_finalizer(handle)
      handle_addr = handle.to_i
      proc do
        # Never try to destroy the null stream.
        next if handle_addr.zero?

        begin
          RuntimeAPI.ensure_loaded!
          RuntimeAPI.stream_destroy(Fiddle::Pointer.new(handle_addr))
        rescue StandardError
          # Silently ignore errors during finalization
        end
      end
    end
  end

  private

  # Create a CUDA stream.
  # @return [Fiddle::Pointer] null pointer for synchronous streams, otherwise
  #   whatever RuntimeAPI.stream_create returns
  def create_stream
    return Fiddle::Pointer.new(0) if @synchronous

    RuntimeAPI.ensure_loaded!
    RuntimeAPI.stream_create
  end

  # Handle to pass to RuntimeAPI calls; a nil handle is mapped to the null
  # pointer so the default stream is always addressed as address 0.
  # @return [Fiddle::Pointer]
  def native_handle
    @handle || Fiddle::Pointer.new(0)
  end
end
119
+
120
# Stream subclass that stands for the null (default) stream.
class DefaultStream < Stream
  # Sets the instance state directly. There is no +super+ call, so
  # Stream#initialize does not run: no stream is created and no finalizer
  # is registered for this object.
  def initialize
    @synchronous = true
    @destroyed = false
    @handle = Fiddle::Pointer.new(0)
  end

  # @return [Boolean] always true
  def default?
    true
  end

  # No-op: the default stream cannot be destroyed.
  # @return [nil]
  def destroy!
    nil
  end
end
136
+
137
# CUDA event for timing and synchronization.
#
# Wraps an event handle obtained through the Fiddle-based RuntimeAPI methods.
class Event
  # @return [Fiddle::Pointer] Event handle
  attr_reader :handle

  # NOTE(review): +blocking+ and +disable_timing+ are stored on the instance,
  # but #create_event calls RuntimeAPI.event_create without arguments, so
  # they do not appear to influence the created event - confirm against
  # RuntimeAPI before relying on them.
  #
  # @param blocking [Boolean] If true, CPU will block on synchronize
  # @param disable_timing [Boolean] If true, timing is disabled for better performance
  def initialize(blocking: false, disable_timing: false)
    @blocking = blocking
    @disable_timing = disable_timing
    @handle = create_event
    @destroyed = false

    # The finalizer proc is built by a class method from the handle alone so
    # that it holds no reference to this instance.
    ObjectSpace.define_finalizer(self, self.class.release_finalizer(@handle))
  end

  # @return [Boolean] true once {#destroy!} has released the event
  def destroyed?
    @destroyed
  end

  # Record the event on a stream.
  # @param stream [Stream, nil] Stream to record on (nil for default)
  # @return [void]
  # @raise [InvalidOperationError] if the event, or the given stream, has
  #   been destroyed
  def record(stream: nil)
    raise InvalidOperationError, 'Event has been destroyed' if @destroyed
    # A destroyed stream still exposes its stale handle through #to_ptr;
    # refuse it here instead of handing it to the runtime.
    if stream.respond_to?(:destroyed?) && stream.destroyed?
      raise InvalidOperationError, 'Stream has been destroyed'
    end

    RuntimeAPI.ensure_loaded!
    stream_handle = stream&.to_ptr || Fiddle::Pointer.new(0)
    RuntimeAPI.event_record(@handle, stream_handle)
  end

  # Wait for the event to complete.
  # @return [void]
  # @raise [InvalidOperationError] if the event has been destroyed
  def synchronize
    raise InvalidOperationError, 'Event has been destroyed' if @destroyed

    RuntimeAPI.ensure_loaded!
    RuntimeAPI.event_synchronize(@handle)
  end

  # @return [Fiddle::Pointer] for interop
  def to_ptr
    @handle
  end

  # Destroy the event. Does nothing if it was already destroyed.
  # @return [void]
  def destroy!
    return if @destroyed

    RuntimeAPI.ensure_loaded!
    RuntimeAPI.event_destroy(@handle)

    @destroyed = true
    # The handle is gone; make sure the finalizer cannot destroy it again.
    ObjectSpace.undefine_finalizer(self)
  end

  class << self
    # Calculate elapsed time between two events.
    # @param start_event [Event] Start event
    # @param end_event [Event] End event
    # @return [Float] Elapsed time in milliseconds
    # @raise [InvalidOperationError] if either event has been destroyed
    def elapsed_time(start_event, end_event)
      # Same guard as #record / #synchronize: never pass a released handle on.
      if start_event.destroyed? || end_event.destroyed?
        raise InvalidOperationError, 'Event has been destroyed'
      end

      RuntimeAPI.ensure_loaded!
      RuntimeAPI.event_elapsed_time(start_event.handle, end_event.handle)
    end

    # Create a finalizer for event cleanup.
    #
    # Only the integer address is captured, so the returned proc keeps
    # neither the Event nor the original pointer object alive.
    #
    # @param handle [Fiddle::Pointer]
    # @return [Proc]
    def release_finalizer(handle)
      handle_addr = handle.to_i
      proc do
        # Nothing to release for a null handle (mirrors the stream finalizer).
        next if handle_addr.zero?

        begin
          RuntimeAPI.ensure_loaded!
          RuntimeAPI.event_destroy(Fiddle::Pointer.new(handle_addr))
        rescue StandardError
          # Silently ignore errors during finalization
        end
      end
    end
  end

  private

  # Create a CUDA event.
  # @return [Fiddle::Pointer] whatever RuntimeAPI.event_create returns
  def create_event
    RuntimeAPI.ensure_loaded!
    RuntimeAPI.event_create
  end
end
233
+ end
234
+ end
@@ -0,0 +1,139 @@
1
+ # frozen_string_literal: true
2
+
3
+ module Ignis
4
# Supported data types with their properties
module DType
  # Per-dtype metadata:
  #   bytes   - size of one element in bytes
  #   ffi     - FFI type symbol of the underlying storage
  #   cuda    - cudaDataType_t enumerator value
  #   complex - true when an element is a pair of real components
  TYPES = {
    float16: { bytes: 2, ffi: :uint16, cuda: 2, complex: false },    # CUDA_R_16F (stored as uint16)
    bfloat16: { bytes: 2, ffi: :uint16, cuda: 14, complex: false },  # CUDA_R_16BF (stored as uint16)
    float32: { bytes: 4, ffi: :float, cuda: 0, complex: false },     # CUDA_R_32F
    float64: { bytes: 8, ffi: :double, cuda: 1, complex: false },    # CUDA_R_64F
    int8: { bytes: 1, ffi: :int8, cuda: 3, complex: false },         # CUDA_R_8I
    int16: { bytes: 2, ffi: :int16, cuda: 20, complex: false },      # CUDA_R_16I
    int32: { bytes: 4, ffi: :int32, cuda: 10, complex: false },      # CUDA_R_32I
    int64: { bytes: 8, ffi: :int64, cuda: 24, complex: false },      # CUDA_R_64I
    uint8: { bytes: 1, ffi: :uint8, cuda: 8, complex: false },       # CUDA_R_8U
    uint16: { bytes: 2, ffi: :uint16, cuda: 22, complex: false },    # CUDA_R_16U
    uint32: { bytes: 4, ffi: :uint32, cuda: 12, complex: false },    # CUDA_R_32U
    uint64: { bytes: 8, ffi: :uint64, cuda: 26, complex: false },    # CUDA_R_64U
    complex64: { bytes: 8, ffi: :float, cuda: 4, complex: true },    # CUDA_C_32F, 2x float32
    complex128: { bytes: 16, ffi: :double, cuda: 5, complex: true }  # CUDA_C_64F, 2x float64
  }.freeze

  # cuBLAS data type constants (cudaDataType_t values)
  CUBLAS_TYPES = {
    float16: 2,     # CUDA_R_16F
    bfloat16: 14,   # CUDA_R_16BF
    float32: 0,     # CUDA_R_32F
    float64: 1,     # CUDA_R_64F
    complex64: 4,   # CUDA_C_32F
    complex128: 5,  # CUDA_C_64F
    int8: 3,        # CUDA_R_8I
    int32: 10       # CUDA_R_32I
  }.freeze

  # Membership lists for the predicates below, built once instead of per call.
  FLOAT_DTYPES = %i[float16 bfloat16 float32 float64 complex64 complex128].freeze
  INTEGER_DTYPES = %i[int8 int16 int32 int64 uint8 uint16 uint32 uint64].freeze
  UNSIGNED_DTYPES = %i[uint8 uint16 uint32 uint64].freeze
  private_constant :FLOAT_DTYPES, :INTEGER_DTYPES, :UNSIGNED_DTYPES

  class << self
    # Get byte size for a dtype
    # @param dtype [Symbol] Data type
    # @return [Integer] Size in bytes
    # @raise [UnsupportedDTypeError] If dtype is not in TYPES
    def byte_size(dtype)
      type_info(dtype)[:bytes]
    end

    # Get FFI type for a dtype
    # @param dtype [Symbol] Data type
    # @return [Symbol] FFI type symbol
    # @raise [UnsupportedDTypeError] If dtype is not in TYPES
    def ffi_type(dtype)
      type_info(dtype)[:ffi]
    end

    # Check if dtype is complex
    # @param dtype [Symbol] Data type
    # @return [Boolean]
    # @raise [UnsupportedDTypeError] If dtype is not in TYPES
    def complex?(dtype)
      type_info(dtype)[:complex]
    end

    # Check if dtype is floating point. Complex dtypes count as floating
    # point here; unknown dtypes yield false.
    # @param dtype [Symbol] Data type
    # @return [Boolean]
    def float?(dtype)
      FLOAT_DTYPES.include?(dtype)
    end

    # Check if dtype is integer. Unknown dtypes yield false.
    # @param dtype [Symbol] Data type
    # @return [Boolean]
    def integer?(dtype)
      INTEGER_DTYPES.include?(dtype)
    end

    # Check if dtype is signed. Every known dtype except the unsigned
    # integers is signed; unknown dtypes yield false, like the other
    # non-raising predicates.
    # @param dtype [Symbol] Data type
    # @return [Boolean]
    def signed?(dtype)
      TYPES.key?(dtype) && !UNSIGNED_DTYPES.include?(dtype)
    end

    # Get cuBLAS type constant
    # @param dtype [Symbol] Data type
    # @return [Integer] cuBLAS type constant
    # @raise [UnsupportedDTypeError] If dtype has no entry in CUBLAS_TYPES
    def cublas_type(dtype)
      type = CUBLAS_TYPES[dtype]
      raise UnsupportedDTypeError.new(dtype, operation: "cuBLAS") unless type

      type
    end

    # Validate dtype
    # @param dtype [Symbol] Data type to validate
    # @return [Symbol] Validated dtype
    # @raise [UnsupportedDTypeError] If dtype is not valid
    def validate!(dtype)
      raise UnsupportedDTypeError, dtype unless TYPES.key?(dtype)

      dtype
    end

    # List all supported dtypes
    # @return [Array<Symbol>] All dtype symbols
    def all
      TYPES.keys
    end

    # Get real component dtype for complex types. Any other dtype is
    # returned unchanged.
    # @param dtype [Symbol] Complex data type
    # @return [Symbol] Real component dtype
    def real_dtype(dtype)
      case dtype
      when :complex64 then :float32
      when :complex128 then :float64
      else dtype
      end
    end

    # Get complex dtype for real types. Complex dtypes map to themselves.
    # @param dtype [Symbol] Real data type
    # @return [Symbol] Complex dtype
    # @raise [UnsupportedDTypeError] If dtype has no complex counterpart
    def complex_dtype(dtype)
      case dtype
      when :float32 then :complex64
      when :float64 then :complex128
      when :complex64, :complex128 then dtype
      else
        raise UnsupportedDTypeError.new(dtype, operation: "complex conversion")
      end
    end

    private

    # Look up the metadata hash for a dtype, raising when it is unknown.
    # @param dtype [Symbol] Data type
    # @return [Hash] Entry from TYPES
    # @raise [UnsupportedDTypeError] If dtype is not in TYPES
    def type_info(dtype)
      info = TYPES[dtype]
      raise UnsupportedDTypeError, dtype unless info

      info
    end
  end
end
139
+ end