liger-kernel-nightly 0.5.10.dev20250605223455__py3-none-any.whl → 0.5.10.dev20250606182408__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- liger_kernel/ops/fused_neighborhood_attention.py +1022 -0
- liger_kernel/transformers/functional.py +28 -0
- liger_kernel/transformers/fused_neighborhood_attention.py +234 -0
- {liger_kernel_nightly-0.5.10.dev20250605223455.dist-info → liger_kernel_nightly-0.5.10.dev20250606182408.dist-info}/METADATA +1 -1
- {liger_kernel_nightly-0.5.10.dev20250605223455.dist-info → liger_kernel_nightly-0.5.10.dev20250606182408.dist-info}/RECORD +9 -7
- {liger_kernel_nightly-0.5.10.dev20250605223455.dist-info → liger_kernel_nightly-0.5.10.dev20250606182408.dist-info}/LICENSE +0 -0
- {liger_kernel_nightly-0.5.10.dev20250605223455.dist-info → liger_kernel_nightly-0.5.10.dev20250606182408.dist-info}/NOTICE +0 -0
- {liger_kernel_nightly-0.5.10.dev20250605223455.dist-info → liger_kernel_nightly-0.5.10.dev20250606182408.dist-info}/WHEEL +0 -0
- {liger_kernel_nightly-0.5.10.dev20250605223455.dist-info → liger_kernel_nightly-0.5.10.dev20250606182408.dist-info}/top_level.txt +0 -0
@@ -0,0 +1,1022 @@
|
|
1
|
+
import math
|
2
|
+
|
3
|
+
import torch
|
4
|
+
import triton
|
5
|
+
import triton.language as tl
|
6
|
+
|
7
|
+
from liger_kernel.ops.softmax import _softmax_backward
|
8
|
+
from liger_kernel.ops.softmax import _softmax_forward
|
9
|
+
from liger_kernel.ops.utils import calculate_settings
|
10
|
+
from liger_kernel.ops.utils import ensure_contiguous
|
11
|
+
|
12
|
+
|
13
|
+
@triton.jit
def _neighborhood_mask_kernel(
    mask_ptr,
    seq_len: tl.constexpr,
    kernel_size: tl.constexpr,
    dilation: tl.constexpr,
    BLOCK_SIZE: tl.constexpr,
    num_stages: tl.constexpr,
    num_warps: tl.constexpr,
):
    """
    Write one row of a binary neighborhood-attention mask.

    Row ``i`` of the [seq_len, seq_len] mask marks (with 1.0) every key column
    ``j`` that query ``i`` may attend to, and writes 0.0 everywhere else. A key
    is allowed when it lies within ``(kernel_size // 2) * dilation`` positions
    of the query (clipped to ``[0, seq_len)``). When ``dilation > 1``, the key's
    offset from the query must additionally be a multiple of ``dilation``.

    Args:
        mask_ptr: Pointer to the output mask tensor [seq_len, seq_len] (row-major)
        seq_len: Length of the input sequence
        kernel_size: Size of the neighborhood window (expected to be odd)
        dilation: Dilation factor for the neighborhood pattern
        BLOCK_SIZE: Number of columns handled by the program; must be >= seq_len,
            since columns at or beyond BLOCK_SIZE are never written
        num_stages: Number of pipeline stages (compile-time constant)
        num_warps: Number of warps (compile-time constant)

    Grid: (seq_len,)
        Program ``i`` fills row ``i`` of the mask.
    """
    query_pos = tl.program_id(0)
    radius = (kernel_size // 2) * dilation

    # Half-open window [window_lo, window_hi), clipped to the sequence.
    window_lo = tl.maximum(0, query_pos - radius)
    window_hi = tl.minimum(seq_len, query_pos + radius + 1)

    cols = tl.arange(0, BLOCK_SIZE)
    in_bounds = cols < seq_len
    in_window = (cols >= window_lo) & (cols < window_hi)

    if dilation > 1:
        # Only keys at a multiple-of-dilation distance from the query survive.
        in_window = in_window & (((cols - query_pos) % dilation) == 0)

    row_values = tl.where(in_window & in_bounds, 1.0, 0.0)
    tl.store(mask_ptr + query_pos * seq_len + cols, row_values, mask=in_bounds)
|
68
|
+
|
69
|
+
|
70
|
+
@triton.jit
def _fused_neighborhood_attention_qk_kernel(
    Q_ptr,
    K_ptr,
    QK_ptr,
    mask_ptr,
    q_batch_stride,
    q_head_stride,
    q_seq_stride,
    q_dim_stride,
    k_batch_stride,
    k_head_stride,
    k_seq_stride,
    k_dim_stride,
    qk_batch_stride,
    qk_head_stride,
    qk_seq_stride,
    qk_seq2_stride,
    batch_size: tl.constexpr,
    num_heads: tl.constexpr,
    seq_len: tl.constexpr,
    head_dim: tl.constexpr,
    scale: tl.constexpr,
    kernel_size: tl.constexpr,
    dilation: tl.constexpr,
    BLOCK_SIZE_M: tl.constexpr,
    BLOCK_SIZE_N: tl.constexpr,
    BLOCK_SIZE_K: tl.constexpr,
    num_stages: tl.constexpr,
    num_warps: tl.constexpr,
):
    """
    Compute Q @ K^T with neighborhood masking and scaling.

    This kernel performs the first stage of neighborhood attention. It computes
    scaled attention scores between queries and keys. Positions whose
    neighborhood-mask entry is not > 0 are set to -inf, so the result is ready
    for softmax normalization.

    The computation is tiled across both sequence dimensions. Each tile
    accumulates its block of scores in float32 by iterating over the head
    dimension in chunks of BLOCK_SIZE_K.

    The batch/head base offsets are computed in int64. The score tensor has
    batch_size * num_heads * seq_len * seq_len elements, which can exceed the
    int32 range. Offsets inside a single [seq_len, seq_len] matrix remain int32.

    Args:
        Q_ptr: Pointer to query tensor [batch_size, num_heads, seq_len, head_dim]
        K_ptr: Pointer to key tensor [batch_size, num_heads, seq_len, head_dim]
        QK_ptr: Pointer to output tensor [batch_size, num_heads, seq_len, seq_len]
        mask_ptr: Pointer to row-major neighborhood mask [seq_len, seq_len]
        q_*_stride: Strides for query tensor
        k_*_stride: Strides for key tensor
        qk_*_stride: Strides for output tensor
        batch_size: Number of batches
        num_heads: Number of attention heads
        seq_len: Sequence length
        head_dim: Dimension of each attention head
        scale: Scaling factor for attention scores (typically 1/sqrt(head_dim))
        kernel_size: Size of the neighborhood window (unused here; the mask encodes it)
        dilation: Dilation factor (unused here; the mask encodes it)
        BLOCK_SIZE_M: Block size for sequence dimension (rows)
        BLOCK_SIZE_N: Block size for sequence dimension (cols)
        BLOCK_SIZE_K: Block size for head dimension
        num_stages: Number of pipeline stages
        num_warps: Number of warps

    Grid: (batch_size * num_heads, cdiv(seq_len, BLOCK_SIZE_M), cdiv(seq_len, BLOCK_SIZE_N))
        Each program computes a tile of the attention score matrix.
    """
    # int64 so batch_id * qk_batch_stride cannot wrap for large score tensors.
    batch_head_id = tl.program_id(0).to(tl.int64)
    tile_m = tl.program_id(1)
    tile_n = tl.program_id(2)

    batch_id = batch_head_id // num_heads
    head_id = batch_head_id % num_heads

    row_offsets = tile_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)
    col_offsets = tile_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)
    row_valid = row_offsets < seq_len
    col_valid = col_offsets < seq_len

    # Loop-invariant per-(batch, head) base pointers.
    q_base = Q_ptr + batch_id * q_batch_stride + head_id * q_head_stride
    k_base = K_ptr + batch_id * k_batch_stride + head_id * k_head_stride

    acc = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32)

    for k_start in range(0, head_dim, BLOCK_SIZE_K):
        dim_offsets = k_start + tl.arange(0, BLOCK_SIZE_K)
        dim_valid = dim_offsets < head_dim

        q_ptrs = q_base + row_offsets[:, None] * q_seq_stride + dim_offsets[None, :] * q_dim_stride
        q_chunk = tl.load(q_ptrs, mask=row_valid[:, None] & dim_valid[None, :], other=0.0)

        k_ptrs = k_base + col_offsets[:, None] * k_seq_stride + dim_offsets[None, :] * k_dim_stride
        k_chunk = tl.load(k_ptrs, mask=col_valid[:, None] & dim_valid[None, :], other=0.0)

        acc += tl.dot(q_chunk, tl.trans(k_chunk))

    acc = acc * scale

    valid_mask = row_valid[:, None] & col_valid[None, :]

    # Keys outside the neighborhood get -inf so softmax assigns them zero weight.
    mask_ptrs = mask_ptr + row_offsets[:, None] * seq_len + col_offsets[None, :]
    neighborhood_mask = tl.load(mask_ptrs, mask=valid_mask, other=0.0)
    acc = tl.where(neighborhood_mask > 0.0, acc, float("-inf"))

    qk_ptrs = (
        QK_ptr
        + batch_id * qk_batch_stride
        + head_id * qk_head_stride
        + row_offsets[:, None] * qk_seq_stride
        + col_offsets[None, :] * qk_seq2_stride
    )
    tl.store(qk_ptrs, acc, mask=valid_mask)
|
194
|
+
|
195
|
+
|
196
|
+
@triton.jit
def _fused_neighborhood_attention_av_kernel(
    Attn_ptr,
    V_ptr,
    Out_ptr,
    attn_batch_stride,
    attn_head_stride,
    attn_seq_stride,
    attn_seq2_stride,
    v_batch_stride,
    v_head_stride,
    v_seq_stride,
    v_dim_stride,
    out_batch_stride,
    out_head_stride,
    out_seq_stride,
    out_dim_stride,
    batch_size: tl.constexpr,
    num_heads: tl.constexpr,
    seq_len: tl.constexpr,
    head_dim: tl.constexpr,
    BLOCK_SIZE_M: tl.constexpr,
    BLOCK_SIZE_N: tl.constexpr,
    BLOCK_SIZE_K: tl.constexpr,
    num_stages: tl.constexpr,
    num_warps: tl.constexpr,
):
    """
    Compute Attention @ V to produce the final output.

    This kernel performs the second stage of neighborhood attention by multiplying
    the normalized attention weights with the value matrix. Each tile accumulates
    a [BLOCK_SIZE_M, BLOCK_SIZE_N] block of the output in float32, reducing over
    key positions in chunks of BLOCK_SIZE_K.

    The batch/head base offsets are computed in int64. The attention tensor has
    batch_size * num_heads * seq_len * seq_len elements, which can exceed the
    int32 range.

    Args:
        Attn_ptr: Pointer to attention weights [batch_size, num_heads, seq_len, seq_len]
        V_ptr: Pointer to value tensor [batch_size, num_heads, seq_len, head_dim]
        Out_ptr: Pointer to output tensor [batch_size, num_heads, seq_len, head_dim]
        attn_*_stride: Strides for attention weights tensor
        v_*_stride: Strides for value tensor
        out_*_stride: Strides for output tensor
        batch_size: Number of batches
        num_heads: Number of attention heads
        seq_len: Sequence length
        head_dim: Dimension of each attention head
        BLOCK_SIZE_M: Block size for sequence dimension (rows)
        BLOCK_SIZE_N: Block size for head dimension (cols)
        BLOCK_SIZE_K: Block size for sequence dimension (reduction)
        num_stages: Number of pipeline stages
        num_warps: Number of warps

    Grid: (batch_size * num_heads, cdiv(seq_len, BLOCK_SIZE_M), cdiv(head_dim, BLOCK_SIZE_N))
        Each program computes a tile of the output matrix.
    """
    # int64 so batch_id * attn_batch_stride cannot wrap for large attention tensors.
    batch_head_id = tl.program_id(0).to(tl.int64)
    tile_m = tl.program_id(1)
    tile_n = tl.program_id(2)

    batch_id = batch_head_id // num_heads
    head_id = batch_head_id % num_heads

    row_offsets = tile_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)
    col_offsets = tile_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)
    row_valid = row_offsets < seq_len
    col_valid = col_offsets < head_dim

    # Loop-invariant per-(batch, head) base pointers.
    attn_base = Attn_ptr + batch_id * attn_batch_stride + head_id * attn_head_stride
    v_base = V_ptr + batch_id * v_batch_stride + head_id * v_head_stride

    acc = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32)

    for k_start in range(0, seq_len, BLOCK_SIZE_K):
        key_offsets = k_start + tl.arange(0, BLOCK_SIZE_K)
        key_valid = key_offsets < seq_len

        attn_ptrs = attn_base + row_offsets[:, None] * attn_seq_stride + key_offsets[None, :] * attn_seq2_stride
        attn_chunk = tl.load(attn_ptrs, mask=row_valid[:, None] & key_valid[None, :], other=0.0)

        v_ptrs = v_base + key_offsets[:, None] * v_seq_stride + col_offsets[None, :] * v_dim_stride
        v_chunk = tl.load(v_ptrs, mask=key_valid[:, None] & col_valid[None, :], other=0.0)

        acc += tl.dot(attn_chunk, v_chunk)

    out_ptrs = (
        Out_ptr
        + batch_id * out_batch_stride
        + head_id * out_head_stride
        + row_offsets[:, None] * out_seq_stride
        + col_offsets[None, :] * out_dim_stride
    )
    tl.store(out_ptrs, acc, mask=row_valid[:, None] & col_valid[None, :])
|
300
|
+
|
301
|
+
|
302
|
+
@triton.jit
def _fused_neighborhood_attention_grad_qk_kernel(
    grad_attn_ptr,
    K_ptr,
    grad_Q_ptr,
    grad_attn_batch_stride,
    grad_attn_head_stride,
    grad_attn_seq_stride,
    grad_attn_seq2_stride,
    k_batch_stride,
    k_head_stride,
    k_seq_stride,
    k_dim_stride,
    grad_q_batch_stride,
    grad_q_head_stride,
    grad_q_seq_stride,
    grad_q_dim_stride,
    batch_size: tl.constexpr,
    num_heads: tl.constexpr,
    seq_len: tl.constexpr,
    head_dim: tl.constexpr,
    scale: tl.constexpr,
    BLOCK_SIZE_M: tl.constexpr,
    BLOCK_SIZE_N: tl.constexpr,
    BLOCK_SIZE_K: tl.constexpr,
    num_stages: tl.constexpr,
    num_warps: tl.constexpr,
):
    """
    Compute gradient with respect to queries: grad_Q = grad_attn @ K * scale.

    This kernel computes the gradient of the loss with respect to the query tensor
    by multiplying the gradient of the attention scores with the key tensor and
    applying the same scale used in the forward pass. Accumulation is in float32,
    reducing over key positions in chunks of BLOCK_SIZE_K.

    The batch/head base offsets are computed in int64. The grad-attention tensor
    has batch_size * num_heads * seq_len * seq_len elements, which can exceed the
    int32 range.

    Args:
        grad_attn_ptr: Pointer to gradient of attention scores [batch_size, num_heads, seq_len, seq_len]
        K_ptr: Pointer to key tensor [batch_size, num_heads, seq_len, head_dim]
        grad_Q_ptr: Pointer to output gradient tensor [batch_size, num_heads, seq_len, head_dim]
        grad_attn_*_stride: Strides for gradient attention tensor
        k_*_stride: Strides for key tensor
        grad_q_*_stride: Strides for gradient query tensor
        batch_size: Number of batches
        num_heads: Number of attention heads
        seq_len: Sequence length
        head_dim: Dimension of each attention head
        scale: Scaling factor applied to attention scores
        BLOCK_SIZE_M: Block size for sequence dimension (rows)
        BLOCK_SIZE_N: Block size for head dimension (cols)
        BLOCK_SIZE_K: Block size for sequence dimension (reduction)
        num_stages: Number of pipeline stages
        num_warps: Number of warps

    Grid: (batch_size * num_heads, cdiv(seq_len, BLOCK_SIZE_M), cdiv(head_dim, BLOCK_SIZE_N))
        Each program computes a tile of the query gradient matrix.
    """
    # int64 so batch_id * grad_attn_batch_stride cannot wrap for large tensors.
    batch_head_id = tl.program_id(0).to(tl.int64)
    tile_m = tl.program_id(1)
    tile_n = tl.program_id(2)

    batch_id = batch_head_id // num_heads
    head_id = batch_head_id % num_heads

    row_offsets = tile_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)
    col_offsets = tile_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)
    row_valid = row_offsets < seq_len
    col_valid = col_offsets < head_dim

    # Loop-invariant per-(batch, head) base pointers.
    grad_attn_base = grad_attn_ptr + batch_id * grad_attn_batch_stride + head_id * grad_attn_head_stride
    k_base = K_ptr + batch_id * k_batch_stride + head_id * k_head_stride

    acc = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32)

    for k_start in range(0, seq_len, BLOCK_SIZE_K):
        key_offsets = k_start + tl.arange(0, BLOCK_SIZE_K)
        key_valid = key_offsets < seq_len

        grad_attn_ptrs = (
            grad_attn_base + row_offsets[:, None] * grad_attn_seq_stride + key_offsets[None, :] * grad_attn_seq2_stride
        )
        grad_attn_chunk = tl.load(grad_attn_ptrs, mask=row_valid[:, None] & key_valid[None, :], other=0.0)

        k_ptrs = k_base + key_offsets[:, None] * k_seq_stride + col_offsets[None, :] * k_dim_stride
        k_chunk = tl.load(k_ptrs, mask=key_valid[:, None] & col_valid[None, :], other=0.0)

        acc += tl.dot(grad_attn_chunk, k_chunk)

    acc = acc * scale

    grad_q_ptrs = (
        grad_Q_ptr
        + batch_id * grad_q_batch_stride
        + head_id * grad_q_head_stride
        + row_offsets[:, None] * grad_q_seq_stride
        + col_offsets[None, :] * grad_q_dim_stride
    )
    tl.store(grad_q_ptrs, acc, mask=row_valid[:, None] & col_valid[None, :])
|
410
|
+
|
411
|
+
|
412
|
+
@triton.jit
def _fused_neighborhood_attention_grad_k_kernel(
    grad_attn_ptr,
    Q_ptr,
    grad_K_ptr,
    grad_attn_batch_stride,
    grad_attn_head_stride,
    grad_attn_seq_stride,
    grad_attn_seq2_stride,
    q_batch_stride,
    q_head_stride,
    q_seq_stride,
    q_dim_stride,
    grad_k_batch_stride,
    grad_k_head_stride,
    grad_k_seq_stride,
    grad_k_dim_stride,
    batch_size: tl.constexpr,
    num_heads: tl.constexpr,
    seq_len: tl.constexpr,
    head_dim: tl.constexpr,
    scale: tl.constexpr,
    BLOCK_SIZE_M: tl.constexpr,
    BLOCK_SIZE_N: tl.constexpr,
    BLOCK_SIZE_K: tl.constexpr,
    num_stages: tl.constexpr,
    num_warps: tl.constexpr,
):
    """
    Compute gradient with respect to keys: grad_K = grad_attn^T @ Q * scale.

    This kernel computes the gradient of the loss with respect to the key tensor
    by multiplying the transpose of the attention-score gradient with the query
    tensor. The transpose is never materialised: the grad-attention tile is
    loaded with its two sequence strides swapped. Accumulation is in float32.

    The batch/head base offsets are computed in int64. The grad-attention tensor
    has batch_size * num_heads * seq_len * seq_len elements, which can exceed the
    int32 range.

    Args:
        grad_attn_ptr: Pointer to gradient of attention scores [batch_size, num_heads, seq_len, seq_len]
        Q_ptr: Pointer to query tensor [batch_size, num_heads, seq_len, head_dim]
        grad_K_ptr: Pointer to output gradient tensor [batch_size, num_heads, seq_len, head_dim]
        grad_attn_*_stride: Strides for gradient attention tensor
        q_*_stride: Strides for query tensor
        grad_k_*_stride: Strides for gradient key tensor
        batch_size: Number of batches
        num_heads: Number of attention heads
        seq_len: Sequence length
        head_dim: Dimension of each attention head
        scale: Scaling factor applied to attention scores
        BLOCK_SIZE_M: Block size for sequence dimension (rows, key positions)
        BLOCK_SIZE_N: Block size for head dimension (cols)
        BLOCK_SIZE_K: Block size for sequence dimension (reduction, query positions)
        num_stages: Number of pipeline stages
        num_warps: Number of warps

    Grid: (batch_size * num_heads, cdiv(seq_len, BLOCK_SIZE_M), cdiv(head_dim, BLOCK_SIZE_N))
        Each program computes a tile of the key gradient matrix.
    """
    # int64 so batch_id * grad_attn_batch_stride cannot wrap for large tensors.
    batch_head_id = tl.program_id(0).to(tl.int64)
    tile_m = tl.program_id(1)
    tile_n = tl.program_id(2)

    batch_id = batch_head_id // num_heads
    head_id = batch_head_id % num_heads

    row_offsets = tile_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)
    col_offsets = tile_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)
    row_valid = row_offsets < seq_len
    col_valid = col_offsets < head_dim

    # Loop-invariant per-(batch, head) base pointers.
    q_base = Q_ptr + batch_id * q_batch_stride + head_id * q_head_stride
    grad_attn_base = grad_attn_ptr + batch_id * grad_attn_batch_stride + head_id * grad_attn_head_stride

    acc = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32)

    for k_start in range(0, seq_len, BLOCK_SIZE_K):
        query_offsets = k_start + tl.arange(0, BLOCK_SIZE_K)
        query_valid = query_offsets < seq_len

        q_ptrs = q_base + query_offsets[:, None] * q_seq_stride + col_offsets[None, :] * q_dim_stride
        q_chunk = tl.load(q_ptrs, mask=query_valid[:, None] & col_valid[None, :], other=0.0)

        # Swapped strides: element [m, k] of this tile is grad_attn[query=k, key=m].
        grad_attn_T_ptrs = (
            grad_attn_base
            + row_offsets[:, None] * grad_attn_seq2_stride
            + query_offsets[None, :] * grad_attn_seq_stride
        )
        grad_attn_T_chunk = tl.load(grad_attn_T_ptrs, mask=row_valid[:, None] & query_valid[None, :], other=0.0)

        acc += tl.dot(grad_attn_T_chunk, q_chunk)

    acc = acc * scale

    grad_k_ptrs = (
        grad_K_ptr
        + batch_id * grad_k_batch_stride
        + head_id * grad_k_head_stride
        + row_offsets[:, None] * grad_k_seq_stride
        + col_offsets[None, :] * grad_k_dim_stride
    )
    tl.store(grad_k_ptrs, acc, mask=row_valid[:, None] & col_valid[None, :])
|
520
|
+
|
521
|
+
|
522
|
+
@triton.jit
def _fused_neighborhood_attention_grad_v_kernel(
    Attn_ptr,
    grad_output_ptr,
    grad_V_ptr,
    attn_batch_stride,
    attn_head_stride,
    attn_seq_stride,
    attn_seq2_stride,
    grad_out_batch_stride,
    grad_out_head_stride,
    grad_out_seq_stride,
    grad_out_dim_stride,
    grad_v_batch_stride,
    grad_v_head_stride,
    grad_v_seq_stride,
    grad_v_dim_stride,
    batch_size: tl.constexpr,
    num_heads: tl.constexpr,
    seq_len: tl.constexpr,
    head_dim: tl.constexpr,
    BLOCK_SIZE_M: tl.constexpr,
    BLOCK_SIZE_N: tl.constexpr,
    BLOCK_SIZE_K: tl.constexpr,
    num_stages: tl.constexpr,
    num_warps: tl.constexpr,
):
    """
    Compute gradient with respect to values: grad_V = Attn^T @ grad_output.

    This kernel computes the gradient of the loss with respect to the value tensor
    by multiplying the transpose of the attention weights with the gradient of the
    output. Each iteration loads an [BLOCK_SIZE_K, BLOCK_SIZE_M] attention tile
    (query x key) and transposes it in registers before the dot product.
    Accumulation is in float32.

    The batch/head base offsets are computed in int64. The attention tensor has
    batch_size * num_heads * seq_len * seq_len elements, which can exceed the
    int32 range.

    Args:
        Attn_ptr: Pointer to attention weights [batch_size, num_heads, seq_len, seq_len]
        grad_output_ptr: Pointer to gradient of output [batch_size, num_heads, seq_len, head_dim]
        grad_V_ptr: Pointer to output gradient tensor [batch_size, num_heads, seq_len, head_dim]
        attn_*_stride: Strides for attention weights tensor
        grad_out_*_stride: Strides for gradient output tensor
        grad_v_*_stride: Strides for gradient value tensor
        batch_size: Number of batches
        num_heads: Number of attention heads
        seq_len: Sequence length
        head_dim: Dimension of each attention head
        BLOCK_SIZE_M: Block size for sequence dimension (rows, key positions)
        BLOCK_SIZE_N: Block size for head dimension (cols)
        BLOCK_SIZE_K: Block size for sequence dimension (reduction, query positions)
        num_stages: Number of pipeline stages
        num_warps: Number of warps

    Grid: (batch_size * num_heads, cdiv(seq_len, BLOCK_SIZE_M), cdiv(head_dim, BLOCK_SIZE_N))
        Each program computes a tile of the value gradient matrix.
    """
    # int64 so batch_id * attn_batch_stride cannot wrap for large attention tensors.
    batch_head_id = tl.program_id(0).to(tl.int64)
    tile_m = tl.program_id(1)
    tile_n = tl.program_id(2)

    batch_id = batch_head_id // num_heads
    head_id = batch_head_id % num_heads

    row_offsets = tile_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)
    col_offsets = tile_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)
    row_valid = row_offsets < seq_len
    col_valid = col_offsets < head_dim

    # Loop-invariant per-(batch, head) base pointers.
    attn_base = Attn_ptr + batch_id * attn_batch_stride + head_id * attn_head_stride
    grad_out_base = grad_output_ptr + batch_id * grad_out_batch_stride + head_id * grad_out_head_stride

    acc = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32)

    for k_start in range(0, seq_len, BLOCK_SIZE_K):
        query_offsets = k_start + tl.arange(0, BLOCK_SIZE_K)
        query_valid = query_offsets < seq_len

        # Tile element [k, m] is attn[query=k, key=m].
        attn_ptrs = attn_base + query_offsets[:, None] * attn_seq_stride + row_offsets[None, :] * attn_seq2_stride
        attn_chunk = tl.load(attn_ptrs, mask=query_valid[:, None] & row_valid[None, :], other=0.0)

        grad_out_ptrs = (
            grad_out_base + query_offsets[:, None] * grad_out_seq_stride + col_offsets[None, :] * grad_out_dim_stride
        )
        grad_out_chunk = tl.load(grad_out_ptrs, mask=query_valid[:, None] & col_valid[None, :], other=0.0)

        acc += tl.dot(tl.trans(attn_chunk), grad_out_chunk)

    grad_v_ptrs = (
        grad_V_ptr
        + batch_id * grad_v_batch_stride
        + head_id * grad_v_head_stride
        + row_offsets[:, None] * grad_v_seq_stride
        + col_offsets[None, :] * grad_v_dim_stride
    )
    tl.store(grad_v_ptrs, acc, mask=row_valid[:, None] & col_valid[None, :])
|
626
|
+
|
627
|
+
|
628
|
+
@triton.jit
def _fused_neighborhood_attention_grad_attn_kernel(
    grad_output_ptr,
    V_ptr,
    grad_attn_ptr,
    grad_out_batch_stride,
    grad_out_head_stride,
    grad_out_seq_stride,
    grad_out_dim_stride,
    v_batch_stride,
    v_head_stride,
    v_seq_stride,
    v_dim_stride,
    grad_attn_batch_stride,
    grad_attn_head_stride,
    grad_attn_seq_stride,
    grad_attn_seq2_stride,
    batch_size: tl.constexpr,
    num_heads: tl.constexpr,
    seq_len: tl.constexpr,
    head_dim: tl.constexpr,
    BLOCK_SIZE_M: tl.constexpr,
    BLOCK_SIZE_N: tl.constexpr,
    BLOCK_SIZE_K: tl.constexpr,
    num_stages: tl.constexpr,
    num_warps: tl.constexpr,
):
    """
    Compute gradient with respect to attention weights: grad_attn = grad_output @ V^T.

    This kernel computes the gradient of the loss with respect to the attention
    weights by multiplying the gradient of the output with the transpose of the
    value tensor. V^T is never materialised: the value tile is loaded with its
    sequence and head-dim strides swapped. The result is intended to be passed
    through the softmax backward pass.

    The batch/head base offsets are computed in int64. The grad-attention tensor
    has batch_size * num_heads * seq_len * seq_len elements, which can exceed the
    int32 range.

    Args:
        grad_output_ptr: Pointer to gradient of output [batch_size, num_heads, seq_len, head_dim]
        V_ptr: Pointer to value tensor [batch_size, num_heads, seq_len, head_dim]
        grad_attn_ptr: Pointer to output gradient tensor [batch_size, num_heads, seq_len, seq_len]
        grad_out_*_stride: Strides for gradient output tensor
        v_*_stride: Strides for value tensor
        grad_attn_*_stride: Strides for gradient attention tensor
        batch_size: Number of batches
        num_heads: Number of attention heads
        seq_len: Sequence length
        head_dim: Dimension of each attention head
        BLOCK_SIZE_M: Block size for sequence dimension (rows)
        BLOCK_SIZE_N: Block size for sequence dimension (cols)
        BLOCK_SIZE_K: Block size for head dimension (reduction)
        num_stages: Number of pipeline stages
        num_warps: Number of warps

    Grid: (batch_size * num_heads, cdiv(seq_len, BLOCK_SIZE_M), cdiv(seq_len, BLOCK_SIZE_N))
        Each program computes a tile of the attention gradient matrix.
    """
    # int64 so batch_id * grad_attn_batch_stride cannot wrap for large tensors.
    batch_head_id = tl.program_id(0).to(tl.int64)
    tile_m = tl.program_id(1)
    tile_n = tl.program_id(2)

    batch_id = batch_head_id // num_heads
    head_id = batch_head_id % num_heads

    row_offsets = tile_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)
    col_offsets = tile_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)
    row_valid = row_offsets < seq_len
    col_valid = col_offsets < seq_len

    # Loop-invariant per-(batch, head) base pointers.
    grad_out_base = grad_output_ptr + batch_id * grad_out_batch_stride + head_id * grad_out_head_stride
    v_base = V_ptr + batch_id * v_batch_stride + head_id * v_head_stride

    acc = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32)

    for k_start in range(0, head_dim, BLOCK_SIZE_K):
        dim_offsets = k_start + tl.arange(0, BLOCK_SIZE_K)
        dim_valid = dim_offsets < head_dim

        grad_out_ptrs = (
            grad_out_base + row_offsets[:, None] * grad_out_seq_stride + dim_offsets[None, :] * grad_out_dim_stride
        )
        grad_out_chunk = tl.load(grad_out_ptrs, mask=row_valid[:, None] & dim_valid[None, :], other=0.0)

        # Swapped strides: tile element [d, n] is V[seq=n, dim=d], i.e. a V^T tile.
        v_ptrs = v_base + col_offsets[None, :] * v_seq_stride + dim_offsets[:, None] * v_dim_stride
        v_chunk = tl.load(v_ptrs, mask=col_valid[None, :] & dim_valid[:, None], other=0.0)

        acc += tl.dot(grad_out_chunk, v_chunk)

    grad_attn_ptrs = (
        grad_attn_ptr
        + batch_id * grad_attn_batch_stride
        + head_id * grad_attn_head_stride
        + row_offsets[:, None] * grad_attn_seq_stride
        + col_offsets[None, :] * grad_attn_seq2_stride
    )
    tl.store(grad_attn_ptrs, acc, mask=row_valid[:, None] & col_valid[None, :])
|
733
|
+
|
734
|
+
|
735
|
+
def fused_neighborhood_attention_forward(
    query: torch.Tensor,
    key: torch.Tensor,
    value: torch.Tensor,
    kernel_size: int = 7,
    dilation: int = 1,
    scale: float = None,
    return_lse: bool = False,
) -> tuple:
    """
    Fused neighborhood attention forward pass.

    Pipeline (each step is a separate launch):
        1. ``_neighborhood_mask_kernel`` fills a float32 [seq_len, seq_len] mask.
        2. ``_fused_neighborhood_attention_qk_kernel`` computes scaled QK^T scores,
           receiving that mask.
        3. ``_softmax_forward`` normalizes every score row.
        4. ``_fused_neighborhood_attention_av_kernel`` multiplies the weights by V.

    Args:
        query: Query tensor of shape [batch_size, num_heads, seq_len, head_dim]
        key: Key tensor of shape [batch_size, num_heads, seq_len, head_dim]
        value: Value tensor of shape [batch_size, num_heads, seq_len, head_dim]
        kernel_size: Size of the neighborhood window
        dilation: Dilation factor for the neighborhood
        scale: Scaling factor for attention scores (default: rsqrt(head_dim))
        return_lse: Whether to return log-sum-exp values. Not supported yet;
            passing True raises before any kernel is launched.

    Returns:
        Tuple ``(output, attn_weights, softmax_params)``:
            output: tensor with the same shape/dtype/device as ``query``.
            attn_weights: post-softmax weights, [batch_size, num_heads, seq_len, seq_len].
            softmax_params: ``(BLOCK_SIZE, num_warps, multi_block_launch)`` from
                ``_softmax_forward``, needed later by ``_softmax_backward``.

    Raises:
        NotImplementedError: if ``return_lse`` is True.
        ValueError: if ``query`` is not 4-D or ``key``/``value`` shapes differ from ``query``.
    """
    # Fail fast instead of running every kernel and then discarding the result.
    if return_lse:
        raise NotImplementedError("return_lse=True is not supported yet.")

    if query.dim() != 4:
        raise ValueError(
            f"query must be 4-D [batch_size, num_heads, seq_len, head_dim], got shape {tuple(query.shape)}"
        )
    # All kernels take their sizes from query.shape, so a mismatched key/value
    # would be indexed out of bounds.
    if key.shape != query.shape or value.shape != query.shape:
        raise ValueError(
            "query, key and value must have identical shapes, got "
            f"{tuple(query.shape)}, {tuple(key.shape)} and {tuple(value.shape)}"
        )

    batch_size, num_heads, seq_len, head_dim = query.shape

    if scale is None:
        scale = 1.0 / math.sqrt(head_dim)

    query = query.contiguous()
    key = key.contiguous()
    value = value.contiguous()

    output = torch.empty_like(query)
    # Full [seq_len, seq_len] score matrix per (batch, head): memory is quadratic in seq_len.
    qk_scores = torch.empty(batch_size, num_heads, seq_len, seq_len, device=query.device, dtype=query.dtype)

    mask = torch.zeros(seq_len, seq_len, device=query.device, dtype=torch.float32)

    BLOCK_SIZE, num_warps = calculate_settings(seq_len)
    BLOCK_SIZE_M = min(64, triton.next_power_of_2(seq_len))
    BLOCK_SIZE_N = min(64, triton.next_power_of_2(seq_len))
    # tl.dot needs operand dims >= 16, hence the lower bound.
    BLOCK_SIZE_K = max(16, triton.next_power_of_2(head_dim))

    num_stages = 4 if seq_len >= 512 else 2

    # One program per mask row.
    grid_mask = (seq_len,)
    _neighborhood_mask_kernel[grid_mask](
        mask,
        seq_len,
        kernel_size,
        dilation,
        BLOCK_SIZE,
        num_stages,
        num_warps,
    )

    grid_qk = (batch_size * num_heads, triton.cdiv(seq_len, BLOCK_SIZE_M), triton.cdiv(seq_len, BLOCK_SIZE_N))
    _fused_neighborhood_attention_qk_kernel[grid_qk](
        query,
        key,
        qk_scores,
        mask,
        *query.stride(),
        *key.stride(),
        *qk_scores.stride(),
        batch_size,
        num_heads,
        seq_len,
        head_dim,
        scale,
        kernel_size,
        dilation,
        BLOCK_SIZE_M,
        BLOCK_SIZE_N,
        BLOCK_SIZE_K,
        num_stages,
        num_warps,
    )

    # Row-wise softmax over the flattened [B*H*N, N] score matrix.
    qk_reshaped = qk_scores.view(batch_size * num_heads * seq_len, seq_len)
    attn_reshaped, BLOCK_SIZE_softmax, num_warps_softmax, multi_block_launch = _softmax_forward(qk_reshaped)
    attn_weights = attn_reshaped.view(batch_size, num_heads, seq_len, seq_len)

    grid_av = (batch_size * num_heads, triton.cdiv(seq_len, BLOCK_SIZE_M), triton.cdiv(head_dim, BLOCK_SIZE_N))
    _fused_neighborhood_attention_av_kernel[grid_av](
        attn_weights,
        value,
        output,
        *attn_weights.stride(),
        *value.stride(),
        *output.stride(),
        batch_size,
        num_heads,
        seq_len,
        head_dim,
        BLOCK_SIZE_M,
        BLOCK_SIZE_N,
        BLOCK_SIZE_K,
        num_stages,
        num_warps,
    )

    softmax_params = (BLOCK_SIZE_softmax, num_warps_softmax, multi_block_launch)
    return output, attn_weights, softmax_params
|
860
|
+
|
861
|
+
|
862
|
+
class LigerFusedNeighborhoodAttentionFunction(torch.autograd.Function):
    """
    Autograd wrapper around the fused neighborhood attention Triton kernels.

    ``forward`` delegates to ``fused_neighborhood_attention_forward`` and saves
    the inputs plus the post-softmax attention weights. ``backward`` rebuilds
    the gradients from those tensors with one Triton kernel per gradient and
    the shared softmax backward.
    """

    @staticmethod
    @ensure_contiguous
    def forward(ctx, query, key, value, kernel_size=7, dilation=1, scale=None):
        """
        Args:
            ctx: autograd context.
            query, key, value: tensors unpacked as
                [batch_size, num_heads, seq_len, head_dim] (see ``backward``).
            kernel_size: neighborhood window size forwarded to the kernels.
            dilation: neighborhood dilation forwarded to the kernels.
            scale: attention score scale; None means 1/sqrt(head_dim).

        Returns:
            The attention output (same shape as ``query``).
        """
        output, attn_weights, softmax_params = fused_neighborhood_attention_forward(
            query, key, value, kernel_size, dilation, scale
        )
        # attn_weights is reused by both the softmax backward and the grad_V kernel.
        ctx.save_for_backward(query, key, value, attn_weights)
        ctx.kernel_size = kernel_size
        ctx.dilation = dilation
        ctx.scale = scale
        ctx.softmax_params = softmax_params
        return output

    @staticmethod
    @ensure_contiguous
    def backward(ctx, grad_output):
        """
        Compute gradients with respect to query, key and value.

        Steps:
            1. grad_attn = grad_output @ V^T (tiled Triton kernel).
            2. grad_scores = softmax backward of grad_attn using the saved weights.
            3. grad_Q and grad_K kernels consume grad_scores (and ``scale``).
            4. grad_V kernel consumes attn_weights and grad_output.

        Returns:
            ``(grad_query, grad_key, grad_value, None, None, None)``. The trailing
            Nones are for kernel_size, dilation and scale, which are not differentiable.
        """
        query, key, value, attn_weights = ctx.saved_tensors
        BLOCK_SIZE_softmax, num_warps_softmax, multi_block_launch = ctx.softmax_params

        batch_size, num_heads, seq_len, head_dim = query.shape
        scale = ctx.scale if ctx.scale is not None else 1.0 / math.sqrt(head_dim)

        grad_query = torch.zeros_like(query)
        grad_key = torch.zeros_like(key)
        grad_value = torch.zeros_like(value)
        grad_attn_weights = torch.zeros_like(attn_weights)

        BLOCK_SIZE_M = min(64, triton.next_power_of_2(seq_len))
        BLOCK_SIZE_N = min(64, triton.next_power_of_2(seq_len))
        # tl.dot rejects operand dims < 16. Clamp from below, as the forward pass
        # does with max(16, ...), so head_dim <= 8 still yields a valid dot shape.
        # Out-of-range K indices are masked (see k_mask in the grad_attn kernel).
        BLOCK_SIZE_K = max(16, min(64, triton.next_power_of_2(head_dim)))
        num_stages = 4 if seq_len >= 512 else 2
        _, num_warps = calculate_settings(seq_len)

        grid_grad_attn = (
            batch_size * num_heads,
            triton.cdiv(seq_len, BLOCK_SIZE_M),
            triton.cdiv(seq_len, BLOCK_SIZE_N),
        )
        _fused_neighborhood_attention_grad_attn_kernel[grid_grad_attn](
            grad_output,
            value,
            grad_attn_weights,
            *grad_output.stride(),
            *value.stride(),
            *grad_attn_weights.stride(),
            batch_size,
            num_heads,
            seq_len,
            head_dim,
            BLOCK_SIZE_M,
            BLOCK_SIZE_N,
            BLOCK_SIZE_K,
            num_stages,
            num_warps,
        )

        grad_attn_reshaped = grad_attn_weights.view(batch_size * num_heads * seq_len, seq_len)
        attn_reshaped = attn_weights.view(batch_size * num_heads * seq_len, seq_len)

        grad_qk_reshaped = _softmax_backward(
            grad_attn_reshaped, attn_reshaped, BLOCK_SIZE_softmax, num_warps_softmax, multi_block_launch
        )
        grad_qk_scores = grad_qk_reshaped.view(batch_size, num_heads, seq_len, seq_len)

        # grad_Q, grad_K and grad_V all tile [seq_len, head_dim] outputs, so they share one grid.
        grid_grad_qkv = (
            batch_size * num_heads,
            triton.cdiv(seq_len, BLOCK_SIZE_M),
            triton.cdiv(head_dim, BLOCK_SIZE_N),
        )

        _fused_neighborhood_attention_grad_qk_kernel[grid_grad_qkv](
            grad_qk_scores,
            key,
            grad_query,
            *grad_qk_scores.stride(),
            *key.stride(),
            *grad_query.stride(),
            batch_size,
            num_heads,
            seq_len,
            head_dim,
            scale,
            BLOCK_SIZE_M,
            BLOCK_SIZE_N,
            BLOCK_SIZE_K,
            num_stages,
            num_warps,
        )

        _fused_neighborhood_attention_grad_k_kernel[grid_grad_qkv](
            grad_qk_scores,
            query,
            grad_key,
            *grad_qk_scores.stride(),
            *query.stride(),
            *grad_key.stride(),
            batch_size,
            num_heads,
            seq_len,
            head_dim,
            scale,
            BLOCK_SIZE_M,
            BLOCK_SIZE_N,
            BLOCK_SIZE_K,
            num_stages,
            num_warps,
        )

        _fused_neighborhood_attention_grad_v_kernel[grid_grad_qkv](
            attn_weights,
            grad_output,
            grad_value,
            *attn_weights.stride(),
            *grad_output.stride(),
            *grad_value.stride(),
            batch_size,
            num_heads,
            seq_len,
            head_dim,
            BLOCK_SIZE_M,
            BLOCK_SIZE_N,
            BLOCK_SIZE_K,
            num_stages,
            num_warps,
        )

        return grad_query, grad_key, grad_value, None, None, None
|
@@ -4,6 +4,7 @@ from liger_kernel.ops.cross_entropy import LigerCrossEntropyFunction
|
|
4
4
|
from liger_kernel.ops.dyt import LigerDyTFunction
|
5
5
|
from liger_kernel.ops.fused_linear_cross_entropy import LigerFusedLinearCrossEntropyFunction
|
6
6
|
from liger_kernel.ops.fused_linear_jsd import LigerFusedLinearJSDFunction
|
7
|
+
from liger_kernel.ops.fused_neighborhood_attention import LigerFusedNeighborhoodAttentionFunction
|
7
8
|
from liger_kernel.ops.geglu import LigerGELUMulFunction
|
8
9
|
from liger_kernel.ops.group_norm import LigerGroupNormFunction
|
9
10
|
from liger_kernel.ops.jsd import LigerJSDFunction
|
@@ -197,6 +198,33 @@ def liger_multi_token_attention(
|
|
197
198
|
return LigerMultiTokenAttentionFunction.apply(scores, weight, bias, stride, padding, dilation, groups, sparse)
|
198
199
|
|
199
200
|
|
201
|
+
def liger_fused_neighborhood_attention(
    query,
    key,
    value,
    kernel_size: int = 7,
    dilation: int = 1,
    scale: float = None,
):
    """
    Apply Liger's fused neighborhood attention (autograd-enabled).

    paper: https://arxiv.org/pdf/2504.16922

    Args:
        query: Query tensor of shape [batch_size, num_heads, seq_len, head_dim]
        key: Key tensor of shape [batch_size, num_heads, seq_len, head_dim]
        value: Value tensor of shape [batch_size, num_heads, seq_len, head_dim]
        kernel_size: Width of the neighborhood window (default: 7)
        dilation: Spacing factor for the neighborhood (default: 1)
        scale: Multiplier for attention scores; None means rsqrt(head_dim)

    Returns:
        Attention output of shape [batch_size, num_heads, seq_len, head_dim]
    """
    # autograd.Function.apply takes positional arguments in forward's order.
    attention_args = (query, key, value, kernel_size, dilation, scale)
    return LigerFusedNeighborhoodAttentionFunction.apply(*attention_args)
|
226
|
+
|
227
|
+
|
200
228
|
def liger_tvd(
|
201
229
|
input,
|
202
230
|
target,
|
@@ -0,0 +1,234 @@
|
|
1
|
+
import math
|
2
|
+
|
3
|
+
from typing import Optional
|
4
|
+
|
5
|
+
import torch
|
6
|
+
import torch.nn as nn
|
7
|
+
|
8
|
+
from liger_kernel.ops.fused_neighborhood_attention import LigerFusedNeighborhoodAttentionFunction
|
9
|
+
|
10
|
+
|
11
|
+
class LigerFusedNeighborhoodAttention(nn.Module):
    """
    Liger Fused Neighborhood Attention Module.

    Paper: https://arxiv.org/pdf/2504.16922

    Fused neighborhood attention restricts each position's attention to a local
    neighborhood of ``kernel_size`` positions (spaced by ``dilation``).

    Note: the current Triton backend still materializes full
    [seq_len, seq_len] score and weight tensors per (batch, head), so memory
    and compute remain quadratic in seq_len.

    Args:
        hidden_size (int): The hidden dimension size
        num_heads (int): Number of attention heads (must be positive and divide hidden_size)
        kernel_size (int): Size of the neighborhood window, positive and odd (default: 7)
        dilation (int): Dilation factor for the neighborhood, >= 1 (default: 1)
        bias (bool): Whether to use bias in linear projections (default: True)
        dropout (float): Dropout probability in [0, 1], applied to the attention
            output before ``out_proj`` (default: 0.0)
        scale (Optional[float]): Scaling factor for attention scores.
            If None, uses 1/sqrt(head_dim) (default: None)

    Raises:
        ValueError: if any of the constraints above is violated.
    """

    def __init__(
        self,
        hidden_size: int,
        num_heads: int,
        kernel_size: int = 7,
        dilation: int = 1,
        bias: bool = True,
        dropout: float = 0.0,
        scale: Optional[float] = None,
    ):
        super().__init__()

        # Checked first so the divisibility test below cannot hit ZeroDivisionError.
        if num_heads <= 0:
            raise ValueError(f"num_heads ({num_heads}) must be positive")

        if hidden_size % num_heads != 0:
            raise ValueError(f"hidden_size ({hidden_size}) must be divisible by num_heads ({num_heads})")

        if kernel_size <= 0:
            raise ValueError(f"kernel_size ({kernel_size}) must be positive")

        if kernel_size % 2 == 0:
            raise ValueError(f"kernel_size ({kernel_size}) must be odd")

        if dilation < 1:
            raise ValueError(f"dilation ({dilation}) must be positive")

        # A negative probability used to be silently treated as "no dropout".
        if not 0.0 <= dropout <= 1.0:
            raise ValueError(f"dropout ({dropout}) must be in the range [0, 1]")

        self.hidden_size = hidden_size
        self.num_heads = num_heads
        self.head_dim = hidden_size // num_heads
        self.kernel_size = kernel_size
        self.dilation = dilation
        self.scale = scale if scale is not None else 1.0 / math.sqrt(self.head_dim)
        self.dropout_p = dropout

        self.q_proj = nn.Linear(hidden_size, hidden_size, bias=bias)
        self.k_proj = nn.Linear(hidden_size, hidden_size, bias=bias)
        self.v_proj = nn.Linear(hidden_size, hidden_size, bias=bias)

        self.out_proj = nn.Linear(hidden_size, hidden_size, bias=bias)

        if dropout > 0.0:
            self.dropout = nn.Dropout(dropout)
        else:
            self.dropout = None

    def forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        """
        Forward pass of the fused neighborhood attention module.

        Args:
            hidden_states (torch.Tensor): Input tensor of shape [batch_size, seq_len, hidden_size]
            attention_mask (Optional[torch.Tensor]): Attention mask (currently not supported)

        Returns:
            torch.Tensor: Output tensor of shape [batch_size, seq_len, hidden_size]

        Raises:
            NotImplementedError: if ``attention_mask`` is not None.
        """
        if attention_mask is not None:
            raise NotImplementedError("Attention mask is not yet supported in LigerFusedNeighborhoodAttention")

        batch_size, seq_len, hidden_size = hidden_states.shape

        query = self.q_proj(hidden_states)
        key = self.k_proj(hidden_states)
        value = self.v_proj(hidden_states)

        # [B, N, H*D] -> [B, H, N, D], the layout the fused kernel expects.
        query = query.view(batch_size, seq_len, self.num_heads, self.head_dim).transpose(1, 2)
        key = key.view(batch_size, seq_len, self.num_heads, self.head_dim).transpose(1, 2)
        value = value.view(batch_size, seq_len, self.num_heads, self.head_dim).transpose(1, 2)

        attn_output = LigerFusedNeighborhoodAttentionFunction.apply(
            query, key, value, self.kernel_size, self.dilation, self.scale
        )

        # [B, H, N, D] -> [B, N, H*D]
        attn_output = attn_output.transpose(1, 2).contiguous().view(batch_size, seq_len, hidden_size)

        if self.dropout is not None:
            attn_output = self.dropout(attn_output)

        output = self.out_proj(attn_output)

        return output

    def extra_repr(self) -> str:
        return (
            f"hidden_size={self.hidden_size}, num_heads={self.num_heads}, "
            f"head_dim={self.head_dim}, kernel_size={self.kernel_size}, "
            f"dilation={self.dilation}, scale={self.scale}, dropout={self.dropout_p}"
        )
|
122
|
+
|
123
|
+
|
124
|
+
class LigerFusedNeighborhoodAttentionLayer(nn.Module):
    """
    Pre-norm neighborhood attention block computing
    ``hidden_states + dropout(attention(layer_norm(hidden_states)))``.

    Args:
        hidden_size (int): The hidden dimension size
        num_heads (int): Number of attention heads
        kernel_size (int): Size of the neighborhood window (default: 7)
        dilation (int): Dilation factor for the neighborhood (default: 1)
        bias (bool): Whether to use bias in linear projections (default: True)
        dropout (float): Dropout probability, used both inside the attention
            module and on its output before the residual add (default: 0.0)
        layer_norm_eps (float): Epsilon for layer normalization (default: 1e-5)
        scale (Optional[float]): Scaling factor for attention scores (default: None)
    """

    def __init__(
        self,
        hidden_size: int,
        num_heads: int,
        kernel_size: int = 7,
        dilation: int = 1,
        bias: bool = True,
        dropout: float = 0.0,
        layer_norm_eps: float = 1e-5,
        scale: Optional[float] = None,
    ):
        super().__init__()

        # Registration order (attention, layer_norm, dropout) fixes the
        # state_dict / parameters() ordering, so keep it stable.
        self.attention = LigerFusedNeighborhoodAttention(
            hidden_size=hidden_size,
            num_heads=num_heads,
            kernel_size=kernel_size,
            dilation=dilation,
            bias=bias,
            dropout=dropout,
            scale=scale,
        )
        self.layer_norm = nn.LayerNorm(hidden_size, eps=layer_norm_eps)
        self.dropout = nn.Dropout(dropout) if dropout > 0.0 else None

    def forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        """
        Run the pre-norm attention block and add the residual.

        Args:
            hidden_states (torch.Tensor): Input tensor of shape [batch_size, seq_len, hidden_size]
            attention_mask (Optional[torch.Tensor]): Forwarded to the attention module,
                which currently rejects anything but None

        Returns:
            torch.Tensor: Output tensor of shape [batch_size, seq_len, hidden_size]
        """
        residual = hidden_states
        attended = self.attention(self.layer_norm(hidden_states), attention_mask)
        if self.dropout is not None:
            attended = self.dropout(attended)
        return residual + attended
|
194
|
+
|
195
|
+
|
196
|
+
class LigerFusedNeighborhoodAttentionConfig:
    """
    Plain container of Fused Neighborhood Attention hyper-parameters.

    Each constructor argument is stored as an attribute of the same name, and
    ``to_dict`` returns them as a dict (handy for building layers from
    per-architecture settings).
    """

    # Attribute names in the order ``to_dict`` emits them.
    _FIELDS = (
        "hidden_size",
        "num_heads",
        "kernel_size",
        "dilation",
        "bias",
        "dropout",
        "layer_norm_eps",
        "scale",
    )

    def __init__(
        self,
        hidden_size: int = 768,
        num_heads: int = 12,
        kernel_size: int = 7,
        dilation: int = 1,
        bias: bool = True,
        dropout: float = 0.0,
        layer_norm_eps: float = 1e-5,
        scale: Optional[float] = None,
    ):
        self.hidden_size = hidden_size
        self.num_heads = num_heads
        self.kernel_size = kernel_size
        self.dilation = dilation
        self.bias = bias
        self.dropout = dropout
        self.layer_norm_eps = layer_norm_eps
        self.scale = scale

    def to_dict(self):
        """Return the configuration as a new ``{field_name: value}`` dict."""
        return {field_name: getattr(self, field_name) for field_name in self._FIELDS}
|
@@ -20,6 +20,7 @@ liger_kernel/ops/cross_entropy.py,sha256=e8THGnhOcy_0SbOLABx67HEM7-B8a8pG7nDKbCR
|
|
20
20
|
liger_kernel/ops/dyt.py,sha256=Y180EIvtUc2z83mhyub0EVOCQHJmWX3JnscqkOJqswk,5467
|
21
21
|
liger_kernel/ops/fused_linear_cross_entropy.py,sha256=5fbGhN85n3zf0uIdJ7PYHWIRzTf0VTFiS0ARtOmqIP0,11020
|
22
22
|
liger_kernel/ops/fused_linear_jsd.py,sha256=CSoprxb-YcJy-YUKiTcYkxN8sb9h2kdk_iHuncvSV5c,9683
|
23
|
+
liger_kernel/ops/fused_neighborhood_attention.py,sha256=vPi5xbnh6wxyZehaqo6Tuilqo2fN5SGDiONjnNmIKqs,35556
|
23
24
|
liger_kernel/ops/geglu.py,sha256=axGvCIvlBzuluoAIrWTsp2iZM4BFKNInkPov8YVvH9E,4126
|
24
25
|
liger_kernel/ops/group_norm.py,sha256=qD4D4lSjSgVtO52EBNLC2iTseALRgPgqXE50U2woggk,10837
|
25
26
|
liger_kernel/ops/grpo_loss.py,sha256=anRnv7k1-AV3pCC6_TqP0GMg78YYUfRAJrbpx6PVhl0,9448
|
@@ -42,9 +43,10 @@ liger_kernel/transformers/auto_model.py,sha256=0qCTRZt280Bj_LcFdzo9hlaR-BWNazawX
|
|
42
43
|
liger_kernel/transformers/cross_entropy.py,sha256=z3KTWQnFxr_IZaVjtYt0ZNEWQdDdYThN35xWkHlDGH0,1683
|
43
44
|
liger_kernel/transformers/dyt.py,sha256=i-4GPaMrl-jab9TVI5qN0-H9qycn_mCbV82ozU4nbmU,723
|
44
45
|
liger_kernel/transformers/fsdp.py,sha256=CUiyjTmjkjY7pLXQv8ly9rnzgXw6529csd9pvtJNMYc,3096
|
45
|
-
liger_kernel/transformers/functional.py,sha256=
|
46
|
+
liger_kernel/transformers/functional.py,sha256=7Emw7D6VPMg8hfasC33NiolvKmQVF1gV6VayKQCEWJM,7446
|
46
47
|
liger_kernel/transformers/fused_linear_cross_entropy.py,sha256=O8Sg5BT81nTaY9fSGoOY9dOD9ekibwwiuXhdUHaxntQ,1742
|
47
48
|
liger_kernel/transformers/fused_linear_jsd.py,sha256=bZ4otCvWBuOnA5XdQL-FzZVItJlDt-ht9e_pG7PG93E,3999
|
49
|
+
liger_kernel/transformers/fused_neighborhood_attention.py,sha256=TxYDUAt9B6WSP14aJP66C_2Mbds2sSIPGnamhUSTrC8,7957
|
48
50
|
liger_kernel/transformers/geglu.py,sha256=mrgqzIUVd6lN7fkDKLkw5YaESDxDtFgbot430WwPVOQ,1107
|
49
51
|
liger_kernel/transformers/group_norm.py,sha256=6qMAWOprr4SzP0YhNVNGQIBpM5aUHplUD2VuGJrMBz0,2173
|
50
52
|
liger_kernel/transformers/grpo_loss.py,sha256=uAkUNKSnUGEOqa82L9w2e6AI1kcmG8K45-QxyaT8zhM,3897
|
@@ -85,9 +87,9 @@ liger_kernel/transformers/trainer/__init__.py,sha256=p7yQfklV8-467qSz_ZMimkbDF7H
|
|
85
87
|
liger_kernel/transformers/trainer/orpo_trainer.py,sha256=tX0h63aOFe3rNqTmk6JpMf75UPo981yzEa6TghnjS0Q,5370
|
86
88
|
liger_kernel/triton/__init__.py,sha256=qCiCamzCRv6lpV8IqpAc9YMdNKC7GKurClWceQPnlis,92
|
87
89
|
liger_kernel/triton/monkey_patch.py,sha256=Rd0hUHAzDkFfHvnX7-PBaNK5EKnZhtfM_h-fgQH9HPY,1568
|
88
|
-
liger_kernel_nightly-0.5.10.
|
89
|
-
liger_kernel_nightly-0.5.10.
|
90
|
-
liger_kernel_nightly-0.5.10.
|
91
|
-
liger_kernel_nightly-0.5.10.
|
92
|
-
liger_kernel_nightly-0.5.10.
|
93
|
-
liger_kernel_nightly-0.5.10.
|
90
|
+
liger_kernel_nightly-0.5.10.dev20250606182408.dist-info/LICENSE,sha256=OhzLDHJ0to4a8sodVLELZiCFylZ1NAAYLs-HrjPy0ag,1312
|
91
|
+
liger_kernel_nightly-0.5.10.dev20250606182408.dist-info/METADATA,sha256=764bP9ZeY2N1vaqqE92BFRP9RvdR9a8k10FwZs4XS9A,24309
|
92
|
+
liger_kernel_nightly-0.5.10.dev20250606182408.dist-info/NOTICE,sha256=njwnoPZLh9AN8SJQzxvCGLHi-8X__AvWRze6joNXIY8,2066
|
93
|
+
liger_kernel_nightly-0.5.10.dev20250606182408.dist-info/WHEEL,sha256=iAkIy5fosb7FzIOwONchHf19Qu7_1wCWyFNR5gu9nU0,91
|
94
|
+
liger_kernel_nightly-0.5.10.dev20250606182408.dist-info/top_level.txt,sha256=2eghu4hA3LnkM7ElW92tQ8zegWKgSbeo-k-aGe1YnvY,13
|
95
|
+
liger_kernel_nightly-0.5.10.dev20250606182408.dist-info/RECORD,,
|
File without changes
|
File without changes
|
File without changes
|