ignis 0.0.1
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- checksums.yaml +7 -0
- data/README.md +15 -0
- data/lib/ignis.rb +94 -0
- data/lib/nnw/platform.rb +304 -0
- data/lib/nnw/shared/event_bus.rb +240 -0
- data/lib/nnw/shared/ffi_loader.rb +63 -0
- data/lib/nnw/shared/memory_contract.rb +204 -0
- data/lib/nnw/shared/nv_array.rb +710 -0
- data/lib/nnw/shared/recovery_protocol.rb +307 -0
- data/lib/nvruby/configuration.rb +217 -0
- data/lib/nvruby/cuda/device.rb +275 -0
- data/lib/nvruby/cuda/device_props.rb +202 -0
- data/lib/nvruby/cuda/graph.rb +265 -0
- data/lib/nvruby/cuda/graph_bindings.rb +119 -0
- data/lib/nvruby/cuda/library_loader.rb +285 -0
- data/lib/nvruby/cuda/memory.rb +410 -0
- data/lib/nvruby/cuda/runtime_api.rb +804 -0
- data/lib/nvruby/cuda/stream.rb +234 -0
- data/lib/nvruby/dtype.rb +139 -0
- data/lib/nvruby/epilogues.rb +438 -0
- data/lib/nvruby/errors.rb +303 -0
- data/lib/nvruby/half.rb +97 -0
- data/lib/nvruby/jit/compiled_kernel.rb +80 -0
- data/lib/nvruby/jit/compiler.rb +231 -0
- data/lib/nvruby/jit/driver_api_bindings.rb +363 -0
- data/lib/nvruby/jit/kernel.rb +240 -0
- data/lib/nvruby/jit/kernel_module.rb +133 -0
- data/lib/nvruby/jit/kernels/activations.rb +179 -0
- data/lib/nvruby/jit/kernels/attention.rb +504 -0
- data/lib/nvruby/jit/kernels/elementwise.rb +488 -0
- data/lib/nvruby/jit/kernels/loss.rb +213 -0
- data/lib/nvruby/jit/kernels/normalization.rb +200 -0
- data/lib/nvruby/jit/kernels/optimizer.rb +193 -0
- data/lib/nvruby/jit/nvrtc_bindings.rb +282 -0
- data/lib/nvruby/linalg/cublas_bindings.rb +295 -0
- data/lib/nvruby/linalg/cublaslt_bindings.rb +342 -0
- data/lib/nvruby/linalg/epilog.rb +67 -0
- data/lib/nvruby/linalg/matmul.rb +247 -0
- data/lib/nvruby/linalg/matmul_plan.rb +229 -0
- data/lib/nvruby/linalg/optimized_matmul.rb +412 -0
- data/lib/nvruby/memory/cuda_async_memory_resource.rb +123 -0
- data/lib/nvruby/memory/cuda_memory_resource.rb +68 -0
- data/lib/nvruby/memory/device_memory_resource.rb +106 -0
- data/lib/nvruby/memory/pinned_host_memory_resource.rb +112 -0
- data/lib/nvruby/memory/pool_memory_resource.rb +242 -0
- data/lib/nvruby/memory/stats.rb +107 -0
- data/lib/nvruby/memory.rb +124 -0
- data/lib/nvruby/version.rb +5 -0
- metadata +108 -0
|
@@ -0,0 +1,488 @@
|
|
|
1
|
+
# frozen_string_literal: true
|
|
2
|
+
|
|
3
|
+
module Ignis
|
|
4
|
+
module JIT
|
|
5
|
+
module Kernels
|
|
6
|
+
# Elementwise CUDA kernels for AI tensor operations.
|
|
7
|
+
# Includes arithmetic ops, initialization, and embedding ops.
|
|
8
|
+
module Elementwise
|
|
9
|
+
class << self
|
|
10
|
+
# Builds the kernel for elementwise float addition: c[i] = a[i] + b[i].
# Each thread handles the single element at its global x index and does
# nothing when that index is >= n.
# @return [Ignis::JIT::Kernel]
def add_forward
  compile_cached(<<~CUDA, "add_forward")
    extern "C" __global__
    void add_forward(const float* __restrict__ a,
                     const float* __restrict__ b,
                     float* __restrict__ c,
                     const int n) {
      int idx = blockIdx.x * blockDim.x + threadIdx.x;
      if (idx < n) {
        c[idx] = a[idx] + b[idx];
      }
    }
  CUDA
end
|
|
27
|
+
|
|
28
|
+
# Builds the kernel that reduces a gradient over the batch axis:
#   grad_bias[f] = sum over b of grad_output[b * features + f]
# i.e. the backward pass of a bias that was broadcast-added to every row.
# One thread owns one feature column and loops serially over batch_size.
# @return [Ignis::JIT::Kernel]
def add_backward_broadcast
  compile_cached(<<~CUDA, "add_backward_broadcast")
    extern "C" __global__
    void add_backward_broadcast(const float* __restrict__ grad_output,
                                float* __restrict__ grad_bias,
                                const int batch_size,
                                const int features) {
      int f = blockIdx.x * blockDim.x + threadIdx.x;
      if (f < features) {
        float sum = 0.0f;
        for (int b = 0; b < batch_size; b++) {
          sum += grad_output[b * features + f];
        }
        grad_bias[f] = sum;
      }
    }
  CUDA
end
|
|
50
|
+
|
|
51
|
+
# Builds the kernel for elementwise float subtraction: c[i] = a[i] - b[i].
# One thread per element; threads whose index is >= n are idle.
# @return [Ignis::JIT::Kernel]
def sub_forward
  compile_cached(<<~CUDA, "sub_forward")
    extern "C" __global__
    void sub_forward(const float* __restrict__ a,
                     const float* __restrict__ b,
                     float* __restrict__ c,
                     const int n) {
      int idx = blockIdx.x * blockDim.x + threadIdx.x;
      if (idx < n) {
        c[idx] = a[idx] - b[idx];
      }
    }
  CUDA
end
|
|
68
|
+
|
|
69
|
+
# Builds the kernel for the elementwise (Hadamard) product: c[i] = a[i] * b[i].
# One thread per element; threads whose index is >= n are idle.
# @return [Ignis::JIT::Kernel]
def mul_forward
  compile_cached(<<~CUDA, "mul_forward")
    extern "C" __global__
    void mul_forward(const float* __restrict__ a,
                     const float* __restrict__ b,
                     float* __restrict__ c,
                     const int n) {
      int idx = blockIdx.x * blockDim.x + threadIdx.x;
      if (idx < n) {
        c[idx] = a[idx] * b[idx];
      }
    }
  CUDA
end
|
|
86
|
+
|
|
87
|
+
# Builds the backward kernel of an elementwise multiply with respect to one
# operand: grad_input[i] = grad_output[i] * other[i], where `other` is the
# operand that is NOT being differentiated.
# @return [Ignis::JIT::Kernel]
def mul_backward
  compile_cached(<<~CUDA, "mul_backward")
    extern "C" __global__
    void mul_backward(const float* __restrict__ grad_output,
                      const float* __restrict__ other,
                      float* __restrict__ grad_input,
                      const int n) {
      int idx = blockIdx.x * blockDim.x + threadIdx.x;
      if (idx < n) {
        grad_input[idx] = grad_output[idx] * other[idx];
      }
    }
  CUDA
end
|
|
104
|
+
|
|
105
|
+
# Builds the kernel for the elementwise minimum: c[i] = fminf(a[i], b[i]).
# (The original author notes it is used by collective reductions.)
# @return [Ignis::JIT::Kernel]
def min_forward
  compile_cached(<<~CUDA, "min_forward")
    extern "C" __global__
    void min_forward(const float* __restrict__ a,
                     const float* __restrict__ b,
                     float* __restrict__ c,
                     const int n) {
      int idx = blockIdx.x * blockDim.x + threadIdx.x;
      if (idx < n) {
        c[idx] = fminf(a[idx], b[idx]);
      }
    }
  CUDA
end
|
|
122
|
+
|
|
123
|
+
# Builds the kernel for the elementwise maximum: c[i] = fmaxf(a[i], b[i]).
# (The original author notes it is used by collective reductions.)
# @return [Ignis::JIT::Kernel]
def max_forward
  compile_cached(<<~CUDA, "max_forward")
    extern "C" __global__
    void max_forward(const float* __restrict__ a,
                     const float* __restrict__ b,
                     float* __restrict__ c,
                     const int n) {
      int idx = blockIdx.x * blockDim.x + threadIdx.x;
      if (idx < n) {
        c[idx] = fmaxf(a[idx], b[idx]);
      }
    }
  CUDA
end
|
|
140
|
+
|
|
141
|
+
# Builds the kernel that multiplies every element by one scalar:
#   output[i] = input[i] * scalar
# @return [Ignis::JIT::Kernel]
def scale_forward
  compile_cached(<<~CUDA, "scale_forward")
    extern "C" __global__
    void scale_forward(const float* __restrict__ input,
                       float* __restrict__ output,
                       const float scalar,
                       const int n) {
      int idx = blockIdx.x * blockDim.x + threadIdx.x;
      if (idx < n) {
        output[idx] = input[idx] * scalar;
      }
    }
  CUDA
end
|
|
158
|
+
|
|
159
|
+
# Builds the kernel that writes the same float into the first n elements:
#   output[i] = value
# @return [Ignis::JIT::Kernel]
def fill
  compile_cached(<<~CUDA, "fill")
    extern "C" __global__
    void fill(float* __restrict__ output,
              const float value,
              const int n) {
      int idx = blockIdx.x * blockDim.x + threadIdx.x;
      if (idx < n) {
        output[idx] = value;
      }
    }
  CUDA
end
|
|
175
|
+
|
|
176
|
+
# Kaiming uniform initialization: fills output with samples from U[-bound, bound).
#
# Randomness is a stateless, counter-based hash: each element mixes
# (seed + idx) through the MurmurHash3 64-bit finalizer (fmix64), so the
# same seed always reproduces the same tensor regardless of launch shape.
# This is NOT cuRAND/Philox.
#
# Only 24 bits of the hash are converted to float. A float has a 24-bit
# significand, so the conversion is exact and u stays strictly below 1.0;
# converting a full 32-bit value can round up to 2^32 and yield u == 1.0.
#
# NOTE: seed + idx == 0 hashes to 0, so that single element is exactly -bound.
# @return [Ignis::JIT::Kernel]
def kaiming_uniform_init
  source = <<~CUDA
    extern "C" __global__
    void kaiming_uniform_init(float* __restrict__ output,
                              const float bound,
                              const unsigned long long seed,
                              const int n) {
      int idx = blockIdx.x * blockDim.x + threadIdx.x;
      if (idx < n) {
        // MurmurHash3 fmix64 avalanche of the per-element counter.
        unsigned long long state = seed + (unsigned long long)idx;
        state ^= state >> 33;
        state *= 0xff51afd7ed558ccdULL;
        state ^= state >> 33;
        state *= 0xc4ceb9fe1a85ec53ULL;
        state ^= state >> 33;
        // Top 24 of the low 32 bits: exactly representable, so u is in [0, 1).
        float u = (float)((state >> 8) & 0xFFFFFFULL) * (1.0f / 16777216.0f);
        output[idx] = (2.0f * u - 1.0f) * bound;
      }
    }
  CUDA
  compile_cached(source, "kaiming_uniform_init")
end
|
|
201
|
+
|
|
202
|
+
# Builds the row-gather kernel used for an embedding lookup:
#   output[i, :] = weight[indices[i], :]
# One thread copies one float; the flat thread index is split into
# (row, col) with embed_dim as the row width.
# NOTE(review): indices[row] is used as-is — the kernel has no table-size
# argument, so callers must guarantee every index is a valid weight row.
# @return [Ignis::JIT::Kernel]
def gather_rows
  compile_cached(<<~CUDA, "gather_rows")
    extern "C" __global__
    void gather_rows(const float* __restrict__ weight,
                     const int* __restrict__ indices,
                     float* __restrict__ output,
                     const int num_indices,
                     const int embed_dim) {
      int idx = blockIdx.x * blockDim.x + threadIdx.x;
      int total = num_indices * embed_dim;
      if (idx < total) {
        int row = idx / embed_dim;
        int col = idx % embed_dim;
        int src_row = indices[row];
        output[idx] = weight[src_row * embed_dim + col];
      }
    }
  CUDA
end
|
|
224
|
+
|
|
225
|
+
# Builds the scatter-add kernel used for an embedding backward pass:
#   grad_weight[indices[i], :] += grad_output[i, :]
# Several positions may name the same row, so the accumulation goes through
# atomicAdd. grad_weight is only added to, never cleared, by this kernel.
# NOTE(review): indices are not range-checked inside the kernel.
# @return [Ignis::JIT::Kernel]
def scatter_add
  compile_cached(<<~CUDA, "scatter_add")
    extern "C" __global__
    void scatter_add(const float* __restrict__ grad_output,
                     const int* __restrict__ indices,
                     float* __restrict__ grad_weight,
                     const int num_indices,
                     const int embed_dim) {
      int idx = blockIdx.x * blockDim.x + threadIdx.x;
      int total = num_indices * embed_dim;
      if (idx < total) {
        int row = idx / embed_dim;
        int col = idx % embed_dim;
        int dst_row = indices[row];
        atomicAdd(&grad_weight[dst_row * embed_dim + col], grad_output[idx]);
      }
    }
  CUDA
end
|
|
248
|
+
|
|
249
|
+
# Builds the in-place accumulation kernel: dst[i] += src[i].
# (The original author notes it is used for gradient accumulation.)
# @return [Ignis::JIT::Kernel]
def accumulate
  compile_cached(<<~CUDA, "accumulate")
    extern "C" __global__
    void accumulate(float* __restrict__ dst,
                    const float* __restrict__ src,
                    const int n) {
      int idx = blockIdx.x * blockDim.x + threadIdx.x;
      if (idx < n) {
        dst[idx] += src[idx];
      }
    }
  CUDA
end
|
|
265
|
+
|
|
266
|
+
# Builds the kernel that sums contiguous runs of reduce_size floats:
#   output[i] = sum over j of input[i * reduce_size + j]
# i.e. a reduction over the last (fastest-varying) dimension. One thread
# produces one output element with a serial loop.
# @return [Ignis::JIT::Kernel]
def sum_reduce
  compile_cached(<<~CUDA, "sum_reduce")
    extern "C" __global__
    void sum_reduce(const float* __restrict__ input,
                    float* __restrict__ output,
                    const int outer_size,
                    const int reduce_size) {
      int idx = blockIdx.x * blockDim.x + threadIdx.x;
      if (idx < outer_size) {
        float sum = 0.0f;
        for (int j = 0; j < reduce_size; j++) {
          sum += input[idx * reduce_size + j];
        }
        output[idx] = sum;
      }
    }
  CUDA
end
|
|
287
|
+
|
|
288
|
+
# Builds the kernel that expands a scalar gradient to n elements:
#   grad_input[i] = grad_output[0] * scale
# Only element 0 of grad_output is ever read.
# @return [Ignis::JIT::Kernel]
def broadcast_grad
  compile_cached(<<~CUDA, "broadcast_grad")
    extern "C" __global__
    void broadcast_grad(const float* __restrict__ grad_output,
                        float* __restrict__ grad_input,
                        const float scale,
                        const int n) {
      int idx = blockIdx.x * blockDim.x + threadIdx.x;
      if (idx < n) {
        grad_input[idx] = grad_output[0] * scale;
      }
    }
  CUDA
end
|
|
305
|
+
|
|
306
|
+
# Builds the 2D transpose kernel: output[j, i] = input[i, j], where input is
# rows x cols and output is cols x rows (both row-major).
#
# Each block stages a TILE_DIM x TILE_DIM tile in shared memory, synchronizes,
# then writes it back with the block indices swapped. The tile has one extra
# padding column (a common way to avoid shared-memory bank conflicts).
#
# NOTE(review): both loops advance by BLOCK_ROWS (8) rather than blockDim.y,
# so the kernel presumably expects a 32 x 8 thread block — confirm against
# the launcher.
# @return [Ignis::JIT::Kernel]
def transpose_2d
  compile_cached(<<~CUDA, "transpose_2d")
    #define TILE_DIM 32
    #define BLOCK_ROWS 8

    extern "C" __global__
    void transpose_2d(const float* __restrict__ input,
                      float* __restrict__ output,
                      const int rows,
                      const int cols) {
      __shared__ float tile[TILE_DIM][TILE_DIM + 1];

      int x = blockIdx.x * TILE_DIM + threadIdx.x;
      int y = blockIdx.y * TILE_DIM + threadIdx.y;

      for (int j = 0; j < TILE_DIM; j += BLOCK_ROWS) {
        if (x < cols && (y + j) < rows) {
          tile[threadIdx.y + j][threadIdx.x] = input[(y + j) * cols + x];
        }
      }
      __syncthreads();

      x = blockIdx.y * TILE_DIM + threadIdx.x;
      y = blockIdx.x * TILE_DIM + threadIdx.y;

      for (int j = 0; j < TILE_DIM; j += BLOCK_ROWS) {
        if (x < rows && (y + j) < cols) {
          output[(y + j) * rows + x] = tile[threadIdx.x][threadIdx.y + j];
        }
      }
    }
  CUDA
end
|
|
343
|
+
|
|
344
|
+
# Builds the kernel that extracts a column window from every row:
#   dst[r, c] = src[r, col_off + c]   for 0 <= c < len
# src is [rows, total_cols]; dst is a dense [rows, len] buffer.
# (Per the original author: splits [seq, embed] projections into per-head
# [seq, head_dim] slices.)
# @return [Ignis::JIT::Kernel]
def slice_cols
  compile_cached(<<~CUDA, "slice_cols")
    extern "C" __global__
    void slice_cols(const float* __restrict__ src,
                    float* __restrict__ dst,
                    const int rows, const int total_cols,
                    const int col_off, const int len) {
      int idx = blockIdx.x * blockDim.x + threadIdx.x;
      int total = rows * len;
      if (idx < total) {
        int r = idx / len;
        int c = idx % len;
        dst[idx] = src[r * total_cols + col_off + c];
      }
    }
  CUDA
end
|
|
366
|
+
|
|
367
|
+
# Builds the kernel that writes a dense [rows, len] buffer into a column
# window of a wider [rows, total_cols] buffer (overwriting what was there):
#   dst[r, col_off + c] = src[r, c]
# (Per the original author: places per-head [seq, head_dim] results back
# into [seq, embed].)
# @return [Ignis::JIT::Kernel]
def scatter_cols
  compile_cached(<<~CUDA, "scatter_cols")
    extern "C" __global__
    void scatter_cols(const float* __restrict__ src,
                      float* __restrict__ dst,
                      const int rows, const int total_cols,
                      const int col_off, const int len) {
      int idx = blockIdx.x * blockDim.x + threadIdx.x;
      int total = rows * len;
      if (idx < total) {
        int r = idx / len;
        int c = idx % len;
        dst[r * total_cols + col_off + c] = src[idx];
      }
    }
  CUDA
end
|
|
388
|
+
|
|
389
|
+
# Builds the accumulating variant of the column scatter:
#   dst[r, col_off + c] += src[r, c]
# (Per the original author: needed for GQA backward, where every query head
# in a group adds its dK/dV contribution to the same KV-head columns.)
#
# The update is a plain (non-atomic) read-modify-write. Within one launch
# each (r, col_off + c) cell is touched by exactly one thread, so that is
# safe; contributions from different heads are expected to arrive through
# separate launches into the same buffer.
# @return [Ignis::JIT::Kernel]
def scatter_cols_add
  compile_cached(<<~CUDA, "scatter_cols_add")
    extern "C" __global__
    void scatter_cols_add(const float* __restrict__ src,
                          float* __restrict__ dst,
                          const int rows, const int total_cols,
                          const int col_off, const int len) {
      int idx = blockIdx.x * blockDim.x + threadIdx.x;
      int total = rows * len;
      if (idx < total) {
        int r = idx / len;
        int c = idx % len;
        dst[r * total_cols + col_off + c] += src[idx];
      }
    }
  CUDA
end
|
|
414
|
+
|
|
415
|
+
# Builds the kernel that widens bfloat16 values to float32 on the device.
# Each 16-bit input is shifted into the upper half of a 32-bit word and
# reinterpreted as a float (float32_bits = bf16_bits << 16). bf16 has the
# same sign/exponent layout as fp32 with a shorter mantissa, so the widening
# is exact.
# (Per the original author: lets bf16 checkpoints be loaded into fp32 weights
# without building a large host-side array.)
# @return [Ignis::JIT::Kernel]
def bf16_to_f32
  compile_cached(<<~CUDA, "bf16_to_f32")
    extern "C" __global__
    void bf16_to_f32(const unsigned short* __restrict__ src,
                     float* __restrict__ dst,
                     const int n) {
      int i = blockIdx.x * blockDim.x + threadIdx.x;
      if (i < n) {
        unsigned int bits = ((unsigned int)src[i]) << 16;
        dst[i] = __uint_as_float(bits);
      }
    }
  CUDA
end
|
|
435
|
+
|
|
436
|
+
# Builds the elementwise affine kernel: output[i] = input[i] * scale + shift.
# (Per the original author: e.g. rescales uniform [0, 1) samples into
# [low, high).)
# @return [Ignis::JIT::Kernel]
def affine_forward
  compile_cached(<<~CUDA, "affine_forward")
    extern "C" __global__
    void affine_forward(const float* __restrict__ input,
                        float* __restrict__ output,
                        const float scale, const float shift,
                        const int n) {
      int idx = blockIdx.x * blockDim.x + threadIdx.x;
      if (idx < n) {
        output[idx] = input[idx] * scale + shift;
      }
    }
  CUDA
end
|
|
454
|
+
|
|
455
|
+
# Builds the kernel that adds a per-column bias to every row:
#   out[r, c] = a[r, c] + bias[c]
# a and out are [rows, cols] row-major; bias has cols entries and is selected
# with (flat index % cols).
# @return [Ignis::JIT::Kernel]
def add_bias_rows
  compile_cached(<<~CUDA, "add_bias_rows")
    extern "C" __global__
    void add_bias_rows(const float* __restrict__ a,
                       const float* __restrict__ bias,
                       float* __restrict__ out,
                       const int rows, const int cols) {
      int idx = blockIdx.x * blockDim.x + threadIdx.x;
      int total = rows * cols;
      if (idx < total) {
        out[idx] = a[idx] + bias[idx % cols];
      }
    }
  CUDA
end
|
|
474
|
+
|
|
475
|
+
private
|
|
476
|
+
|
|
477
|
+
# Hands a kernel's CUDA source to the JIT compiler. This is a thin
# pass-through to Ignis::JIT::Compiler.compile; whatever caching happens is
# implemented there, not in this method.
# @param source [String] CUDA source code
# @param name [String] kernel function name
# @param device_id [Integer] forwarded to the compiler (defaults to 0)
# @return [Ignis::JIT::Kernel]
def compile_cached(source, name, device_id: 0)
  compiler = Ignis::JIT::Compiler
  compiler.compile(source, name, device_id: device_id)
end
|
|
484
|
+
end
|
|
485
|
+
end
|
|
486
|
+
end
|
|
487
|
+
end
|
|
488
|
+
end
|
|
@@ -0,0 +1,213 @@
|
|
|
1
|
+
# frozen_string_literal: true
|
|
2
|
+
|
|
3
|
+
module Ignis
|
|
4
|
+
module JIT
|
|
5
|
+
module Kernels
|
|
6
|
+
# Loss function CUDA kernels for training.
|
|
7
|
+
# All are fused implementations for numerical stability and performance.
|
|
8
|
+
module Loss
|
|
9
|
+
class << self
|
|
10
|
+
# Fused cross-entropy forward: log-softmax + NLL (with optional label
# smoothing) in a single kernel.
#
# One thread processes one row of logits:
#   1. row max (for numerical stability),
#   2. log(sum(exp(x - max))),
#   3. writes the full log-softmax row to log_softmax_out (the backward
#      kernel consumes it, so it IS materialized),
#   4. losses[row] = -lsm[target], or with label_smoothing = e > 0:
#      (1 - e) * -lsm[target] + e * mean_j(-lsm[j]).
#
# Row offsets are computed in 64-bit: batch_size * vocab_size can exceed
# INT_MAX for long sequences with large vocabularies.
#
# NOTE(review): targets[row] is not validated; callers must pass values in
# [0, vocab_size). vocab_size is assumed to be >= 1.
# @return [Ignis::JIT::Kernel]
def cross_entropy_forward
  source = <<~CUDA
    extern "C" __global__
    void cross_entropy_forward(const float* __restrict__ logits,
                               const int* __restrict__ targets,
                               float* __restrict__ losses,
                               float* __restrict__ log_softmax_out,
                               const int batch_size,
                               const int vocab_size,
                               const float label_smoothing) {
      int row = blockIdx.x * blockDim.x + threadIdx.x;
      if (row < batch_size) {
        // 64-bit offset: row * vocab_size overflows int for large tensors.
        const long long row_offset = (long long)row * (long long)vocab_size;
        const float* row_logits = logits + row_offset;
        float* row_lsm = log_softmax_out + row_offset;
        int target = targets[row];

        // Find max for numerical stability
        float max_val = row_logits[0];
        for (int j = 1; j < vocab_size; j++) {
          max_val = fmaxf(max_val, row_logits[j]);
        }

        // log_softmax = x - max - log(sum(exp(x - max)))
        float log_sum_exp = 0.0f;
        for (int j = 0; j < vocab_size; j++) {
          log_sum_exp += expf(row_logits[j] - max_val);
        }
        log_sum_exp = logf(log_sum_exp);

        // Compute log_softmax and store
        for (int j = 0; j < vocab_size; j++) {
          row_lsm[j] = row_logits[j] - max_val - log_sum_exp;
        }

        // NLL loss with optional label smoothing
        if (label_smoothing > 0.0f) {
          float smooth_loss = 0.0f;
          for (int j = 0; j < vocab_size; j++) {
            smooth_loss -= row_lsm[j];
          }
          smooth_loss /= (float)vocab_size;
          float nll = -row_lsm[target];
          losses[row] = (1.0f - label_smoothing) * nll + label_smoothing * smooth_loss;
        } else {
          losses[row] = -row_lsm[target];
        }
      }
    }
  CUDA
  compile_cached(source, "cross_entropy_forward")
end
|
|
64
|
+
|
|
65
|
+
# Cross-entropy backward: grad_logits = grad_output[row] * (softmax - target_dist).
#
# Takes log-softmax values and exponentiates them in-kernel to recover the
# softmax. target_dist is one_hot(target) when label_smoothing == 0;
# otherwise it is (1 - e) * one_hot(target) + e / vocab_size.
#
# One thread per (row, col) element. The flat index and element count are
# 64-bit because batch_size * vocab_size can exceed INT_MAX for long
# sequences with large vocabularies.
# @return [Ignis::JIT::Kernel]
def cross_entropy_backward
  source = <<~CUDA
    extern "C" __global__
    void cross_entropy_backward(const float* __restrict__ log_softmax,
                                const int* __restrict__ targets,
                                const float* __restrict__ grad_output,
                                float* __restrict__ grad_logits,
                                const int batch_size,
                                const int vocab_size,
                                const float label_smoothing) {
      // 64-bit indexing: a 32-bit idx/total wraps for large logits tensors.
      const long long idx = (long long)blockIdx.x * blockDim.x + threadIdx.x;
      const long long total = (long long)batch_size * (long long)vocab_size;
      if (idx < total) {
        int row = (int)(idx / vocab_size);
        int col = (int)(idx % vocab_size);
        int target = targets[row];

        float softmax_val = expf(log_softmax[idx]);
        float grad_scale = grad_output[row];

        if (label_smoothing > 0.0f) {
          float smooth_target = label_smoothing / (float)vocab_size;
          float hard_target = (col == target) ? (1.0f - label_smoothing + smooth_target) : smooth_target;
          grad_logits[idx] = grad_scale * (softmax_val - hard_target);
        } else {
          float indicator = (col == target) ? 1.0f : 0.0f;
          grad_logits[idx] = grad_scale * (softmax_val - indicator);
        }
      }
    }
  CUDA
  compile_cached(source, "cross_entropy_backward")
end
|
|
101
|
+
|
|
102
|
+
# Builds the per-element squared-error kernel:
#   losses[i] = (predictions[i] - targets[i])^2
# No reduction is performed here; the output has n entries.
# @return [Ignis::JIT::Kernel]
def mse_forward
  compile_cached(<<~CUDA, "mse_forward")
    extern "C" __global__
    void mse_forward(const float* __restrict__ predictions,
                     const float* __restrict__ targets,
                     float* __restrict__ losses,
                     const int n) {
      int idx = blockIdx.x * blockDim.x + threadIdx.x;
      if (idx < n) {
        float diff = predictions[idx] - targets[idx];
        losses[idx] = diff * diff;
      }
    }
  CUDA
end
|
|
120
|
+
|
|
121
|
+
# Builds the squared-error backward kernel:
#   grad_input[i] = grad_output[i] * 2 * (predictions[i] - targets[i]) * scale
# The kernel itself never divides by n; any mean normalization has to arrive
# through `scale` (presumably 1/n — confirm against the caller).
# @return [Ignis::JIT::Kernel]
def mse_backward
  compile_cached(<<~CUDA, "mse_backward")
    extern "C" __global__
    void mse_backward(const float* __restrict__ predictions,
                      const float* __restrict__ targets,
                      const float* __restrict__ grad_output,
                      float* __restrict__ grad_input,
                      const int n,
                      const float scale) {
      int idx = blockIdx.x * blockDim.x + threadIdx.x;
      if (idx < n) {
        grad_input[idx] = grad_output[idx] * 2.0f * (predictions[idx] - targets[idx]) * scale;
      }
    }
  CUDA
end
|
|
140
|
+
|
|
141
|
+
# Binary cross-entropy with logits, per element:
#   loss = -[y * log(sigmoid(x)) + (1 - y) * log(1 - sigmoid(x))]
# evaluated in the overflow-safe form
#   max(x, 0) - x * y + log1p(exp(-|x|)).
#
# log1pf is used instead of logf(1 + ...): for large |x|, exp(-|x|) is below
# float epsilon, so 1 + exp(-|x|) rounds to exactly 1 and the term would be
# lost entirely.
# @return [Ignis::JIT::Kernel]
def bce_forward
  source = <<~CUDA
    extern "C" __global__
    void bce_forward(const float* __restrict__ logits,
                     const float* __restrict__ targets,
                     float* __restrict__ losses,
                     const int n) {
      int idx = blockIdx.x * blockDim.x + threadIdx.x;
      if (idx < n) {
        float x = logits[idx];
        float y = targets[idx];
        // Numerically stable: max(x,0) - x*y + log1p(exp(-|x|))
        float max_val = fmaxf(x, 0.0f);
        losses[idx] = max_val - x * y + log1pf(expf(-fabsf(x)));
      }
    }
  CUDA
  compile_cached(source, "bce_forward")
end
|
|
162
|
+
|
|
163
|
+
# Builds the backward kernel for binary cross-entropy with logits:
#   grad_input[i] = grad_output[i] * (sigmoid(logits[i]) - targets[i])
# The sigmoid is recomputed from the logits as 1 / (1 + exp(-x)).
# @return [Ignis::JIT::Kernel]
def bce_backward
  compile_cached(<<~CUDA, "bce_backward")
    extern "C" __global__
    void bce_backward(const float* __restrict__ logits,
                      const float* __restrict__ targets,
                      const float* __restrict__ grad_output,
                      float* __restrict__ grad_input,
                      const int n) {
      int idx = blockIdx.x * blockDim.x + threadIdx.x;
      if (idx < n) {
        float sig = 1.0f / (1.0f + expf(-logits[idx]));
        grad_input[idx] = grad_output[idx] * (sig - targets[idx]);
      }
    }
  CUDA
end
|
|
182
|
+
|
|
183
|
+
# Builds the kernel that writes the arithmetic mean of n floats to output[0].
# Only thread 0 of block 0 does any work: it sums the input serially in a
# float accumulator and divides by n. Every other thread exits immediately.
# With n == 0 the result is 0.0f / 0.0f.
# @return [Ignis::JIT::Kernel]
def mean_reduce
  compile_cached(<<~CUDA, "mean_reduce")
    extern "C" __global__
    void mean_reduce(const float* __restrict__ input,
                     float* __restrict__ output,
                     const int n) {
      // Single-thread simple reduction (for loss scalar)
      if (blockIdx.x == 0 && threadIdx.x == 0) {
        float sum = 0.0f;
        for (int i = 0; i < n; i++) {
          sum += input[i];
        }
        output[0] = sum / (float)n;
      }
    }
  CUDA
end
|
|
203
|
+
|
|
204
|
+
private
|
|
205
|
+
|
|
206
|
+
# Hands a kernel's CUDA source to the JIT compiler. This is a thin
# pass-through to Ignis::JIT::Compiler.compile; whatever caching happens is
# implemented there, not in this method.
# @param source [String] CUDA source code
# @param name [String] kernel function name
# @param device_id [Integer] forwarded to the compiler (defaults to 0)
def compile_cached(source, name, device_id: 0)
  compiler = Ignis::JIT::Compiler
  compiler.compile(source, name, device_id: device_id)
end
|
|
209
|
+
end
|
|
210
|
+
end
|
|
211
|
+
end
|
|
212
|
+
end
|
|
213
|
+
end
|