liger-kernel-nightly 0.6.4.dev20251202054858__py3-none-any.whl → 0.6.4.dev20260107181130__py3-none-any.whl
This diff shows the changes between two publicly released versions of the package, as they appear in their respective public registries, and is provided for informational purposes only.
Potentially problematic release.
This version of liger-kernel-nightly might be problematic.
- liger_kernel/chunked_loss/cosine_similarity_loss.py +7 -1
- liger_kernel/chunked_loss/fused_linear_distillation.py +10 -3
- liger_kernel/chunked_loss/jsd_loss.py +21 -6
- liger_kernel/ops/__init__.py +141 -0
- liger_kernel/ops/backends/README.md +151 -0
- liger_kernel/ops/backends/__init__.py +13 -0
- liger_kernel/ops/backends/_ascend/__init__.py +5 -0
- liger_kernel/ops/backends/_ascend/ascend-ub-manager-design.md +485 -0
- liger_kernel/ops/backends/_ascend/ops/__init__.py +43 -0
- liger_kernel/ops/backends/_ascend/ops/geglu.py +244 -0
- liger_kernel/ops/backends/_ascend/ops/qwen2vl_mrope.py +285 -0
- liger_kernel/ops/backends/_ascend/ops/rope.py +290 -0
- liger_kernel/ops/backends/_ascend/ops/swiglu.py +142 -0
- liger_kernel/ops/backends/_ascend/ub_manager.py +349 -0
- liger_kernel/ops/backends/registry.py +61 -0
- liger_kernel/ops/cross_entropy.py +12 -3
- liger_kernel/ops/fused_linear_cross_entropy.py +2 -1
- liger_kernel/ops/geglu.py +3 -2
- liger_kernel/ops/rms_norm.py +126 -49
- liger_kernel/ops/utils.py +12 -0
- liger_kernel/transformers/__init__.py +3 -0
- liger_kernel/transformers/auto_model.py +21 -0
- liger_kernel/transformers/cross_entropy.py +1 -1
- liger_kernel/transformers/dyt.py +1 -1
- liger_kernel/transformers/experimental/embedding.py +1 -1
- liger_kernel/transformers/functional.py +20 -20
- liger_kernel/transformers/fused_add_rms_norm.py +1 -1
- liger_kernel/transformers/fused_linear_cross_entropy.py +1 -1
- liger_kernel/transformers/fused_linear_jsd.py +1 -1
- liger_kernel/transformers/fused_neighborhood_attention.py +1 -1
- liger_kernel/transformers/geglu.py +1 -1
- liger_kernel/transformers/group_norm.py +1 -1
- liger_kernel/transformers/grpo_loss.py +1 -1
- liger_kernel/transformers/jsd.py +1 -1
- liger_kernel/transformers/kl_div.py +1 -1
- liger_kernel/transformers/layer_norm.py +1 -1
- liger_kernel/transformers/llama4_rope.py +1 -1
- liger_kernel/transformers/model/gemma3.py +1 -0
- liger_kernel/transformers/model/gpt_oss.py +211 -0
- liger_kernel/transformers/model/paligemma.py +1 -0
- liger_kernel/transformers/monkey_patch.py +118 -39
- liger_kernel/transformers/multi_token_attention.py +1 -1
- liger_kernel/transformers/poly_norm.py +1 -1
- liger_kernel/transformers/qwen2vl_mrope.py +1 -1
- liger_kernel/transformers/rms_norm.py +8 -3
- liger_kernel/transformers/rope.py +28 -27
- liger_kernel/transformers/softmax.py +1 -1
- liger_kernel/transformers/sparsemax.py +1 -1
- liger_kernel/transformers/swiglu.py +1 -1
- liger_kernel/transformers/tiled_mlp.py +3 -3
- liger_kernel/transformers/tvd.py +1 -1
- liger_kernel/utils.py +27 -0
- {liger_kernel_nightly-0.6.4.dev20251202054858.dist-info → liger_kernel_nightly-0.6.4.dev20260107181130.dist-info}/METADATA +9 -3
- {liger_kernel_nightly-0.6.4.dev20251202054858.dist-info → liger_kernel_nightly-0.6.4.dev20260107181130.dist-info}/RECORD +58 -46
- {liger_kernel_nightly-0.6.4.dev20251202054858.dist-info → liger_kernel_nightly-0.6.4.dev20260107181130.dist-info}/LICENSE +0 -0
- {liger_kernel_nightly-0.6.4.dev20251202054858.dist-info → liger_kernel_nightly-0.6.4.dev20260107181130.dist-info}/NOTICE +0 -0
- {liger_kernel_nightly-0.6.4.dev20251202054858.dist-info → liger_kernel_nightly-0.6.4.dev20260107181130.dist-info}/WHEEL +0 -0
- {liger_kernel_nightly-0.6.4.dev20251202054858.dist-info → liger_kernel_nightly-0.6.4.dev20260107181130.dist-info}/top_level.txt +0 -0
--- /dev/null
+++ b/liger_kernel/ops/backends/_ascend/ops/geglu.py
@@ -0,0 +1,244 @@
+"""
+UB-aware GEGLU implementation for Ascend NPU.
+
+This implementation automatically adjusts block sizes to fit within UB constraints,
+preventing UB overflow errors when running on Ascend NPU.
+
+It reuses the original kernels when possible, and only uses tiling when necessary.
+"""
+
+import operator
+
+import torch
+import triton
+import triton.language as tl
+
+from liger_kernel.ops.backends._ascend.ub_manager import compute_default_tiling_strategy
+from liger_kernel.ops.utils import calculate_settings
+from liger_kernel.ops.utils import compare_version
+from liger_kernel.ops.utils import ensure_contiguous
+from liger_kernel.utils import is_npu_available
+
+if compare_version("triton", operator.ge, "3.0.0") and not is_npu_available():
+    try:
+        from triton.language.extra.libdevice import tanh
+    except ModuleNotFoundError:
+        from triton.language.extra.cuda.libdevice import tanh
+else:
+    from triton.language.math import tanh
+
+
+@triton.jit
+def _geglu_tanh_forward_kernel_npu(a, b, c, stride, n_cols: tl.constexpr, BLOCK_SIZE: tl.constexpr):
+    """
+    UB-aware GEGLU forward kernel for NPU.
+
+    Uses tiling loop to handle cases where BLOCK_SIZE < n_cols (due to UB constraints).
+    When BLOCK_SIZE >= n_cols, the loop executes only once, maintaining original behavior.
+    """
+    program_id = tl.program_id(0).to(tl.int64)
+
+    # locate start index
+    a += program_id * stride
+    b += program_id * stride
+    c += program_id * stride
+
+    # Process in tiles when BLOCK_SIZE < n_cols
+    for i in range(0, n_cols, BLOCK_SIZE):
+        col_offsets = i + tl.arange(0, BLOCK_SIZE)
+        mask = col_offsets < n_cols
+
+        a_row = tl.load(a + col_offsets, mask=mask, other=0).to(tl.float32)
+        b_row = tl.load(b + col_offsets, mask=mask, other=0)
+
+        # tanh approximation form of GELU is computed with:
+        # 0.5 * a * (1 + tanh(sqrt(2 / pi) * (a + 0.044715 * a^3)))
+        sqrt_2_over_pi = 0.7978845608028654  # sqrt(2 / pi)
+        a_cubed = a_row * a_row * a_row
+        tanh_arg = sqrt_2_over_pi * (a_row + 0.044715 * a_cubed)
+        tanh_result = tanh(tanh_arg)
+        geglu_a = 0.5 * a_row * (1 + tanh_result)
+        c_row = geglu_a.cast(b_row.dtype) * b_row
+
+        tl.store(c + col_offsets, c_row, mask=mask)
+
+
+@triton.jit
+def _geglu_tanh_backward_kernel_npu(dc, a, b, stride, n_cols: tl.constexpr, BLOCK_SIZE: tl.constexpr):
+    """
+    UB-aware GEGLU backward kernel for NPU.
+
+    Uses tiling loop to handle cases where BLOCK_SIZE < n_cols (due to UB constraints).
+    When BLOCK_SIZE >= n_cols, the loop executes only once, maintaining original behavior.
+    """
+    program_id = tl.program_id(0).to(tl.int64)
+
+    # locate start index
+    dc += program_id * stride
+    a += program_id * stride
+    b += program_id * stride
+
+    # Process in tiles when BLOCK_SIZE < n_cols
+    for i in range(0, n_cols, BLOCK_SIZE):
+        col_offsets = i + tl.arange(0, BLOCK_SIZE)
+        mask = col_offsets < n_cols
+
+        dc_row = tl.load(dc + col_offsets, mask=mask, other=0)
+        a_row = tl.load(a + col_offsets, mask=mask, other=0).to(tl.float32)
+        b_row = tl.load(b + col_offsets, mask=mask, other=0)
+
+        # recomputation to save memory
+        sqrt_2_over_pi = 0.7978845608028654  # sqrt(2 / pi)
+        a_cubed = a_row * a_row * a_row
+        tanh_arg = sqrt_2_over_pi * (a_row + 0.044715 * a_cubed)
+        tanh_result = tanh(tanh_arg)
+        geglu_a = 0.5 * a_row * (1 + tanh_result)
+        geglu_a = geglu_a.to(dc_row.dtype).to(tl.float32)
+
+        db_row = dc_row.cast(tl.float32) * geglu_a
+
+        # Gradient w.r.t. a can be computed with:
+        # b * (0.5 * (1 + tanh(z)) + 0.5 * a * (1 - tanh(z)^2) * (sqrt(2/pi) * (1 + 3 * 0.044715 * a^2)))
+        # where z = sqrt(2/pi) * (a + 0.044715 * a^3)
+        term1 = 0.5 * (1 + tanh_result)
+        tanh_sq = tanh_result * tanh_result
+        term2 = 0.5 * a_row * (1 - tanh_sq) * (sqrt_2_over_pi * (1 + 3 * 0.044715 * a_row * a_row))
+        da_row = dc_row * b_row * (term1 + term2)
+
+        tl.store(a + col_offsets, da_row, mask=mask)
+        tl.store(b + col_offsets, db_row.to(dc_row.dtype), mask=mask)
+
+
+def geglu_forward(a, b):
+    """
+    UB-aware GEGLU forward pass for NPU.
+
+    Automatically adjusts block size to fit within UB constraints.
+    """
+    ori_shape = a.shape
+
+    n_cols = ori_shape[-1]
+    a = a.view(-1, n_cols)
+    b = b.view(-1, n_cols)
+    c = torch.empty_like(a)
+    n_rows = a.shape[0]
+
+    # Calculate desired block size
+    desired_block_size, num_warps = calculate_settings(n_cols)
+
+    # Compute tiling strategy based on UB capacity
+    dtype_size = a.element_size()
+    # GEGLU forward tiling strategy:
+    # - Calculates maximum safe block size based on UB capacity
+    # - Memory analysis:
+    #   * Inputs: a, b
+    #   * Intermediates: a_cubed, tanh_arg, tanh_result, geglu_a
+    #   * Output: c
+    #   * Total: ~7x * BLOCK_SIZE * dtype_size
+    # - Uses memory_multiplier=7.0 * BLOCK_SIZE * dtype_size * 8 bits for safety
+    # - shapes: ((n_cols,),)
+    # - tiling_dims: (0,) means first dimension can be tiled
+    # - Returns: ((block_size,),)
+    shapes = ((n_cols,),)
+    tile_shapes = compute_default_tiling_strategy(
+        safety_margin=0.80,
+        dtype_size=dtype_size,
+        memory_multiplier=7.0,
+        shapes=shapes,
+        tiling_dims=(0,),
+    )
+
+    if tile_shapes is not None and len(tile_shapes) > 0 and len(tile_shapes[0]) > 0:
+        # Strategy returns ((block_size,),)
+        adjusted_block_size = tile_shapes[0][0]
+    else:
+        # Fallback to desired block size if no best practice found (no tiling needed)
+        adjusted_block_size = desired_block_size
+    # Always use the unified NPU kernel
+    # When adjusted_block_size >= n_cols, the loop executes only once (no tiling)
+    # When adjusted_block_size < n_cols, the loop handles tiling automatically
+    _geglu_tanh_forward_kernel_npu[(n_rows,)](
+        a,
+        b,
+        c,
+        c.stride(-2),
+        n_cols=n_cols,
+        BLOCK_SIZE=adjusted_block_size,
+        num_warps=num_warps,
+    )
+    return a, b, c.view(*ori_shape)
+
+
+def geglu_backward(a, b, dc):
+    """
+    UB-aware GEGLU backward pass for NPU.
+
+    Automatically adjusts block size to fit within UB constraints.
+    """
+    ori_shape = dc.shape
+    n_cols = ori_shape[-1]
+    dc = dc.view(-1, n_cols)
+    n_rows = dc.shape[0]
+
+    # Calculate desired block size
+    desired_block_size, num_warps = calculate_settings(n_cols)
+
+    # Compute tiling strategy based on UB capacity
+    dtype_size = dc.element_size()
+    # GEGLU backward tiling strategy:
+    # - Calculates maximum safe block size based on UB capacity
+    # - Memory analysis:
+    #   * More intermediates for gradient computation compared to forward
+    #   * Total: ~10x * BLOCK_SIZE * dtype_size
+    # - Uses memory_multiplier=10.0 * BLOCK_SIZE * dtype_size * 8 bits for safety
+    # - shapes: ((n_cols,),)
+    # - tiling_dims: (0,) means first dimension can be tiled
+    # - Returns: ((block_size,),)
+    shapes = ((n_cols,),)
+    tile_shapes = compute_default_tiling_strategy(
+        safety_margin=0.80,
+        dtype_size=dtype_size,
+        memory_multiplier=10.0,
+        shapes=shapes,
+        tiling_dims=(0,),
+    )
+
+    if tile_shapes is not None and len(tile_shapes) > 0 and len(tile_shapes[0]) > 0:
+        # Strategy returns ((block_size,),)
+        adjusted_block_size = tile_shapes[0][0]
+    else:
+        # Fallback to desired block size if no best practice found (no tiling needed)
+        adjusted_block_size = desired_block_size
+
+    # Always use the unified NPU kernel
+    # When adjusted_block_size >= n_cols, the loop executes only once (no tiling)
+    # When adjusted_block_size < n_cols, the loop handles tiling automatically
+    _geglu_tanh_backward_kernel_npu[(n_rows,)](
+        dc,
+        a,
+        b,
+        dc.stride(-2),
+        n_cols=n_cols,
+        BLOCK_SIZE=adjusted_block_size,
+        num_warps=num_warps,
+    )
+
+    return a.view(*ori_shape), b.view(*ori_shape)
+
+
+class LigerGELUMulFunction(torch.autograd.Function):
+    """UB-aware GEGLU function for Ascend NPU."""
+
+    @staticmethod
+    @ensure_contiguous
+    def forward(ctx, a, b):
+        a, b, c = geglu_forward(a, b)
+        ctx.save_for_backward(a, b)
+        return c
+
+    @staticmethod
+    @ensure_contiguous
+    def backward(ctx, dc):
+        a, b = ctx.saved_tensors
+        a, b = geglu_backward(a, b, dc)
+        return a, b
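The GEGLU kernels above recompute the tanh-approximate GELU, 0.5 * a * (1 + tanh(sqrt(2/pi) * (a + 0.044715 * a^3))), and size BLOCK_SIZE so that roughly 7 (forward) or 10 (backward) live BLOCK_SIZE-wide buffers fit inside the NPU's unified buffer (UB). The sketch below illustrates that sizing arithmetic and a numerical check against PyTorch's tanh-approximate GELU. The 192 KiB UB figure and the `max_block_size` helper are illustrative assumptions, not part of this package, and the commented reference check assumes an environment where `LigerGELUMulFunction` can actually launch (an Ascend NPU build with torch_npu and Triton).

```python
# Illustrative only: actual UB capacity varies by Ascend SoC; 192 KiB is an assumed figure,
# and max_block_size is a hypothetical helper, not an API of this package.
UB_CAPACITY_BYTES = 192 * 1024


def max_block_size(memory_multiplier: float, dtype_size: int, safety_margin: float = 0.80) -> int:
    """Largest power-of-two BLOCK_SIZE whose live buffers fit in the assumed UB budget."""
    budget = UB_CAPACITY_BYTES * safety_margin
    limit = int(budget // (memory_multiplier * dtype_size))
    return 1 << (limit.bit_length() - 1)


print(max_block_size(7.0, 2))   # forward kernel, bf16: 8192 under these assumptions
print(max_block_size(10.0, 2))  # backward kernel, bf16: 4096 under these assumptions

# Reference check for the fused op (requires a device where the Triton NPU kernel can run):
#   import torch, torch.nn.functional as F
#   a, b = (torch.randn(4, 11008, device="npu", dtype=torch.bfloat16) for _ in range(2))
#   c = LigerGELUMulFunction.apply(a, b)
#   torch.testing.assert_close(c, F.gelu(a, approximate="tanh") * b, rtol=2e-2, atol=2e-2)
```

The power-of-two rounding mirrors the power-of-two block sizes produced by `calculate_settings`; whether `compute_default_tiling_strategy` performs exactly this arithmetic is not shown in this diff.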
--- /dev/null
+++ b/liger_kernel/ops/backends/_ascend/ops/qwen2vl_mrope.py
@@ -0,0 +1,285 @@
+import torch
+import triton
+import triton.language as tl
+
+from liger_kernel.ops.backends._ascend.ub_manager import compute_default_tiling_strategy
+
+
+@triton.jit
+def _triton_qwen2vl_mrope_npu(
+    q_ptr,
+    q_row_stride,
+    k_ptr,
+    k_row_stride,
+    cos,
+    sin,
+    sl,
+    bs: tl.constexpr,
+    n_qh: tl.constexpr,
+    n_kh: tl.constexpr,
+    hd: tl.constexpr,
+    mrope_section_t: tl.constexpr,
+    mrope_section_h: tl.constexpr,
+    BLOCK_Q: tl.constexpr,
+    BLOCK_K: tl.constexpr,
+    BACKWARD_PASS: tl.constexpr = False,
+):
+    pid = tl.program_id(0).to(tl.int64)
+
+    t_end = mrope_section_t
+    h_end = t_end + mrope_section_h
+
+    t_cos = cos + pid * hd
+    h_cos = t_cos + bs * sl * hd
+    w_cos = h_cos + bs * sl * hd
+    t_sin = sin + pid * hd
+    h_sin = t_sin + bs * sl * hd
+    w_sin = h_sin + bs * sl * hd
+
+    q_base = q_ptr + pid * q_row_stride
+    k_base = k_ptr + pid * k_row_stride
+
+    d_idx = tl.arange(0, hd // 2)
+    d_mask = d_idx < (hd // 2)
+
+    pos_mask_t = d_idx < t_end
+    pos_mask_h = (d_idx >= t_end) & (d_idx < h_end)
+
+    text_cos_vals = tl.load(t_cos + d_idx, mask=d_mask, other=0)
+    text_sin_vals = tl.load(t_sin + d_idx, mask=d_mask, other=0)
+    height_cos_vals = tl.load(h_cos + d_idx, mask=d_mask, other=0)
+    height_sin_vals = tl.load(h_sin + d_idx, mask=d_mask, other=0)
+    width_cos_vals = tl.load(w_cos + d_idx, mask=d_mask, other=0)
+    width_sin_vals = tl.load(w_sin + d_idx, mask=d_mask, other=0)
+
+    cos_vals = tl.where(pos_mask_t, text_cos_vals, tl.where(pos_mask_h, height_cos_vals, width_cos_vals))
+    sin_vals = tl.where(pos_mask_t, text_sin_vals, tl.where(pos_mask_h, height_sin_vals, width_sin_vals))
+
+    for qh_block in range(0, n_qh, BLOCK_Q):
+        qh_idx = tl.arange(0, BLOCK_Q) + qh_block
+        qh_mask = qh_idx < n_qh
+
+        block_mask = qh_mask[:, None] & d_mask[None, :]
+        offsets = qh_idx[:, None] * hd + d_idx[None, :]
+
+        q_left = tl.load(q_base + offsets, mask=block_mask, other=0)
+        q_right = tl.load(q_base + offsets + (hd // 2), mask=block_mask, other=0)
+
+        if not BACKWARD_PASS:
+            new_left = q_left * cos_vals - q_right * sin_vals
+            new_right = q_right * cos_vals + q_left * sin_vals
+        else:
+            new_left = q_left * cos_vals + q_right * sin_vals
+            new_right = q_right * cos_vals - q_left * sin_vals
+
+        tl.store(q_base + offsets, new_left, mask=block_mask)
+        tl.store(q_base + offsets + (hd // 2), new_right, mask=block_mask)
+
+    for kh_block in range(0, n_kh, BLOCK_K):
+        kh_idx = tl.arange(0, BLOCK_K) + kh_block
+        kh_mask = kh_idx < n_kh
+
+        block_mask = kh_mask[:, None] & d_mask[None, :]
+        offsets = kh_idx[:, None] * hd + d_idx[None, :]
+
+        k_left = tl.load(k_base + offsets, mask=block_mask, other=0)
+        k_right = tl.load(k_base + offsets + (hd // 2), mask=block_mask, other=0)
+
+        if not BACKWARD_PASS:
+            new_left = k_left * cos_vals - k_right * sin_vals
+            new_right = k_right * cos_vals + k_left * sin_vals
+        else:
+            new_left = k_left * cos_vals + k_right * sin_vals
+            new_right = k_right * cos_vals - k_left * sin_vals
+
+        tl.store(k_base + offsets, new_left, mask=block_mask)
+        tl.store(k_base + offsets + (hd // 2), new_right, mask=block_mask)
+
+
+def qwen2vl_mrope_forward(q, k, cos, sin, mrope_section):
+    # transpose it back to the physical shape because Triton looks at the physical storage
+    # note: q and k are incontiguous before the transformation and will become contiguous after transpose
+    q = q.transpose(1, 2)
+    k = k.transpose(1, 2)
+
+    batch_size, seq_len, n_q_head, head_dim = q.shape
+    n_kv_head = k.shape[2]
+    pad_hd = triton.next_power_of_2(head_dim)
+    pad_n_q_head = triton.next_power_of_2(n_q_head)
+    pad_n_kv_head = triton.next_power_of_2(n_kv_head)
+
+    n_row = batch_size * seq_len
+
+    # ensure tensors passed into the kernel are contiguous. It will be no-op if they are already contiguous
+    q = q.contiguous()
+    k = k.contiguous()
+    cos = cos.contiguous()
+    sin = sin.contiguous()
+
+    # Compute tiling strategy based on UB capacity
+    dtype_size = q.element_size()
+    # MROPE forward tiling strategy:
+    # - cos_vals and sin_vals (include text, height and width) are loaded once outside loops (shared): (pad_hd // 2) * 4 = 2 * pad_hd elements each
+    # - In q heads loop (peak memory):
+    #   * q_left: BLOCK_Q * (pad_hd // 2) elements
+    #   * q_right: BLOCK_Q * (pad_hd // 2) elements
+    #   * new_left: BLOCK_Q * (pad_hd // 2) elements (intermediate result)
+    #   * new_right: BLOCK_Q * (pad_hd // 2) elements (intermediate result)
+    #   * Total: 4 * BLOCK_Q * (pad_hd // 2) = 2 * BLOCK_Q * pad_hd elements
+    # - In k heads loop (peak memory):
+    #   * k_left: BLOCK_K * (pad_hd // 2) elements
+    #   * k_right: BLOCK_K * (pad_hd // 2) elements
+    #   * new_left: BLOCK_K * (pad_hd // 2) elements (intermediate result)
+    #   * new_right: BLOCK_K * (pad_hd // 2) elements (intermediate result)
+    #   * Total: 4 * BLOCK_K * (pad_hd // 2) = 2 * BLOCK_K * pad_hd elements
+    # - Since q and k are processed separately, peak memory is max(BLOCK_Q, BLOCK_K) case
+    # - Plus shared cos/sin: 2 * (pad_hd // 2) = pad_hd elements
+    # - Conservative estimate: (2 * BLOCK_SIZE * pad_hd + pad_hd) * dtype_size * 8 bits
+    # - Simplified: (2 * BLOCK_SIZE + 2) * pad_hd * dtype_size * 8 bits
+    # - For safety, use: memory_multiplier=3.0 * BLOCK_SIZE * pad_hd * dtype_size * 8 bits
+    # - shapes: ((pad_n_q_head, pad_hd), (pad_n_kv_head, pad_hd))
+    # - tiling_dims: (0, 0) means first dimension of each shape can be tiled
+    # - Returns: ((block_size_q, pad_hd), (block_size_kv, pad_hd))
+    shapes = ((pad_n_q_head, pad_hd), (pad_n_kv_head, pad_hd))
+    tile_shapes = compute_default_tiling_strategy(
+        safety_margin=0.90,
+        dtype_size=dtype_size,
+        memory_multiplier=3.0,
+        shapes=shapes,
+        tiling_dims=(0, 0),
+    )
+
+    if tile_shapes is not None and len(tile_shapes) == len(shapes):
+        # Strategy returns ((block_size_q, pad_hd), (block_size_kv, pad_hd))
+        q_tile_shape, k_tile_shape = tile_shapes
+        BLOCK_Q, _ = q_tile_shape
+        BLOCK_K, _ = k_tile_shape
+    else:
+        # Fallback to conservative defaults
+        BLOCK_Q = triton.next_power_of_2(pad_n_q_head)
+        BLOCK_K = triton.next_power_of_2(pad_n_kv_head)
+    _triton_qwen2vl_mrope_npu[(n_row,)](
+        q,
+        q.stride(1),
+        k,
+        k.stride(1),
+        cos,
+        sin,
+        seq_len,
+        batch_size,
+        n_q_head,
+        n_kv_head,
+        head_dim,
+        mrope_section[0],
+        mrope_section[1],
+        BLOCK_Q,
+        BLOCK_K,
+        BACKWARD_PASS=False,
+    )
+    return q.transpose(1, 2), k.transpose(1, 2), cos, sin
+
+
+def qwen2vl_mrope_backward(dq, dk, cos, sin, mrope_section):
+    dq = dq.transpose(1, 2)
+    dk = dk.transpose(1, 2)
+
+    batch_size, seq_len, n_q_head, head_dim = dq.shape
+    n_kv_head = dk.shape[2]
+    pad_hd = triton.next_power_of_2(head_dim)
+    pad_n_q_head = triton.next_power_of_2(n_q_head)
+    pad_n_kv_head = triton.next_power_of_2(n_kv_head)
+
+    n_row = batch_size * seq_len
+
+    # ensure dq and dk are contiguous
+    dq = dq.contiguous()
+    dk = dk.contiguous()
+
+    # Compute tiling strategy based on UB capacity
+    dtype_size = dq.element_size()
+    # MROPE backward tiling strategy:
+    # - cos_vals and sin_vals (include text, height and width) are loaded once outside loops (shared): (pad_hd // 2) * 4 = 2 * pad_hd elements each
+    # - In q heads loop (peak memory):
+    #   * q_left: BLOCK_Q * (pad_hd // 2) elements
+    #   * q_right: BLOCK_Q * (pad_hd // 2) elements
+    #   * new_left: BLOCK_Q * (pad_hd // 2) elements (intermediate result)
+    #   * new_right: BLOCK_Q * (pad_hd // 2) elements (intermediate result)
+    #   * Total: 4 * BLOCK_Q * (pad_hd // 2) = 2 * BLOCK_Q * pad_hd elements
+    # - In k heads loop (peak memory):
+    #   * k_left: BLOCK_K * (pad_hd // 2) elements
+    #   * k_right: BLOCK_K * (pad_hd // 2) elements
+    #   * new_left: BLOCK_K * (pad_hd // 2) elements (intermediate result)
+    #   * new_right: BLOCK_K * (pad_hd // 2) elements (intermediate result)
+    #   * Total: 4 * BLOCK_K * (pad_hd // 2) = 2 * BLOCK_K * pad_hd elements
+    # - Since q and k are processed separately, peak memory is max(BLOCK_Q, BLOCK_K) case
+    # - Plus shared cos/sin: 2 * (pad_hd // 2) = pad_hd elements
+    # - Conservative estimate: (2 * BLOCK_SIZE * pad_hd + pad_hd) * dtype_size * 8 bits
+    # - Simplified: (2 * BLOCK_SIZE + 2) * pad_hd * dtype_size * 8 bits
+    # - For safety, use: memory_multiplier=3.0 * BLOCK_SIZE * pad_hd * dtype_size * 8 bits
+    # - shapes: ((pad_n_q_head, pad_hd), (pad_n_kv_head, pad_hd))
+    # - tiling_dims: (0, 0) means first dimension of each shape can be tiled
+    # - Returns: ((block_size_q, pad_hd), (block_size_kv, pad_hd))
+    shapes = ((pad_n_q_head, pad_hd), (pad_n_kv_head, pad_hd))
+    tile_shapes = compute_default_tiling_strategy(
+        safety_margin=0.90,
+        dtype_size=dtype_size,
+        memory_multiplier=3.0,
+        shapes=shapes,
+        tiling_dims=(0, 0),
+    )
+
+    if tile_shapes is not None and len(tile_shapes) == len(shapes):
+        # Strategy returns ((block_size_q, pad_hd), (block_size_kv, pad_hd))
+        q_tile_shape, k_tile_shape = tile_shapes
+        BLOCK_Q, _ = q_tile_shape
+        BLOCK_K, _ = k_tile_shape
+    else:
+        # Fallback to conservative defaults
+        BLOCK_Q = triton.next_power_of_2(pad_n_q_head)
+        BLOCK_K = triton.next_power_of_2(pad_n_kv_head)
+    _triton_qwen2vl_mrope_npu[(n_row,)](
+        dq,
+        dq.stride(1),
+        dk,
+        dk.stride(1),
+        cos,
+        sin,
+        seq_len,
+        batch_size,
+        n_q_head,
+        n_kv_head,
+        head_dim,
+        mrope_section[0],
+        mrope_section[1],
+        BLOCK_Q,
+        BLOCK_K,
+        BACKWARD_PASS=True,
+    )
+    return dq.transpose(1, 2), dk.transpose(1, 2)
+
+
+class LigerQwen2VLMRopeFunction(torch.autograd.Function):
+    @staticmethod
+    def forward(ctx, q, k, cos, sin, mrope_section, unsqueeze_dim=1):
+        """
+        q size: (bsz, n_q_head, seq_len, head_dim)
+        k size: (bsz, n_kv_head, seq_len, head_dim)
+        cos size: (3, bsz, seq_len, head_dim)
+        sin size: (3, bsz, seq_len, head_dim)
+        """
+        q, k, cos, sin = qwen2vl_mrope_forward(q, k, cos, sin, mrope_section)
+        ctx.save_for_backward(cos, sin)
+        ctx.mrope_section = mrope_section
+        return q, k
+
+    def backward(ctx, dq, dk):
+        """
+        dq size: (bsz, n_q_head, seq_len, head_dim)
+        dk size: (bsz, n_kv_head, seq_len, head_dim)
+        cos size: (3, bsz, seq_len, head_dim)
+        sin size: (3, bsz, seq_len, head_dim)
+        """
+        cos, sin = ctx.saved_tensors
+        mrope_section = ctx.mrope_section
+        dq, dk = qwen2vl_mrope_backward(dq, dk, cos, sin, mrope_section)
+        return dq, dk, None, None, None, None