ignis 0.0.1
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- checksums.yaml +7 -0
- data/README.md +15 -0
- data/lib/ignis.rb +94 -0
- data/lib/nnw/platform.rb +304 -0
- data/lib/nnw/shared/event_bus.rb +240 -0
- data/lib/nnw/shared/ffi_loader.rb +63 -0
- data/lib/nnw/shared/memory_contract.rb +204 -0
- data/lib/nnw/shared/nv_array.rb +710 -0
- data/lib/nnw/shared/recovery_protocol.rb +307 -0
- data/lib/nvruby/configuration.rb +217 -0
- data/lib/nvruby/cuda/device.rb +275 -0
- data/lib/nvruby/cuda/device_props.rb +202 -0
- data/lib/nvruby/cuda/graph.rb +265 -0
- data/lib/nvruby/cuda/graph_bindings.rb +119 -0
- data/lib/nvruby/cuda/library_loader.rb +285 -0
- data/lib/nvruby/cuda/memory.rb +410 -0
- data/lib/nvruby/cuda/runtime_api.rb +804 -0
- data/lib/nvruby/cuda/stream.rb +234 -0
- data/lib/nvruby/dtype.rb +139 -0
- data/lib/nvruby/epilogues.rb +438 -0
- data/lib/nvruby/errors.rb +303 -0
- data/lib/nvruby/half.rb +97 -0
- data/lib/nvruby/jit/compiled_kernel.rb +80 -0
- data/lib/nvruby/jit/compiler.rb +231 -0
- data/lib/nvruby/jit/driver_api_bindings.rb +363 -0
- data/lib/nvruby/jit/kernel.rb +240 -0
- data/lib/nvruby/jit/kernel_module.rb +133 -0
- data/lib/nvruby/jit/kernels/activations.rb +179 -0
- data/lib/nvruby/jit/kernels/attention.rb +504 -0
- data/lib/nvruby/jit/kernels/elementwise.rb +488 -0
- data/lib/nvruby/jit/kernels/loss.rb +213 -0
- data/lib/nvruby/jit/kernels/normalization.rb +200 -0
- data/lib/nvruby/jit/kernels/optimizer.rb +193 -0
- data/lib/nvruby/jit/nvrtc_bindings.rb +282 -0
- data/lib/nvruby/linalg/cublas_bindings.rb +295 -0
- data/lib/nvruby/linalg/cublaslt_bindings.rb +342 -0
- data/lib/nvruby/linalg/epilog.rb +67 -0
- data/lib/nvruby/linalg/matmul.rb +247 -0
- data/lib/nvruby/linalg/matmul_plan.rb +229 -0
- data/lib/nvruby/linalg/optimized_matmul.rb +412 -0
- data/lib/nvruby/memory/cuda_async_memory_resource.rb +123 -0
- data/lib/nvruby/memory/cuda_memory_resource.rb +68 -0
- data/lib/nvruby/memory/device_memory_resource.rb +106 -0
- data/lib/nvruby/memory/pinned_host_memory_resource.rb +112 -0
- data/lib/nvruby/memory/pool_memory_resource.rb +242 -0
- data/lib/nvruby/memory/stats.rb +107 -0
- data/lib/nvruby/memory.rb +124 -0
- data/lib/nvruby/version.rb +5 -0
- metadata +108 -0
|
@@ -0,0 +1,504 @@
|
|
|
1
|
+
# frozen_string_literal: true
|
|
2
|
+
|
|
3
|
+
module Ignis
|
|
4
|
+
module JIT
|
|
5
|
+
module Kernels
|
|
6
|
+
# Attention and softmax CUDA kernels for transformer models.
|
|
7
|
+
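# Includes numerically stable softmax (forward/backward), rotary position
# embedding, top-k/top-p filtering, causal score masking, and fused
# online-softmax attention (forward/backward).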
|
|
8
|
+
module Attention
|
|
9
|
+
class << self
|
|
10
|
+
# Numerically stable softmax forward along the last dimension.
# One thread per row: it finds the row max, writes exp(x - max) into
# `output` while summing, then rescales by 1/sum (three passes over the row).
# Layout: row-major [outer_size, dim_size]; softmax is taken along dim_size.
# Launch with at least outer_size threads along x; surplus threads exit.
# @return [Ignis::JIT::Kernel]
def softmax_forward
  source = <<~CUDA
    extern "C" __global__
    void softmax_forward(const float* __restrict__ input,
                         float* __restrict__ output,
                         const int outer_size,
                         const int dim_size) {
      int row = blockIdx.x * blockDim.x + threadIdx.x;
      // dim_size <= 0 would make the in_row[0] read below out of bounds.
      if (row >= outer_size || dim_size <= 0) return;

      // 64-bit offset: row * dim_size overflows int for large tensors
      // (e.g. many tokens times a 100k+ vocabulary).
      const long long offset = (long long)row * dim_size;
      const float* in_row = input + offset;
      float* out_row = output + offset;

      // Row max, subtracted before exp so exp never overflows.
      float max_val = in_row[0];
      for (int j = 1; j < dim_size; j++) {
        max_val = fmaxf(max_val, in_row[j]);
      }

      // exp(x - max), staged in the output buffer, and their sum.
      float sum = 0.0f;
      for (int j = 0; j < dim_size; j++) {
        float e = expf(in_row[j] - max_val);
        out_row[j] = e;
        sum += e;
      }

      // Normalize. For finite input sum >= 1 (the max element contributes exp(0)).
      float inv_sum = 1.0f / sum;
      for (int j = 0; j < dim_size; j++) {
        out_row[j] *= inv_sum;
      }
    }
  CUDA
  compile_cached(source, "softmax_forward")
end
|
|
50
|
+
|
|
51
|
+
# Rotary Position Embedding (RoPE), HF/Llama/Qwen "rotate_half" convention.
# Input x is [seq, n_heads*head_dim] (heads contiguous). For each head, the
# first/second halves form rotation pairs. With half = head_dim/2 and
# angle = (row + pos_offset) * inv_freq[i]:
#   d <  half: out[d] = x[d]*cos - x[d+half]*sin    (i = d)
#   d >= half: out[d] = x[d]*cos + x[d-half]*sin    (i = d-half)
# inv_freq is passed precomputed ([head_dim/2] floats). Standard RoPE uses
# inv_freq[i] = base^(-2i/head_dim); scaled variants (NTK/llama3/YaRN) remap
# the frequencies on the host.
# The rotation is orthogonal, so the BACKWARD is this same kernel with the
# sin sign flipped (R^T = R(-θ)); callers pass sin_sign = +1 fwd, -1 bwd.
# pos_offset lets decode rotate a single new token at its absolute position.
# One thread per element: launch with >= seq*n_heads*head_dim threads along x.
# x and out must not alias: each thread also reads its partner lane from x.
# If head_dim is odd, the last lane of each head has no partner and is copied
# through unchanged.
# @return [Ignis::JIT::Kernel]
def rope_apply
  source = <<~CUDA
    extern "C" __global__
    void rope_apply(const float* __restrict__ x,
                    float* __restrict__ out,
                    const int seq, const int n_heads, const int head_dim,
                    const int pos_offset,
                    const float* __restrict__ inv_freq,  // [head_dim/2] precomputed freqs
                    const float sin_sign) {
      int total = seq * n_heads * head_dim;
      int idx = blockIdx.x * blockDim.x + threadIdx.x;
      if (idx >= total) return;

      int d = idx % head_dim;
      int row = (idx / head_dim) / n_heads;  // sequence position within this call
      int half = head_dim / 2;

      // Odd head_dim: the trailing lane has no rotation partner, and its
      // frequency index (== half) would read one past the end of inv_freq.
      if (d >= 2 * half) {
        out[idx] = x[idx];
        return;
      }

      int pos = row + pos_offset;
      int freq_idx = (d < half) ? d : (d - half);
      float angle = (float)pos * inv_freq[freq_idx];
      float c = cosf(angle);
      float s = sinf(angle) * sin_sign;

      float xd = x[idx];
      if (d < half) {
        out[idx] = xd * c - x[idx + half] * s;
      } else {
        out[idx] = xd * c + x[idx - half] * s;
      }
    }
  CUDA
  compile_cached(source, "rope_apply")
end
|
|
97
|
+
|
|
98
|
+
# Softmax backward (Jacobian-vector product), one thread per row:
#   grad_input[j] = softmax[j] * (grad_output[j] - sum_k grad_output[k] * softmax[k])
# All buffers are row-major [outer_size, dim_size]. softmax_output is the
# forward result (probabilities), not the logits.
# Launch with at least outer_size threads along x; surplus threads exit.
# @return [Ignis::JIT::Kernel]
def softmax_backward
  source = <<~CUDA
    extern "C" __global__
    void softmax_backward(const float* __restrict__ grad_output,
                          const float* __restrict__ softmax_output,
                          float* __restrict__ grad_input,
                          const int outer_size,
                          const int dim_size) {
      int row = blockIdx.x * blockDim.x + threadIdx.x;
      if (row >= outer_size) return;

      // 64-bit offset: row * dim_size overflows int for large tensors.
      const long long offset = (long long)row * dim_size;
      const float* go = grad_output + offset;
      const float* so = softmax_output + offset;
      float* gi = grad_input + offset;

      // dot(grad_output, softmax_output) for this row
      float dot = 0.0f;
      for (int j = 0; j < dim_size; j++) {
        dot += go[j] * so[j];
      }

      // grad_input = softmax * (grad_output - dot)
      for (int j = 0; j < dim_size; j++) {
        gi[j] = so[j] * (go[j] - dot);
      }
    }
  CUDA
  compile_cached(source, "softmax_backward")
end
|
|
130
|
+
|
|
131
|
+
# Top-k mask, applied in place: every logit strictly below the k-th largest
# value of its row is overwritten with -1e20 (a large negative value, so a
# following softmax gives it ~0 weight). Used for top-k sampling.
#
# Launch: one block per row (row = blockIdx.x). Only thread 0 of each block
# works, so a block size of 1 is sufficient. No dynamic shared memory is used.
# Semantics:
# - Values tied with the k-th largest are all kept, so more than k logits
#   can survive.
# - k is capped at 256 (the size of the per-thread top-k buffer); a larger k
#   behaves like k = 256.
# - k <= 0 or k >= vocab_size leaves the row unchanged (no filtering).
# Cost: O(vocab_size * min(k, 256)) worst case, sequential per row.
# @return [Ignis::JIT::Kernel]
def topk_mask
  source = <<~CUDA
    #define TOPK_MAX 256

    extern "C" __global__
    void topk_mask(float* __restrict__ logits,
                   const int vocab_size,
                   const int k) {
      // The selection below is sequential; thread 0 of each block owns the row.
      if (threadIdx.x != 0) return;
      // k <= 0 means "no top-k filtering" (and would index top_vals[-1]);
      // k >= vocab_size keeps every logit anyway.
      if (k <= 0 || k >= vocab_size) return;

      // 64-bit offset: row * vocab_size can overflow int for large batches.
      float* row_logits = logits + (long long)blockIdx.x * vocab_size;

      // top_vals holds the actual_k largest values seen so far, sorted
      // descending; each better candidate is placed by one insertion step.
      const int actual_k = k < TOPK_MAX ? k : TOPK_MAX;
      float top_vals[TOPK_MAX];
      for (int i = 0; i < actual_k; i++) top_vals[i] = -1e20f;

      for (int j = 0; j < vocab_size; j++) {
        float v = row_logits[j];
        if (v > top_vals[actual_k - 1]) {
          int m = actual_k - 1;
          while (m > 0 && top_vals[m - 1] < v) {
            top_vals[m] = top_vals[m - 1];
            m--;
          }
          top_vals[m] = v;
        }
      }
      const float kth = top_vals[actual_k - 1];

      // Mask everything strictly below the k-th largest value.
      for (int j = 0; j < vocab_size; j++) {
        if (row_logits[j] < kth) row_logits[j] = -1e20f;
      }
    }
  CUDA
  compile_cached(source, "topk_mask")
end
|
|
202
|
+
|
|
203
|
+
# Top-p (nucleus) mask, applied in place: keep the highest-probability tokens
# until their cumulative mass reaches p, and set every other entry to 0.
# Assumes each row already holds softmax probabilities (non-negative values).
#
# Launch: one block per row (row = blockIdx.x). Only thread 0 of each block
# works, so a block size of 1 is sufficient.
# Semantics:
# - Tokens are admitted one probability level at a time, from the highest
#   down. All tokens tied at the cutoff value are kept, so the kept set can
#   be slightly larger than the minimal one.
# - The top level is always kept, even for p <= 0, so a row is never zeroed
#   entirely.
# - Entries that are NaN or <= -1 are never admitted and end up 0.
# - No device-heap allocation, so there is no allocation that can fail and
#   silently skip the masking.
# Cost: O(vocab_size * number of distinct admitted levels), sequential per row.
# @return [Ignis::JIT::Kernel]
def topp_mask
  source = <<~CUDA
    extern "C" __global__
    void topp_mask(float* __restrict__ probs,
                   const int vocab_size,
                   const float p) {
      // The selection below is sequential; thread 0 of each block owns the row.
      if (threadIdx.x != 0) return;

      // 64-bit offset: row * vocab_size can overflow int for large batches.
      float* row_probs = probs + (long long)blockIdx.x * vocab_size;

      // `cutoff` is the smallest probability level admitted so far. Each pass
      // finds the largest value strictly below it and how many entries share
      // that value, then admits the whole level.
      float cumsum = 0.0f;
      float cutoff = 0.0f;
      bool have_cutoff = false;
      do {
        float level = -1.0f;
        int count = 0;
        for (int j = 0; j < vocab_size; j++) {
          float v = row_probs[j];
          if (have_cutoff && !(v < cutoff)) continue;  // already admitted, or NaN
          if (v > level) {
            level = v;
            count = 1;
          } else if (count > 0 && v == level) {
            count++;
          }
        }
        if (count == 0) break;  // nothing left to admit
        cumsum += level * (float)count;
        cutoff = level;
        have_cutoff = true;
      } while (cumsum < p);

      if (!have_cutoff) return;

      // Zero everything below the cutoff; !(v >= cutoff) also clears NaNs.
      for (int j = 0; j < vocab_size; j++) {
        if (!(row_probs[j] >= cutoff)) row_probs[j] = 0.0f;
      }
    }
  CUDA
  compile_cached(source, "topp_mask")
end
|
|
259
|
+
|
|
260
|
+
# Scales precomputed raw attention scores and optionally applies a causal mask:
#   masked_scores[r, c] = -1e9                  if use_causal_mask != 0 and c > r
#   masked_scores[r, c] = scores[r, c] * scale  otherwise
# The Q @ K^T product itself is NOT computed here: `scores` must already hold
# it, row-major [seq_len, seq_len]. Pass scale = 1/sqrt(d_k) for standard
# scaled dot-product attention.
# One thread per element: launch with >= seq_len*seq_len threads along x.
# @return [Ignis::JIT::Kernel]
def attention_score
  source = <<~CUDA
    extern "C" __global__
    void attention_score(const float* __restrict__ scores,
                         float* __restrict__ masked_scores,
                         const float scale,
                         const int seq_len,
                         const int use_causal_mask) {
      // 64-bit indexing: seq_len * seq_len overflows int for seq_len > 46340.
      long long idx = (long long)blockIdx.x * blockDim.x + threadIdx.x;
      long long total = (long long)seq_len * seq_len;
      if (idx >= total) return;

      long long row = idx / seq_len;
      long long col = idx % seq_len;

      float val = scores[idx] * scale;

      // Causal mask: future positions (col > row) get a large negative score
      // so a following softmax assigns them ~0 weight.
      if (use_causal_mask && col > row) {
        val = -1e9f;
      }

      masked_scores[idx] = val;
    }
  CUDA
  compile_cached(source, "attention_score")
end
|
|
290
|
+
|
|
291
|
+
# Fused attention forward using the online softmax (running max + rescaled sum,
# Milakov & Gimelshein), in the spirit of FlashAttention. The full
# seq_len x seq_len score matrix is never materialized; each query keeps
# O(head_dim) running state.
# NOTE: this is a simple one-thread-per-query kernel. K/V rows are read
# directly from global memory (no shared-memory tiling), so it does not
# implement FlashAttention-2's tiled schedule.
#
# Layout: Q, K, V, O are row-major [seq_len, head_dim] (one head per launch).
# Launch: thread t of block b handles query b*64 + t (TILE_SIZE = 64), e.g.
# grid = ceil(seq_len / 64), block = 64.
# head_dim must be <= 128 (HEAD_DIM_MAX); higher dimensions are ignored.
# @return [Ignis::JIT::Kernel]
def flash_attention_forward
  source = <<~CUDA
    #define TILE_SIZE 64
    #define HEAD_DIM_MAX 128

    extern "C" __global__
    void flash_attention_forward(
        const float* __restrict__ Q,
        const float* __restrict__ K,
        const float* __restrict__ V,
        float* __restrict__ O,
        const int seq_len,
        const int head_dim,
        const float scale,
        const int use_causal_mask) {

      const int q_idx = blockIdx.x * TILE_SIZE + threadIdx.x;
      if (q_idx >= seq_len) return;

      // Dimensions beyond HEAD_DIM_MAX do not fit the per-thread arrays.
      const int dims = head_dim < HEAD_DIM_MAX ? head_dim : HEAD_DIM_MAX;

      // Query row and output accumulator, held per thread.
      float q_row[HEAD_DIM_MAX];
      float acc[HEAD_DIM_MAX];
      for (int d = 0; d < dims; d++) {
        q_row[d] = Q[q_idx * head_dim + d];
        acc[d] = 0.0f;
      }

      // Running max and running sum of exp(score - max).
      float row_max = -1e20f;
      float row_sum = 0.0f;

      // Under the causal mask, keys after q_idx contribute nothing, so stop
      // there instead of scanning and skipping the rest of the sequence.
      const int kv_end = use_causal_mask ? q_idx + 1 : seq_len;
      for (int k_idx = 0; k_idx < kv_end; k_idx++) {
        const float* k_row = K + k_idx * head_dim;
        const float* v_row = V + k_idx * head_dim;

        float score = 0.0f;
        for (int d = 0; d < dims; d++) {
          score += q_row[d] * k_row[d];
        }
        score *= scale;

        // Online softmax: rescale previous state by exp(old_max - new_max)
        // and add this key's weight exp(score - new_max).
        float new_max = fmaxf(row_max, score);
        float exp_diff = expf(row_max - new_max);
        float exp_score = expf(score - new_max);
        float new_sum = row_sum * exp_diff + exp_score;

        for (int d = 0; d < dims; d++) {
          acc[d] = acc[d] * exp_diff + exp_score * v_row[d];
        }

        row_max = new_max;
        row_sum = new_sum;
      }

      // For finite scores row_sum >= 1 (the max key contributes exp(0)). The
      // guard only matters for non-finite input; O is written in every case
      // rather than being left uninitialized.
      const float inv_sum = row_sum > 0.0f ? 1.0f / row_sum : 0.0f;
      for (int d = 0; d < dims; d++) {
        O[q_idx * head_dim + d] = acc[d] * inv_sum;
      }
    }
  CUDA
  compile_cached(source, "flash_attention_forward")
end
|
|
384
|
+
|
|
385
|
+
# Attention backward matching flash_attention_forward. The softmax weights are
# recomputed from Q and K (three passes over the keys: max, sum, gradients)
# instead of storing the seq_len x seq_len matrix.
# One thread per query: q_idx = blockIdx.x * blockDim.x + threadIdx.x.
# With S = scale * Q K^T, P = softmax(S) and D_i = dot(dO_i, O_i):
#   dV_j += P_ij * dO_i
#   dS_ij = P_ij * (dot(dO_i, V_j) - D_i) * scale
#   dQ_i  = sum_j dS_ij * K_j
#   dK_j += dS_ij * Q_i
# dQ is written directly. dK and dV are accumulated with atomicAdd across
# query threads, so the caller must zero them before launch.
# Layout: all buffers row-major [seq_len, head_dim] (one head per launch).
# head_dim must be <= 128 (HEAD_DIM_MAX); higher dimensions are ignored.
# @return [Ignis::JIT::Kernel]
def flash_attention_backward
  source = <<~CUDA
    #define HEAD_DIM_MAX 128

    extern "C" __global__
    void flash_attention_backward(
        const float* __restrict__ Q,
        const float* __restrict__ K,
        const float* __restrict__ V,
        const float* __restrict__ O,
        const float* __restrict__ dO,
        float* __restrict__ dQ,
        float* __restrict__ dK,
        float* __restrict__ dV,
        const int seq_len,
        const int head_dim,
        const float scale,
        const int use_causal_mask) {

      const int q_idx = blockIdx.x * blockDim.x + threadIdx.x;
      if (q_idx >= seq_len) return;

      // Dimensions beyond HEAD_DIM_MAX do not fit the per-thread arrays.
      const int dims = head_dim < HEAD_DIM_MAX ? head_dim : HEAD_DIM_MAX;
      // Keys after q_idx are masked under causal attention; never visit them.
      const int kv_end = use_causal_mask ? q_idx + 1 : seq_len;

      // Load Q and dO rows; D_i = dot(dO_i, O_i).
      float q_row[HEAD_DIM_MAX], do_row[HEAD_DIM_MAX];
      float D_i = 0.0f;
      for (int d = 0; d < dims; d++) {
        q_row[d] = Q[q_idx * head_dim + d];
        do_row[d] = dO[q_idx * head_dim + d];
        D_i += do_row[d] * O[q_idx * head_dim + d];
      }

      // Pass 1: row max of the scaled scores.
      float row_max = -1e20f;
      for (int k_idx = 0; k_idx < kv_end; k_idx++) {
        float score = 0.0f;
        for (int d = 0; d < dims; d++) {
          score += q_row[d] * K[k_idx * head_dim + d];
        }
        score *= scale;
        row_max = fmaxf(row_max, score);
      }

      // Pass 2: softmax denominator.
      float row_sum = 0.0f;
      for (int k_idx = 0; k_idx < kv_end; k_idx++) {
        float score = 0.0f;
        for (int d = 0; d < dims; d++) {
          score += q_row[d] * K[k_idx * head_dim + d];
        }
        score *= scale;
        row_sum += expf(score - row_max);
      }

      // Pass 3: gradients.
      float dq_acc[HEAD_DIM_MAX];
      for (int d = 0; d < dims; d++) dq_acc[d] = 0.0f;

      for (int k_idx = 0; k_idx < kv_end; k_idx++) {
        const float* k_row = K + k_idx * head_dim;
        const float* v_row = V + k_idx * head_dim;

        float score = 0.0f;
        for (int d = 0; d < dims; d++) {
          score += q_row[d] * k_row[d];
        }
        score *= scale;
        float p_ij = expf(score - row_max) / row_sum;

        // dV[k_idx] += p_ij * dO_i (shared across queries -> atomic);
        // dP = dot(dO_i, V[k_idx]).
        float dP = 0.0f;
        for (int d = 0; d < dims; d++) {
          atomicAdd(&dV[k_idx * head_dim + d], p_ij * do_row[d]);
          dP += do_row[d] * v_row[d];
        }

        float dS = p_ij * (dP - D_i) * scale;

        // dQ_i += dS * K[k_idx] (thread-private); dK[k_idx] += dS * Q_i (atomic).
        for (int d = 0; d < dims; d++) {
          dq_acc[d] += dS * k_row[d];
          atomicAdd(&dK[k_idx * head_dim + d], dS * q_row[d]);
        }
      }

      for (int d = 0; d < dims; d++) {
        dQ[q_idx * head_dim + d] = dq_acc[d];
      }
    }
  CUDA
  compile_cached(source, "flash_attention_backward")
end
|
|
490
|
+
|
|
491
|
+
private
|
|
492
|
+
|
|
493
|
+
# Compiles a CUDA source string and returns the named kernel, delegating
# entirely to Ignis::JIT::Compiler.compile. This method itself does not cache
# anything; presumably the compiler does (TODO confirm).
# @param cuda_source [String] CUDA C++ source text
# @param kernel_name [String] name of the extern "C" __global__ entry point
# @param device_id [Integer] forwarded unchanged as `device_id:` (default 0)
# @return [Ignis::JIT::Kernel]
def compile_cached(cuda_source, kernel_name, device_id: 0)
  compiler = Ignis::JIT::Compiler
  compiler.compile(cuda_source, kernel_name, device_id: device_id)
end
|
|
500
|
+
end
|
|
501
|
+
end
|
|
502
|
+
end
|
|
503
|
+
end
|
|
504
|
+
end
|