liger-kernel-nightly 0.5.10.dev20250605223455__py3-none-any.whl → 0.5.10.dev20250605224739__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -0,0 +1,1022 @@
1
+ import math
2
+
3
+ import torch
4
+ import triton
5
+ import triton.language as tl
6
+
7
+ from liger_kernel.ops.softmax import _softmax_backward
8
+ from liger_kernel.ops.softmax import _softmax_forward
9
+ from liger_kernel.ops.utils import calculate_settings
10
+ from liger_kernel.ops.utils import ensure_contiguous
11
+
12
+
13
+ @triton.jit
14
+ def _neighborhood_mask_kernel(
15
+ mask_ptr,
16
+ seq_len: tl.constexpr,
17
+ kernel_size: tl.constexpr,
18
+ dilation: tl.constexpr,
19
+ BLOCK_SIZE: tl.constexpr,
20
+ num_stages: tl.constexpr,
21
+ num_warps: tl.constexpr,
22
+ ):
23
+ """
24
+ Generate a neighborhood attention mask for a given sequence.
25
+
26
+ This kernel creates a binary mask that defines which positions in a sequence
27
+ can attend to each other based on a neighborhood window with optional dilation.
28
+ Each row of the mask corresponds to a query position, and each column indicates
29
+ whether that key position is within the allowed neighborhood.
30
+
31
+ The neighborhood is defined as positions within kernel_size//2 * dilation distance
32
+ from the center position. When dilation > 1, only positions whose offset from the
33
+ center is a multiple of the dilation factor are included in the neighborhood.
34
+
35
+ Args:
36
+ mask_ptr: Pointer to the output mask tensor [seq_len, seq_len]
37
+ seq_len: Length of the input sequence
38
+ kernel_size: Size of the neighborhood window (must be odd)
39
+ dilation: Dilation factor for the neighborhood pattern
40
+ BLOCK_SIZE: Block size for processing; must be >= seq_len, since each program writes one full row (compile-time constant)
41
+ num_stages: Number of pipeline stages (compile-time constant)
42
+ num_warps: Number of warps (compile-time constant)
43
+
44
+ Grid: (seq_len,)
45
+ Each program processes one row of the mask matrix.
46
+ """
47
+ row_id = tl.program_id(0)
48
+
49
+ center = row_id
50
+ half_kernel = kernel_size // 2
51
+
52
+ start = tl.maximum(0, center - half_kernel * dilation)
53
+ end = tl.minimum(seq_len, center + half_kernel * dilation + 1)
54
+
55
+ col_offsets = tl.arange(0, BLOCK_SIZE)
56
+ mask = col_offsets < seq_len
57
+
58
+ valid_neighbors = (col_offsets >= start) & (col_offsets < end)
59
+ if dilation > 1:
60
+ relative_pos = col_offsets - center
61
+ valid_dilation = (relative_pos % dilation) == 0
62
+ valid_neighbors = valid_neighbors & valid_dilation
63
+
64
+ mask_values = tl.where(valid_neighbors & mask, 1.0, 0.0)
65
+
66
+ base_offset = row_id * seq_len
67
+ tl.store(mask_ptr + base_offset + col_offsets, mask_values, mask=mask)
68
+
69
+
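To make the windowing rule concrete, here is a small host-side PyTorch sketch of the same mask (illustrative only, not part of this file; the helper name is hypothetical). For seq_len=6, kernel_size=3, dilation=2, row 2 ends up with ones at columns 0, 2 and 4:

    import torch

    def reference_neighborhood_mask(seq_len: int, kernel_size: int, dilation: int) -> torch.Tensor:
        # Same rule as the kernel: |j - i| <= (kernel_size // 2) * dilation, and when
        # dilation > 1 the offset j - i must additionally lie on the dilation grid.
        idx = torch.arange(seq_len)
        rel = idx[None, :] - idx[:, None]
        allowed = rel.abs() <= (kernel_size // 2) * dilation
        if dilation > 1:
            allowed &= (rel % dilation) == 0
        return allowed.float()   # 1.0 = may attend, 0.0 = masked

    print(reference_neighborhood_mask(6, 3, 2)[2])   # tensor([1., 0., 1., 0., 1., 0.])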
70
+ @triton.jit
71
+ def _fused_neighborhood_attention_qk_kernel(
72
+ Q_ptr,
73
+ K_ptr,
74
+ QK_ptr,
75
+ mask_ptr,
76
+ q_batch_stride,
77
+ q_head_stride,
78
+ q_seq_stride,
79
+ q_dim_stride,
80
+ k_batch_stride,
81
+ k_head_stride,
82
+ k_seq_stride,
83
+ k_dim_stride,
84
+ qk_batch_stride,
85
+ qk_head_stride,
86
+ qk_seq_stride,
87
+ qk_seq2_stride,
88
+ batch_size: tl.constexpr,
89
+ num_heads: tl.constexpr,
90
+ seq_len: tl.constexpr,
91
+ head_dim: tl.constexpr,
92
+ scale: tl.constexpr,
93
+ kernel_size: tl.constexpr,
94
+ dilation: tl.constexpr,
95
+ BLOCK_SIZE_M: tl.constexpr,
96
+ BLOCK_SIZE_N: tl.constexpr,
97
+ BLOCK_SIZE_K: tl.constexpr,
98
+ num_stages: tl.constexpr,
99
+ num_warps: tl.constexpr,
100
+ ):
101
+ """
102
+ Compute Q @ K^T with neighborhood masking and scaling.
103
+
104
+ This kernel performs the first stage of neighborhood attention by computing
105
+ the attention scores between queries and keys, applying scaling, and masking
106
+ positions outside the neighborhood window. The result is a matrix of attention
107
+ scores ready for softmax normalization.
108
+
109
+ The computation is tiled across sequence dimensions for memory efficiency.
110
+ Each tile computes a block of the attention score matrix by iterating over
111
+ the head dimension and accumulating dot products.
112
+
113
+ Args:
114
+ Q_ptr: Pointer to query tensor [batch_size, num_heads, seq_len, head_dim]
115
+ K_ptr: Pointer to key tensor [batch_size, num_heads, seq_len, head_dim]
116
+ QK_ptr: Pointer to output tensor [batch_size, num_heads, seq_len, seq_len]
117
+ mask_ptr: Pointer to neighborhood mask [seq_len, seq_len]
118
+ q_*_stride: Strides for query tensor
119
+ k_*_stride: Strides for key tensor
120
+ qk_*_stride: Strides for output tensor
121
+ batch_size: Number of batches
122
+ num_heads: Number of attention heads
123
+ seq_len: Sequence length
124
+ head_dim: Dimension of each attention head
125
+ scale: Scaling factor for attention scores (typically 1/sqrt(head_dim))
126
+ kernel_size: Size of the neighborhood window
127
+ dilation: Dilation factor for the neighborhood
128
+ BLOCK_SIZE_M: Block size for sequence dimension (rows)
129
+ BLOCK_SIZE_N: Block size for sequence dimension (cols)
130
+ BLOCK_SIZE_K: Block size for head dimension
131
+ num_stages: Number of pipeline stages
132
+ num_warps: Number of warps
133
+
134
+ Grid: (batch_size * num_heads, cdiv(seq_len, BLOCK_SIZE_M), cdiv(seq_len, BLOCK_SIZE_N))
135
+ Each program computes a tile of the attention score matrix.
136
+ """
137
+ batch_head_id = tl.program_id(0)
138
+ tile_m = tl.program_id(1)
139
+ tile_n = tl.program_id(2)
140
+
141
+ batch_id = batch_head_id // num_heads
142
+ head_id = batch_head_id % num_heads
143
+
144
+ row_start = tile_m * BLOCK_SIZE_M
145
+ col_start = tile_n * BLOCK_SIZE_N
146
+
147
+ row_offsets = row_start + tl.arange(0, BLOCK_SIZE_M)
148
+ col_offsets = col_start + tl.arange(0, BLOCK_SIZE_N)
149
+
150
+ acc = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32)
151
+
152
+ for k_start in range(0, head_dim, BLOCK_SIZE_K):
153
+ k_offsets = k_start + tl.arange(0, BLOCK_SIZE_K)
154
+ k_mask = k_offsets < head_dim
155
+
156
+ q_ptrs = (
157
+ Q_ptr
158
+ + batch_id * q_batch_stride
159
+ + head_id * q_head_stride
160
+ + row_offsets[:, None] * q_seq_stride
161
+ + k_offsets[None, :] * q_dim_stride
162
+ )
163
+ q_mask = (row_offsets[:, None] < seq_len) & k_mask[None, :]
164
+ q_chunk = tl.load(q_ptrs, mask=q_mask, other=0.0)
165
+
166
+ k_ptrs = (
167
+ K_ptr
168
+ + batch_id * k_batch_stride
169
+ + head_id * k_head_stride
170
+ + col_offsets[:, None] * k_seq_stride
171
+ + k_offsets[None, :] * k_dim_stride
172
+ )
173
+ k_mask = (col_offsets[:, None] < seq_len) & k_mask[None, :]
174
+ k_chunk = tl.load(k_ptrs, mask=k_mask, other=0.0)
175
+
176
+ acc += tl.dot(q_chunk, tl.trans(k_chunk))
177
+
178
+ acc = acc * scale
179
+
180
+ mask_ptrs = mask_ptr + row_offsets[:, None] * seq_len + col_offsets[None, :]
181
+ valid_mask = (row_offsets[:, None] < seq_len) & (col_offsets[None, :] < seq_len)
182
+ neighborhood_mask = tl.load(mask_ptrs, mask=valid_mask, other=0.0)
183
+
184
+ acc = tl.where(neighborhood_mask > 0.0, acc, float("-inf"))
185
+
186
+ qk_ptrs = (
187
+ QK_ptr
188
+ + batch_id * qk_batch_stride
189
+ + head_id * qk_head_stride
190
+ + row_offsets[:, None] * qk_seq_stride
191
+ + col_offsets[None, :] * qk_seq2_stride
192
+ )
193
+ tl.store(qk_ptrs, acc, mask=valid_mask)
194
+
195
+
196
+ @triton.jit
197
+ def _fused_neighborhood_attention_av_kernel(
198
+ Attn_ptr,
199
+ V_ptr,
200
+ Out_ptr,
201
+ attn_batch_stride,
202
+ attn_head_stride,
203
+ attn_seq_stride,
204
+ attn_seq2_stride,
205
+ v_batch_stride,
206
+ v_head_stride,
207
+ v_seq_stride,
208
+ v_dim_stride,
209
+ out_batch_stride,
210
+ out_head_stride,
211
+ out_seq_stride,
212
+ out_dim_stride,
213
+ batch_size: tl.constexpr,
214
+ num_heads: tl.constexpr,
215
+ seq_len: tl.constexpr,
216
+ head_dim: tl.constexpr,
217
+ BLOCK_SIZE_M: tl.constexpr,
218
+ BLOCK_SIZE_N: tl.constexpr,
219
+ BLOCK_SIZE_K: tl.constexpr,
220
+ num_stages: tl.constexpr,
221
+ num_warps: tl.constexpr,
222
+ ):
223
+ """
224
+ Compute Attention @ V to produce the final output.
225
+
226
+ This kernel performs the second stage of neighborhood attention by multiplying
227
+ the normalized attention weights with the value matrix. The computation is
228
+ tiled for memory efficiency, with each tile computing a block of the output.
229
+
230
+ Args:
231
+ Attn_ptr: Pointer to attention weights [batch_size, num_heads, seq_len, seq_len]
232
+ V_ptr: Pointer to value tensor [batch_size, num_heads, seq_len, head_dim]
233
+ Out_ptr: Pointer to output tensor [batch_size, num_heads, seq_len, head_dim]
234
+ attn_*_stride: Strides for attention weights tensor
235
+ v_*_stride: Strides for value tensor
236
+ out_*_stride: Strides for output tensor
237
+ batch_size: Number of batches
238
+ num_heads: Number of attention heads
239
+ seq_len: Sequence length
240
+ head_dim: Dimension of each attention head
241
+ BLOCK_SIZE_M: Block size for sequence dimension (rows)
242
+ BLOCK_SIZE_N: Block size for head dimension (cols)
243
+ BLOCK_SIZE_K: Block size for sequence dimension (reduction)
244
+ num_stages: Number of pipeline stages
245
+ num_warps: Number of warps
246
+
247
+ Grid: (batch_size * num_heads, cdiv(seq_len, BLOCK_SIZE_M), cdiv(head_dim, BLOCK_SIZE_N))
248
+ Each program computes a tile of the output matrix.
249
+ """
250
+ batch_head_id = tl.program_id(0)
251
+ tile_m = tl.program_id(1)
252
+ tile_n = tl.program_id(2)
253
+
254
+ batch_id = batch_head_id // num_heads
255
+ head_id = batch_head_id % num_heads
256
+
257
+ row_start = tile_m * BLOCK_SIZE_M
258
+ col_start = tile_n * BLOCK_SIZE_N
259
+
260
+ row_offsets = row_start + tl.arange(0, BLOCK_SIZE_M)
261
+ col_offsets = col_start + tl.arange(0, BLOCK_SIZE_N)
262
+
263
+ acc = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32)
264
+
265
+ for k_start in range(0, seq_len, BLOCK_SIZE_K):
266
+ k_offsets = k_start + tl.arange(0, BLOCK_SIZE_K)
267
+ k_mask = k_offsets < seq_len
268
+
269
+ attn_ptrs = (
270
+ Attn_ptr
271
+ + batch_id * attn_batch_stride
272
+ + head_id * attn_head_stride
273
+ + row_offsets[:, None] * attn_seq_stride
274
+ + k_offsets[None, :] * attn_seq2_stride
275
+ )
276
+ attn_mask = (row_offsets[:, None] < seq_len) & k_mask[None, :]
277
+ attn_chunk = tl.load(attn_ptrs, mask=attn_mask, other=0.0)
278
+
279
+ v_ptrs = (
280
+ V_ptr
281
+ + batch_id * v_batch_stride
282
+ + head_id * v_head_stride
283
+ + k_offsets[:, None] * v_seq_stride
284
+ + col_offsets[None, :] * v_dim_stride
285
+ )
286
+ v_mask = k_mask[:, None] & (col_offsets[None, :] < head_dim)
287
+ v_chunk = tl.load(v_ptrs, mask=v_mask, other=0.0)
288
+
289
+ acc += tl.dot(attn_chunk, v_chunk)
290
+
291
+ out_ptrs = (
292
+ Out_ptr
293
+ + batch_id * out_batch_stride
294
+ + head_id * out_head_stride
295
+ + row_offsets[:, None] * out_seq_stride
296
+ + col_offsets[None, :] * out_dim_stride
297
+ )
298
+ valid_mask = (row_offsets[:, None] < seq_len) & (col_offsets[None, :] < head_dim)
299
+ tl.store(out_ptrs, acc, mask=valid_mask)
300
+
301
+
302
+ @triton.jit
303
+ def _fused_neighborhood_attention_grad_qk_kernel(
304
+ grad_attn_ptr,
305
+ K_ptr,
306
+ grad_Q_ptr,
307
+ grad_attn_batch_stride,
308
+ grad_attn_head_stride,
309
+ grad_attn_seq_stride,
310
+ grad_attn_seq2_stride,
311
+ k_batch_stride,
312
+ k_head_stride,
313
+ k_seq_stride,
314
+ k_dim_stride,
315
+ grad_q_batch_stride,
316
+ grad_q_head_stride,
317
+ grad_q_seq_stride,
318
+ grad_q_dim_stride,
319
+ batch_size: tl.constexpr,
320
+ num_heads: tl.constexpr,
321
+ seq_len: tl.constexpr,
322
+ head_dim: tl.constexpr,
323
+ scale: tl.constexpr,
324
+ BLOCK_SIZE_M: tl.constexpr,
325
+ BLOCK_SIZE_N: tl.constexpr,
326
+ BLOCK_SIZE_K: tl.constexpr,
327
+ num_stages: tl.constexpr,
328
+ num_warps: tl.constexpr,
329
+ ):
330
+ """
331
+ Compute gradient with respect to queries: grad_Q = grad_attn @ K * scale.
332
+
333
+ This kernel computes the gradient of the loss with respect to the query tensor
334
+ by multiplying the gradient of attention weights with the key tensor. The
335
+ computation follows the chain rule for the attention mechanism.
336
+
337
+ Args:
338
+ grad_attn_ptr: Pointer to gradient of attention weights [batch_size, num_heads, seq_len, seq_len]
339
+ K_ptr: Pointer to key tensor [batch_size, num_heads, seq_len, head_dim]
340
+ grad_Q_ptr: Pointer to output gradient tensor [batch_size, num_heads, seq_len, head_dim]
341
+ grad_attn_*_stride: Strides for gradient attention tensor
342
+ k_*_stride: Strides for key tensor
343
+ grad_q_*_stride: Strides for gradient query tensor
344
+ batch_size: Number of batches
345
+ num_heads: Number of attention heads
346
+ seq_len: Sequence length
347
+ head_dim: Dimension of each attention head
348
+ scale: Scaling factor applied to attention scores
349
+ BLOCK_SIZE_M: Block size for sequence dimension (rows)
350
+ BLOCK_SIZE_N: Block size for head dimension (cols)
351
+ BLOCK_SIZE_K: Block size for sequence dimension (reduction)
352
+ num_stages: Number of pipeline stages
353
+ num_warps: Number of warps
354
+
355
+ Grid: (batch_size * num_heads, cdiv(seq_len, BLOCK_SIZE_M), cdiv(head_dim, BLOCK_SIZE_N))
356
+ Each program computes a tile of the query gradient matrix.
357
+ """
358
+ batch_head_id = tl.program_id(0)
359
+ tile_m = tl.program_id(1)
360
+ tile_n = tl.program_id(2)
361
+
362
+ batch_id = batch_head_id // num_heads
363
+ head_id = batch_head_id % num_heads
364
+
365
+ row_start = tile_m * BLOCK_SIZE_M
366
+ col_start = tile_n * BLOCK_SIZE_N
367
+
368
+ row_offsets = row_start + tl.arange(0, BLOCK_SIZE_M)
369
+ col_offsets = col_start + tl.arange(0, BLOCK_SIZE_N)
370
+
371
+ acc = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32)
372
+
373
+ for k_start in range(0, seq_len, BLOCK_SIZE_K):
374
+ k_offsets = k_start + tl.arange(0, BLOCK_SIZE_K)
375
+ k_mask = k_offsets < seq_len
376
+
377
+ grad_attn_ptrs = (
378
+ grad_attn_ptr
379
+ + batch_id * grad_attn_batch_stride
380
+ + head_id * grad_attn_head_stride
381
+ + row_offsets[:, None] * grad_attn_seq_stride
382
+ + k_offsets[None, :] * grad_attn_seq2_stride
383
+ )
384
+ grad_attn_mask = (row_offsets[:, None] < seq_len) & k_mask[None, :]
385
+ grad_attn_chunk = tl.load(grad_attn_ptrs, mask=grad_attn_mask, other=0.0)
386
+
387
+ k_ptrs = (
388
+ K_ptr
389
+ + batch_id * k_batch_stride
390
+ + head_id * k_head_stride
391
+ + k_offsets[:, None] * k_seq_stride
392
+ + col_offsets[None, :] * k_dim_stride
393
+ )
394
+ k_mask_2d = k_mask[:, None] & (col_offsets[None, :] < head_dim)
395
+ k_chunk = tl.load(k_ptrs, mask=k_mask_2d, other=0.0)
396
+
397
+ acc += tl.dot(grad_attn_chunk, k_chunk)
398
+
399
+ acc = acc * scale
400
+
401
+ grad_q_ptrs = (
402
+ grad_Q_ptr
403
+ + batch_id * grad_q_batch_stride
404
+ + head_id * grad_q_head_stride
405
+ + row_offsets[:, None] * grad_q_seq_stride
406
+ + col_offsets[None, :] * grad_q_dim_stride
407
+ )
408
+ valid_mask = (row_offsets[:, None] < seq_len) & (col_offsets[None, :] < head_dim)
409
+ tl.store(grad_q_ptrs, acc, mask=valid_mask)
410
+
411
+
412
+ @triton.jit
413
+ def _fused_neighborhood_attention_grad_k_kernel(
414
+ grad_attn_ptr,
415
+ Q_ptr,
416
+ grad_K_ptr,
417
+ grad_attn_batch_stride,
418
+ grad_attn_head_stride,
419
+ grad_attn_seq_stride,
420
+ grad_attn_seq2_stride,
421
+ q_batch_stride,
422
+ q_head_stride,
423
+ q_seq_stride,
424
+ q_dim_stride,
425
+ grad_k_batch_stride,
426
+ grad_k_head_stride,
427
+ grad_k_seq_stride,
428
+ grad_k_dim_stride,
429
+ batch_size: tl.constexpr,
430
+ num_heads: tl.constexpr,
431
+ seq_len: tl.constexpr,
432
+ head_dim: tl.constexpr,
433
+ scale: tl.constexpr,
434
+ BLOCK_SIZE_M: tl.constexpr,
435
+ BLOCK_SIZE_N: tl.constexpr,
436
+ BLOCK_SIZE_K: tl.constexpr,
437
+ num_stages: tl.constexpr,
438
+ num_warps: tl.constexpr,
439
+ ):
440
+ """
441
+ Compute gradient with respect to keys: grad_K = grad_attn^T @ Q * scale.
442
+
443
+ This kernel computes the gradient of the loss with respect to the key tensor
444
+ by multiplying the transpose of the gradient of attention weights with the
445
+ query tensor. The computation follows the chain rule for the attention mechanism.
446
+
447
+ Args:
448
+ grad_attn_ptr: Pointer to gradient of attention weights [batch_size, num_heads, seq_len, seq_len]
449
+ Q_ptr: Pointer to query tensor [batch_size, num_heads, seq_len, head_dim]
450
+ grad_K_ptr: Pointer to output gradient tensor [batch_size, num_heads, seq_len, head_dim]
451
+ grad_attn_*_stride: Strides for gradient attention tensor
452
+ q_*_stride: Strides for query tensor
453
+ grad_k_*_stride: Strides for gradient key tensor
454
+ batch_size: Number of batches
455
+ num_heads: Number of attention heads
456
+ seq_len: Sequence length
457
+ head_dim: Dimension of each attention head
458
+ scale: Scaling factor applied to attention scores
459
+ BLOCK_SIZE_M: Block size for sequence dimension (rows)
460
+ BLOCK_SIZE_N: Block size for head dimension (cols)
461
+ BLOCK_SIZE_K: Block size for sequence dimension (reduction)
462
+ num_stages: Number of pipeline stages
463
+ num_warps: Number of warps
464
+
465
+ Grid: (batch_size * num_heads, cdiv(seq_len, BLOCK_SIZE_M), cdiv(head_dim, BLOCK_SIZE_N))
466
+ Each program computes a tile of the key gradient matrix.
467
+ """
468
+ batch_head_id = tl.program_id(0)
469
+ tile_m = tl.program_id(1)
470
+ tile_n = tl.program_id(2)
471
+
472
+ batch_id = batch_head_id // num_heads
473
+ head_id = batch_head_id % num_heads
474
+
475
+ row_start = tile_m * BLOCK_SIZE_M
476
+ col_start = tile_n * BLOCK_SIZE_N
477
+
478
+ row_offsets = row_start + tl.arange(0, BLOCK_SIZE_M)
479
+ col_offsets = col_start + tl.arange(0, BLOCK_SIZE_N)
480
+
481
+ acc = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32)
482
+
483
+ for k_start in range(0, seq_len, BLOCK_SIZE_K):
484
+ k_offsets = k_start + tl.arange(0, BLOCK_SIZE_K)
485
+ k_mask = k_offsets < seq_len
486
+
487
+ q_ptrs = (
488
+ Q_ptr
489
+ + batch_id * q_batch_stride
490
+ + head_id * q_head_stride
491
+ + k_offsets[:, None] * q_seq_stride
492
+ + col_offsets[None, :] * q_dim_stride
493
+ )
494
+ q_mask = k_mask[:, None] & (col_offsets[None, :] < head_dim)
495
+ q_chunk = tl.load(q_ptrs, mask=q_mask, other=0.0)
496
+
497
+ grad_attn_T_ptrs = (
498
+ grad_attn_ptr
499
+ + batch_id * grad_attn_batch_stride
500
+ + head_id * grad_attn_head_stride
501
+ + row_offsets[:, None] * grad_attn_seq2_stride
502
+ + k_offsets[None, :] * grad_attn_seq_stride
503
+ )
504
+ grad_attn_T_mask = (row_offsets[:, None] < seq_len) & k_mask[None, :]
505
+ grad_attn_T_chunk = tl.load(grad_attn_T_ptrs, mask=grad_attn_T_mask, other=0.0)
506
+
507
+ acc += tl.dot(grad_attn_T_chunk, q_chunk)
508
+
509
+ acc = acc * scale
510
+
511
+ grad_k_ptrs = (
512
+ grad_K_ptr
513
+ + batch_id * grad_k_batch_stride
514
+ + head_id * grad_k_head_stride
515
+ + row_offsets[:, None] * grad_k_seq_stride
516
+ + col_offsets[None, :] * grad_k_dim_stride
517
+ )
518
+ valid_mask = (row_offsets[:, None] < seq_len) & (col_offsets[None, :] < head_dim)
519
+ tl.store(grad_k_ptrs, acc, mask=valid_mask)
520
+
521
+
522
+ @triton.jit
523
+ def _fused_neighborhood_attention_grad_v_kernel(
524
+ Attn_ptr,
525
+ grad_output_ptr,
526
+ grad_V_ptr,
527
+ attn_batch_stride,
528
+ attn_head_stride,
529
+ attn_seq_stride,
530
+ attn_seq2_stride,
531
+ grad_out_batch_stride,
532
+ grad_out_head_stride,
533
+ grad_out_seq_stride,
534
+ grad_out_dim_stride,
535
+ grad_v_batch_stride,
536
+ grad_v_head_stride,
537
+ grad_v_seq_stride,
538
+ grad_v_dim_stride,
539
+ batch_size: tl.constexpr,
540
+ num_heads: tl.constexpr,
541
+ seq_len: tl.constexpr,
542
+ head_dim: tl.constexpr,
543
+ BLOCK_SIZE_M: tl.constexpr,
544
+ BLOCK_SIZE_N: tl.constexpr,
545
+ BLOCK_SIZE_K: tl.constexpr,
546
+ num_stages: tl.constexpr,
547
+ num_warps: tl.constexpr,
548
+ ):
549
+ """
550
+ Compute gradient with respect to values: grad_V = Attn^T @ grad_output.
551
+
552
+ This kernel computes the gradient of the loss with respect to the value tensor
553
+ by multiplying the transpose of the attention weights with the gradient of the
554
+ output. The computation follows the chain rule for the attention mechanism.
555
+
556
+ Args:
557
+ Attn_ptr: Pointer to attention weights [batch_size, num_heads, seq_len, seq_len]
558
+ grad_output_ptr: Pointer to gradient of output [batch_size, num_heads, seq_len, head_dim]
559
+ grad_V_ptr: Pointer to output gradient tensor [batch_size, num_heads, seq_len, head_dim]
560
+ attn_*_stride: Strides for attention weights tensor
561
+ grad_out_*_stride: Strides for gradient output tensor
562
+ grad_v_*_stride: Strides for gradient value tensor
563
+ batch_size: Number of batches
564
+ num_heads: Number of attention heads
565
+ seq_len: Sequence length
566
+ head_dim: Dimension of each attention head
567
+ BLOCK_SIZE_M: Block size for sequence dimension (rows)
568
+ BLOCK_SIZE_N: Block size for head dimension (cols)
569
+ BLOCK_SIZE_K: Block size for sequence dimension (reduction)
570
+ num_stages: Number of pipeline stages
571
+ num_warps: Number of warps
572
+
573
+ Grid: (batch_size * num_heads, cdiv(seq_len, BLOCK_SIZE_M), cdiv(head_dim, BLOCK_SIZE_N))
574
+ Each program computes a tile of the value gradient matrix.
575
+ """
576
+ batch_head_id = tl.program_id(0)
577
+ tile_m = tl.program_id(1)
578
+ tile_n = tl.program_id(2)
579
+
580
+ batch_id = batch_head_id // num_heads
581
+ head_id = batch_head_id % num_heads
582
+
583
+ row_start = tile_m * BLOCK_SIZE_M
584
+ col_start = tile_n * BLOCK_SIZE_N
585
+
586
+ row_offsets = row_start + tl.arange(0, BLOCK_SIZE_M)
587
+ col_offsets = col_start + tl.arange(0, BLOCK_SIZE_N)
588
+
589
+ acc = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32)
590
+
591
+ for k_start in range(0, seq_len, BLOCK_SIZE_K):
592
+ k_offsets = k_start + tl.arange(0, BLOCK_SIZE_K)
593
+ k_mask = k_offsets < seq_len
594
+
595
+ attn_ptrs = (
596
+ Attn_ptr
597
+ + batch_id * attn_batch_stride
598
+ + head_id * attn_head_stride
599
+ + k_offsets[:, None] * attn_seq_stride
600
+ + row_offsets[None, :] * attn_seq2_stride
601
+ )
602
+ attn_mask = k_mask[:, None] & (row_offsets[None, :] < seq_len)
603
+ attn_chunk = tl.load(attn_ptrs, mask=attn_mask, other=0.0)
604
+
605
+ grad_out_ptrs = (
606
+ grad_output_ptr
607
+ + batch_id * grad_out_batch_stride
608
+ + head_id * grad_out_head_stride
609
+ + k_offsets[:, None] * grad_out_seq_stride
610
+ + col_offsets[None, :] * grad_out_dim_stride
611
+ )
612
+ grad_out_mask = k_mask[:, None] & (col_offsets[None, :] < head_dim)
613
+ grad_out_chunk = tl.load(grad_out_ptrs, mask=grad_out_mask, other=0.0)
614
+
615
+ acc += tl.dot(tl.trans(attn_chunk), grad_out_chunk)
616
+
617
+ grad_v_ptrs = (
618
+ grad_V_ptr
619
+ + batch_id * grad_v_batch_stride
620
+ + head_id * grad_v_head_stride
621
+ + row_offsets[:, None] * grad_v_seq_stride
622
+ + col_offsets[None, :] * grad_v_dim_stride
623
+ )
624
+ valid_mask = (row_offsets[:, None] < seq_len) & (col_offsets[None, :] < head_dim)
625
+ tl.store(grad_v_ptrs, acc, mask=valid_mask)
626
+
627
+
628
+ @triton.jit
629
+ def _fused_neighborhood_attention_grad_attn_kernel(
630
+ grad_output_ptr,
631
+ V_ptr,
632
+ grad_attn_ptr,
633
+ grad_out_batch_stride,
634
+ grad_out_head_stride,
635
+ grad_out_seq_stride,
636
+ grad_out_dim_stride,
637
+ v_batch_stride,
638
+ v_head_stride,
639
+ v_seq_stride,
640
+ v_dim_stride,
641
+ grad_attn_batch_stride,
642
+ grad_attn_head_stride,
643
+ grad_attn_seq_stride,
644
+ grad_attn_seq2_stride,
645
+ batch_size: tl.constexpr,
646
+ num_heads: tl.constexpr,
647
+ seq_len: tl.constexpr,
648
+ head_dim: tl.constexpr,
649
+ BLOCK_SIZE_M: tl.constexpr,
650
+ BLOCK_SIZE_N: tl.constexpr,
651
+ BLOCK_SIZE_K: tl.constexpr,
652
+ num_stages: tl.constexpr,
653
+ num_warps: tl.constexpr,
654
+ ):
655
+ """
656
+ Compute gradient with respect to attention weights: grad_attn = grad_output @ V^T.
657
+
658
+ This kernel computes the gradient of the loss with respect to the attention
659
+ weights by multiplying the gradient of the output with the transpose of the
660
+ value tensor. This gradient will later be passed through the softmax backward
661
+ pass to compute gradients for the attention scores.
662
+
663
+ Args:
664
+ grad_output_ptr: Pointer to gradient of output [batch_size, num_heads, seq_len, head_dim]
665
+ V_ptr: Pointer to value tensor [batch_size, num_heads, seq_len, head_dim]
666
+ grad_attn_ptr: Pointer to output gradient tensor [batch_size, num_heads, seq_len, seq_len]
667
+ grad_out_*_stride: Strides for gradient output tensor
668
+ v_*_stride: Strides for value tensor
669
+ grad_attn_*_stride: Strides for gradient attention tensor
670
+ batch_size: Number of batches
671
+ num_heads: Number of attention heads
672
+ seq_len: Sequence length
673
+ head_dim: Dimension of each attention head
674
+ BLOCK_SIZE_M: Block size for sequence dimension (rows)
675
+ BLOCK_SIZE_N: Block size for sequence dimension (cols)
676
+ BLOCK_SIZE_K: Block size for head dimension (reduction)
677
+ num_stages: Number of pipeline stages
678
+ num_warps: Number of warps
679
+
680
+ Grid: (batch_size * num_heads, cdiv(seq_len, BLOCK_SIZE_M), cdiv(seq_len, BLOCK_SIZE_N))
681
+ Each program computes a tile of the attention gradient matrix.
682
+ """
683
+ batch_head_id = tl.program_id(0)
684
+ tile_m = tl.program_id(1)
685
+ tile_n = tl.program_id(2)
686
+
687
+ batch_id = batch_head_id // num_heads
688
+ head_id = batch_head_id % num_heads
689
+
690
+ row_start = tile_m * BLOCK_SIZE_M
691
+ col_start = tile_n * BLOCK_SIZE_N
692
+
693
+ row_offsets = row_start + tl.arange(0, BLOCK_SIZE_M)
694
+ col_offsets = col_start + tl.arange(0, BLOCK_SIZE_N)
695
+
696
+ acc = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32)
697
+
698
+ for k_start in range(0, head_dim, BLOCK_SIZE_K):
699
+ k_offsets = k_start + tl.arange(0, BLOCK_SIZE_K)
700
+ k_mask = k_offsets < head_dim
701
+
702
+ grad_out_ptrs = (
703
+ grad_output_ptr
704
+ + batch_id * grad_out_batch_stride
705
+ + head_id * grad_out_head_stride
706
+ + row_offsets[:, None] * grad_out_seq_stride
707
+ + k_offsets[None, :] * grad_out_dim_stride
708
+ )
709
+ grad_out_mask = (row_offsets[:, None] < seq_len) & k_mask[None, :]
710
+ grad_out_chunk = tl.load(grad_out_ptrs, mask=grad_out_mask, other=0.0)
711
+
712
+ v_ptrs = (
713
+ V_ptr
714
+ + batch_id * v_batch_stride
715
+ + head_id * v_head_stride
716
+ + col_offsets[None, :] * v_seq_stride
717
+ + k_offsets[:, None] * v_dim_stride
718
+ )
719
+ v_mask = (col_offsets[None, :] < seq_len) & k_mask[:, None]
720
+ v_chunk = tl.load(v_ptrs, mask=v_mask, other=0.0)
721
+
722
+ acc += tl.dot(grad_out_chunk, v_chunk)
723
+
724
+ grad_attn_ptrs = (
725
+ grad_attn_ptr
726
+ + batch_id * grad_attn_batch_stride
727
+ + head_id * grad_attn_head_stride
728
+ + row_offsets[:, None] * grad_attn_seq_stride
729
+ + col_offsets[None, :] * grad_attn_seq2_stride
730
+ )
731
+ valid_mask = (row_offsets[:, None] < seq_len) & (col_offsets[None, :] < seq_len)
732
+ tl.store(grad_attn_ptrs, acc, mask=valid_mask)
733
+
734
+
735
+ def fused_neighborhood_attention_forward(
736
+ query: torch.Tensor,
737
+ key: torch.Tensor,
738
+ value: torch.Tensor,
739
+ kernel_size: int = 7,
740
+ dilation: int = 1,
741
+ scale: float = None,
742
+ return_lse: bool = False,
743
+ ) -> tuple:
744
+ """
745
+ Fused neighborhood attention forward pass.
746
+
747
+ Args:
748
+ query: Query tensor of shape [batch_size, num_heads, seq_len, head_dim]
749
+ key: Key tensor of shape [batch_size, num_heads, seq_len, head_dim]
750
+ value: Value tensor of shape [batch_size, num_heads, seq_len, head_dim]
751
+ kernel_size: Size of the neighborhood window
752
+ dilation: Dilation factor for the neighborhood
753
+ scale: Scaling factor for attention scores (default: rsqrt(head_dim))
754
+ return_lse: Whether to return log-sum-exp values (not yet supported; True raises NotImplementedError)
755
+
756
+ Returns:
757
+ Tuple of (output tensor, attention weights, softmax parameters for backward)
758
+ """
759
+ batch_size, num_heads, seq_len, head_dim = query.shape
760
+
761
+ if scale is None:
762
+ scale = 1.0 / math.sqrt(head_dim)
763
+
764
+ query = query.contiguous()
765
+ key = key.contiguous()
766
+ value = value.contiguous()
767
+
768
+ output = torch.empty_like(query)
769
+ qk_scores = torch.empty(batch_size, num_heads, seq_len, seq_len, device=query.device, dtype=query.dtype)
770
+
771
+ mask = torch.zeros(seq_len, seq_len, device=query.device, dtype=torch.float32)
772
+
773
+ BLOCK_SIZE, num_warps = calculate_settings(seq_len)
774
+ BLOCK_SIZE_M = min(64, triton.next_power_of_2(seq_len))
775
+ BLOCK_SIZE_N = min(64, triton.next_power_of_2(seq_len))
776
+ BLOCK_SIZE_K = max(16, triton.next_power_of_2(head_dim))
777
+
778
+ num_stages = 4 if seq_len >= 512 else 2
779
+
780
+ grid_mask = (seq_len,)
781
+ _neighborhood_mask_kernel[grid_mask](
782
+ mask,
783
+ seq_len,
784
+ kernel_size,
785
+ dilation,
786
+ BLOCK_SIZE,
787
+ num_stages,
788
+ num_warps,
789
+ )
790
+
791
+ grid_qk = (batch_size * num_heads, triton.cdiv(seq_len, BLOCK_SIZE_M), triton.cdiv(seq_len, BLOCK_SIZE_N))
792
+ _fused_neighborhood_attention_qk_kernel[grid_qk](
793
+ query,
794
+ key,
795
+ qk_scores,
796
+ mask,
797
+ query.stride(0),
798
+ query.stride(1),
799
+ query.stride(2),
800
+ query.stride(3),
801
+ key.stride(0),
802
+ key.stride(1),
803
+ key.stride(2),
804
+ key.stride(3),
805
+ qk_scores.stride(0),
806
+ qk_scores.stride(1),
807
+ qk_scores.stride(2),
808
+ qk_scores.stride(3),
809
+ batch_size,
810
+ num_heads,
811
+ seq_len,
812
+ head_dim,
813
+ scale,
814
+ kernel_size,
815
+ dilation,
816
+ BLOCK_SIZE_M,
817
+ BLOCK_SIZE_N,
818
+ BLOCK_SIZE_K,
819
+ num_stages,
820
+ num_warps,
821
+ )
822
+
823
+ qk_reshaped = qk_scores.view(batch_size * num_heads * seq_len, seq_len)
824
+ attn_reshaped, BLOCK_SIZE_softmax, num_warps_softmax, multi_block_launch = _softmax_forward(qk_reshaped)
825
+ attn_weights = attn_reshaped.view(batch_size, num_heads, seq_len, seq_len)
826
+
827
+ grid_av = (batch_size * num_heads, triton.cdiv(seq_len, BLOCK_SIZE_M), triton.cdiv(head_dim, BLOCK_SIZE_N))
828
+ _fused_neighborhood_attention_av_kernel[grid_av](
829
+ attn_weights,
830
+ value,
831
+ output,
832
+ attn_weights.stride(0),
833
+ attn_weights.stride(1),
834
+ attn_weights.stride(2),
835
+ attn_weights.stride(3),
836
+ value.stride(0),
837
+ value.stride(1),
838
+ value.stride(2),
839
+ value.stride(3),
840
+ output.stride(0),
841
+ output.stride(1),
842
+ output.stride(2),
843
+ output.stride(3),
844
+ batch_size,
845
+ num_heads,
846
+ seq_len,
847
+ head_dim,
848
+ BLOCK_SIZE_M,
849
+ BLOCK_SIZE_N,
850
+ BLOCK_SIZE_K,
851
+ num_stages,
852
+ num_warps,
853
+ )
854
+
855
+ if return_lse:
856
+ raise NotImplementedError("return_lse=True is not supported yet.")
857
+
858
+ softmax_params = (BLOCK_SIZE_softmax, num_warps_softmax, multi_block_launch)
859
+ return output, attn_weights, softmax_params
860
+
861
+
862
+ class LigerFusedNeighborhoodAttentionFunction(torch.autograd.Function):
863
+ @staticmethod
864
+ @ensure_contiguous
865
+ def forward(ctx, query, key, value, kernel_size=7, dilation=1, scale=None):
866
+ output, attn_weights, softmax_params = fused_neighborhood_attention_forward(
867
+ query, key, value, kernel_size, dilation, scale
868
+ )
869
+ ctx.save_for_backward(query, key, value, attn_weights)
870
+ ctx.kernel_size = kernel_size
871
+ ctx.dilation = dilation
872
+ ctx.scale = scale
873
+ ctx.softmax_params = softmax_params
874
+ return output
875
+
876
+ @staticmethod
877
+ @ensure_contiguous
878
+ def backward(ctx, grad_output):
879
+ query, key, value, attn_weights = ctx.saved_tensors
880
+ BLOCK_SIZE_softmax, num_warps_softmax, multi_block_launch = ctx.softmax_params
881
+
882
+ batch_size, num_heads, seq_len, head_dim = query.shape
883
+ scale = ctx.scale if ctx.scale is not None else 1.0 / math.sqrt(head_dim)
884
+
885
+ grad_query = torch.zeros_like(query)
886
+ grad_key = torch.zeros_like(key)
887
+ grad_value = torch.zeros_like(value)
888
+ grad_attn_weights = torch.zeros_like(attn_weights)
889
+
890
+ BLOCK_SIZE_M = min(64, triton.next_power_of_2(seq_len))
891
+ BLOCK_SIZE_N = min(64, triton.next_power_of_2(seq_len))
892
+ BLOCK_SIZE_K = min(64, triton.next_power_of_2(head_dim))
893
+ num_stages = 4 if seq_len >= 512 else 2
894
+ _, num_warps = calculate_settings(seq_len)
895
+
896
+ grid_grad_attn = (
897
+ batch_size * num_heads,
898
+ triton.cdiv(seq_len, BLOCK_SIZE_M),
899
+ triton.cdiv(seq_len, BLOCK_SIZE_N),
900
+ )
901
+ _fused_neighborhood_attention_grad_attn_kernel[grid_grad_attn](
902
+ grad_output,
903
+ value,
904
+ grad_attn_weights,
905
+ grad_output.stride(0),
906
+ grad_output.stride(1),
907
+ grad_output.stride(2),
908
+ grad_output.stride(3),
909
+ value.stride(0),
910
+ value.stride(1),
911
+ value.stride(2),
912
+ value.stride(3),
913
+ grad_attn_weights.stride(0),
914
+ grad_attn_weights.stride(1),
915
+ grad_attn_weights.stride(2),
916
+ grad_attn_weights.stride(3),
917
+ batch_size,
918
+ num_heads,
919
+ seq_len,
920
+ head_dim,
921
+ BLOCK_SIZE_M,
922
+ BLOCK_SIZE_N,
923
+ BLOCK_SIZE_K,
924
+ num_stages,
925
+ num_warps,
926
+ )
927
+
928
+ grad_attn_reshaped = grad_attn_weights.view(batch_size * num_heads * seq_len, seq_len)
929
+ attn_reshaped = attn_weights.view(batch_size * num_heads * seq_len, seq_len)
930
+
931
+ grad_qk_reshaped = _softmax_backward(
932
+ grad_attn_reshaped, attn_reshaped, BLOCK_SIZE_softmax, num_warps_softmax, multi_block_launch
933
+ )
934
+ grad_qk_scores = grad_qk_reshaped.view(batch_size, num_heads, seq_len, seq_len)
935
+
936
+ grid_grad_q = (batch_size * num_heads, triton.cdiv(seq_len, BLOCK_SIZE_M), triton.cdiv(head_dim, BLOCK_SIZE_N))
937
+ _fused_neighborhood_attention_grad_qk_kernel[grid_grad_q](
938
+ grad_qk_scores,
939
+ key,
940
+ grad_query,
941
+ grad_qk_scores.stride(0),
942
+ grad_qk_scores.stride(1),
943
+ grad_qk_scores.stride(2),
944
+ grad_qk_scores.stride(3),
945
+ key.stride(0),
946
+ key.stride(1),
947
+ key.stride(2),
948
+ key.stride(3),
949
+ grad_query.stride(0),
950
+ grad_query.stride(1),
951
+ grad_query.stride(2),
952
+ grad_query.stride(3),
953
+ batch_size,
954
+ num_heads,
955
+ seq_len,
956
+ head_dim,
957
+ scale,
958
+ BLOCK_SIZE_M,
959
+ BLOCK_SIZE_N,
960
+ BLOCK_SIZE_K,
961
+ num_stages,
962
+ num_warps,
963
+ )
964
+
965
+ grid_grad_k = (batch_size * num_heads, triton.cdiv(seq_len, BLOCK_SIZE_M), triton.cdiv(head_dim, BLOCK_SIZE_N))
966
+ _fused_neighborhood_attention_grad_k_kernel[grid_grad_k](
967
+ grad_qk_scores,
968
+ query,
969
+ grad_key,
970
+ grad_qk_scores.stride(0),
971
+ grad_qk_scores.stride(1),
972
+ grad_qk_scores.stride(2),
973
+ grad_qk_scores.stride(3),
974
+ query.stride(0),
975
+ query.stride(1),
976
+ query.stride(2),
977
+ query.stride(3),
978
+ grad_key.stride(0),
979
+ grad_key.stride(1),
980
+ grad_key.stride(2),
981
+ grad_key.stride(3),
982
+ batch_size,
983
+ num_heads,
984
+ seq_len,
985
+ head_dim,
986
+ scale,
987
+ BLOCK_SIZE_M,
988
+ BLOCK_SIZE_N,
989
+ BLOCK_SIZE_K,
990
+ num_stages,
991
+ num_warps,
992
+ )
993
+
994
+ grid_grad_v = (batch_size * num_heads, triton.cdiv(seq_len, BLOCK_SIZE_M), triton.cdiv(head_dim, BLOCK_SIZE_N))
995
+ _fused_neighborhood_attention_grad_v_kernel[grid_grad_v](
996
+ attn_weights,
997
+ grad_output,
998
+ grad_value,
999
+ attn_weights.stride(0),
1000
+ attn_weights.stride(1),
1001
+ attn_weights.stride(2),
1002
+ attn_weights.stride(3),
1003
+ grad_output.stride(0),
1004
+ grad_output.stride(1),
1005
+ grad_output.stride(2),
1006
+ grad_output.stride(3),
1007
+ grad_value.stride(0),
1008
+ grad_value.stride(1),
1009
+ grad_value.stride(2),
1010
+ grad_value.stride(3),
1011
+ batch_size,
1012
+ num_heads,
1013
+ seq_len,
1014
+ head_dim,
1015
+ BLOCK_SIZE_M,
1016
+ BLOCK_SIZE_N,
1017
+ BLOCK_SIZE_K,
1018
+ num_stages,
1019
+ num_warps,
1020
+ )
1021
+
1022
+ return grad_query, grad_key, grad_value, None, None, None
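For reference, the forward path that the kernels above implement (neighborhood mask, scaled Q·K^T, masking with -inf, softmax, then attention @ V) collapses into a few lines of plain PyTorch. The sketch below is illustrative only and is not part of the package; it reuses the hypothetical reference_neighborhood_mask helper shown after the mask kernel.

    import torch

    def reference_neighborhood_attention(q, k, v, kernel_size=7, dilation=1, scale=None):
        # q, k, v: [batch_size, num_heads, seq_len, head_dim], the same layout the kernels use.
        seq_len, head_dim = q.shape[-2], q.shape[-1]
        scale = head_dim ** -0.5 if scale is None else scale
        allowed = reference_neighborhood_mask(seq_len, kernel_size, dilation).to(q.device) > 0
        scores = (q @ k.transpose(-2, -1)) * scale              # scaled Q K^T
        scores = scores.masked_fill(~allowed, float("-inf"))    # mask outside the neighborhood
        return scores.softmax(dim=-1) @ v                       # attention-weighted values

The backward kernels implement the chain rule through this same graph: grad_attn = grad_output @ V^T, the softmax backward from liger_kernel.ops.softmax, then grad_Q = grad_scores @ K * scale, grad_K = grad_scores^T @ Q * scale, and grad_V = attn^T @ grad_output, matching the docstrings above.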
@@ -4,6 +4,7 @@ from liger_kernel.ops.cross_entropy import LigerCrossEntropyFunction
4
4
  from liger_kernel.ops.dyt import LigerDyTFunction
5
5
  from liger_kernel.ops.fused_linear_cross_entropy import LigerFusedLinearCrossEntropyFunction
6
6
  from liger_kernel.ops.fused_linear_jsd import LigerFusedLinearJSDFunction
7
+ from liger_kernel.ops.fused_neighborhood_attention import LigerFusedNeighborhoodAttentionFunction
7
8
  from liger_kernel.ops.geglu import LigerGELUMulFunction
8
9
  from liger_kernel.ops.group_norm import LigerGroupNormFunction
9
10
  from liger_kernel.ops.jsd import LigerJSDFunction
@@ -197,6 +198,33 @@ def liger_multi_token_attention(
197
198
  return LigerMultiTokenAttentionFunction.apply(scores, weight, bias, stride, padding, dilation, groups, sparse)
198
199
 
199
200
 
201
+ def liger_fused_neighborhood_attention(
202
+ query,
203
+ key,
204
+ value,
205
+ kernel_size: int = 7,
206
+ dilation: int = 1,
207
+ scale: float = None,
208
+ ):
209
+ """
210
+ Liger fused neighborhood attention.
211
+
212
+ Paper: https://arxiv.org/pdf/2504.16922
213
+
214
+ Args:
215
+ query: Query tensor of shape [batch_size, num_heads, seq_len, head_dim]
216
+ key: Key tensor of shape [batch_size, num_heads, seq_len, head_dim]
217
+ value: Value tensor of shape [batch_size, num_heads, seq_len, head_dim]
218
+ kernel_size: Size of the neighborhood window (default: 7)
219
+ dilation: Dilation factor for the neighborhood (default: 1)
220
+ scale: Scaling factor for attention scores (default: rsqrt(head_dim))
221
+
222
+ Returns:
223
+ Output tensor of shape [batch_size, num_heads, seq_len, head_dim]
224
+ """
225
+ return LigerFusedNeighborhoodAttentionFunction.apply(query, key, value, kernel_size, dilation, scale)
226
+
227
+
200
228
  def liger_tvd(
201
229
  input,
202
230
  target,
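A minimal usage sketch for the liger_fused_neighborhood_attention wrapper added in the hunk above (not part of the diff): tensor sizes are illustrative and a CUDA device is assumed, since the underlying kernels are Triton-based.

    import torch
    from liger_kernel.transformers.functional import liger_fused_neighborhood_attention

    batch_size, num_heads, seq_len, head_dim = 2, 8, 128, 64
    q = torch.randn(batch_size, num_heads, seq_len, head_dim, device="cuda", requires_grad=True)
    k = torch.randn_like(q, requires_grad=True)
    v = torch.randn_like(q, requires_grad=True)

    out = liger_fused_neighborhood_attention(q, k, v, kernel_size=7, dilation=1)
    out.sum().backward()   # runs the fused backward kernels; q.grad, k.grad, v.grad are populated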
@@ -0,0 +1,234 @@
1
+ import math
2
+
3
+ from typing import Optional
4
+
5
+ import torch
6
+ import torch.nn as nn
7
+
8
+ from liger_kernel.ops.fused_neighborhood_attention import LigerFusedNeighborhoodAttentionFunction
9
+
10
+
11
+ class LigerFusedNeighborhoodAttention(nn.Module):
12
+ """
13
+ Liger Fused Neighborhood Attention Module.
14
+
15
+ Paper: https://arxiv.org/pdf/2504.16922
16
+
17
+ Neighborhood attention restricts the attention mechanism to a local neighborhood
18
+ around each position, reducing computational complexity from O(n²) to O(n*k)
19
+ where k is the neighborhood size.
20
+
21
+ Args:
22
+ hidden_size (int): The hidden dimension size
23
+ num_heads (int): Number of attention heads
24
+ kernel_size (int): Size of the neighborhood window (default: 7)
25
+ dilation (int): Dilation factor for the neighborhood (default: 1)
26
+ bias (bool): Whether to use bias in linear projections (default: True)
27
+ dropout (float): Dropout probability (default: 0.0)
28
+ scale (Optional[float]): Scaling factor for attention scores.
29
+ If None, uses 1/sqrt(head_dim) (default: None)
30
+ """
31
+
32
+ def __init__(
33
+ self,
34
+ hidden_size: int,
35
+ num_heads: int,
36
+ kernel_size: int = 7,
37
+ dilation: int = 1,
38
+ bias: bool = True,
39
+ dropout: float = 0.0,
40
+ scale: Optional[float] = None,
41
+ ):
42
+ super().__init__()
43
+
44
+ if hidden_size % num_heads != 0:
45
+ raise ValueError(f"hidden_size ({hidden_size}) must be divisible by num_heads ({num_heads})")
46
+
47
+ if kernel_size <= 0:
48
+ raise ValueError(f"kernel_size ({kernel_size}) must be positive")
49
+
50
+ if kernel_size % 2 == 0:
51
+ raise ValueError(f"kernel_size ({kernel_size}) must be odd")
52
+
53
+ if dilation < 1:
54
+ raise ValueError(f"dilation ({dilation}) must be positive")
55
+
56
+ self.hidden_size = hidden_size
57
+ self.num_heads = num_heads
58
+ self.head_dim = hidden_size // num_heads
59
+ self.kernel_size = kernel_size
60
+ self.dilation = dilation
61
+ self.scale = scale if scale is not None else 1.0 / math.sqrt(self.head_dim)
62
+ self.dropout_p = dropout
63
+
64
+ self.q_proj = nn.Linear(hidden_size, hidden_size, bias=bias)
65
+ self.k_proj = nn.Linear(hidden_size, hidden_size, bias=bias)
66
+ self.v_proj = nn.Linear(hidden_size, hidden_size, bias=bias)
67
+
68
+ self.out_proj = nn.Linear(hidden_size, hidden_size, bias=bias)
69
+
70
+ if dropout > 0.0:
71
+ self.dropout = nn.Dropout(dropout)
72
+ else:
73
+ self.dropout = None
74
+
75
+ def forward(
76
+ self,
77
+ hidden_states: torch.Tensor,
78
+ attention_mask: Optional[torch.Tensor] = None,
79
+ ) -> torch.Tensor:
80
+ """
81
+ Forward pass of the fused neighborhood attention module.
82
+
83
+ Args:
84
+ hidden_states (torch.Tensor): Input tensor of shape [batch_size, seq_len, hidden_size]
85
+ attention_mask (Optional[torch.Tensor]): Attention mask (currently not supported)
86
+
87
+ Returns:
88
+ torch.Tensor: Output tensor of shape [batch_size, seq_len, hidden_size]
89
+ """
90
+ if attention_mask is not None:
91
+ raise NotImplementedError("Attention mask is not yet supported in LigerFusedNeighborhoodAttention")
92
+
93
+ batch_size, seq_len, hidden_size = hidden_states.shape
94
+
95
+ query = self.q_proj(hidden_states)
96
+ key = self.k_proj(hidden_states)
97
+ value = self.v_proj(hidden_states)
98
+
99
+ query = query.view(batch_size, seq_len, self.num_heads, self.head_dim).transpose(1, 2)
100
+ key = key.view(batch_size, seq_len, self.num_heads, self.head_dim).transpose(1, 2)
101
+ value = value.view(batch_size, seq_len, self.num_heads, self.head_dim).transpose(1, 2)
102
+
103
+ attn_output = LigerFusedNeighborhoodAttentionFunction.apply(
104
+ query, key, value, self.kernel_size, self.dilation, self.scale
105
+ )
106
+
107
+ attn_output = attn_output.transpose(1, 2).contiguous().view(batch_size, seq_len, hidden_size)
108
+
109
+ if self.dropout is not None:
110
+ attn_output = self.dropout(attn_output)
111
+
112
+ output = self.out_proj(attn_output)
113
+
114
+ return output
115
+
116
+ def extra_repr(self) -> str:
117
+ return (
118
+ f"hidden_size={self.hidden_size}, num_heads={self.num_heads}, "
119
+ f"head_dim={self.head_dim}, kernel_size={self.kernel_size}, "
120
+ f"dilation={self.dilation}, scale={self.scale}, dropout={self.dropout_p}"
121
+ )
122
+
123
+
124
+ class LigerFusedNeighborhoodAttentionLayer(nn.Module):
125
+ """
126
+ A complete neighborhood attention layer with layer norm and residual connection.
127
+
128
+ Args:
129
+ hidden_size (int): The hidden dimension size
130
+ num_heads (int): Number of attention heads
131
+ kernel_size (int): Size of the neighborhood window (default: 7)
132
+ dilation (int): Dilation factor for the neighborhood (default: 1)
133
+ bias (bool): Whether to use bias in linear projections (default: True)
134
+ dropout (float): Dropout probability (default: 0.0)
135
+ layer_norm_eps (float): Epsilon for layer normalization (default: 1e-5)
136
+ scale (Optional[float]): Scaling factor for attention scores (default: None)
137
+ """
138
+
139
+ def __init__(
140
+ self,
141
+ hidden_size: int,
142
+ num_heads: int,
143
+ kernel_size: int = 7,
144
+ dilation: int = 1,
145
+ bias: bool = True,
146
+ dropout: float = 0.0,
147
+ layer_norm_eps: float = 1e-5,
148
+ scale: Optional[float] = None,
149
+ ):
150
+ super().__init__()
151
+
152
+ self.attention = LigerFusedNeighborhoodAttention(
153
+ hidden_size=hidden_size,
154
+ num_heads=num_heads,
155
+ kernel_size=kernel_size,
156
+ dilation=dilation,
157
+ bias=bias,
158
+ dropout=dropout,
159
+ scale=scale,
160
+ )
161
+
162
+ self.layer_norm = nn.LayerNorm(hidden_size, eps=layer_norm_eps)
163
+
164
+ if dropout > 0.0:
165
+ self.dropout = nn.Dropout(dropout)
166
+ else:
167
+ self.dropout = None
168
+
169
+ def forward(
170
+ self,
171
+ hidden_states: torch.Tensor,
172
+ attention_mask: Optional[torch.Tensor] = None,
173
+ ) -> torch.Tensor:
174
+ """
175
+ Forward pass with residual connection and layer normalization.
176
+
177
+ Args:
178
+ hidden_states (torch.Tensor): Input tensor of shape [batch_size, seq_len, hidden_size]
179
+ attention_mask (Optional[torch.Tensor]): Attention mask (currently not supported)
180
+
181
+ Returns:
182
+ torch.Tensor: Output tensor of shape [batch_size, seq_len, hidden_size]
183
+ """
184
+ normed_hidden_states = self.layer_norm(hidden_states)
185
+
186
+ attn_output = self.attention(normed_hidden_states, attention_mask)
187
+
188
+ if self.dropout is not None:
189
+ attn_output = self.dropout(attn_output)
190
+
191
+ output = hidden_states + attn_output
192
+
193
+ return output
194
+
195
+
196
+ class LigerFusedNeighborhoodAttentionConfig:
197
+ """
198
+ Configuration class for Fused Neighborhood Attention.
199
+
200
+ This can be used to easily configure neighborhood attention parameters
201
+ for different model architectures.
202
+ """
203
+
204
+ def __init__(
205
+ self,
206
+ hidden_size: int = 768,
207
+ num_heads: int = 12,
208
+ kernel_size: int = 7,
209
+ dilation: int = 1,
210
+ bias: bool = True,
211
+ dropout: float = 0.0,
212
+ layer_norm_eps: float = 1e-5,
213
+ scale: Optional[float] = None,
214
+ ):
215
+ self.hidden_size = hidden_size
216
+ self.num_heads = num_heads
217
+ self.kernel_size = kernel_size
218
+ self.dilation = dilation
219
+ self.bias = bias
220
+ self.dropout = dropout
221
+ self.layer_norm_eps = layer_norm_eps
222
+ self.scale = scale
223
+
224
+ def to_dict(self):
225
+ return {
226
+ "hidden_size": self.hidden_size,
227
+ "num_heads": self.num_heads,
228
+ "kernel_size": self.kernel_size,
229
+ "dilation": self.dilation,
230
+ "bias": self.bias,
231
+ "dropout": self.dropout,
232
+ "layer_norm_eps": self.layer_norm_eps,
233
+ "scale": self.scale,
234
+ }
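A minimal usage sketch for the modules added above (illustrative, not part of the diff); hidden sizes are arbitrary and a CUDA device is assumed.

    import torch
    from liger_kernel.transformers.fused_neighborhood_attention import (
        LigerFusedNeighborhoodAttention,
        LigerFusedNeighborhoodAttentionLayer,
    )

    attn = LigerFusedNeighborhoodAttention(hidden_size=768, num_heads=12, kernel_size=7, dilation=1).cuda()
    layer = LigerFusedNeighborhoodAttentionLayer(hidden_size=768, num_heads=12, kernel_size=7).cuda()

    x = torch.randn(2, 256, 768, device="cuda")
    y = attn(x)     # plain attention block: [2, 256, 768]
    z = layer(x)    # pre-norm block with residual connection: [2, 256, 768]

LigerFusedNeighborhoodAttentionConfig simply bundles the same constructor arguments, so a configured layer can be built with LigerFusedNeighborhoodAttentionLayer(**config.to_dict()).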
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.1
2
2
  Name: liger_kernel_nightly
3
- Version: 0.5.10.dev20250605223455
3
+ Version: 0.5.10.dev20250605224739
4
4
  Summary: Efficient Triton kernels for LLM Training
5
5
  License: BSD 2-CLAUSE LICENSE
6
6
  Copyright 2024 LinkedIn Corporation
@@ -20,6 +20,7 @@ liger_kernel/ops/cross_entropy.py,sha256=e8THGnhOcy_0SbOLABx67HEM7-B8a8pG7nDKbCR
20
20
  liger_kernel/ops/dyt.py,sha256=Y180EIvtUc2z83mhyub0EVOCQHJmWX3JnscqkOJqswk,5467
21
21
  liger_kernel/ops/fused_linear_cross_entropy.py,sha256=5fbGhN85n3zf0uIdJ7PYHWIRzTf0VTFiS0ARtOmqIP0,11020
22
22
  liger_kernel/ops/fused_linear_jsd.py,sha256=CSoprxb-YcJy-YUKiTcYkxN8sb9h2kdk_iHuncvSV5c,9683
23
+ liger_kernel/ops/fused_neighborhood_attention.py,sha256=vPi5xbnh6wxyZehaqo6Tuilqo2fN5SGDiONjnNmIKqs,35556
23
24
  liger_kernel/ops/geglu.py,sha256=axGvCIvlBzuluoAIrWTsp2iZM4BFKNInkPov8YVvH9E,4126
24
25
  liger_kernel/ops/group_norm.py,sha256=qD4D4lSjSgVtO52EBNLC2iTseALRgPgqXE50U2woggk,10837
25
26
  liger_kernel/ops/grpo_loss.py,sha256=anRnv7k1-AV3pCC6_TqP0GMg78YYUfRAJrbpx6PVhl0,9448
@@ -42,9 +43,10 @@ liger_kernel/transformers/auto_model.py,sha256=0qCTRZt280Bj_LcFdzo9hlaR-BWNazawX
42
43
  liger_kernel/transformers/cross_entropy.py,sha256=z3KTWQnFxr_IZaVjtYt0ZNEWQdDdYThN35xWkHlDGH0,1683
43
44
  liger_kernel/transformers/dyt.py,sha256=i-4GPaMrl-jab9TVI5qN0-H9qycn_mCbV82ozU4nbmU,723
44
45
  liger_kernel/transformers/fsdp.py,sha256=CUiyjTmjkjY7pLXQv8ly9rnzgXw6529csd9pvtJNMYc,3096
45
- liger_kernel/transformers/functional.py,sha256=QmnAFpRgIbp9Rzlfp8QibwiEbf5BUcANxfY68an7o8c,6444
46
+ liger_kernel/transformers/functional.py,sha256=7Emw7D6VPMg8hfasC33NiolvKmQVF1gV6VayKQCEWJM,7446
46
47
  liger_kernel/transformers/fused_linear_cross_entropy.py,sha256=O8Sg5BT81nTaY9fSGoOY9dOD9ekibwwiuXhdUHaxntQ,1742
47
48
  liger_kernel/transformers/fused_linear_jsd.py,sha256=bZ4otCvWBuOnA5XdQL-FzZVItJlDt-ht9e_pG7PG93E,3999
49
+ liger_kernel/transformers/fused_neighborhood_attention.py,sha256=TxYDUAt9B6WSP14aJP66C_2Mbds2sSIPGnamhUSTrC8,7957
48
50
  liger_kernel/transformers/geglu.py,sha256=mrgqzIUVd6lN7fkDKLkw5YaESDxDtFgbot430WwPVOQ,1107
49
51
  liger_kernel/transformers/group_norm.py,sha256=6qMAWOprr4SzP0YhNVNGQIBpM5aUHplUD2VuGJrMBz0,2173
50
52
  liger_kernel/transformers/grpo_loss.py,sha256=uAkUNKSnUGEOqa82L9w2e6AI1kcmG8K45-QxyaT8zhM,3897
@@ -85,9 +87,9 @@ liger_kernel/transformers/trainer/__init__.py,sha256=p7yQfklV8-467qSz_ZMimkbDF7H
85
87
  liger_kernel/transformers/trainer/orpo_trainer.py,sha256=tX0h63aOFe3rNqTmk6JpMf75UPo981yzEa6TghnjS0Q,5370
86
88
  liger_kernel/triton/__init__.py,sha256=qCiCamzCRv6lpV8IqpAc9YMdNKC7GKurClWceQPnlis,92
87
89
  liger_kernel/triton/monkey_patch.py,sha256=Rd0hUHAzDkFfHvnX7-PBaNK5EKnZhtfM_h-fgQH9HPY,1568
88
- liger_kernel_nightly-0.5.10.dev20250605223455.dist-info/LICENSE,sha256=OhzLDHJ0to4a8sodVLELZiCFylZ1NAAYLs-HrjPy0ag,1312
89
- liger_kernel_nightly-0.5.10.dev20250605223455.dist-info/METADATA,sha256=jtKbBFfhtiyDQ7ZfpSZ1EwxGFNTYt0ND_4jL8Xr_pmc,24309
90
- liger_kernel_nightly-0.5.10.dev20250605223455.dist-info/NOTICE,sha256=njwnoPZLh9AN8SJQzxvCGLHi-8X__AvWRze6joNXIY8,2066
91
- liger_kernel_nightly-0.5.10.dev20250605223455.dist-info/WHEEL,sha256=iAkIy5fosb7FzIOwONchHf19Qu7_1wCWyFNR5gu9nU0,91
92
- liger_kernel_nightly-0.5.10.dev20250605223455.dist-info/top_level.txt,sha256=2eghu4hA3LnkM7ElW92tQ8zegWKgSbeo-k-aGe1YnvY,13
93
- liger_kernel_nightly-0.5.10.dev20250605223455.dist-info/RECORD,,
90
+ liger_kernel_nightly-0.5.10.dev20250605224739.dist-info/LICENSE,sha256=OhzLDHJ0to4a8sodVLELZiCFylZ1NAAYLs-HrjPy0ag,1312
91
+ liger_kernel_nightly-0.5.10.dev20250605224739.dist-info/METADATA,sha256=fJJZbkI2vH7QV5qhJouSk17zKPSUuZNWCWY2kXjDYPQ,24309
92
+ liger_kernel_nightly-0.5.10.dev20250605224739.dist-info/NOTICE,sha256=njwnoPZLh9AN8SJQzxvCGLHi-8X__AvWRze6joNXIY8,2066
93
+ liger_kernel_nightly-0.5.10.dev20250605224739.dist-info/WHEEL,sha256=iAkIy5fosb7FzIOwONchHf19Qu7_1wCWyFNR5gu9nU0,91
94
+ liger_kernel_nightly-0.5.10.dev20250605224739.dist-info/top_level.txt,sha256=2eghu4hA3LnkM7ElW92tQ8zegWKgSbeo-k-aGe1YnvY,13
95
+ liger_kernel_nightly-0.5.10.dev20250605224739.dist-info/RECORD,,