liger-kernel 0.6.3__py3-none-any.whl → 0.6.5__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- liger_kernel/chunked_loss/cosine_similarity_loss.py +20 -5
- liger_kernel/chunked_loss/fused_linear_distillation.py +23 -5
- liger_kernel/chunked_loss/fused_linear_ppo.py +21 -5
- liger_kernel/chunked_loss/grpo_loss.py +8 -5
- liger_kernel/chunked_loss/jsd_loss.py +39 -11
- liger_kernel/ops/__init__.py +141 -0
- liger_kernel/ops/backends/README.md +151 -0
- liger_kernel/ops/backends/__init__.py +13 -0
- liger_kernel/ops/backends/_ascend/__init__.py +5 -0
- liger_kernel/ops/backends/_ascend/ascend-ub-manager-design.md +492 -0
- liger_kernel/ops/backends/_ascend/ops/__init__.py +61 -0
- liger_kernel/ops/backends/_ascend/ops/embedding.py +214 -0
- liger_kernel/ops/backends/_ascend/ops/geglu.py +191 -0
- liger_kernel/ops/backends/_ascend/ops/llama4_rope.py +298 -0
- liger_kernel/ops/backends/_ascend/ops/qwen2vl_mrope.py +275 -0
- liger_kernel/ops/backends/_ascend/ops/rope.py +265 -0
- liger_kernel/ops/backends/_ascend/ops/swiglu.py +142 -0
- liger_kernel/ops/backends/_ascend/ops/tvd.py +223 -0
- liger_kernel/ops/backends/_ascend/ub_manager.py +367 -0
- liger_kernel/ops/backends/registry.py +61 -0
- liger_kernel/ops/cross_entropy.py +71 -11
- liger_kernel/ops/dyt.py +5 -2
- liger_kernel/ops/fused_add_rms_norm.py +21 -23
- liger_kernel/ops/fused_linear_cross_entropy.py +32 -5
- liger_kernel/ops/geglu.py +5 -3
- liger_kernel/ops/group_norm.py +12 -8
- liger_kernel/ops/grpo_loss.py +3 -1
- liger_kernel/ops/kl_div.py +8 -11
- liger_kernel/ops/layer_norm.py +89 -69
- liger_kernel/ops/poly_norm.py +19 -21
- liger_kernel/ops/rms_norm.py +149 -71
- liger_kernel/ops/tiled_mlp.py +136 -0
- liger_kernel/ops/utils.py +25 -0
- liger_kernel/transformers/__init__.py +25 -0
- liger_kernel/transformers/auto_model.py +21 -0
- liger_kernel/transformers/cross_entropy.py +9 -4
- liger_kernel/transformers/dyt.py +1 -1
- liger_kernel/transformers/experimental/embedding.py +1 -1
- liger_kernel/transformers/functional.py +44 -26
- liger_kernel/transformers/fused_add_rms_norm.py +1 -1
- liger_kernel/transformers/fused_linear_cross_entropy.py +9 -4
- liger_kernel/transformers/fused_linear_jsd.py +1 -1
- liger_kernel/transformers/fused_neighborhood_attention.py +1 -1
- liger_kernel/transformers/geglu.py +1 -1
- liger_kernel/transformers/group_norm.py +1 -1
- liger_kernel/transformers/grpo_loss.py +57 -2
- liger_kernel/transformers/jsd.py +1 -1
- liger_kernel/transformers/kl_div.py +1 -1
- liger_kernel/transformers/layer_norm.py +1 -1
- liger_kernel/transformers/llama4_rope.py +1 -1
- liger_kernel/transformers/model/exaone4.py +136 -0
- liger_kernel/transformers/model/falcon_h1.py +19 -5
- liger_kernel/transformers/model/gemma.py +17 -6
- liger_kernel/transformers/model/gemma2.py +17 -8
- liger_kernel/transformers/model/gemma3.py +35 -16
- liger_kernel/transformers/model/glm4.py +16 -4
- liger_kernel/transformers/model/glm4v.py +16 -4
- liger_kernel/transformers/model/glm4v_moe.py +23 -4
- liger_kernel/transformers/model/gpt_oss.py +211 -0
- liger_kernel/transformers/model/hunyuan_v1.py +134 -0
- liger_kernel/transformers/model/internvl.py +12 -5
- liger_kernel/transformers/model/llama.py +14 -5
- liger_kernel/transformers/model/llama4.py +16 -4
- liger_kernel/transformers/model/llava.py +12 -4
- liger_kernel/transformers/model/loss_utils.py +37 -3
- liger_kernel/transformers/model/mistral.py +15 -6
- liger_kernel/transformers/model/mixtral.py +16 -7
- liger_kernel/transformers/model/mllama.py +12 -4
- liger_kernel/transformers/model/olmo2.py +16 -4
- liger_kernel/transformers/model/olmo3.py +142 -0
- liger_kernel/transformers/model/output_classes.py +147 -0
- liger_kernel/transformers/model/paligemma.py +23 -5
- liger_kernel/transformers/model/phi3.py +14 -7
- liger_kernel/transformers/model/qwen2.py +16 -3
- liger_kernel/transformers/model/qwen2_5_vl.py +14 -6
- liger_kernel/transformers/model/qwen2_vl.py +16 -4
- liger_kernel/transformers/model/qwen3.py +20 -5
- liger_kernel/transformers/model/qwen3_moe.py +19 -5
- liger_kernel/transformers/model/qwen3_next.py +17 -5
- liger_kernel/transformers/model/qwen3_vl.py +150 -0
- liger_kernel/transformers/model/qwen3_vl_moe.py +126 -0
- liger_kernel/transformers/model/smollm3.py +15 -6
- liger_kernel/transformers/monkey_patch.py +584 -49
- liger_kernel/transformers/multi_token_attention.py +1 -1
- liger_kernel/transformers/poly_norm.py +1 -1
- liger_kernel/transformers/qwen2vl_mrope.py +1 -1
- liger_kernel/transformers/rms_norm.py +8 -3
- liger_kernel/transformers/rope.py +45 -1
- liger_kernel/transformers/softmax.py +1 -1
- liger_kernel/transformers/sparsemax.py +1 -1
- liger_kernel/transformers/swiglu.py +18 -1
- liger_kernel/transformers/tiled_mlp.py +125 -0
- liger_kernel/transformers/tvd.py +1 -1
- liger_kernel/utils.py +54 -0
- {liger_kernel-0.6.3.dist-info → liger_kernel-0.6.5.dist-info}/METADATA +14 -4
- liger_kernel-0.6.5.dist-info/RECORD +134 -0
- {liger_kernel-0.6.3.dist-info → liger_kernel-0.6.5.dist-info}/WHEEL +1 -1
- liger_kernel-0.6.3.dist-info/RECORD +0 -111
- {liger_kernel-0.6.3.dist-info → liger_kernel-0.6.5.dist-info}/licenses/LICENSE +0 -0
- {liger_kernel-0.6.3.dist-info → liger_kernel-0.6.5.dist-info}/licenses/NOTICE +0 -0
- {liger_kernel-0.6.3.dist-info → liger_kernel-0.6.5.dist-info}/top_level.txt +0 -0
liger_kernel/ops/backends/_ascend/ops/llama4_rope.py
@@ -0,0 +1,298 @@

```python
import torch
import triton
import triton.language as tl

from liger_kernel.ops.backends._ascend.ub_manager import compute_default_tiling_strategy


def _prepare_freqs(freqs_cis: torch.Tensor, seq_len: int, head_dim_half: int):
    """
    Canonicalize freqs to (seq_len, head_dim_half) real/imag tensors.

    Supports:
    - complex freqs: (..., head_dim_half) complex -> real/imag
    - packed freqs: (..., 2*head_dim_half) real -> split into real/imag
    """
    if freqs_cis.is_complex():
        freqs_real = freqs_cis.real
        freqs_imag = freqs_cis.imag
    else:
        if freqs_cis.shape[-1] == 2 * head_dim_half:
            freqs_real = freqs_cis[..., :head_dim_half]
            freqs_imag = freqs_cis[..., head_dim_half:]
        else:
            raise ValueError(
                f"Unexpected freqs_cis shape for non-complex input: {freqs_cis.shape}, "
                f"expected last dim = {2 * head_dim_half}"
            )

    if freqs_real.shape[-1] != head_dim_half:
        raise ValueError(f"Unexpected last dim for freqs: {freqs_real.shape[-1]} (expected {head_dim_half})")

    # Flatten leading dims -> (N, head_dim_half)
    freqs_real = freqs_real.reshape(-1, head_dim_half)
    freqs_imag = freqs_imag.reshape(-1, head_dim_half)

    # Broadcast/slice to (seq_len, head_dim_half)
    if freqs_real.shape[0] < seq_len:
        if freqs_real.shape[0] == 1:
            freqs_real = freqs_real.expand(seq_len, -1)
            freqs_imag = freqs_imag.expand(seq_len, -1)
        else:
            raise ValueError(f"Insufficient rows in freqs: {freqs_real.shape[0]} < seq_len={seq_len}")
    elif freqs_real.shape[0] > seq_len:
        freqs_real = freqs_real[:seq_len]
        freqs_imag = freqs_imag[:seq_len]

    return freqs_real, freqs_imag


def _cast_and_contiguous(q, k, freqs_real, freqs_imag):
    # Align dtype: fp32 only when q is fp32; otherwise keep q dtype for perf
    compute_dtype = torch.float32 if q.dtype == torch.float32 else q.dtype

    if k.dtype != q.dtype:
        k = k.to(q.dtype)

    q = q.to(compute_dtype).contiguous()
    k = k.to(compute_dtype).contiguous()
    freqs_real = freqs_real.to(compute_dtype).contiguous()
    freqs_imag = freqs_imag.to(compute_dtype).contiguous()
    return q, k, freqs_real, freqs_imag, compute_dtype


@triton.jit
def _triton_llama4_rope_npu(
    q_ptr,
    k_ptr,
    freqs_real_ptr,
    freqs_imag_ptr,
    q_row_stride,
    k_row_stride,
    q_head_stride,
    k_head_stride,
    freqs_row_stride,
    sl,
    bs: tl.constexpr,
    n_qh: tl.constexpr,
    n_kh: tl.constexpr,
    hd: tl.constexpr,
    BLOCK_Q: tl.constexpr,
    BLOCK_K: tl.constexpr,
    imag_sign: tl.constexpr,
):
    """
    Llama4 RoPE on Ascend NPU for interleaved complex layout:
    - q/k shape: (bs, sl, n_heads, hd)
    - last dim layout: [real0, imag0, real1, imag1, ...]
    - freqs_real/imag: (sl, hd//2)
    """
    pid = tl.program_id(0).to(tl.int64)
    batch_idx = pid // sl
    seq_idx = pid % sl

    if batch_idx >= bs:
        return

    q_base = q_ptr + pid * q_row_stride
    k_base = k_ptr + pid * k_row_stride

    freq_base = seq_idx * freqs_row_stride
    hd_idx = tl.arange(0, hd)
    hd_mask = hd_idx < hd

    freq_idx = tl.arange(0, hd // 2)
    freq_mask = freq_idx < (hd // 2)

    freqs_real = tl.load(freqs_real_ptr + freq_base + freq_idx, mask=freq_mask, other=0.0)
    freqs_imag = tl.load(freqs_imag_ptr + freq_base + freq_idx, mask=freq_mask, other=0.0) * imag_sign

    # Q heads (chunked for UB)
    for qh_block in range(0, n_qh, BLOCK_Q):
        qh_idx = tl.arange(0, BLOCK_Q) + qh_block
        qh_mask = qh_idx < n_qh
        block_mask = qh_mask[:, None] & hd_mask[None, :]

        head_ptr = q_base + qh_idx[:, None] * q_head_stride

        q_pair = tl.load(
            head_ptr + hd_idx[None, :],
            mask=block_mask,
            other=0.0,
        )
        q_pair = q_pair.reshape(BLOCK_Q, hd // 2, 2, can_reorder=True)
        q_real, q_imag = tl.split(q_pair)

        new_real = tl.math.fma(q_real, freqs_real, -(q_imag * freqs_imag))
        new_imag = tl.math.fma(q_real, freqs_imag, q_imag * freqs_real)
        new_q_pair = tl.interleave(new_real, new_imag)

        tl.store(head_ptr + hd_idx[None, :], new_q_pair, mask=block_mask)

    # K heads (chunked for UB)
    for kh_block in range(0, n_kh, BLOCK_K):
        kh_idx = tl.arange(0, BLOCK_K) + kh_block
        kh_mask = kh_idx < n_kh
        block_mask = kh_mask[:, None] & hd_mask[None, :]

        head_ptr = k_base + kh_idx[:, None] * k_head_stride

        k_pair = tl.load(
            head_ptr + hd_idx[None, :],
            mask=block_mask,
            other=0.0,
        )

        k_pair = k_pair.reshape(BLOCK_K, hd // 2, 2, can_reorder=True)
        k_real, k_imag = tl.split(k_pair)

        new_real = tl.math.fma(k_real, freqs_real, -(k_imag * freqs_imag))
        new_imag = tl.math.fma(k_real, freqs_imag, k_imag * freqs_real)
        new_k_pair = tl.interleave(new_real, new_imag)

        tl.store(head_ptr + hd_idx[None, :], new_k_pair, mask=block_mask)


def llama4_rope_forward(q, k, freqs_cis):
    """
    Ascend NPU implementation of Llama4 RoPE.

    q/k: (bs, sl, n_heads, hd) with interleaved complex last-dim layout.
    freqs_cis: complex (..., hd//2) OR packed (..., 2*(hd//2)).
    """
    original_dtype = q.dtype

    bs, sl, n_qh, hd = q.shape
    _, _, n_kh, _ = k.shape
    if hd % 2 != 0:
        raise ValueError(f"head_dim must be even for interleaved complex layout, got {hd}")
    hd_half = hd // 2

    freqs_real, freqs_imag = _prepare_freqs(freqs_cis, sl, hd_half)
    q, k, freqs_real, freqs_imag, compute_dtype = _cast_and_contiguous(q, k, freqs_real, freqs_imag)

    # UB tiling strategy: tile heads dimension only
    dtype_size = q.element_size()
    shapes = ((n_qh, hd), (n_kh, hd))
    tile_shapes = compute_default_tiling_strategy(
        safety_margin=0.90,
        dtype_size=dtype_size,
        memory_multiplier=12.0,
        shapes=shapes,
        tiling_dims=(0, 0),
    )

    if tile_shapes is not None and len(tile_shapes) == len(shapes):
        q_tile_shape, k_tile_shape = tile_shapes
        BLOCK_Q, _ = q_tile_shape
        BLOCK_K, _ = k_tile_shape
    else:
        BLOCK_Q = triton.next_power_of_2(n_qh)
        BLOCK_K = triton.next_power_of_2(n_kh)

    n_row = bs * sl

    _triton_llama4_rope_npu[(n_row,)](
        q,
        k,
        freqs_real,
        freqs_imag,
        q.stride(1),
        k.stride(1),
        q.stride(2),
        k.stride(2),
        freqs_real.stride(0),
        sl,
        bs,
        n_qh,
        n_kh,
        hd,
        BLOCK_Q,
        BLOCK_K,
        imag_sign=1.0,
    )

    if compute_dtype != original_dtype:
        q = q.to(original_dtype)
        k = k.to(original_dtype)
    return q, k


def llama4_rope_backward(dq, dk, freqs_cis):
    """
    Backward pass of the Ascend NPU Llama4 RoPE (applies the conjugate
    rotation via imag_sign=-1.0).

    dq/dk: (bs, sl, n_heads, hd) with interleaved complex last-dim layout.
    freqs_cis: complex (..., hd//2) OR packed (..., 2*(hd//2)).
    """
    original_dtype = dq.dtype

    bs, sl, n_qh, hd = dq.shape
    _, _, n_kh, _ = dk.shape
    if hd % 2 != 0:
        raise ValueError(f"head_dim must be even for interleaved complex layout, got {hd}")
    hd_half = hd // 2

    freqs_real, freqs_imag = _prepare_freqs(freqs_cis, sl, hd_half)
    dq, dk, freqs_real, freqs_imag, compute_dtype = _cast_and_contiguous(dq, dk, freqs_real, freqs_imag)

    # UB tiling strategy: tile heads dimension only
    dtype_size = dq.element_size()
    shapes = ((n_qh, hd), (n_kh, hd))
    tile_shapes = compute_default_tiling_strategy(
        safety_margin=0.90,
        dtype_size=dtype_size,
        memory_multiplier=12.0,
        shapes=shapes,
        tiling_dims=(0, 0),
    )

    if tile_shapes is not None and len(tile_shapes) == len(shapes):
        q_tile_shape, k_tile_shape = tile_shapes
        BLOCK_Q, _ = q_tile_shape
        BLOCK_K, _ = k_tile_shape
    else:
        BLOCK_Q = triton.next_power_of_2(n_qh)
        BLOCK_K = triton.next_power_of_2(n_kh)

    n_row = bs * sl

    _triton_llama4_rope_npu[(n_row,)](
        dq,
        dk,
        freqs_real,
        freqs_imag,
        dq.stride(1),
        dk.stride(1),
        dq.stride(2),
        dk.stride(2),
        freqs_real.stride(0),
        sl,
        bs,
        n_qh,
        n_kh,
        hd,
        BLOCK_Q,
        BLOCK_K,
        imag_sign=-1.0,
    )

    if compute_dtype != original_dtype:
        dq = dq.to(original_dtype)
        dk = dk.to(original_dtype)
    return dq, dk


class LigerLlama4RopeFunction(torch.autograd.Function):
    @staticmethod
    def forward(ctx, q, k, freqs_cis, BLOCK_SIZE: int = None):
        # BLOCK_SIZE is ignored for Ascend (we auto-tile heads by UB), kept for API compatibility
        q_out, k_out = llama4_rope_forward(q, k, freqs_cis)
        ctx.save_for_backward(freqs_cis.detach() if isinstance(freqs_cis, torch.Tensor) else freqs_cis)
        return q_out, k_out

    @staticmethod
    def backward(ctx, dq, dk):
        (freqs_cis,) = ctx.saved_tensors
        dq_out, dk_out = llama4_rope_backward(dq, dk, freqs_cis)
        return dq_out, dk_out, None, None
```
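The kernel above treats the last dimension as interleaved (real, imag) pairs and rotates each pair by complex multiplication with the per-position frequency; the backward pass multiplies by the conjugate via `imag_sign=-1.0`. As a numerical cross-check, here is a minimal eager-PyTorch sketch of the same rotation. `llama4_rope_reference` is a hypothetical helper written for illustration (not part of the package), assuming complex `freqs_cis` of shape `(sl, hd // 2)`:

```python
import torch


def llama4_rope_reference(q, k, freqs_cis):
    """Eager reference: rotate interleaved (real, imag) pairs by freqs_cis."""

    def rotate(x):
        # (bs, sl, n_heads, hd) -> complex (bs, sl, n_heads, hd // 2)
        x_c = torch.view_as_complex(x.float().reshape(*x.shape[:-1], -1, 2))
        rotated = x_c * freqs_cis[None, :, None, :]  # broadcast over batch and heads
        # Back to the interleaved real layout
        return torch.view_as_real(rotated).flatten(-2).to(x.dtype)

    return rotate(q), rotate(k)


# CPU example (the Triton kernel itself requires an Ascend NPU):
bs, sl, n_qh, n_kh, hd = 2, 8, 4, 2, 16
q = torch.randn(bs, sl, n_qh, hd)
k = torch.randn(bs, sl, n_kh, hd)
angle = torch.rand(sl, hd // 2) * 6.283185
freqs_cis = torch.polar(torch.ones_like(angle), angle)  # complex64, (sl, hd // 2)
q_ref, k_ref = llama4_rope_reference(q, k, freqs_cis)
```

Note that `llama4_rope_forward` stores its result into the contiguous compute-dtype buffers in place, so when comparing against a reference like this, pass clones of `q` and `k` to the kernel.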
liger_kernel/ops/backends/_ascend/ops/qwen2vl_mrope.py
@@ -0,0 +1,275 @@

```python
import torch
import triton
import triton.language as tl

from liger_kernel.ops.backends._ascend.ub_manager import compute_default_tiling_strategy
from liger_kernel.ops.utils import get_npu_core_count


@triton.jit
def _triton_qwen2vl_mrope_npu(
    q_ptr,
    q_row_stride,
    k_ptr,
    k_row_stride,
    cos,
    sin,
    sl,
    bs: tl.constexpr,
    total_rows: tl.constexpr,
    n_qh: tl.constexpr,
    n_kh: tl.constexpr,
    hd: tl.constexpr,
    mrope_section_t: tl.constexpr,
    mrope_section_h: tl.constexpr,
    BLOCK_Q: tl.constexpr,
    BLOCK_K: tl.constexpr,
    NUM_STAGES: tl.constexpr,
    BACKWARD_PASS: tl.constexpr = False,
):
    program_id = tl.program_id(0)
    num_programs = tl.num_programs(0)

    rows_per_program = (total_rows + num_programs - 1) // num_programs
    start_row = program_id * rows_per_program
    actual_rows = tl.minimum(rows_per_program, total_rows - start_row)

    for row_offset in tl.range(0, actual_rows, num_stages=NUM_STAGES):
        pid = start_row + row_offset

        t_end = mrope_section_t
        h_end = t_end + mrope_section_h

        t_cos = cos + pid * hd
        h_cos = t_cos + bs * sl * hd
        w_cos = h_cos + bs * sl * hd
        t_sin = sin + pid * hd
        h_sin = t_sin + bs * sl * hd
        w_sin = h_sin + bs * sl * hd

        q_base = q_ptr + pid * q_row_stride
        k_base = k_ptr + pid * k_row_stride

        d_idx = tl.arange(0, hd // 2)
        d_mask = d_idx < (hd // 2)

        pos_mask_t = d_idx < t_end
        pos_mask_h = (d_idx >= t_end) & (d_idx < h_end)

        text_cos_vals = tl.load(t_cos + d_idx, mask=d_mask, other=0)
        text_sin_vals = tl.load(t_sin + d_idx, mask=d_mask, other=0)
        height_cos_vals = tl.load(h_cos + d_idx, mask=d_mask, other=0)
        height_sin_vals = tl.load(h_sin + d_idx, mask=d_mask, other=0)
        width_cos_vals = tl.load(w_cos + d_idx, mask=d_mask, other=0)
        width_sin_vals = tl.load(w_sin + d_idx, mask=d_mask, other=0)

        cos_vals = tl.where(pos_mask_t, text_cos_vals, tl.where(pos_mask_h, height_cos_vals, width_cos_vals))
        sin_vals = tl.where(pos_mask_t, text_sin_vals, tl.where(pos_mask_h, height_sin_vals, width_sin_vals))

        # Process q heads in chunks to prevent UB overflow
        for qh_block in range(0, n_qh, BLOCK_Q):
            qh_idx = tl.arange(0, BLOCK_Q) + qh_block
            qh_mask = qh_idx < n_qh

            block_mask = qh_mask[:, None] & d_mask[None, :]
            offsets = qh_idx[:, None] * hd + d_idx[None, :]

            q_left = tl.load(q_base + offsets, mask=block_mask, other=0)
            q_right = tl.load(q_base + offsets + (hd // 2), mask=block_mask, other=0)

            if not BACKWARD_PASS:
                new_left = q_left * cos_vals - q_right * sin_vals
                new_right = q_right * cos_vals + q_left * sin_vals
            else:
                new_left = q_left * cos_vals + q_right * sin_vals
                new_right = q_right * cos_vals - q_left * sin_vals

            tl.store(q_base + offsets, new_left, mask=block_mask)
            tl.store(q_base + offsets + (hd // 2), new_right, mask=block_mask)

        # Process k heads in chunks to prevent UB overflow
        for kh_block in range(0, n_kh, BLOCK_K):
            kh_idx = tl.arange(0, BLOCK_K) + kh_block
            kh_mask = kh_idx < n_kh

            block_mask = kh_mask[:, None] & d_mask[None, :]
            offsets = kh_idx[:, None] * hd + d_idx[None, :]

            k_left = tl.load(k_base + offsets, mask=block_mask, other=0)
            k_right = tl.load(k_base + offsets + (hd // 2), mask=block_mask, other=0)

            if not BACKWARD_PASS:
                new_left = k_left * cos_vals - k_right * sin_vals
                new_right = k_right * cos_vals + k_left * sin_vals
            else:
                new_left = k_left * cos_vals + k_right * sin_vals
                new_right = k_right * cos_vals - k_left * sin_vals

            tl.store(k_base + offsets, new_left, mask=block_mask)
            tl.store(k_base + offsets + (hd // 2), new_right, mask=block_mask)


def get_optimal_block_size_mrope(pad_n_q_head, pad_n_kv_head, pad_hd, dtype_size):
    # MROPE forward tiling strategy:
    # - cos_vals and sin_vals (text, height, and width) are loaded once outside the loops (shared):
    #   (pad_hd // 2) * 6 = 3 * pad_hd elements in total
    # - In the q heads loop (peak memory):
    #   * q_left:    BLOCK_Q * (pad_hd // 2) elements
    #   * q_right:   BLOCK_Q * (pad_hd // 2) elements
    #   * new_left:  BLOCK_Q * (pad_hd // 2) elements (intermediate result)
    #   * new_right: BLOCK_Q * (pad_hd // 2) elements (intermediate result)
    #   * Total: 4 * BLOCK_Q * (pad_hd // 2) = 2 * BLOCK_Q * pad_hd elements
    # - In the k heads loop (peak memory):
    #   * k_left:    BLOCK_K * (pad_hd // 2) elements
    #   * k_right:   BLOCK_K * (pad_hd // 2) elements
    #   * new_left:  BLOCK_K * (pad_hd // 2) elements (intermediate result)
    #   * new_right: BLOCK_K * (pad_hd // 2) elements (intermediate result)
    #   * Total: 4 * BLOCK_K * (pad_hd // 2) = 2 * BLOCK_K * pad_hd elements
    # - Since q and k are processed separately, peak memory is the max(BLOCK_Q, BLOCK_K) case
    # - Plus shared cos/sin: 6 * (pad_hd // 2) = 3 * pad_hd elements
    # - Conservative estimate: (2 * BLOCK_SIZE * pad_hd + 3 * pad_hd) * dtype_size * 8 bits
    # - Simplified: (2 * BLOCK_SIZE + 3) * pad_hd * dtype_size * 8 bits
    # - For safety, use memory_multiplier=3.0, i.e. 3.0 * BLOCK_SIZE * pad_hd * dtype_size * 8 bits
    # - shapes: ((pad_n_q_head, pad_hd), (pad_n_kv_head, pad_hd))
    # - tiling_dims: (0, 0) means the first dimension of each shape can be tiled
    # - Returns: ((block_size_q, pad_hd), (block_size_kv, pad_hd))
    shapes = ((pad_n_q_head, pad_hd), (pad_n_kv_head, pad_hd))
    tile_shapes = compute_default_tiling_strategy(
        safety_margin=0.90,
        dtype_size=dtype_size,
        memory_multiplier=3.0,
        shapes=shapes,
        tiling_dims=(0, 0),
    )

    if tile_shapes is not None and len(tile_shapes) == len(shapes):
        # Strategy returns ((block_size_q, pad_hd), (block_size_kv, pad_hd))
        q_tile_shape, k_tile_shape = tile_shapes
        BLOCK_Q, _ = q_tile_shape
        BLOCK_K, _ = k_tile_shape
    else:
        # Fallback to conservative defaults
        BLOCK_Q = 2048
        BLOCK_K = 2048

    return BLOCK_Q, BLOCK_K


def qwen2vl_mrope_forward(q, k, cos, sin, mrope_section):
    # Transpose back to the physical (bs, sl, n_heads, hd) layout because Triton
    # operates on the physical storage
    q = q.transpose(1, 2)
    k = k.transpose(1, 2)

    batch_size, seq_len, n_q_head, head_dim = q.shape
    n_kv_head = k.shape[2]
    pad_hd = triton.next_power_of_2(head_dim)
    pad_n_q_head = triton.next_power_of_2(n_q_head)
    pad_n_kv_head = triton.next_power_of_2(n_kv_head)

    n_row = batch_size * seq_len

    # ensure tensors passed into the kernel are contiguous
    q = q.contiguous()
    k = k.contiguous()
    cos = cos.contiguous()
    sin = sin.contiguous()

    dtype_size = q.element_size()
    BLOCK_Q, BLOCK_K = get_optimal_block_size_mrope(pad_n_q_head, pad_n_kv_head, pad_hd, dtype_size)

    num_cores = get_npu_core_count()
    grid_size = min(num_cores, n_row)

    _triton_qwen2vl_mrope_npu[(grid_size,)](
        q,
        q.stride(1),
        k,
        k.stride(1),
        cos,
        sin,
        seq_len,
        batch_size,
        n_row,
        n_q_head,
        n_kv_head,
        head_dim,
        mrope_section[0],
        mrope_section[1],
        BLOCK_Q,
        BLOCK_K,
        NUM_STAGES=3,
        BACKWARD_PASS=False,
    )
    return q.transpose(1, 2), k.transpose(1, 2), cos, sin


def qwen2vl_mrope_backward(dq, dk, cos, sin, mrope_section):
    dq = dq.transpose(1, 2)
    dk = dk.transpose(1, 2)

    batch_size, seq_len, n_q_head, head_dim = dq.shape
    n_kv_head = dk.shape[2]
    pad_hd = triton.next_power_of_2(head_dim)
    pad_n_q_head = triton.next_power_of_2(n_q_head)
    pad_n_kv_head = triton.next_power_of_2(n_kv_head)

    n_row = batch_size * seq_len

    # ensure dq and dk are contiguous
    dq = dq.contiguous()
    dk = dk.contiguous()

    dtype_size = dq.element_size()
    BLOCK_Q, BLOCK_K = get_optimal_block_size_mrope(pad_n_q_head, pad_n_kv_head, pad_hd, dtype_size)

    num_cores = get_npu_core_count()
    grid_size = min(num_cores, n_row)

    _triton_qwen2vl_mrope_npu[(grid_size,)](
        dq,
        dq.stride(1),
        dk,
        dk.stride(1),
        cos,
        sin,
        seq_len,
        batch_size,
        n_row,
        n_q_head,
        n_kv_head,
        head_dim,
        mrope_section[0],
        mrope_section[1],
        BLOCK_Q,
        BLOCK_K,
        NUM_STAGES=3,
        BACKWARD_PASS=True,
    )
    return dq.transpose(1, 2), dk.transpose(1, 2)


class LigerQwen2VLMRopeFunction(torch.autograd.Function):
    @staticmethod
    def forward(ctx, q, k, cos, sin, mrope_section, unsqueeze_dim=1):
        """
        q size: (bsz, n_q_head, seq_len, head_dim)
        k size: (bsz, n_kv_head, seq_len, head_dim)
        cos size: (3, bsz, seq_len, head_dim)
        sin size: (3, bsz, seq_len, head_dim)
        """
        q, k, cos, sin = qwen2vl_mrope_forward(q, k, cos, sin, mrope_section)
        ctx.save_for_backward(cos, sin)
        ctx.mrope_section = mrope_section
        return q, k

    @staticmethod
    def backward(ctx, dq, dk):
        """
        dq size: (bsz, n_q_head, seq_len, head_dim)
        dk size: (bsz, n_kv_head, seq_len, head_dim)
        cos size: (3, bsz, seq_len, head_dim)
        sin size: (3, bsz, seq_len, head_dim)
        """
        cos, sin = ctx.saved_tensors
        mrope_section = ctx.mrope_section
        dq, dk = qwen2vl_mrope_backward(dq, dk, cos, sin, mrope_section)
        return dq, dk, None, None, None, None
```