liger-kernel-nightly 0.5.10.dev20250605210201__py3-none-any.whl → 0.5.10.dev20250605224739__py3-none-any.whl

This diff shows the changes between two package versions that have been publicly released to one of the supported registries. It is provided for informational purposes only and reflects the packages as they appear in their respective public registries.
Files changed (22)
  1. liger_kernel/ops/fused_neighborhood_attention.py +1022 -0
  2. liger_kernel/transformers/functional.py +28 -0
  3. liger_kernel/transformers/fused_neighborhood_attention.py +234 -0
  4. liger_kernel/transformers/model/gemma.py +5 -4
  5. liger_kernel/transformers/model/gemma2.py +7 -4
  6. liger_kernel/transformers/model/glm4.py +5 -4
  7. liger_kernel/transformers/model/llama.py +5 -4
  8. liger_kernel/transformers/model/mistral.py +5 -4
  9. liger_kernel/transformers/model/mixtral.py +5 -4
  10. liger_kernel/transformers/model/mllama.py +5 -4
  11. liger_kernel/transformers/model/olmo2.py +5 -4
  12. liger_kernel/transformers/model/phi3.py +5 -4
  13. liger_kernel/transformers/model/qwen2.py +5 -4
  14. liger_kernel/transformers/model/qwen2_5_vl.py +4 -3
  15. liger_kernel/transformers/model/qwen2_vl.py +4 -3
  16. liger_kernel/transformers/model/qwen3_moe.py +5 -4
  17. {liger_kernel_nightly-0.5.10.dev20250605210201.dist-info → liger_kernel_nightly-0.5.10.dev20250605224739.dist-info}/METADATA +1 -1
  18. {liger_kernel_nightly-0.5.10.dev20250605210201.dist-info → liger_kernel_nightly-0.5.10.dev20250605224739.dist-info}/RECORD +22 -20
  19. {liger_kernel_nightly-0.5.10.dev20250605210201.dist-info → liger_kernel_nightly-0.5.10.dev20250605224739.dist-info}/LICENSE +0 -0
  20. {liger_kernel_nightly-0.5.10.dev20250605210201.dist-info → liger_kernel_nightly-0.5.10.dev20250605224739.dist-info}/NOTICE +0 -0
  21. {liger_kernel_nightly-0.5.10.dev20250605210201.dist-info → liger_kernel_nightly-0.5.10.dev20250605224739.dist-info}/WHEEL +0 -0
  22. {liger_kernel_nightly-0.5.10.dev20250605210201.dist-info → liger_kernel_nightly-0.5.10.dev20250605224739.dist-info}/top_level.txt +0 -0
@@ -0,0 +1,1022 @@
1
+ import math
2
+
3
+ import torch
4
+ import triton
5
+ import triton.language as tl
6
+
7
+ from liger_kernel.ops.softmax import _softmax_backward
8
+ from liger_kernel.ops.softmax import _softmax_forward
9
+ from liger_kernel.ops.utils import calculate_settings
10
+ from liger_kernel.ops.utils import ensure_contiguous
11
+
12
+
13
+ @triton.jit
14
+ def _neighborhood_mask_kernel(
15
+ mask_ptr,
16
+ seq_len: tl.constexpr,
17
+ kernel_size: tl.constexpr,
18
+ dilation: tl.constexpr,
19
+ BLOCK_SIZE: tl.constexpr,
20
+ num_stages: tl.constexpr,
21
+ num_warps: tl.constexpr,
22
+ ):
23
+ """
24
+ Generate a neighborhood attention mask for a given sequence.
25
+
26
+ This kernel creates a binary mask that defines which positions in a sequence
27
+ can attend to each other based on a neighborhood window with optional dilation.
28
+ Each row of the mask corresponds to a query position, and each column indicates
29
+ whether that key position is within the allowed neighborhood.
30
+
31
+ The neighborhood is defined as positions within kernel_size//2 * dilation distance
32
+ from the center position. When dilation > 1, only positions at multiples of the
33
+ dilation factor are included in the neighborhood.
34
+
35
+ Args:
36
+ mask_ptr: Pointer to the output mask tensor [seq_len, seq_len]
37
+ seq_len: Length of the input sequence
38
+ kernel_size: Size of the neighborhood window (must be odd)
39
+ dilation: Dilation factor for the neighborhood pattern
40
+ BLOCK_SIZE: Block size for processing (compile-time constant)
41
+ num_stages: Number of pipeline stages (compile-time constant)
42
+ num_warps: Number of warps (compile-time constant)
43
+
44
+ Grid: (seq_len,)
45
+ Each program processes one row of the mask matrix.
46
+ """
47
+ row_id = tl.program_id(0)
48
+
49
+ center = row_id
50
+ half_kernel = kernel_size // 2
51
+
52
+ start = tl.maximum(0, center - half_kernel * dilation)
53
+ end = tl.minimum(seq_len, center + half_kernel * dilation + 1)
54
+
55
+ col_offsets = tl.arange(0, BLOCK_SIZE)
56
+ mask = col_offsets < seq_len
57
+
58
+ valid_neighbors = (col_offsets >= start) & (col_offsets < end)
59
+ if dilation > 1:
60
+ relative_pos = col_offsets - center
61
+ valid_dilation = (relative_pos % dilation) == 0
62
+ valid_neighbors = valid_neighbors & valid_dilation
63
+
64
+ mask_values = tl.where(valid_neighbors & mask, 1.0, 0.0)
65
+
66
+ base_offset = row_id * seq_len
67
+ tl.store(mask_ptr + base_offset + col_offsets, mask_values, mask=mask)
68
+
69
+
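For intuition, the same mask can be built in eager PyTorch. The sketch below is ours (the helper name neighborhood_mask_reference is not part of the package) and mirrors the kernel's window-plus-dilation rule; it is handy for sanity-checking the tensor the kernel writes.

import torch

def neighborhood_mask_reference(seq_len: int, kernel_size: int, dilation: int = 1) -> torch.Tensor:
    # A key position j is visible to query i when |j - i| <= (kernel_size // 2) * dilation
    # and, for dilation > 1, (j - i) is a multiple of the dilation factor.
    half = kernel_size // 2
    idx = torch.arange(seq_len)
    rel = idx[None, :] - idx[:, None]
    visible = rel.abs() <= half * dilation
    if dilation > 1:
        visible &= (rel % dilation) == 0
    return visible.float()  # [seq_len, seq_len]; 1.0 means "may attend"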
70
+ @triton.jit
71
+ def _fused_neighborhood_attention_qk_kernel(
72
+ Q_ptr,
73
+ K_ptr,
74
+ QK_ptr,
75
+ mask_ptr,
76
+ q_batch_stride,
77
+ q_head_stride,
78
+ q_seq_stride,
79
+ q_dim_stride,
80
+ k_batch_stride,
81
+ k_head_stride,
82
+ k_seq_stride,
83
+ k_dim_stride,
84
+ qk_batch_stride,
85
+ qk_head_stride,
86
+ qk_seq_stride,
87
+ qk_seq2_stride,
88
+ batch_size: tl.constexpr,
89
+ num_heads: tl.constexpr,
90
+ seq_len: tl.constexpr,
91
+ head_dim: tl.constexpr,
92
+ scale: tl.constexpr,
93
+ kernel_size: tl.constexpr,
94
+ dilation: tl.constexpr,
95
+ BLOCK_SIZE_M: tl.constexpr,
96
+ BLOCK_SIZE_N: tl.constexpr,
97
+ BLOCK_SIZE_K: tl.constexpr,
98
+ num_stages: tl.constexpr,
99
+ num_warps: tl.constexpr,
100
+ ):
101
+ """
102
+ Compute Q @ K^T with neighborhood masking and scaling.
103
+
104
+ This kernel performs the first stage of neighborhood attention by computing
105
+ the attention scores between queries and keys, applying scaling, and masking
106
+ positions outside the neighborhood window. The result is a matrix of attention
107
+ scores ready for softmax normalization.
108
+
109
+ The computation is tiled across sequence dimensions for memory efficiency.
110
+ Each tile computes a block of the attention score matrix by iterating over
111
+ the head dimension and accumulating dot products.
112
+
113
+ Args:
114
+ Q_ptr: Pointer to query tensor [batch_size, num_heads, seq_len, head_dim]
115
+ K_ptr: Pointer to key tensor [batch_size, num_heads, seq_len, head_dim]
116
+ QK_ptr: Pointer to output tensor [batch_size, num_heads, seq_len, seq_len]
117
+ mask_ptr: Pointer to neighborhood mask [seq_len, seq_len]
118
+ q_*_stride: Strides for query tensor
119
+ k_*_stride: Strides for key tensor
120
+ qk_*_stride: Strides for output tensor
121
+ batch_size: Number of batches
122
+ num_heads: Number of attention heads
123
+ seq_len: Sequence length
124
+ head_dim: Dimension of each attention head
125
+ scale: Scaling factor for attention scores (typically 1/sqrt(head_dim))
126
+ kernel_size: Size of the neighborhood window
127
+ dilation: Dilation factor for the neighborhood
128
+ BLOCK_SIZE_M: Block size for sequence dimension (rows)
129
+ BLOCK_SIZE_N: Block size for sequence dimension (cols)
130
+ BLOCK_SIZE_K: Block size for head dimension
131
+ num_stages: Number of pipeline stages
132
+ num_warps: Number of warps
133
+
134
+ Grid: (batch_size * num_heads, cdiv(seq_len, BLOCK_SIZE_M), cdiv(seq_len, BLOCK_SIZE_N))
135
+ Each program computes a tile of the attention score matrix.
136
+ """
137
+ batch_head_id = tl.program_id(0)
138
+ tile_m = tl.program_id(1)
139
+ tile_n = tl.program_id(2)
140
+
141
+ batch_id = batch_head_id // num_heads
142
+ head_id = batch_head_id % num_heads
143
+
144
+ row_start = tile_m * BLOCK_SIZE_M
145
+ col_start = tile_n * BLOCK_SIZE_N
146
+
147
+ row_offsets = row_start + tl.arange(0, BLOCK_SIZE_M)
148
+ col_offsets = col_start + tl.arange(0, BLOCK_SIZE_N)
149
+
150
+ acc = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32)
151
+
152
+ for k_start in range(0, head_dim, BLOCK_SIZE_K):
153
+ k_offsets = k_start + tl.arange(0, BLOCK_SIZE_K)
154
+ k_mask = k_offsets < head_dim
155
+
156
+ q_ptrs = (
157
+ Q_ptr
158
+ + batch_id * q_batch_stride
159
+ + head_id * q_head_stride
160
+ + row_offsets[:, None] * q_seq_stride
161
+ + k_offsets[None, :] * q_dim_stride
162
+ )
163
+ q_mask = (row_offsets[:, None] < seq_len) & k_mask[None, :]
164
+ q_chunk = tl.load(q_ptrs, mask=q_mask, other=0.0)
165
+
166
+ k_ptrs = (
167
+ K_ptr
168
+ + batch_id * k_batch_stride
169
+ + head_id * k_head_stride
170
+ + col_offsets[:, None] * k_seq_stride
171
+ + k_offsets[None, :] * k_dim_stride
172
+ )
173
+ k_mask = (col_offsets[:, None] < seq_len) & k_mask[None, :]
174
+ k_chunk = tl.load(k_ptrs, mask=k_mask, other=0.0)
175
+
176
+ acc += tl.dot(q_chunk, tl.trans(k_chunk))
177
+
178
+ acc = acc * scale
179
+
180
+ mask_ptrs = mask_ptr + row_offsets[:, None] * seq_len + col_offsets[None, :]
181
+ valid_mask = (row_offsets[:, None] < seq_len) & (col_offsets[None, :] < seq_len)
182
+ neighborhood_mask = tl.load(mask_ptrs, mask=valid_mask, other=0.0)
183
+
184
+ acc = tl.where(neighborhood_mask > 0.0, acc, float("-inf"))
185
+
186
+ qk_ptrs = (
187
+ QK_ptr
188
+ + batch_id * qk_batch_stride
189
+ + head_id * qk_head_stride
190
+ + row_offsets[:, None] * qk_seq_stride
191
+ + col_offsets[None, :] * qk_seq2_stride
192
+ )
193
+ tl.store(qk_ptrs, acc, mask=valid_mask)
194
+
195
+
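As an eager-mode reference for what this kernel writes out, the scaled and neighborhood-masked score matrix takes two lines; the helper below is our sketch (it reuses neighborhood_mask_reference from above), not package code.

import torch

def masked_scores_reference(q: torch.Tensor, k: torch.Tensor, mask: torch.Tensor, scale: float) -> torch.Tensor:
    # q, k: [batch, heads, seq_len, head_dim]; mask: [seq_len, seq_len] of 0/1 entries.
    scores = torch.matmul(q, k.transpose(-1, -2)) * scale
    return scores.masked_fill(mask == 0, float("-inf"))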
196
+ @triton.jit
197
+ def _fused_neighborhood_attention_av_kernel(
198
+ Attn_ptr,
199
+ V_ptr,
200
+ Out_ptr,
201
+ attn_batch_stride,
202
+ attn_head_stride,
203
+ attn_seq_stride,
204
+ attn_seq2_stride,
205
+ v_batch_stride,
206
+ v_head_stride,
207
+ v_seq_stride,
208
+ v_dim_stride,
209
+ out_batch_stride,
210
+ out_head_stride,
211
+ out_seq_stride,
212
+ out_dim_stride,
213
+ batch_size: tl.constexpr,
214
+ num_heads: tl.constexpr,
215
+ seq_len: tl.constexpr,
216
+ head_dim: tl.constexpr,
217
+ BLOCK_SIZE_M: tl.constexpr,
218
+ BLOCK_SIZE_N: tl.constexpr,
219
+ BLOCK_SIZE_K: tl.constexpr,
220
+ num_stages: tl.constexpr,
221
+ num_warps: tl.constexpr,
222
+ ):
223
+ """
224
+ Compute Attention @ V to produce the final output.
225
+
226
+ This kernel performs the second stage of neighborhood attention by multiplying
227
+ the normalized attention weights with the value matrix. The computation is
228
+ tiled for memory efficiency, with each tile computing a block of the output.
229
+
230
+ Args:
231
+ Attn_ptr: Pointer to attention weights [batch_size, num_heads, seq_len, seq_len]
232
+ V_ptr: Pointer to value tensor [batch_size, num_heads, seq_len, head_dim]
233
+ Out_ptr: Pointer to output tensor [batch_size, num_heads, seq_len, head_dim]
234
+ attn_*_stride: Strides for attention weights tensor
235
+ v_*_stride: Strides for value tensor
236
+ out_*_stride: Strides for output tensor
237
+ batch_size: Number of batches
238
+ num_heads: Number of attention heads
239
+ seq_len: Sequence length
240
+ head_dim: Dimension of each attention head
241
+ BLOCK_SIZE_M: Block size for sequence dimension (rows)
242
+ BLOCK_SIZE_N: Block size for head dimension (cols)
243
+ BLOCK_SIZE_K: Block size for sequence dimension (reduction)
244
+ num_stages: Number of pipeline stages
245
+ num_warps: Number of warps
246
+
247
+ Grid: (batch_size * num_heads, cdiv(seq_len, BLOCK_SIZE_M), cdiv(head_dim, BLOCK_SIZE_N))
248
+ Each program computes a tile of the output matrix.
249
+ """
250
+ batch_head_id = tl.program_id(0)
251
+ tile_m = tl.program_id(1)
252
+ tile_n = tl.program_id(2)
253
+
254
+ batch_id = batch_head_id // num_heads
255
+ head_id = batch_head_id % num_heads
256
+
257
+ row_start = tile_m * BLOCK_SIZE_M
258
+ col_start = tile_n * BLOCK_SIZE_N
259
+
260
+ row_offsets = row_start + tl.arange(0, BLOCK_SIZE_M)
261
+ col_offsets = col_start + tl.arange(0, BLOCK_SIZE_N)
262
+
263
+ acc = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32)
264
+
265
+ for k_start in range(0, seq_len, BLOCK_SIZE_K):
266
+ k_offsets = k_start + tl.arange(0, BLOCK_SIZE_K)
267
+ k_mask = k_offsets < seq_len
268
+
269
+ attn_ptrs = (
270
+ Attn_ptr
271
+ + batch_id * attn_batch_stride
272
+ + head_id * attn_head_stride
273
+ + row_offsets[:, None] * attn_seq_stride
274
+ + k_offsets[None, :] * attn_seq2_stride
275
+ )
276
+ attn_mask = (row_offsets[:, None] < seq_len) & k_mask[None, :]
277
+ attn_chunk = tl.load(attn_ptrs, mask=attn_mask, other=0.0)
278
+
279
+ v_ptrs = (
280
+ V_ptr
281
+ + batch_id * v_batch_stride
282
+ + head_id * v_head_stride
283
+ + k_offsets[:, None] * v_seq_stride
284
+ + col_offsets[None, :] * v_dim_stride
285
+ )
286
+ v_mask = k_mask[:, None] & (col_offsets[None, :] < head_dim)
287
+ v_chunk = tl.load(v_ptrs, mask=v_mask, other=0.0)
288
+
289
+ acc += tl.dot(attn_chunk, v_chunk)
290
+
291
+ out_ptrs = (
292
+ Out_ptr
293
+ + batch_id * out_batch_stride
294
+ + head_id * out_head_stride
295
+ + row_offsets[:, None] * out_seq_stride
296
+ + col_offsets[None, :] * out_dim_stride
297
+ )
298
+ valid_mask = (row_offsets[:, None] < seq_len) & (col_offsets[None, :] < head_dim)
299
+ tl.store(out_ptrs, acc, mask=valid_mask)
300
+
301
+
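Chaining the mask, the masked scores, a row-wise softmax, and the attention-times-value product gives an eager reference for the whole forward pass; this sketch (built on the hypothetical helpers above) is what the fused kernels can be checked against numerically.

import torch

def neighborhood_attention_reference(q, k, v, kernel_size=7, dilation=1, scale=None):
    # All inputs are [batch, heads, seq_len, head_dim].
    if scale is None:
        scale = 1.0 / (q.shape[-1] ** 0.5)
    mask = neighborhood_mask_reference(q.shape[-2], kernel_size, dilation).to(q.device)
    scores = masked_scores_reference(q, k, mask, scale)
    attn = torch.softmax(scores, dim=-1)
    return torch.matmul(attn, v)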
302
+ @triton.jit
303
+ def _fused_neighborhood_attention_grad_qk_kernel(
304
+ grad_attn_ptr,
305
+ K_ptr,
306
+ grad_Q_ptr,
307
+ grad_attn_batch_stride,
308
+ grad_attn_head_stride,
309
+ grad_attn_seq_stride,
310
+ grad_attn_seq2_stride,
311
+ k_batch_stride,
312
+ k_head_stride,
313
+ k_seq_stride,
314
+ k_dim_stride,
315
+ grad_q_batch_stride,
316
+ grad_q_head_stride,
317
+ grad_q_seq_stride,
318
+ grad_q_dim_stride,
319
+ batch_size: tl.constexpr,
320
+ num_heads: tl.constexpr,
321
+ seq_len: tl.constexpr,
322
+ head_dim: tl.constexpr,
323
+ scale: tl.constexpr,
324
+ BLOCK_SIZE_M: tl.constexpr,
325
+ BLOCK_SIZE_N: tl.constexpr,
326
+ BLOCK_SIZE_K: tl.constexpr,
327
+ num_stages: tl.constexpr,
328
+ num_warps: tl.constexpr,
329
+ ):
330
+ """
331
+ Compute gradient with respect to queries: grad_Q = grad_attn @ K * scale.
332
+
333
+ This kernel computes the gradient of the loss with respect to the query tensor
334
+ by multiplying the gradient of attention weights with the key tensor. The
335
+ computation follows the chain rule for the attention mechanism.
336
+
337
+ Args:
338
+ grad_attn_ptr: Pointer to gradient of attention weights [batch_size, num_heads, seq_len, seq_len]
339
+ K_ptr: Pointer to key tensor [batch_size, num_heads, seq_len, head_dim]
340
+ grad_Q_ptr: Pointer to output gradient tensor [batch_size, num_heads, seq_len, head_dim]
341
+ grad_attn_*_stride: Strides for gradient attention tensor
342
+ k_*_stride: Strides for key tensor
343
+ grad_q_*_stride: Strides for gradient query tensor
344
+ batch_size: Number of batches
345
+ num_heads: Number of attention heads
346
+ seq_len: Sequence length
347
+ head_dim: Dimension of each attention head
348
+ scale: Scaling factor applied to attention scores
349
+ BLOCK_SIZE_M: Block size for sequence dimension (rows)
350
+ BLOCK_SIZE_N: Block size for head dimension (cols)
351
+ BLOCK_SIZE_K: Block size for sequence dimension (reduction)
352
+ num_stages: Number of pipeline stages
353
+ num_warps: Number of warps
354
+
355
+ Grid: (batch_size * num_heads, cdiv(seq_len, BLOCK_SIZE_M), cdiv(head_dim, BLOCK_SIZE_N))
356
+ Each program computes a tile of the query gradient matrix.
357
+ """
358
+ batch_head_id = tl.program_id(0)
359
+ tile_m = tl.program_id(1)
360
+ tile_n = tl.program_id(2)
361
+
362
+ batch_id = batch_head_id // num_heads
363
+ head_id = batch_head_id % num_heads
364
+
365
+ row_start = tile_m * BLOCK_SIZE_M
366
+ col_start = tile_n * BLOCK_SIZE_N
367
+
368
+ row_offsets = row_start + tl.arange(0, BLOCK_SIZE_M)
369
+ col_offsets = col_start + tl.arange(0, BLOCK_SIZE_N)
370
+
371
+ acc = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32)
372
+
373
+ for k_start in range(0, seq_len, BLOCK_SIZE_K):
374
+ k_offsets = k_start + tl.arange(0, BLOCK_SIZE_K)
375
+ k_mask = k_offsets < seq_len
376
+
377
+ grad_attn_ptrs = (
378
+ grad_attn_ptr
379
+ + batch_id * grad_attn_batch_stride
380
+ + head_id * grad_attn_head_stride
381
+ + row_offsets[:, None] * grad_attn_seq_stride
382
+ + k_offsets[None, :] * grad_attn_seq2_stride
383
+ )
384
+ grad_attn_mask = (row_offsets[:, None] < seq_len) & k_mask[None, :]
385
+ grad_attn_chunk = tl.load(grad_attn_ptrs, mask=grad_attn_mask, other=0.0)
386
+
387
+ k_ptrs = (
388
+ K_ptr
389
+ + batch_id * k_batch_stride
390
+ + head_id * k_head_stride
391
+ + k_offsets[:, None] * k_seq_stride
392
+ + col_offsets[None, :] * k_dim_stride
393
+ )
394
+ k_mask_2d = k_mask[:, None] & (col_offsets[None, :] < head_dim)
395
+ k_chunk = tl.load(k_ptrs, mask=k_mask_2d, other=0.0)
396
+
397
+ acc += tl.dot(grad_attn_chunk, k_chunk)
398
+
399
+ acc = acc * scale
400
+
401
+ grad_q_ptrs = (
402
+ grad_Q_ptr
403
+ + batch_id * grad_q_batch_stride
404
+ + head_id * grad_q_head_stride
405
+ + row_offsets[:, None] * grad_q_seq_stride
406
+ + col_offsets[None, :] * grad_q_dim_stride
407
+ )
408
+ valid_mask = (row_offsets[:, None] < seq_len) & (col_offsets[None, :] < head_dim)
409
+ tl.store(grad_q_ptrs, acc, mask=valid_mask)
410
+
411
+
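In eager terms this kernel is a single scaled matrix product; a reference sketch (the name grad_q_reference is ours):

import torch

def grad_q_reference(grad_scores: torch.Tensor, k: torch.Tensor, scale: float) -> torch.Tensor:
    # grad_scores: gradient w.r.t. the (scaled, masked) scores, i.e. what the softmax backward produces.
    return torch.matmul(grad_scores, k) * scale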
412
+ @triton.jit
413
+ def _fused_neighborhood_attention_grad_k_kernel(
414
+ grad_attn_ptr,
415
+ Q_ptr,
416
+ grad_K_ptr,
417
+ grad_attn_batch_stride,
418
+ grad_attn_head_stride,
419
+ grad_attn_seq_stride,
420
+ grad_attn_seq2_stride,
421
+ q_batch_stride,
422
+ q_head_stride,
423
+ q_seq_stride,
424
+ q_dim_stride,
425
+ grad_k_batch_stride,
426
+ grad_k_head_stride,
427
+ grad_k_seq_stride,
428
+ grad_k_dim_stride,
429
+ batch_size: tl.constexpr,
430
+ num_heads: tl.constexpr,
431
+ seq_len: tl.constexpr,
432
+ head_dim: tl.constexpr,
433
+ scale: tl.constexpr,
434
+ BLOCK_SIZE_M: tl.constexpr,
435
+ BLOCK_SIZE_N: tl.constexpr,
436
+ BLOCK_SIZE_K: tl.constexpr,
437
+ num_stages: tl.constexpr,
438
+ num_warps: tl.constexpr,
439
+ ):
440
+ """
441
+ Compute gradient with respect to keys: grad_K = grad_attn^T @ Q * scale.
442
+
443
+ This kernel computes the gradient of the loss with respect to the key tensor
444
+ by multiplying the transpose of the gradient of attention weights with the
445
+ query tensor. The computation follows the chain rule for the attention mechanism.
446
+
447
+ Args:
448
+ grad_attn_ptr: Pointer to gradient of attention weights [batch_size, num_heads, seq_len, seq_len]
449
+ Q_ptr: Pointer to query tensor [batch_size, num_heads, seq_len, head_dim]
450
+ grad_K_ptr: Pointer to output gradient tensor [batch_size, num_heads, seq_len, head_dim]
451
+ grad_attn_*_stride: Strides for gradient attention tensor
452
+ q_*_stride: Strides for query tensor
453
+ grad_k_*_stride: Strides for gradient key tensor
454
+ batch_size: Number of batches
455
+ num_heads: Number of attention heads
456
+ seq_len: Sequence length
457
+ head_dim: Dimension of each attention head
458
+ scale: Scaling factor applied to attention scores
459
+ BLOCK_SIZE_M: Block size for sequence dimension (rows)
460
+ BLOCK_SIZE_N: Block size for head dimension (cols)
461
+ BLOCK_SIZE_K: Block size for sequence dimension (reduction)
462
+ num_stages: Number of pipeline stages
463
+ num_warps: Number of warps
464
+
465
+ Grid: (batch_size * num_heads, cdiv(seq_len, BLOCK_SIZE_M), cdiv(head_dim, BLOCK_SIZE_N))
466
+ Each program computes a tile of the key gradient matrix.
467
+ """
468
+ batch_head_id = tl.program_id(0)
469
+ tile_m = tl.program_id(1)
470
+ tile_n = tl.program_id(2)
471
+
472
+ batch_id = batch_head_id // num_heads
473
+ head_id = batch_head_id % num_heads
474
+
475
+ row_start = tile_m * BLOCK_SIZE_M
476
+ col_start = tile_n * BLOCK_SIZE_N
477
+
478
+ row_offsets = row_start + tl.arange(0, BLOCK_SIZE_M)
479
+ col_offsets = col_start + tl.arange(0, BLOCK_SIZE_N)
480
+
481
+ acc = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32)
482
+
483
+ for k_start in range(0, seq_len, BLOCK_SIZE_K):
484
+ k_offsets = k_start + tl.arange(0, BLOCK_SIZE_K)
485
+ k_mask = k_offsets < seq_len
486
+
487
+ q_ptrs = (
488
+ Q_ptr
489
+ + batch_id * q_batch_stride
490
+ + head_id * q_head_stride
491
+ + k_offsets[:, None] * q_seq_stride
492
+ + col_offsets[None, :] * q_dim_stride
493
+ )
494
+ q_mask = k_mask[:, None] & (col_offsets[None, :] < head_dim)
495
+ q_chunk = tl.load(q_ptrs, mask=q_mask, other=0.0)
496
+
497
+ grad_attn_T_ptrs = (
498
+ grad_attn_ptr
499
+ + batch_id * grad_attn_batch_stride
500
+ + head_id * grad_attn_head_stride
501
+ + row_offsets[:, None] * grad_attn_seq2_stride
502
+ + k_offsets[None, :] * grad_attn_seq_stride
503
+ )
504
+ grad_attn_T_mask = (row_offsets[:, None] < seq_len) & k_mask[None, :]
505
+ grad_attn_T_chunk = tl.load(grad_attn_T_ptrs, mask=grad_attn_T_mask, other=0.0)
506
+
507
+ acc += tl.dot(grad_attn_T_chunk, q_chunk)
508
+
509
+ acc = acc * scale
510
+
511
+ grad_k_ptrs = (
512
+ grad_K_ptr
513
+ + batch_id * grad_k_batch_stride
514
+ + head_id * grad_k_head_stride
515
+ + row_offsets[:, None] * grad_k_seq_stride
516
+ + col_offsets[None, :] * grad_k_dim_stride
517
+ )
518
+ valid_mask = (row_offsets[:, None] < seq_len) & (col_offsets[None, :] < head_dim)
519
+ tl.store(grad_k_ptrs, acc, mask=valid_mask)
520
+
521
+
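The key gradient differs only by transposing the score gradient before the product with Q; eager sketch (the helper name is ours):

import torch

def grad_k_reference(grad_scores: torch.Tensor, q: torch.Tensor, scale: float) -> torch.Tensor:
    # grad_K = grad_scores^T @ Q, scaled by the same factor as the forward scores.
    return torch.matmul(grad_scores.transpose(-1, -2), q) * scale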
522
+ @triton.jit
523
+ def _fused_neighborhood_attention_grad_v_kernel(
524
+ Attn_ptr,
525
+ grad_output_ptr,
526
+ grad_V_ptr,
527
+ attn_batch_stride,
528
+ attn_head_stride,
529
+ attn_seq_stride,
530
+ attn_seq2_stride,
531
+ grad_out_batch_stride,
532
+ grad_out_head_stride,
533
+ grad_out_seq_stride,
534
+ grad_out_dim_stride,
535
+ grad_v_batch_stride,
536
+ grad_v_head_stride,
537
+ grad_v_seq_stride,
538
+ grad_v_dim_stride,
539
+ batch_size: tl.constexpr,
540
+ num_heads: tl.constexpr,
541
+ seq_len: tl.constexpr,
542
+ head_dim: tl.constexpr,
543
+ BLOCK_SIZE_M: tl.constexpr,
544
+ BLOCK_SIZE_N: tl.constexpr,
545
+ BLOCK_SIZE_K: tl.constexpr,
546
+ num_stages: tl.constexpr,
547
+ num_warps: tl.constexpr,
548
+ ):
549
+ """
550
+ Compute gradient with respect to values: grad_V = Attn^T @ grad_output.
551
+
552
+ This kernel computes the gradient of the loss with respect to the value tensor
553
+ by multiplying the transpose of the attention weights with the gradient of the
554
+ output. The computation follows the chain rule for the attention mechanism.
555
+
556
+ Args:
557
+ Attn_ptr: Pointer to attention weights [batch_size, num_heads, seq_len, seq_len]
558
+ grad_output_ptr: Pointer to gradient of output [batch_size, num_heads, seq_len, head_dim]
559
+ grad_V_ptr: Pointer to output gradient tensor [batch_size, num_heads, seq_len, head_dim]
560
+ attn_*_stride: Strides for attention weights tensor
561
+ grad_out_*_stride: Strides for gradient output tensor
562
+ grad_v_*_stride: Strides for gradient value tensor
563
+ batch_size: Number of batches
564
+ num_heads: Number of attention heads
565
+ seq_len: Sequence length
566
+ head_dim: Dimension of each attention head
567
+ BLOCK_SIZE_M: Block size for sequence dimension (rows)
568
+ BLOCK_SIZE_N: Block size for head dimension (cols)
569
+ BLOCK_SIZE_K: Block size for sequence dimension (reduction)
570
+ num_stages: Number of pipeline stages
571
+ num_warps: Number of warps
572
+
573
+ Grid: (batch_size * num_heads, cdiv(seq_len, BLOCK_SIZE_M), cdiv(head_dim, BLOCK_SIZE_N))
574
+ Each program computes a tile of the value gradient matrix.
575
+ """
576
+ batch_head_id = tl.program_id(0)
577
+ tile_m = tl.program_id(1)
578
+ tile_n = tl.program_id(2)
579
+
580
+ batch_id = batch_head_id // num_heads
581
+ head_id = batch_head_id % num_heads
582
+
583
+ row_start = tile_m * BLOCK_SIZE_M
584
+ col_start = tile_n * BLOCK_SIZE_N
585
+
586
+ row_offsets = row_start + tl.arange(0, BLOCK_SIZE_M)
587
+ col_offsets = col_start + tl.arange(0, BLOCK_SIZE_N)
588
+
589
+ acc = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32)
590
+
591
+ for k_start in range(0, seq_len, BLOCK_SIZE_K):
592
+ k_offsets = k_start + tl.arange(0, BLOCK_SIZE_K)
593
+ k_mask = k_offsets < seq_len
594
+
595
+ attn_ptrs = (
596
+ Attn_ptr
597
+ + batch_id * attn_batch_stride
598
+ + head_id * attn_head_stride
599
+ + k_offsets[:, None] * attn_seq_stride
600
+ + row_offsets[None, :] * attn_seq2_stride
601
+ )
602
+ attn_mask = k_mask[:, None] & (row_offsets[None, :] < seq_len)
603
+ attn_chunk = tl.load(attn_ptrs, mask=attn_mask, other=0.0)
604
+
605
+ grad_out_ptrs = (
606
+ grad_output_ptr
607
+ + batch_id * grad_out_batch_stride
608
+ + head_id * grad_out_head_stride
609
+ + k_offsets[:, None] * grad_out_seq_stride
610
+ + col_offsets[None, :] * grad_out_dim_stride
611
+ )
612
+ grad_out_mask = k_mask[:, None] & (col_offsets[None, :] < head_dim)
613
+ grad_out_chunk = tl.load(grad_out_ptrs, mask=grad_out_mask, other=0.0)
614
+
615
+ acc += tl.dot(tl.trans(attn_chunk), grad_out_chunk)
616
+
617
+ grad_v_ptrs = (
618
+ grad_V_ptr
619
+ + batch_id * grad_v_batch_stride
620
+ + head_id * grad_v_head_stride
621
+ + row_offsets[:, None] * grad_v_seq_stride
622
+ + col_offsets[None, :] * grad_v_dim_stride
623
+ )
624
+ valid_mask = (row_offsets[:, None] < seq_len) & (col_offsets[None, :] < head_dim)
625
+ tl.store(grad_v_ptrs, acc, mask=valid_mask)
626
+
627
+
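Eager sketch of the value gradient this kernel produces (attn is the softmax output saved from the forward pass; the helper name is ours):

import torch

def grad_v_reference(attn: torch.Tensor, grad_output: torch.Tensor) -> torch.Tensor:
    # grad_V = Attn^T @ grad_output
    return torch.matmul(attn.transpose(-1, -2), grad_output)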
628
+ @triton.jit
629
+ def _fused_neighborhood_attention_grad_attn_kernel(
630
+ grad_output_ptr,
631
+ V_ptr,
632
+ grad_attn_ptr,
633
+ grad_out_batch_stride,
634
+ grad_out_head_stride,
635
+ grad_out_seq_stride,
636
+ grad_out_dim_stride,
637
+ v_batch_stride,
638
+ v_head_stride,
639
+ v_seq_stride,
640
+ v_dim_stride,
641
+ grad_attn_batch_stride,
642
+ grad_attn_head_stride,
643
+ grad_attn_seq_stride,
644
+ grad_attn_seq2_stride,
645
+ batch_size: tl.constexpr,
646
+ num_heads: tl.constexpr,
647
+ seq_len: tl.constexpr,
648
+ head_dim: tl.constexpr,
649
+ BLOCK_SIZE_M: tl.constexpr,
650
+ BLOCK_SIZE_N: tl.constexpr,
651
+ BLOCK_SIZE_K: tl.constexpr,
652
+ num_stages: tl.constexpr,
653
+ num_warps: tl.constexpr,
654
+ ):
655
+ """
656
+ Compute gradient with respect to attention weights: grad_attn = grad_output @ V^T.
657
+
658
+ This kernel computes the gradient of the loss with respect to the attention
659
+ weights by multiplying the gradient of the output with the transpose of the
660
+ value tensor. This gradient will later be passed through the softmax backward
661
+ pass to compute gradients for the attention scores.
662
+
663
+ Args:
664
+ grad_output_ptr: Pointer to gradient of output [batch_size, num_heads, seq_len, head_dim]
665
+ V_ptr: Pointer to value tensor [batch_size, num_heads, seq_len, head_dim]
666
+ grad_attn_ptr: Pointer to output gradient tensor [batch_size, num_heads, seq_len, seq_len]
667
+ grad_out_*_stride: Strides for gradient output tensor
668
+ v_*_stride: Strides for value tensor
669
+ grad_attn_*_stride: Strides for gradient attention tensor
670
+ batch_size: Number of batches
671
+ num_heads: Number of attention heads
672
+ seq_len: Sequence length
673
+ head_dim: Dimension of each attention head
674
+ BLOCK_SIZE_M: Block size for sequence dimension (rows)
675
+ BLOCK_SIZE_N: Block size for sequence dimension (cols)
676
+ BLOCK_SIZE_K: Block size for head dimension (reduction)
677
+ num_stages: Number of pipeline stages
678
+ num_warps: Number of warps
679
+
680
+ Grid: (batch_size * num_heads, cdiv(seq_len, BLOCK_SIZE_M), cdiv(seq_len, BLOCK_SIZE_N))
681
+ Each program computes a tile of the attention gradient matrix.
682
+ """
683
+ batch_head_id = tl.program_id(0)
684
+ tile_m = tl.program_id(1)
685
+ tile_n = tl.program_id(2)
686
+
687
+ batch_id = batch_head_id // num_heads
688
+ head_id = batch_head_id % num_heads
689
+
690
+ row_start = tile_m * BLOCK_SIZE_M
691
+ col_start = tile_n * BLOCK_SIZE_N
692
+
693
+ row_offsets = row_start + tl.arange(0, BLOCK_SIZE_M)
694
+ col_offsets = col_start + tl.arange(0, BLOCK_SIZE_N)
695
+
696
+ acc = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32)
697
+
698
+ for k_start in range(0, head_dim, BLOCK_SIZE_K):
699
+ k_offsets = k_start + tl.arange(0, BLOCK_SIZE_K)
700
+ k_mask = k_offsets < head_dim
701
+
702
+ grad_out_ptrs = (
703
+ grad_output_ptr
704
+ + batch_id * grad_out_batch_stride
705
+ + head_id * grad_out_head_stride
706
+ + row_offsets[:, None] * grad_out_seq_stride
707
+ + k_offsets[None, :] * grad_out_dim_stride
708
+ )
709
+ grad_out_mask = (row_offsets[:, None] < seq_len) & k_mask[None, :]
710
+ grad_out_chunk = tl.load(grad_out_ptrs, mask=grad_out_mask, other=0.0)
711
+
712
+ v_ptrs = (
713
+ V_ptr
714
+ + batch_id * v_batch_stride
715
+ + head_id * v_head_stride
716
+ + col_offsets[None, :] * v_seq_stride
717
+ + k_offsets[:, None] * v_dim_stride
718
+ )
719
+ v_mask = (col_offsets[None, :] < seq_len) & k_mask[:, None]
720
+ v_chunk = tl.load(v_ptrs, mask=v_mask, other=0.0)
721
+
722
+ acc += tl.dot(grad_out_chunk, v_chunk)
723
+
724
+ grad_attn_ptrs = (
725
+ grad_attn_ptr
726
+ + batch_id * grad_attn_batch_stride
727
+ + head_id * grad_attn_head_stride
728
+ + row_offsets[:, None] * grad_attn_seq_stride
729
+ + col_offsets[None, :] * grad_attn_seq2_stride
730
+ )
731
+ valid_mask = (row_offsets[:, None] < seq_len) & (col_offsets[None, :] < seq_len)
732
+ tl.store(grad_attn_ptrs, acc, mask=valid_mask)
733
+
734
+
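Eager sketch of this kernel's output, which the backward pass then feeds through the softmax backward to obtain the score gradient (the helper name is ours):

import torch

def grad_attn_reference(grad_output: torch.Tensor, v: torch.Tensor) -> torch.Tensor:
    # grad_Attn = grad_output @ V^T
    return torch.matmul(grad_output, v.transpose(-1, -2))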
735
+ def fused_neighborhood_attention_forward(
736
+ query: torch.Tensor,
737
+ key: torch.Tensor,
738
+ value: torch.Tensor,
739
+ kernel_size: int = 7,
740
+ dilation: int = 1,
741
+ scale: float = None,
742
+ return_lse: bool = False,
743
+ ) -> tuple:
744
+ """
745
+ Fused neighborhood attention forward pass.
746
+
747
+ Args:
748
+ query: Query tensor of shape [batch_size, num_heads, seq_len, head_dim]
749
+ key: Key tensor of shape [batch_size, num_heads, seq_len, head_dim]
750
+ value: Value tensor of shape [batch_size, num_heads, seq_len, head_dim]
751
+ kernel_size: Size of the neighborhood window
752
+ dilation: Dilation factor for the neighborhood
753
+ scale: Scaling factor for attention scores (default: rsqrt(head_dim))
754
+ return_lse: Whether to return log-sum-exp values
755
+
756
+ Returns:
757
+ Tuple of (output tensor, attention weights, softmax parameters for backward)
758
+ """
759
+ batch_size, num_heads, seq_len, head_dim = query.shape
760
+
761
+ if scale is None:
762
+ scale = 1.0 / math.sqrt(head_dim)
763
+
764
+ query = query.contiguous()
765
+ key = key.contiguous()
766
+ value = value.contiguous()
767
+
768
+ output = torch.empty_like(query)
769
+ qk_scores = torch.empty(batch_size, num_heads, seq_len, seq_len, device=query.device, dtype=query.dtype)
770
+
771
+ mask = torch.zeros(seq_len, seq_len, device=query.device, dtype=torch.float32)
772
+
773
+ BLOCK_SIZE, num_warps = calculate_settings(seq_len)
774
+ BLOCK_SIZE_M = min(64, triton.next_power_of_2(seq_len))
775
+ BLOCK_SIZE_N = min(64, triton.next_power_of_2(seq_len))
776
+ BLOCK_SIZE_K = max(16, triton.next_power_of_2(head_dim))
777
+
778
+ num_stages = 4 if seq_len >= 512 else 2
779
+
780
+ grid_mask = (seq_len,)
781
+ _neighborhood_mask_kernel[grid_mask](
782
+ mask,
783
+ seq_len,
784
+ kernel_size,
785
+ dilation,
786
+ BLOCK_SIZE,
787
+ num_stages,
788
+ num_warps,
789
+ )
790
+
791
+ grid_qk = (batch_size * num_heads, triton.cdiv(seq_len, BLOCK_SIZE_M), triton.cdiv(seq_len, BLOCK_SIZE_N))
792
+ _fused_neighborhood_attention_qk_kernel[grid_qk](
793
+ query,
794
+ key,
795
+ qk_scores,
796
+ mask,
797
+ query.stride(0),
798
+ query.stride(1),
799
+ query.stride(2),
800
+ query.stride(3),
801
+ key.stride(0),
802
+ key.stride(1),
803
+ key.stride(2),
804
+ key.stride(3),
805
+ qk_scores.stride(0),
806
+ qk_scores.stride(1),
807
+ qk_scores.stride(2),
808
+ qk_scores.stride(3),
809
+ batch_size,
810
+ num_heads,
811
+ seq_len,
812
+ head_dim,
813
+ scale,
814
+ kernel_size,
815
+ dilation,
816
+ BLOCK_SIZE_M,
817
+ BLOCK_SIZE_N,
818
+ BLOCK_SIZE_K,
819
+ num_stages,
820
+ num_warps,
821
+ )
822
+
823
+ qk_reshaped = qk_scores.view(batch_size * num_heads * seq_len, seq_len)
824
+ attn_reshaped, BLOCK_SIZE_softmax, num_warps_softmax, multi_block_launch = _softmax_forward(qk_reshaped)
825
+ attn_weights = attn_reshaped.view(batch_size, num_heads, seq_len, seq_len)
826
+
827
+ grid_av = (batch_size * num_heads, triton.cdiv(seq_len, BLOCK_SIZE_M), triton.cdiv(head_dim, BLOCK_SIZE_N))
828
+ _fused_neighborhood_attention_av_kernel[grid_av](
829
+ attn_weights,
830
+ value,
831
+ output,
832
+ attn_weights.stride(0),
833
+ attn_weights.stride(1),
834
+ attn_weights.stride(2),
835
+ attn_weights.stride(3),
836
+ value.stride(0),
837
+ value.stride(1),
838
+ value.stride(2),
839
+ value.stride(3),
840
+ output.stride(0),
841
+ output.stride(1),
842
+ output.stride(2),
843
+ output.stride(3),
844
+ batch_size,
845
+ num_heads,
846
+ seq_len,
847
+ head_dim,
848
+ BLOCK_SIZE_M,
849
+ BLOCK_SIZE_N,
850
+ BLOCK_SIZE_K,
851
+ num_stages,
852
+ num_warps,
853
+ )
854
+
855
+ if return_lse:
856
+ raise NotImplementedError("return_lse=True is not supported yet.")
857
+
858
+ softmax_params = (BLOCK_SIZE_softmax, num_warps_softmax, multi_block_launch)
859
+ return output, attn_weights, softmax_params
860
+
861
+
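A minimal call sketch for this helper, assuming the nightly wheel is installed and a CUDA device with Triton is available (shapes and dtype are our choice):

import torch
from liger_kernel.ops.fused_neighborhood_attention import fused_neighborhood_attention_forward

q = torch.randn(2, 8, 128, 64, device="cuda")
k = torch.randn_like(q)
v = torch.randn_like(q)
out, attn, softmax_params = fused_neighborhood_attention_forward(q, k, v, kernel_size=7, dilation=1)
# out: [2, 8, 128, 64]; attn: [2, 8, 128, 128]; softmax_params feed the backward pass.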
862
+ class LigerFusedNeighborhoodAttentionFunction(torch.autograd.Function):
863
+ @staticmethod
864
+ @ensure_contiguous
865
+ def forward(ctx, query, key, value, kernel_size=7, dilation=1, scale=None):
866
+ output, attn_weights, softmax_params = fused_neighborhood_attention_forward(
867
+ query, key, value, kernel_size, dilation, scale
868
+ )
869
+ ctx.save_for_backward(query, key, value, attn_weights)
870
+ ctx.kernel_size = kernel_size
871
+ ctx.dilation = dilation
872
+ ctx.scale = scale
873
+ ctx.softmax_params = softmax_params
874
+ return output
875
+
876
+ @staticmethod
877
+ @ensure_contiguous
878
+ def backward(ctx, grad_output):
879
+ query, key, value, attn_weights = ctx.saved_tensors
880
+ BLOCK_SIZE_softmax, num_warps_softmax, multi_block_launch = ctx.softmax_params
881
+
882
+ batch_size, num_heads, seq_len, head_dim = query.shape
883
+ scale = ctx.scale if ctx.scale is not None else 1.0 / math.sqrt(head_dim)
884
+
885
+ grad_query = torch.zeros_like(query)
886
+ grad_key = torch.zeros_like(key)
887
+ grad_value = torch.zeros_like(value)
888
+ grad_attn_weights = torch.zeros_like(attn_weights)
889
+
890
+ BLOCK_SIZE_M = min(64, triton.next_power_of_2(seq_len))
891
+ BLOCK_SIZE_N = min(64, triton.next_power_of_2(seq_len))
892
+ BLOCK_SIZE_K = min(64, triton.next_power_of_2(head_dim))
893
+ num_stages = 4 if seq_len >= 512 else 2
894
+ _, num_warps = calculate_settings(seq_len)
895
+
896
+ grid_grad_attn = (
897
+ batch_size * num_heads,
898
+ triton.cdiv(seq_len, BLOCK_SIZE_M),
899
+ triton.cdiv(seq_len, BLOCK_SIZE_N),
900
+ )
901
+ _fused_neighborhood_attention_grad_attn_kernel[grid_grad_attn](
902
+ grad_output,
903
+ value,
904
+ grad_attn_weights,
905
+ grad_output.stride(0),
906
+ grad_output.stride(1),
907
+ grad_output.stride(2),
908
+ grad_output.stride(3),
909
+ value.stride(0),
910
+ value.stride(1),
911
+ value.stride(2),
912
+ value.stride(3),
913
+ grad_attn_weights.stride(0),
914
+ grad_attn_weights.stride(1),
915
+ grad_attn_weights.stride(2),
916
+ grad_attn_weights.stride(3),
917
+ batch_size,
918
+ num_heads,
919
+ seq_len,
920
+ head_dim,
921
+ BLOCK_SIZE_M,
922
+ BLOCK_SIZE_N,
923
+ BLOCK_SIZE_K,
924
+ num_stages,
925
+ num_warps,
926
+ )
927
+
928
+ grad_attn_reshaped = grad_attn_weights.view(batch_size * num_heads * seq_len, seq_len)
929
+ attn_reshaped = attn_weights.view(batch_size * num_heads * seq_len, seq_len)
930
+
931
+ grad_qk_reshaped = _softmax_backward(
932
+ grad_attn_reshaped, attn_reshaped, BLOCK_SIZE_softmax, num_warps_softmax, multi_block_launch
933
+ )
934
+ grad_qk_scores = grad_qk_reshaped.view(batch_size, num_heads, seq_len, seq_len)
935
+
936
+ grid_grad_q = (batch_size * num_heads, triton.cdiv(seq_len, BLOCK_SIZE_M), triton.cdiv(head_dim, BLOCK_SIZE_N))
937
+ _fused_neighborhood_attention_grad_qk_kernel[grid_grad_q](
938
+ grad_qk_scores,
939
+ key,
940
+ grad_query,
941
+ grad_qk_scores.stride(0),
942
+ grad_qk_scores.stride(1),
943
+ grad_qk_scores.stride(2),
944
+ grad_qk_scores.stride(3),
945
+ key.stride(0),
946
+ key.stride(1),
947
+ key.stride(2),
948
+ key.stride(3),
949
+ grad_query.stride(0),
950
+ grad_query.stride(1),
951
+ grad_query.stride(2),
952
+ grad_query.stride(3),
953
+ batch_size,
954
+ num_heads,
955
+ seq_len,
956
+ head_dim,
957
+ scale,
958
+ BLOCK_SIZE_M,
959
+ BLOCK_SIZE_N,
960
+ BLOCK_SIZE_K,
961
+ num_stages,
962
+ num_warps,
963
+ )
964
+
965
+ grid_grad_k = (batch_size * num_heads, triton.cdiv(seq_len, BLOCK_SIZE_M), triton.cdiv(head_dim, BLOCK_SIZE_N))
966
+ _fused_neighborhood_attention_grad_k_kernel[grid_grad_k](
967
+ grad_qk_scores,
968
+ query,
969
+ grad_key,
970
+ grad_qk_scores.stride(0),
971
+ grad_qk_scores.stride(1),
972
+ grad_qk_scores.stride(2),
973
+ grad_qk_scores.stride(3),
974
+ query.stride(0),
975
+ query.stride(1),
976
+ query.stride(2),
977
+ query.stride(3),
978
+ grad_key.stride(0),
979
+ grad_key.stride(1),
980
+ grad_key.stride(2),
981
+ grad_key.stride(3),
982
+ batch_size,
983
+ num_heads,
984
+ seq_len,
985
+ head_dim,
986
+ scale,
987
+ BLOCK_SIZE_M,
988
+ BLOCK_SIZE_N,
989
+ BLOCK_SIZE_K,
990
+ num_stages,
991
+ num_warps,
992
+ )
993
+
994
+ grid_grad_v = (batch_size * num_heads, triton.cdiv(seq_len, BLOCK_SIZE_M), triton.cdiv(head_dim, BLOCK_SIZE_N))
995
+ _fused_neighborhood_attention_grad_v_kernel[grid_grad_v](
996
+ attn_weights,
997
+ grad_output,
998
+ grad_value,
999
+ attn_weights.stride(0),
1000
+ attn_weights.stride(1),
1001
+ attn_weights.stride(2),
1002
+ attn_weights.stride(3),
1003
+ grad_output.stride(0),
1004
+ grad_output.stride(1),
1005
+ grad_output.stride(2),
1006
+ grad_output.stride(3),
1007
+ grad_value.stride(0),
1008
+ grad_value.stride(1),
1009
+ grad_value.stride(2),
1010
+ grad_value.stride(3),
1011
+ batch_size,
1012
+ num_heads,
1013
+ seq_len,
1014
+ head_dim,
1015
+ BLOCK_SIZE_M,
1016
+ BLOCK_SIZE_N,
1017
+ BLOCK_SIZE_K,
1018
+ num_stages,
1019
+ num_warps,
1020
+ )
1021
+
1022
+ return grad_query, grad_key, grad_value, None, None, None
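End to end, the autograd Function can be exercised and loosely checked against the eager reference sketched earlier; this is a test sketch under our own tolerances and shapes, not a test shipped with the package.

import torch
from liger_kernel.ops.fused_neighborhood_attention import LigerFusedNeighborhoodAttentionFunction

q = torch.randn(1, 4, 64, 32, device="cuda", requires_grad=True)
k = torch.randn_like(q, requires_grad=True)
v = torch.randn_like(q, requires_grad=True)

out = LigerFusedNeighborhoodAttentionFunction.apply(q, k, v, 7, 1, None)
out.sum().backward()  # populates q.grad, k.grad, v.grad via the kernels above

ref = neighborhood_attention_reference(q.detach(), k.detach(), v.detach(), kernel_size=7, dilation=1)
torch.testing.assert_close(out.detach(), ref, rtol=1e-3, atol=1e-3)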