ignis 0.0.1

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (49) hide show
  1. checksums.yaml +7 -0
  2. data/README.md +15 -0
  3. data/lib/ignis.rb +94 -0
  4. data/lib/nnw/platform.rb +304 -0
  5. data/lib/nnw/shared/event_bus.rb +240 -0
  6. data/lib/nnw/shared/ffi_loader.rb +63 -0
  7. data/lib/nnw/shared/memory_contract.rb +204 -0
  8. data/lib/nnw/shared/nv_array.rb +710 -0
  9. data/lib/nnw/shared/recovery_protocol.rb +307 -0
  10. data/lib/nvruby/configuration.rb +217 -0
  11. data/lib/nvruby/cuda/device.rb +275 -0
  12. data/lib/nvruby/cuda/device_props.rb +202 -0
  13. data/lib/nvruby/cuda/graph.rb +265 -0
  14. data/lib/nvruby/cuda/graph_bindings.rb +119 -0
  15. data/lib/nvruby/cuda/library_loader.rb +285 -0
  16. data/lib/nvruby/cuda/memory.rb +410 -0
  17. data/lib/nvruby/cuda/runtime_api.rb +804 -0
  18. data/lib/nvruby/cuda/stream.rb +234 -0
  19. data/lib/nvruby/dtype.rb +139 -0
  20. data/lib/nvruby/epilogues.rb +438 -0
  21. data/lib/nvruby/errors.rb +303 -0
  22. data/lib/nvruby/half.rb +97 -0
  23. data/lib/nvruby/jit/compiled_kernel.rb +80 -0
  24. data/lib/nvruby/jit/compiler.rb +231 -0
  25. data/lib/nvruby/jit/driver_api_bindings.rb +363 -0
  26. data/lib/nvruby/jit/kernel.rb +240 -0
  27. data/lib/nvruby/jit/kernel_module.rb +133 -0
  28. data/lib/nvruby/jit/kernels/activations.rb +179 -0
  29. data/lib/nvruby/jit/kernels/attention.rb +504 -0
  30. data/lib/nvruby/jit/kernels/elementwise.rb +488 -0
  31. data/lib/nvruby/jit/kernels/loss.rb +213 -0
  32. data/lib/nvruby/jit/kernels/normalization.rb +200 -0
  33. data/lib/nvruby/jit/kernels/optimizer.rb +193 -0
  34. data/lib/nvruby/jit/nvrtc_bindings.rb +282 -0
  35. data/lib/nvruby/linalg/cublas_bindings.rb +295 -0
  36. data/lib/nvruby/linalg/cublaslt_bindings.rb +342 -0
  37. data/lib/nvruby/linalg/epilog.rb +67 -0
  38. data/lib/nvruby/linalg/matmul.rb +247 -0
  39. data/lib/nvruby/linalg/matmul_plan.rb +229 -0
  40. data/lib/nvruby/linalg/optimized_matmul.rb +412 -0
  41. data/lib/nvruby/memory/cuda_async_memory_resource.rb +123 -0
  42. data/lib/nvruby/memory/cuda_memory_resource.rb +68 -0
  43. data/lib/nvruby/memory/device_memory_resource.rb +106 -0
  44. data/lib/nvruby/memory/pinned_host_memory_resource.rb +112 -0
  45. data/lib/nvruby/memory/pool_memory_resource.rb +242 -0
  46. data/lib/nvruby/memory/stats.rb +107 -0
  47. data/lib/nvruby/memory.rb +124 -0
  48. data/lib/nvruby/version.rb +5 -0
  49. metadata +108 -0
@@ -0,0 +1,488 @@
1
+ # frozen_string_literal: true
2
+
3
+ module Ignis
4
+ module JIT
5
+ module Kernels
6
+ # Elementwise CUDA kernels for AI tensor operations.
7
+ # Includes arithmetic ops, initialization, and embedding ops.
8
+ module Elementwise
9
+ class << self
10
# Builds the kernel for elementwise addition: c[i] = a[i] + b[i].
#
# Kernel arguments, in order: a, b (float inputs), c (float output),
# n (element count). Each thread handles the single element at its flat
# global index; threads whose index is >= n do nothing.
#
# @return [Ignis::JIT::Kernel]
def add_forward
  compile_cached(<<~CUDA, "add_forward")
    extern "C" __global__
    void add_forward(const float* __restrict__ a,
                     const float* __restrict__ b,
                     float* __restrict__ c,
                     const int n) {
      int idx = blockIdx.x * blockDim.x + threadIdx.x;
      if (idx < n) {
        c[idx] = a[idx] + b[idx];
      }
    }
  CUDA
end
27
+
28
# Backward pass for a broadcast (bias-style) add: reduces the upstream
# gradient over the batch dimension.
#
# grad_output is read as a row-major [batch_size, features] matrix and
# grad_bias[f] receives the sum over b of grad_output[b, f]. One thread
# owns one feature column and walks the batch sequentially, so no atomics
# are involved. (For a same-shape add the gradient is the identity and no
# kernel is required; this kernel covers the broadcast operand.)
#
# Kernel arguments, in order: grad_output, grad_bias, batch_size, features.
#
# @return [Ignis::JIT::Kernel]
def add_backward_broadcast
  source = <<~CUDA
    extern "C" __global__
    void add_backward_broadcast(const float* __restrict__ grad_output,
                                float* __restrict__ grad_bias,
                                const int batch_size,
                                const int features) {
      int f = blockIdx.x * blockDim.x + threadIdx.x;
      if (f < features) {
        float sum = 0.0f;
        for (int b = 0; b < batch_size; b++) {
          // 64-bit offset: batch_size * features may exceed INT_MAX
          sum += grad_output[(long long)b * features + f];
        }
        grad_bias[f] = sum;
      }
    }
  CUDA
  compile_cached(source, "add_backward_broadcast")
end
50
+
51
# Builds the kernel for elementwise subtraction: c[i] = a[i] - b[i].
#
# Kernel arguments, in order: a, b (float inputs), c (float output),
# n (element count). One element per thread, guarded by idx < n.
#
# @return [Ignis::JIT::Kernel]
def sub_forward
  compile_cached(<<~CUDA, "sub_forward")
    extern "C" __global__
    void sub_forward(const float* __restrict__ a,
                     const float* __restrict__ b,
                     float* __restrict__ c,
                     const int n) {
      int idx = blockIdx.x * blockDim.x + threadIdx.x;
      if (idx < n) {
        c[idx] = a[idx] - b[idx];
      }
    }
  CUDA
end
68
+
69
# Builds the kernel for the elementwise (Hadamard) product:
# c[i] = a[i] * b[i].
#
# Kernel arguments, in order: a, b (float inputs), c (float output),
# n (element count). One element per thread, guarded by idx < n.
#
# @return [Ignis::JIT::Kernel]
def mul_forward
  compile_cached(<<~CUDA, "mul_forward")
    extern "C" __global__
    void mul_forward(const float* __restrict__ a,
                     const float* __restrict__ b,
                     float* __restrict__ c,
                     const int n) {
      int idx = blockIdx.x * blockDim.x + threadIdx.x;
      if (idx < n) {
        c[idx] = a[idx] * b[idx];
      }
    }
  CUDA
end
86
+
87
# Builds the backward kernel for an elementwise multiply with respect to
# one operand: grad_input[i] = grad_output[i] * other[i], where `other`
# is the operand that is NOT being differentiated.
#
# Kernel arguments, in order: grad_output, other, grad_input, n.
# One element per thread, guarded by idx < n.
#
# @return [Ignis::JIT::Kernel]
def mul_backward
  compile_cached(<<~CUDA, "mul_backward")
    extern "C" __global__
    void mul_backward(const float* __restrict__ grad_output,
                      const float* __restrict__ other,
                      float* __restrict__ grad_input,
                      const int n) {
      int idx = blockIdx.x * blockDim.x + threadIdx.x;
      if (idx < n) {
        grad_input[idx] = grad_output[idx] * other[idx];
      }
    }
  CUDA
end
104
+
105
# Builds the kernel for the elementwise minimum:
# c[i] = fminf(a[i], b[i]).
# (The original author notes it is used by collective reductions.)
#
# Kernel arguments, in order: a, b (float inputs), c (float output),
# n (element count). One element per thread, guarded by idx < n.
#
# @return [Ignis::JIT::Kernel]
def min_forward
  compile_cached(<<~CUDA, "min_forward")
    extern "C" __global__
    void min_forward(const float* __restrict__ a,
                     const float* __restrict__ b,
                     float* __restrict__ c,
                     const int n) {
      int idx = blockIdx.x * blockDim.x + threadIdx.x;
      if (idx < n) {
        c[idx] = fminf(a[idx], b[idx]);
      }
    }
  CUDA
end
122
+
123
# Builds the kernel for the elementwise maximum:
# c[i] = fmaxf(a[i], b[i]).
# (The original author notes it is used by collective reductions.)
#
# Kernel arguments, in order: a, b (float inputs), c (float output),
# n (element count). One element per thread, guarded by idx < n.
#
# @return [Ignis::JIT::Kernel]
def max_forward
  compile_cached(<<~CUDA, "max_forward")
    extern "C" __global__
    void max_forward(const float* __restrict__ a,
                     const float* __restrict__ b,
                     float* __restrict__ c,
                     const int n) {
      int idx = blockIdx.x * blockDim.x + threadIdx.x;
      if (idx < n) {
        c[idx] = fmaxf(a[idx], b[idx]);
      }
    }
  CUDA
end
140
+
141
# Builds the kernel that multiplies every element by one scalar:
# output[i] = input[i] * scalar.
#
# Kernel arguments, in order: input, output, scalar (float, by value),
# n (element count). One element per thread, guarded by idx < n.
#
# @return [Ignis::JIT::Kernel]
def scale_forward
  compile_cached(<<~CUDA, "scale_forward")
    extern "C" __global__
    void scale_forward(const float* __restrict__ input,
                       float* __restrict__ output,
                       const float scalar,
                       const int n) {
      int idx = blockIdx.x * blockDim.x + threadIdx.x;
      if (idx < n) {
        output[idx] = input[idx] * scalar;
      }
    }
  CUDA
end
158
+
159
# Builds the kernel that sets every element of a float buffer to the
# same value: output[i] = value.
#
# Kernel arguments, in order: output, value (float, by value),
# n (element count). One element per thread, guarded by idx < n.
#
# @return [Ignis::JIT::Kernel]
def fill
  compile_cached(<<~CUDA, "fill")
    extern "C" __global__
    void fill(float* __restrict__ output,
              const float value,
              const int n) {
      int idx = blockIdx.x * blockDim.x + threadIdx.x;
      if (idx < n) {
        output[idx] = value;
      }
    }
  CUDA
end
175
+
176
# Kaiming-style uniform initialization: fills output with values drawn
# from U[-bound, bound).
#
# The generator is stateless and counter-based: element i hashes
# (seed + i) with the 64-bit MurmurHash3 finalizer (fmix64), so the same
# (seed, i) always yields the same value regardless of launch geometry.
# It is NOT Philox/cuRAND. Because the counter is simply seed + i, two
# seeds that differ by k produce the same sequence shifted by k elements;
# pick well-separated seeds for tensors that must be independent.
#
# The top 24 bits of the hash are mapped to u in [0, 1). 24 bits fit a
# float mantissa exactly, so u can never round up to 1.0 and the output
# never reaches +bound.
#
# Kernel arguments, in order: output, bound (float), seed
# (unsigned 64-bit), n (element count).
#
# @return [Ignis::JIT::Kernel]
def kaiming_uniform_init
  source = <<~CUDA
    extern "C" __global__
    void kaiming_uniform_init(float* __restrict__ output,
                              const float bound,
                              const unsigned long long seed,
                              const int n) {
      int idx = blockIdx.x * blockDim.x + threadIdx.x;
      if (idx < n) {
        // MurmurHash3 fmix64 over the per-element counter
        unsigned long long state = seed + (unsigned long long)idx;
        state ^= state >> 33;
        state *= 0xff51afd7ed558ccdULL;
        state ^= state >> 33;
        state *= 0xc4ceb9fe1a85ec53ULL;
        state ^= state >> 33;
        // Top 24 bits -> exactly representable float in [0, 1)
        float u = (float)(state >> 40) * (1.0f / 16777216.0f);
        output[idx] = (2.0f * u - 1.0f) * bound;
      }
    }
  CUDA
  compile_cached(source, "kaiming_uniform_init")
end
201
+
202
# Row gather (embedding forward): output[i, :] = weight[indices[i], :].
#
# output is addressed as row-major [num_indices, embed_dim] and weight as
# row-major [*, embed_dim]. One thread copies one output element.
#
# All flat offsets are computed in 64-bit so that neither
# num_indices * embed_dim nor indices[i] * embed_dim can wrap a 32-bit
# int on large vocabularies or long sequences.
#
# indices[i] is used as a row number without a range check; the caller
# must guarantee every index addresses a valid row of weight.
#
# Kernel arguments, in order: weight, indices (int32), output,
# num_indices, embed_dim.
#
# @return [Ignis::JIT::Kernel]
def gather_rows
  source = <<~CUDA
    extern "C" __global__
    void gather_rows(const float* __restrict__ weight,
                     const int* __restrict__ indices,
                     float* __restrict__ output,
                     const int num_indices,
                     const int embed_dim) {
      const long long idx = (long long)blockIdx.x * blockDim.x + threadIdx.x;
      const long long total = (long long)num_indices * embed_dim;
      if (idx < total) {
        const long long row = idx / embed_dim;
        const long long col = idx % embed_dim;
        const long long src_row = indices[row];
        output[idx] = weight[src_row * embed_dim + col];
      }
    }
  CUDA
  compile_cached(source, "gather_rows")
end
224
+
225
# Row scatter-add (embedding backward):
# grad_weight[indices[i], :] += grad_output[i, :].
#
# grad_output is addressed as row-major [num_indices, embed_dim]. One
# thread handles one element and uses atomicAdd, because several entries
# of indices may name the same destination row. grad_weight is only
# accumulated into, never cleared, so the caller must zero it first when
# a fresh gradient is wanted.
#
# All flat offsets are computed in 64-bit so that neither
# num_indices * embed_dim nor indices[i] * embed_dim can wrap a 32-bit
# int.
#
# indices[i] is used as a row number without a range check; the caller
# must guarantee every index addresses a valid row of grad_weight.
#
# Kernel arguments, in order: grad_output, indices (int32), grad_weight,
# num_indices, embed_dim.
#
# @return [Ignis::JIT::Kernel]
def scatter_add
  source = <<~CUDA
    extern "C" __global__
    void scatter_add(const float* __restrict__ grad_output,
                     const int* __restrict__ indices,
                     float* __restrict__ grad_weight,
                     const int num_indices,
                     const int embed_dim) {
      const long long idx = (long long)blockIdx.x * blockDim.x + threadIdx.x;
      const long long total = (long long)num_indices * embed_dim;
      if (idx < total) {
        const long long row = idx / embed_dim;
        const long long col = idx % embed_dim;
        const long long dst_row = indices[row];
        atomicAdd(&grad_weight[dst_row * embed_dim + col], grad_output[idx]);
      }
    }
  CUDA
  compile_cached(source, "scatter_add")
end
248
+
249
# Builds the in-place accumulation kernel: dst[i] += src[i]
# (e.g. for gradient accumulation).
#
# Kernel arguments, in order: dst (read and written), src, n (element
# count). One element per thread, guarded by idx < n; the update is a
# plain, non-atomic read-modify-write.
#
# @return [Ignis::JIT::Kernel]
def accumulate
  compile_cached(<<~CUDA, "accumulate")
    extern "C" __global__
    void accumulate(float* __restrict__ dst,
                    const float* __restrict__ src,
                    const int n) {
      int idx = blockIdx.x * blockDim.x + threadIdx.x;
      if (idx < n) {
        dst[idx] += src[idx];
      }
    }
  CUDA
end
265
+
266
# Builds the kernel that sums along the last dimension.
#
# input is addressed as row-major [outer_size, reduce_size]; output[i]
# receives the sum of row i. One thread walks one whole row sequentially,
# so threads with index >= outer_size do nothing.
#
# Kernel arguments, in order: input, output, outer_size, reduce_size.
#
# @return [Ignis::JIT::Kernel]
def sum_reduce
  compile_cached(<<~CUDA, "sum_reduce")
    extern "C" __global__
    void sum_reduce(const float* __restrict__ input,
                    float* __restrict__ output,
                    const int outer_size,
                    const int reduce_size) {
      int idx = blockIdx.x * blockDim.x + threadIdx.x;
      if (idx < outer_size) {
        float sum = 0.0f;
        for (int j = 0; j < reduce_size; j++) {
          sum += input[idx * reduce_size + j];
        }
        output[idx] = sum;
      }
    }
  CUDA
end
287
+
288
# Builds the kernel that expands a scalar gradient to n elements:
# grad_input[i] = grad_output[0] * scale.
#
# Only element 0 of grad_output is ever read. Kernel arguments, in
# order: grad_output, grad_input, scale (float, by value), n.
#
# @return [Ignis::JIT::Kernel]
def broadcast_grad
  compile_cached(<<~CUDA, "broadcast_grad")
    extern "C" __global__
    void broadcast_grad(const float* __restrict__ grad_output,
                        float* __restrict__ grad_input,
                        const float scale,
                        const int n) {
      int idx = blockIdx.x * blockDim.x + threadIdx.x;
      if (idx < n) {
        grad_input[idx] = grad_output[0] * scale;
      }
    }
  CUDA
end
305
+
306
# Builds the tiled 2D transpose kernel: output[j, i] = input[i, j].
#
# input is row-major [rows, cols]; output is row-major [cols, rows].
# Each block stages a TILE_DIM x TILE_DIM (32 x 32) tile in shared memory,
# padded by one column, then writes it back with the block indices
# swapped. Every thread moves TILE_DIM / BLOCK_ROWS elements, stepping
# threadIdx.y by BLOCK_ROWS (8).
#
# NOTE(review): the indexing only covers a full tile when blockDim is
# (32, 8) and the grid is (ceil(cols / 32), ceil(rows / 32)) — confirm
# against the launch site.
#
# Kernel arguments, in order: input, output, rows, cols.
#
# @return [Ignis::JIT::Kernel]
def transpose_2d
  compile_cached(<<~CUDA, "transpose_2d")
    #define TILE_DIM 32
    #define BLOCK_ROWS 8

    extern "C" __global__
    void transpose_2d(const float* __restrict__ input,
                      float* __restrict__ output,
                      const int rows,
                      const int cols) {
      __shared__ float tile[TILE_DIM][TILE_DIM + 1];

      int x = blockIdx.x * TILE_DIM + threadIdx.x;
      int y = blockIdx.y * TILE_DIM + threadIdx.y;

      for (int j = 0; j < TILE_DIM; j += BLOCK_ROWS) {
        if (x < cols && (y + j) < rows) {
          tile[threadIdx.y + j][threadIdx.x] = input[(y + j) * cols + x];
        }
      }
      __syncthreads();

      x = blockIdx.y * TILE_DIM + threadIdx.x;
      y = blockIdx.x * TILE_DIM + threadIdx.y;

      for (int j = 0; j < TILE_DIM; j += BLOCK_ROWS) {
        if (x < rows && (y + j) < cols) {
          output[(y + j) * rows + x] = tile[threadIdx.x][threadIdx.y + j];
        }
      }
    }
  CUDA
end
343
+
344
# Builds the kernel that extracts the column window
# [col_off, col_off + len) from every row:
# dst[r, c] = src[r, col_off + c].
#
# src is row-major [rows, total_cols]; dst is row-major [rows, len].
# One thread copies one dst element. (The original author notes it is
# used to split [seq, embed] projections into per-head
# [seq, head_dim] slices.)
#
# Kernel arguments, in order: src, dst, rows, total_cols, col_off, len.
#
# @return [Ignis::JIT::Kernel]
def slice_cols
  compile_cached(<<~CUDA, "slice_cols")
    extern "C" __global__
    void slice_cols(const float* __restrict__ src,
                    float* __restrict__ dst,
                    const int rows, const int total_cols,
                    const int col_off, const int len) {
      int idx = blockIdx.x * blockDim.x + threadIdx.x;
      int total = rows * len;
      if (idx < total) {
        int r = idx / len;
        int c = idx % len;
        dst[idx] = src[r * total_cols + col_off + c];
      }
    }
  CUDA
end
366
+
367
# Builds the kernel that writes a [rows, len] block into a column window
# of a wider matrix: dst[r, col_off + c] = src[r, c]. This is the
# inverse addressing of slice_cols.
#
# src is row-major [rows, len]; dst is row-major [rows, total_cols].
# One thread writes one element; dst columns outside the window are not
# touched. (The original author notes it is used to place per-head
# [seq, head_dim] results back into [seq, embed].)
#
# Kernel arguments, in order: src, dst, rows, total_cols, col_off, len.
#
# @return [Ignis::JIT::Kernel]
def scatter_cols
  compile_cached(<<~CUDA, "scatter_cols")
    extern "C" __global__
    void scatter_cols(const float* __restrict__ src,
                      float* __restrict__ dst,
                      const int rows, const int total_cols,
                      const int col_off, const int len) {
      int idx = blockIdx.x * blockDim.x + threadIdx.x;
      int total = rows * len;
      if (idx < total) {
        int r = idx / len;
        int c = idx % len;
        dst[r * total_cols + col_off + c] = src[idx];
      }
    }
  CUDA
end
388
+
389
# Builds the accumulating variant of scatter_cols:
# dst[r, col_off + c] += src[r, c].
#
# src is row-major [rows, len]; dst is row-major [rows, total_cols].
# The update is a plain (non-atomic) read-modify-write. Within a single
# launch that is safe because each (r, col_off + c) is produced by
# exactly one thread. Summing several sources into the same window (the
# original author's example is GQA backward, where the query heads that
# share a KV head all add into the same dK/dV columns) is done with one
# launch per source.
# NOTE(review): those launches must not run concurrently on the same
# dst window — confirm the caller orders them.
#
# Kernel arguments, in order: src, dst, rows, total_cols, col_off, len.
#
# @return [Ignis::JIT::Kernel]
def scatter_cols_add
  compile_cached(<<~CUDA, "scatter_cols_add")
    extern "C" __global__
    void scatter_cols_add(const float* __restrict__ src,
                          float* __restrict__ dst,
                          const int rows, const int total_cols,
                          const int col_off, const int len) {
      int idx = blockIdx.x * blockDim.x + threadIdx.x;
      int total = rows * len;
      if (idx < total) {
        int r = idx / len;
        int c = idx % len;
        dst[r * total_cols + col_off + c] += src[idx];
      }
    }
  CUDA
end
414
+
415
# Builds the kernel that widens bfloat16 to float32 on the device.
#
# Each 16-bit source word is placed in the upper half of a 32-bit word
# (bits = src[i] << 16) and reinterpreted as a float. A bf16 value is the
# top 16 bits of the corresponding fp32 pattern, so the conversion is
# exact. (The original author notes this lets bf16 checkpoints be loaded
# into fp32 weights without building a large host-side array.)
#
# Kernel arguments, in order: src (raw bf16 bits as unsigned short),
# dst (float), n (element count).
#
# @return [Ignis::JIT::Kernel]
def bf16_to_f32
  compile_cached(<<~CUDA, "bf16_to_f32")
    extern "C" __global__
    void bf16_to_f32(const unsigned short* __restrict__ src,
                     float* __restrict__ dst,
                     const int n) {
      int i = blockIdx.x * blockDim.x + threadIdx.x;
      if (i < n) {
        unsigned int bits = ((unsigned int)src[i]) << 16;
        dst[i] = __uint_as_float(bits);
      }
    }
  CUDA
end
435
+
436
# Builds the fp32 affine-map kernel:
# output[i] = input[i] * scale + shift.
#
# Example from the original author: mapping samples from U[0, 1) into
# U[low, high) with scale = high - low and shift = low.
#
# Kernel arguments, in order: input, output, scale, shift (floats, by
# value), n (element count). One element per thread, guarded by idx < n.
#
# @return [Ignis::JIT::Kernel]
def affine_forward
  compile_cached(<<~CUDA, "affine_forward")
    extern "C" __global__
    void affine_forward(const float* __restrict__ input,
                        float* __restrict__ output,
                        const float scale, const float shift,
                        const int n) {
      int idx = blockIdx.x * blockDim.x + threadIdx.x;
      if (idx < n) {
        output[idx] = input[idx] * scale + shift;
      }
    }
  CUDA
end
454
+
455
# Builds the kernel that adds a per-column bias to every row
# (linear-layer bias): out[r, c] = a[r, c] + bias[c].
#
# a and out are row-major [rows, cols]; bias has cols entries and is
# selected with idx % cols. One element per thread.
#
# Kernel arguments, in order: a, bias, out, rows, cols.
#
# @return [Ignis::JIT::Kernel]
def add_bias_rows
  compile_cached(<<~CUDA, "add_bias_rows")
    extern "C" __global__
    void add_bias_rows(const float* __restrict__ a,
                       const float* __restrict__ bias,
                       float* __restrict__ out,
                       const int rows, const int cols) {
      int idx = blockIdx.x * blockDim.x + threadIdx.x;
      int total = rows * cols;
      if (idx < total) {
        out[idx] = a[idx] + bias[idx % cols];
      }
    }
  CUDA
end
474
+
475
+ private
476
+
477
# Hands one kernel's CUDA source to the JIT compiler.
#
# This is a thin forwarder to Ignis::JIT::Compiler.compile; no caching
# happens in this method itself.
# NOTE(review): the name suggests the compiler caches results — confirm
# in Ignis::JIT::Compiler.
#
# @param source [String] CUDA source code
# @param name [String] kernel function name
# @param device_id [Integer] forwarded unchanged (defaults to 0)
# @return [Ignis::JIT::Kernel]
def compile_cached(source, name, device_id: 0)
  compiler = Ignis::JIT::Compiler
  compiler.compile(source, name, device_id: device_id)
end
484
+ end
485
+ end
486
+ end
487
+ end
488
+ end
@@ -0,0 +1,213 @@
1
+ # frozen_string_literal: true
2
+
3
+ module Ignis
4
+ module JIT
5
+ module Kernels
6
+ # Loss function CUDA kernels for training.
7
+ # All are fused implementations for numerical stability and performance.
8
+ module Loss
9
+ class << self
10
# Fused cross-entropy forward: log-softmax and NLL computed in one kernel.
#
# logits is row-major [batch_size, vocab_size]. One thread processes one
# whole row: it finds the row maximum (for numerical stability), the
# log-sum-exp, then writes
#   * log_softmax_out[row, :] — the full log-softmax of the row (same
#     shape as logits; presumably the input that cross_entropy_backward
#     expects), and
#   * losses[row] — the per-row loss.
#
# With label_smoothing = e > 0 the loss is
#   (1 - e) * -lsm[target] + e * mean_j(-lsm[j]);
# otherwise it is -lsm[target].
#
# The row offset is computed in 64-bit so batch_size * vocab_size may
# exceed INT_MAX (easily reached with large vocabularies).
#
# targets[row] is used as a column index without a range check; every
# target must lie in [0, vocab_size) (sentinel "ignore" values are not
# supported by this kernel).
#
# Kernel arguments, in order: logits, targets (int32), losses,
# log_softmax_out, batch_size, vocab_size, label_smoothing.
#
# @return [Ignis::JIT::Kernel]
def cross_entropy_forward
  source = <<~CUDA
    extern "C" __global__
    void cross_entropy_forward(const float* __restrict__ logits,
                               const int* __restrict__ targets,
                               float* __restrict__ losses,
                               float* __restrict__ log_softmax_out,
                               const int batch_size,
                               const int vocab_size,
                               const float label_smoothing) {
      int row = blockIdx.x * blockDim.x + threadIdx.x;
      if (row < batch_size) {
        // 64-bit offset: row * vocab_size can overflow a 32-bit int
        const long long row_offset = (long long)row * vocab_size;
        const float* row_logits = logits + row_offset;
        float* row_lsm = log_softmax_out + row_offset;
        int target = targets[row];

        // Find max for numerical stability
        float max_val = row_logits[0];
        for (int j = 1; j < vocab_size; j++) {
          max_val = fmaxf(max_val, row_logits[j]);
        }

        // log_softmax = x - max - log(sum(exp(x - max)))
        float log_sum_exp = 0.0f;
        for (int j = 0; j < vocab_size; j++) {
          log_sum_exp += expf(row_logits[j] - max_val);
        }
        log_sum_exp = logf(log_sum_exp);

        // Compute log_softmax and store
        for (int j = 0; j < vocab_size; j++) {
          row_lsm[j] = row_logits[j] - max_val - log_sum_exp;
        }

        // NLL loss with optional label smoothing
        if (label_smoothing > 0.0f) {
          float smooth_loss = 0.0f;
          for (int j = 0; j < vocab_size; j++) {
            smooth_loss -= row_lsm[j];
          }
          smooth_loss /= (float)vocab_size;
          float nll = -row_lsm[target];
          losses[row] = (1.0f - label_smoothing) * nll + label_smoothing * smooth_loss;
        } else {
          losses[row] = -row_lsm[target];
        }
      }
    }
  CUDA
  compile_cached(source, "cross_entropy_forward")
end
64
+
65
# Cross-entropy backward with respect to the logits.
#
# Takes the stored log-softmax (row-major [batch_size, vocab_size]), not
# the raw logits, and recovers softmax with expf. For each element:
#   grad_logits[r, c] = grad_output[r] * (softmax[r, c] - q[r, c])
# where q is the target distribution: one-hot(targets[r]) without
# smoothing, or (1 - e) * one-hot + e / vocab_size with
# label_smoothing = e > 0.
#
# One thread handles one element. The flat index and the element count
# are 64-bit so batch_size * vocab_size may exceed INT_MAX.
#
# grad_output holds one upstream gradient per row (indexed by row).
#
# Kernel arguments, in order: log_softmax, targets (int32), grad_output,
# grad_logits, batch_size, vocab_size, label_smoothing.
#
# @return [Ignis::JIT::Kernel]
def cross_entropy_backward
  source = <<~CUDA
    extern "C" __global__
    void cross_entropy_backward(const float* __restrict__ log_softmax,
                                const int* __restrict__ targets,
                                const float* __restrict__ grad_output,
                                float* __restrict__ grad_logits,
                                const int batch_size,
                                const int vocab_size,
                                const float label_smoothing) {
      // 64-bit indexing: batch_size * vocab_size can overflow a 32-bit int
      const long long idx = (long long)blockIdx.x * blockDim.x + threadIdx.x;
      const long long total = (long long)batch_size * vocab_size;
      if (idx < total) {
        const int row = (int)(idx / vocab_size);
        const int col = (int)(idx % vocab_size);
        int target = targets[row];

        float softmax_val = expf(log_softmax[idx]);
        float grad_scale = grad_output[row];

        if (label_smoothing > 0.0f) {
          float smooth_target = label_smoothing / (float)vocab_size;
          float hard_target = (col == target) ? (1.0f - label_smoothing + smooth_target) : smooth_target;
          grad_logits[idx] = grad_scale * (softmax_val - hard_target);
        } else {
          float indicator = (col == target) ? 1.0f : 0.0f;
          grad_logits[idx] = grad_scale * (softmax_val - indicator);
        }
      }
    }
  CUDA
  compile_cached(source, "cross_entropy_backward")
end
101
+
102
# Builds the per-element squared-error kernel:
# losses[i] = (predictions[i] - targets[i])^2.
#
# No reduction is performed here; the output has one entry per element.
# Kernel arguments, in order: predictions, targets, losses, n.
#
# @return [Ignis::JIT::Kernel]
def mse_forward
  compile_cached(<<~CUDA, "mse_forward")
    extern "C" __global__
    void mse_forward(const float* __restrict__ predictions,
                     const float* __restrict__ targets,
                     float* __restrict__ losses,
                     const int n) {
      int idx = blockIdx.x * blockDim.x + threadIdx.x;
      if (idx < n) {
        float diff = predictions[idx] - targets[idx];
        losses[idx] = diff * diff;
      }
    }
  CUDA
end
120
+
121
# Builds the MSE backward kernel:
# grad_input[i] = grad_output[i] * 2 * (predictions[i] - targets[i]) * scale.
#
# grad_output is read per element (index i), not as a single scalar.
# NOTE(review): scale is presumably 1 / n for a mean-reduced loss —
# confirm against the caller.
#
# Kernel arguments, in order: predictions, targets, grad_output,
# grad_input, n, scale.
#
# @return [Ignis::JIT::Kernel]
def mse_backward
  compile_cached(<<~CUDA, "mse_backward")
    extern "C" __global__
    void mse_backward(const float* __restrict__ predictions,
                      const float* __restrict__ targets,
                      const float* __restrict__ grad_output,
                      float* __restrict__ grad_input,
                      const int n,
                      const float scale) {
      int idx = blockIdx.x * blockDim.x + threadIdx.x;
      if (idx < n) {
        grad_input[idx] = grad_output[idx] * 2.0f * (predictions[idx] - targets[idx]) * scale;
      }
    }
  CUDA
end
140
+
141
# Binary cross-entropy with logits, per element:
#   loss = -[y * log(sigmoid(x)) + (1 - y) * log(1 - sigmoid(x))]
# evaluated in the overflow-free form
#   max(x, 0) - x * y + log1p(exp(-|x|)).
#
# log1pf is used for the last term: once exp(-|x|) drops below float
# epsilon, 1.0f + exp(-|x|) rounds to 1 and a plain logf would return 0,
# discarding the tail of the loss for confident predictions.
#
# No reduction is performed here; the output has one entry per element.
# Kernel arguments, in order: logits, targets (float labels), losses, n.
#
# @return [Ignis::JIT::Kernel]
def bce_forward
  source = <<~CUDA
    extern "C" __global__
    void bce_forward(const float* __restrict__ logits,
                     const float* __restrict__ targets,
                     float* __restrict__ losses,
                     const int n) {
      int idx = blockIdx.x * blockDim.x + threadIdx.x;
      if (idx < n) {
        float x = logits[idx];
        float y = targets[idx];
        // Numerically stable: max(x,0) - x*y + log1p(exp(-|x|))
        float max_val = fmaxf(x, 0.0f);
        losses[idx] = max_val - x * y + log1pf(expf(-fabsf(x)));
      }
    }
  CUDA
  compile_cached(source, "bce_forward")
end
162
+
163
# Builds the BCE-with-logits backward kernel:
# grad_input[i] = grad_output[i] * (sigmoid(logits[i]) - targets[i]),
# with sigmoid(x) evaluated as 1 / (1 + exp(-x)).
#
# grad_output is read per element. Kernel arguments, in order: logits,
# targets, grad_output, grad_input, n.
#
# @return [Ignis::JIT::Kernel]
def bce_backward
  compile_cached(<<~CUDA, "bce_backward")
    extern "C" __global__
    void bce_backward(const float* __restrict__ logits,
                      const float* __restrict__ targets,
                      const float* __restrict__ grad_output,
                      float* __restrict__ grad_input,
                      const int n) {
      int idx = blockIdx.x * blockDim.x + threadIdx.x;
      if (idx < n) {
        float sig = 1.0f / (1.0f + expf(-logits[idx]));
        grad_input[idx] = grad_output[idx] * (sig - targets[idx]);
      }
    }
  CUDA
end
182
+
183
# Builds the kernel that writes the arithmetic mean of n floats to
# output[0].
#
# The whole reduction is done sequentially by thread 0 of block 0; every
# other thread returns immediately, so launch geometry does not affect
# the result. With n == 0 the kernel computes 0 / 0 and stores NaN.
#
# Kernel arguments, in order: input, output, n.
#
# @return [Ignis::JIT::Kernel]
def mean_reduce
  compile_cached(<<~CUDA, "mean_reduce")
    extern "C" __global__
    void mean_reduce(const float* __restrict__ input,
                     float* __restrict__ output,
                     const int n) {
      // Single-thread simple reduction (for loss scalar)
      if (blockIdx.x == 0 && threadIdx.x == 0) {
        float sum = 0.0f;
        for (int i = 0; i < n; i++) {
          sum += input[i];
        }
        output[0] = sum / (float)n;
      }
    }
  CUDA
end
203
+
204
+ private
205
+
206
# Hands one kernel's CUDA source to the JIT compiler.
#
# This is a thin forwarder to Ignis::JIT::Compiler.compile; no caching
# happens in this method itself.
# NOTE(review): the name suggests the compiler caches results — confirm
# in Ignis::JIT::Compiler.
#
# @param source [String] CUDA source code
# @param name [String] kernel function name
# @param device_id [Integer] forwarded unchanged (defaults to 0)
# @return [Ignis::JIT::Kernel]
def compile_cached(source, name, device_id: 0)
  compiler = Ignis::JIT::Compiler
  compiler.compile(source, name, device_id: device_id)
end
209
+ end
210
+ end
211
+ end
212
+ end
213
+ end