ignis 0.0.1
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- checksums.yaml +7 -0
- data/README.md +15 -0
- data/lib/ignis.rb +94 -0
- data/lib/nnw/platform.rb +304 -0
- data/lib/nnw/shared/event_bus.rb +240 -0
- data/lib/nnw/shared/ffi_loader.rb +63 -0
- data/lib/nnw/shared/memory_contract.rb +204 -0
- data/lib/nnw/shared/nv_array.rb +710 -0
- data/lib/nnw/shared/recovery_protocol.rb +307 -0
- data/lib/nvruby/configuration.rb +217 -0
- data/lib/nvruby/cuda/device.rb +275 -0
- data/lib/nvruby/cuda/device_props.rb +202 -0
- data/lib/nvruby/cuda/graph.rb +265 -0
- data/lib/nvruby/cuda/graph_bindings.rb +119 -0
- data/lib/nvruby/cuda/library_loader.rb +285 -0
- data/lib/nvruby/cuda/memory.rb +410 -0
- data/lib/nvruby/cuda/runtime_api.rb +804 -0
- data/lib/nvruby/cuda/stream.rb +234 -0
- data/lib/nvruby/dtype.rb +139 -0
- data/lib/nvruby/epilogues.rb +438 -0
- data/lib/nvruby/errors.rb +303 -0
- data/lib/nvruby/half.rb +97 -0
- data/lib/nvruby/jit/compiled_kernel.rb +80 -0
- data/lib/nvruby/jit/compiler.rb +231 -0
- data/lib/nvruby/jit/driver_api_bindings.rb +363 -0
- data/lib/nvruby/jit/kernel.rb +240 -0
- data/lib/nvruby/jit/kernel_module.rb +133 -0
- data/lib/nvruby/jit/kernels/activations.rb +179 -0
- data/lib/nvruby/jit/kernels/attention.rb +504 -0
- data/lib/nvruby/jit/kernels/elementwise.rb +488 -0
- data/lib/nvruby/jit/kernels/loss.rb +213 -0
- data/lib/nvruby/jit/kernels/normalization.rb +200 -0
- data/lib/nvruby/jit/kernels/optimizer.rb +193 -0
- data/lib/nvruby/jit/nvrtc_bindings.rb +282 -0
- data/lib/nvruby/linalg/cublas_bindings.rb +295 -0
- data/lib/nvruby/linalg/cublaslt_bindings.rb +342 -0
- data/lib/nvruby/linalg/epilog.rb +67 -0
- data/lib/nvruby/linalg/matmul.rb +247 -0
- data/lib/nvruby/linalg/matmul_plan.rb +229 -0
- data/lib/nvruby/linalg/optimized_matmul.rb +412 -0
- data/lib/nvruby/memory/cuda_async_memory_resource.rb +123 -0
- data/lib/nvruby/memory/cuda_memory_resource.rb +68 -0
- data/lib/nvruby/memory/device_memory_resource.rb +106 -0
- data/lib/nvruby/memory/pinned_host_memory_resource.rb +112 -0
- data/lib/nvruby/memory/pool_memory_resource.rb +242 -0
- data/lib/nvruby/memory/stats.rb +107 -0
- data/lib/nvruby/memory.rb +124 -0
- data/lib/nvruby/version.rb +5 -0
- metadata +108 -0
|
@@ -0,0 +1,234 @@
|
|
|
1
|
+
# frozen_string_literal: true
|
|
2
|
+
|
|
3
|
+
require 'fiddle'
|
|
4
|
+
|
|
5
|
+
module Ignis
|
|
6
|
+
module CUDA
|
|
7
|
+
# CUDA stream for asynchronous operations.
|
|
8
|
+
#
|
|
9
|
+
# Refactored to use Fiddle-based RuntimeAPI methods.
|
|
10
|
+
class Stream
  # Marker constant for the default stream.
  DEFAULT = :default

  # @return [Fiddle::Pointer] Stream handle
  attr_reader :handle

  # Builds a stream wrapper. A synchronous stream wraps the null handle and
  # gets no finalizer; any other stream is created through RuntimeAPI and is
  # released by a finalizer when the wrapper is garbage collected.
  #
  # @param synchronous [Boolean] If true, creates a synchronous stream
  def initialize(synchronous: false)
    @synchronous = synchronous
    @handle = create_stream
    @destroyed = false
    return if @synchronous

    # The finalizer proc is built in a class method so it only closes over
    # the handle address, never over this instance.
    ObjectSpace.define_finalizer(self, self.class.release_finalizer(@handle))
  end

  # Get the default (null) stream, memoized on the class.
  # @return [Stream]
  def self.default
    @default_stream ||= DefaultStream.new
  end

  # Build the finalizer proc used for stream cleanup.
  # @param handle [Fiddle::Pointer]
  # @return [Proc]
  def self.release_finalizer(handle)
    address = handle.to_i
    proc do
      # A null handle has nothing to release.
      unless address.zero?
        begin
          RuntimeAPI.ensure_loaded!
          RuntimeAPI.stream_destroy(Fiddle::Pointer.new(address))
        rescue StandardError
          # Errors raised during finalization are deliberately swallowed.
          nil
        end
      end
    end
  end

  # Raw handle, exposed for interop.
  # @return [Fiddle::Pointer]
  def to_ptr
    @handle
  end

  # Whether this wrapper refers to the default stream, i.e. the handle is
  # missing or has address zero.
  # @return [Boolean]
  def default?
    return true if @handle.nil?

    @handle.to_i.zero?
  end

  # @return [Boolean] true once destroy! has released the stream
  def destroyed?
    @destroyed
  end

  # Wait for all operations on the stream to complete. Does nothing for a
  # destroyed or default stream.
  # @return [void]
  def synchronize
    return if @destroyed
    return if default?

    RuntimeAPI.ensure_loaded!
    RuntimeAPI.stream_synchronize(@handle)
  end

  # Query whether the stream has completed. Destroyed and default streams
  # report true without calling into the runtime.
  # @return [Boolean]
  def completed?
    return true if @destroyed
    return true if default?

    RuntimeAPI.ensure_loaded!
    RuntimeAPI.stream_query(@handle)
  end

  # Destroy the stream and drop its finalizer. Does nothing for a destroyed
  # or default stream.
  # @return [void]
  def destroy!
    return if @destroyed
    return if default?

    RuntimeAPI.ensure_loaded!
    RuntimeAPI.stream_destroy(@handle)

    @destroyed = true
    ObjectSpace.undefine_finalizer(self)
  end

  # Human-readable description; the default check wins over the destroyed one.
  # @return [String]
  def to_s
    if default?
      'CudaStream[default]'
    elsif @destroyed
      'CudaStream[destroyed]'
    else
      hex_address = @handle.to_i.to_s(16)
      "CudaStream[0x#{hex_address}]"
    end
  end

  private

  # Obtain the handle for this stream: the null pointer for a synchronous
  # stream, otherwise a stream created through RuntimeAPI.
  # @return [Fiddle::Pointer]
  def create_stream
    if @synchronous
      Fiddle::Pointer.new(0)
    else
      RuntimeAPI.ensure_loaded!
      RuntimeAPI.stream_create
    end
  end
end
|
|
119
|
+
|
|
120
|
+
# Default stream that wraps the null stream.
|
|
121
|
+
class DefaultStream < Stream
  # Wraps the null handle directly. Stream#initialize is not invoked, so no
  # stream is created through RuntimeAPI and no finalizer is registered.
  def initialize
    @synchronous = true
    @destroyed = false
    @handle = Fiddle::Pointer.new(0)
  end

  # @return [Boolean] always true for this wrapper
  def default?
    true
  end

  # The default stream cannot be destroyed, so this is a no-op.
  # @return [nil]
  def destroy!
    nil
  end
end
|
|
136
|
+
|
|
137
|
+
# CUDA event for timing and synchronization.
|
|
138
|
+
#
|
|
139
|
+
# Refactored to use Fiddle-based RuntimeAPI methods.
|
|
140
|
+
class Event
  # @return [Fiddle::Pointer] Event handle
  attr_reader :handle

  # Creates the underlying event through RuntimeAPI and registers a
  # finalizer that destroys it when this wrapper is garbage collected.
  #
  # NOTE(review): +blocking+ and +disable_timing+ are stored on the instance
  # but are not forwarded to RuntimeAPI.event_create in this class; confirm
  # whether the runtime binding accepts creation flags.
  #
  # @param blocking [Boolean] If true, CPU will block on synchronize
  # @param disable_timing [Boolean] If true, timing is disabled for better performance
  def initialize(blocking: false, disable_timing: false)
    @blocking = blocking
    @disable_timing = disable_timing
    @handle = create_event
    @destroyed = false

    # The finalizer proc is built in a class method so it only closes over
    # the handle address, never over this instance.
    ObjectSpace.define_finalizer(self, self.class.release_finalizer(@handle))
  end

  # @return [Boolean] true once destroy! has released the event
  def destroyed?
    @destroyed
  end

  # Record the event on a stream.
  # @param stream [Stream, nil] Stream to record on (nil for default)
  # @return [void]
  # @raise [InvalidOperationError] if the event has been destroyed
  def record(stream: nil)
    raise InvalidOperationError, 'Event has been destroyed' if @destroyed

    RuntimeAPI.ensure_loaded!
    # A missing stream (or one whose to_ptr is nil) maps to the null handle.
    stream_handle = stream&.to_ptr || Fiddle::Pointer.new(0)
    RuntimeAPI.event_record(@handle, stream_handle)
  end

  # Wait for the event to complete.
  # @return [void]
  # @raise [InvalidOperationError] if the event has been destroyed
  def synchronize
    raise InvalidOperationError, 'Event has been destroyed' if @destroyed

    RuntimeAPI.ensure_loaded!
    RuntimeAPI.event_synchronize(@handle)
  end

  # Raw handle, exposed for interop.
  # @return [Fiddle::Pointer]
  def to_ptr
    @handle
  end

  # Calculate elapsed time between two events.
  # @param start_event [Event] Start event
  # @param end_event [Event] End event
  # @return [Float] Elapsed time in milliseconds
  # @raise [InvalidOperationError] if either event has been destroyed
  def self.elapsed_time(start_event, end_event)
    # Guard against handing a released handle to the runtime, matching the
    # checks performed by #record and #synchronize.
    [start_event, end_event].each do |event|
      raise InvalidOperationError, 'Event has been destroyed' if event.destroyed?
    end

    RuntimeAPI.ensure_loaded!
    RuntimeAPI.event_elapsed_time(start_event.handle, end_event.handle)
  end

  # Destroy the event and drop its finalizer. Does nothing if the event was
  # already destroyed.
  # @return [void]
  def destroy!
    return if @destroyed

    RuntimeAPI.ensure_loaded!
    RuntimeAPI.event_destroy(@handle)

    @destroyed = true
    ObjectSpace.undefine_finalizer(self)
  end

  class << self
    # Create a finalizer for event cleanup.
    # @param handle [Fiddle::Pointer]
    # @return [Proc]
    def release_finalizer(handle)
      handle_addr = handle.to_i
      proc do
        # A null handle has nothing to release.
        next if handle_addr.zero?

        begin
          RuntimeAPI.ensure_loaded!
          RuntimeAPI.event_destroy(Fiddle::Pointer.new(handle_addr))
        rescue StandardError
          # Silently ignore errors during finalization
        end
      end
    end
  end

  private

  # Create a CUDA event.
  # @return [Fiddle::Pointer]
  def create_event
    RuntimeAPI.ensure_loaded!
    RuntimeAPI.event_create
  end
end
|
|
233
|
+
end
|
|
234
|
+
end
|
data/lib/nvruby/dtype.rb
ADDED
|
@@ -0,0 +1,139 @@
|
|
|
1
|
+
# frozen_string_literal: true
|
|
2
|
+
|
|
3
|
+
module Ignis
|
|
4
|
+
# Supported data types with their properties
|
|
5
|
+
module DType
  # Data type definitions. Each entry carries:
  #   bytes   - element size in bytes
  #   ffi     - FFI type symbol used for storage
  #   cuda    - cudaDataType_t enum value
  #   complex - whether the element is a (real, imaginary) pair
  TYPES = {
    float16: { bytes: 2, ffi: :uint16, cuda: 2, complex: false }, # CUDA_R_16F (stored as uint16)
    bfloat16: { bytes: 2, ffi: :uint16, cuda: 14, complex: false }, # CUDA_R_16BF (stored as uint16)
    float32: { bytes: 4, ffi: :float, cuda: 0, complex: false }, # CUDA_R_32F
    float64: { bytes: 8, ffi: :double, cuda: 1, complex: false }, # CUDA_R_64F
    int8: { bytes: 1, ffi: :int8, cuda: 3, complex: false }, # CUDA_R_8I
    int16: { bytes: 2, ffi: :int16, cuda: 20, complex: false }, # CUDA_R_16I
    int32: { bytes: 4, ffi: :int32, cuda: 10, complex: false }, # CUDA_R_32I
    int64: { bytes: 8, ffi: :int64, cuda: 24, complex: false }, # CUDA_R_64I
    uint8: { bytes: 1, ffi: :uint8, cuda: 8, complex: false }, # CUDA_R_8U
    uint16: { bytes: 2, ffi: :uint16, cuda: 22, complex: false }, # CUDA_R_16U
    uint32: { bytes: 4, ffi: :uint32, cuda: 12, complex: false }, # CUDA_R_32U
    uint64: { bytes: 8, ffi: :uint64, cuda: 26, complex: false }, # CUDA_R_64U
    complex64: { bytes: 8, ffi: :float, cuda: 4, complex: true }, # CUDA_C_32F, 2x float32
    complex128: { bytes: 16, ffi: :double, cuda: 5, complex: true } # CUDA_C_64F, 2x float64
  }.freeze

  # cuBLAS data type constants (cudaDataType_t values)
  CUBLAS_TYPES = {
    float16: 2, # CUDA_R_16F
    bfloat16: 14, # CUDA_R_16BF
    float32: 0, # CUDA_R_32F
    float64: 1, # CUDA_R_64F
    complex64: 4, # CUDA_C_32F
    complex128: 5, # CUDA_C_64F
    int8: 3, # CUDA_R_8I
    int32: 10 # CUDA_R_32I
  }.freeze

  class << self
    # Get byte size for a dtype
    # @param dtype [Symbol] Data type
    # @return [Integer] Size in bytes
    # @raise [UnsupportedDTypeError] If dtype is not in TYPES
    def byte_size(dtype)
      info = TYPES[dtype]
      raise UnsupportedDTypeError, dtype unless info

      info[:bytes]
    end

    # Get FFI type for a dtype
    # @param dtype [Symbol] Data type
    # @return [Symbol] FFI type symbol
    # @raise [UnsupportedDTypeError] If dtype is not in TYPES
    def ffi_type(dtype)
      info = TYPES[dtype]
      raise UnsupportedDTypeError, dtype unless info

      info[:ffi]
    end

    # Check if dtype is complex
    # @param dtype [Symbol] Data type
    # @return [Boolean]
    # @raise [UnsupportedDTypeError] If dtype is not in TYPES
    def complex?(dtype)
      info = TYPES[dtype]
      raise UnsupportedDTypeError, dtype unless info

      info[:complex]
    end

    # Check if dtype is floating point. Complex dtypes count as floating
    # point here; unknown dtypes yield false rather than raising.
    # @param dtype [Symbol] Data type
    # @return [Boolean]
    def float?(dtype)
      %i[float16 bfloat16 float32 float64 complex64 complex128].include?(dtype)
    end

    # Check if dtype is integer. Unknown dtypes yield false rather than raising.
    # @param dtype [Symbol] Data type
    # @return [Boolean]
    def integer?(dtype)
      %i[int8 int16 int32 int64 uint8 uint16 uint32 uint64].include?(dtype)
    end

    # Check if dtype is signed. Every known dtype except the unsigned
    # integers is signed; unknown dtypes yield false, matching float? and
    # integer?.
    # @param dtype [Symbol] Data type
    # @return [Boolean]
    def signed?(dtype)
      TYPES.key?(dtype) && !%i[uint8 uint16 uint32 uint64].include?(dtype)
    end

    # Get cuBLAS type constant
    # @param dtype [Symbol] Data type
    # @return [Integer] cuBLAS type constant
    # @raise [UnsupportedDTypeError] If dtype has no entry in CUBLAS_TYPES
    def cublas_type(dtype)
      type = CUBLAS_TYPES[dtype]
      raise UnsupportedDTypeError.new(dtype, operation: "cuBLAS") unless type

      type
    end

    # Validate dtype
    # @param dtype [Symbol] Data type to validate
    # @return [Symbol] Validated dtype
    # @raise [UnsupportedDTypeError] If dtype is not valid
    def validate!(dtype)
      raise UnsupportedDTypeError, dtype unless TYPES.key?(dtype)

      dtype
    end

    # List all supported dtypes
    # @return [Array<Symbol>] All dtype symbols
    def all
      TYPES.keys
    end

    # Get real component dtype for complex types. Any other dtype is
    # returned unchanged.
    # @param dtype [Symbol] Complex data type
    # @return [Symbol] Real component dtype
    def real_dtype(dtype)
      case dtype
      when :complex64 then :float32
      when :complex128 then :float64
      else dtype
      end
    end

    # Get complex dtype for real types. Complex dtypes are returned unchanged.
    # @param dtype [Symbol] Real data type
    # @return [Symbol] Complex dtype
    # @raise [UnsupportedDTypeError] If dtype has no complex counterpart
    def complex_dtype(dtype)
      case dtype
      when :float32 then :complex64
      when :float64 then :complex128
      when :complex64, :complex128 then dtype
      else
        raise UnsupportedDTypeError.new(dtype, operation: "complex conversion")
      end
    end
  end
end
|
|
139
|
+
end
|