liger-kernel 0.6.3__py3-none-any.whl → 0.6.5__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (101)
  1. liger_kernel/chunked_loss/cosine_similarity_loss.py +20 -5
  2. liger_kernel/chunked_loss/fused_linear_distillation.py +23 -5
  3. liger_kernel/chunked_loss/fused_linear_ppo.py +21 -5
  4. liger_kernel/chunked_loss/grpo_loss.py +8 -5
  5. liger_kernel/chunked_loss/jsd_loss.py +39 -11
  6. liger_kernel/ops/__init__.py +141 -0
  7. liger_kernel/ops/backends/README.md +151 -0
  8. liger_kernel/ops/backends/__init__.py +13 -0
  9. liger_kernel/ops/backends/_ascend/__init__.py +5 -0
  10. liger_kernel/ops/backends/_ascend/ascend-ub-manager-design.md +492 -0
  11. liger_kernel/ops/backends/_ascend/ops/__init__.py +61 -0
  12. liger_kernel/ops/backends/_ascend/ops/embedding.py +214 -0
  13. liger_kernel/ops/backends/_ascend/ops/geglu.py +191 -0
  14. liger_kernel/ops/backends/_ascend/ops/llama4_rope.py +298 -0
  15. liger_kernel/ops/backends/_ascend/ops/qwen2vl_mrope.py +275 -0
  16. liger_kernel/ops/backends/_ascend/ops/rope.py +265 -0
  17. liger_kernel/ops/backends/_ascend/ops/swiglu.py +142 -0
  18. liger_kernel/ops/backends/_ascend/ops/tvd.py +223 -0
  19. liger_kernel/ops/backends/_ascend/ub_manager.py +367 -0
  20. liger_kernel/ops/backends/registry.py +61 -0
  21. liger_kernel/ops/cross_entropy.py +71 -11
  22. liger_kernel/ops/dyt.py +5 -2
  23. liger_kernel/ops/fused_add_rms_norm.py +21 -23
  24. liger_kernel/ops/fused_linear_cross_entropy.py +32 -5
  25. liger_kernel/ops/geglu.py +5 -3
  26. liger_kernel/ops/group_norm.py +12 -8
  27. liger_kernel/ops/grpo_loss.py +3 -1
  28. liger_kernel/ops/kl_div.py +8 -11
  29. liger_kernel/ops/layer_norm.py +89 -69
  30. liger_kernel/ops/poly_norm.py +19 -21
  31. liger_kernel/ops/rms_norm.py +149 -71
  32. liger_kernel/ops/tiled_mlp.py +136 -0
  33. liger_kernel/ops/utils.py +25 -0
  34. liger_kernel/transformers/__init__.py +25 -0
  35. liger_kernel/transformers/auto_model.py +21 -0
  36. liger_kernel/transformers/cross_entropy.py +9 -4
  37. liger_kernel/transformers/dyt.py +1 -1
  38. liger_kernel/transformers/experimental/embedding.py +1 -1
  39. liger_kernel/transformers/functional.py +44 -26
  40. liger_kernel/transformers/fused_add_rms_norm.py +1 -1
  41. liger_kernel/transformers/fused_linear_cross_entropy.py +9 -4
  42. liger_kernel/transformers/fused_linear_jsd.py +1 -1
  43. liger_kernel/transformers/fused_neighborhood_attention.py +1 -1
  44. liger_kernel/transformers/geglu.py +1 -1
  45. liger_kernel/transformers/group_norm.py +1 -1
  46. liger_kernel/transformers/grpo_loss.py +57 -2
  47. liger_kernel/transformers/jsd.py +1 -1
  48. liger_kernel/transformers/kl_div.py +1 -1
  49. liger_kernel/transformers/layer_norm.py +1 -1
  50. liger_kernel/transformers/llama4_rope.py +1 -1
  51. liger_kernel/transformers/model/exaone4.py +136 -0
  52. liger_kernel/transformers/model/falcon_h1.py +19 -5
  53. liger_kernel/transformers/model/gemma.py +17 -6
  54. liger_kernel/transformers/model/gemma2.py +17 -8
  55. liger_kernel/transformers/model/gemma3.py +35 -16
  56. liger_kernel/transformers/model/glm4.py +16 -4
  57. liger_kernel/transformers/model/glm4v.py +16 -4
  58. liger_kernel/transformers/model/glm4v_moe.py +23 -4
  59. liger_kernel/transformers/model/gpt_oss.py +211 -0
  60. liger_kernel/transformers/model/hunyuan_v1.py +134 -0
  61. liger_kernel/transformers/model/internvl.py +12 -5
  62. liger_kernel/transformers/model/llama.py +14 -5
  63. liger_kernel/transformers/model/llama4.py +16 -4
  64. liger_kernel/transformers/model/llava.py +12 -4
  65. liger_kernel/transformers/model/loss_utils.py +37 -3
  66. liger_kernel/transformers/model/mistral.py +15 -6
  67. liger_kernel/transformers/model/mixtral.py +16 -7
  68. liger_kernel/transformers/model/mllama.py +12 -4
  69. liger_kernel/transformers/model/olmo2.py +16 -4
  70. liger_kernel/transformers/model/olmo3.py +142 -0
  71. liger_kernel/transformers/model/output_classes.py +147 -0
  72. liger_kernel/transformers/model/paligemma.py +23 -5
  73. liger_kernel/transformers/model/phi3.py +14 -7
  74. liger_kernel/transformers/model/qwen2.py +16 -3
  75. liger_kernel/transformers/model/qwen2_5_vl.py +14 -6
  76. liger_kernel/transformers/model/qwen2_vl.py +16 -4
  77. liger_kernel/transformers/model/qwen3.py +20 -5
  78. liger_kernel/transformers/model/qwen3_moe.py +19 -5
  79. liger_kernel/transformers/model/qwen3_next.py +17 -5
  80. liger_kernel/transformers/model/qwen3_vl.py +150 -0
  81. liger_kernel/transformers/model/qwen3_vl_moe.py +126 -0
  82. liger_kernel/transformers/model/smollm3.py +15 -6
  83. liger_kernel/transformers/monkey_patch.py +584 -49
  84. liger_kernel/transformers/multi_token_attention.py +1 -1
  85. liger_kernel/transformers/poly_norm.py +1 -1
  86. liger_kernel/transformers/qwen2vl_mrope.py +1 -1
  87. liger_kernel/transformers/rms_norm.py +8 -3
  88. liger_kernel/transformers/rope.py +45 -1
  89. liger_kernel/transformers/softmax.py +1 -1
  90. liger_kernel/transformers/sparsemax.py +1 -1
  91. liger_kernel/transformers/swiglu.py +18 -1
  92. liger_kernel/transformers/tiled_mlp.py +125 -0
  93. liger_kernel/transformers/tvd.py +1 -1
  94. liger_kernel/utils.py +54 -0
  95. {liger_kernel-0.6.3.dist-info → liger_kernel-0.6.5.dist-info}/METADATA +14 -4
  96. liger_kernel-0.6.5.dist-info/RECORD +134 -0
  97. {liger_kernel-0.6.3.dist-info → liger_kernel-0.6.5.dist-info}/WHEEL +1 -1
  98. liger_kernel-0.6.3.dist-info/RECORD +0 -111
  99. {liger_kernel-0.6.3.dist-info → liger_kernel-0.6.5.dist-info}/licenses/LICENSE +0 -0
  100. {liger_kernel-0.6.3.dist-info → liger_kernel-0.6.5.dist-info}/licenses/NOTICE +0 -0
  101. {liger_kernel-0.6.3.dist-info → liger_kernel-0.6.5.dist-info}/top_level.txt +0 -0
liger_kernel/ops/backends/_ascend/ops/llama4_rope.py
@@ -0,0 +1,298 @@
+ import torch
+ import triton
+ import triton.language as tl
+
+ from liger_kernel.ops.backends._ascend.ub_manager import compute_default_tiling_strategy
+
+
+ def _prepare_freqs(freqs_cis: torch.Tensor, seq_len: int, head_dim_half: int):
+     """
+     Canonicalize freqs to (seq_len, head_dim_half) real/imag tensors.
+
+     Supports:
+     - complex freqs: (..., head_dim_half) complex -> real/imag
+     - packed freqs: (..., 2*head_dim_half) real -> split into real/imag
+     """
+     if freqs_cis.is_complex():
+         freqs_real = freqs_cis.real
+         freqs_imag = freqs_cis.imag
+     else:
+         if freqs_cis.shape[-1] == 2 * head_dim_half:
+             freqs_real = freqs_cis[..., :head_dim_half]
+             freqs_imag = freqs_cis[..., head_dim_half:]
+         else:
+             raise ValueError(
+                 f"Unexpected freqs_cis shape for non-complex input: {freqs_cis.shape}, "
+                 f"expected last dim = {2 * head_dim_half}"
+             )
+
+     if freqs_real.shape[-1] != head_dim_half:
+         raise ValueError(f"Unexpected last dim for freqs: {freqs_real.shape[-1]} (expected {head_dim_half})")
+
+     # Flatten leading dims -> (N, head_dim_half)
+     freqs_real = freqs_real.reshape(-1, head_dim_half)
+     freqs_imag = freqs_imag.reshape(-1, head_dim_half)
+
+     # Broadcast/slice to (seq_len, head_dim_half)
+     if freqs_real.shape[0] < seq_len:
+         if freqs_real.shape[0] == 1:
+             freqs_real = freqs_real.expand(seq_len, -1)
+             freqs_imag = freqs_imag.expand(seq_len, -1)
+         else:
+             raise ValueError(f"Insufficient rows in freqs: {freqs_real.shape[0]} < seq_len={seq_len}")
+     elif freqs_real.shape[0] > seq_len:
+         freqs_real = freqs_real[:seq_len]
+         freqs_imag = freqs_imag[:seq_len]
+
+     return freqs_real, freqs_imag
+
+
+ def _cast_and_contiguous(q, k, freqs_real, freqs_imag):
+     # Align dtype: fp32 only when q is fp32; otherwise keep q dtype for perf
+     compute_dtype = torch.float32 if q.dtype == torch.float32 else q.dtype
+
+     if k.dtype != q.dtype:
+         k = k.to(q.dtype)
+
+     q = q.to(compute_dtype).contiguous()
+     k = k.to(compute_dtype).contiguous()
+     freqs_real = freqs_real.to(compute_dtype).contiguous()
+     freqs_imag = freqs_imag.to(compute_dtype).contiguous()
+     return q, k, freqs_real, freqs_imag, compute_dtype
+
+
+ @triton.jit
+ def _triton_llama4_rope_npu(
+     q_ptr,
+     k_ptr,
+     freqs_real_ptr,
+     freqs_imag_ptr,
+     q_row_stride,
+     k_row_stride,
+     q_head_stride,
+     k_head_stride,
+     freqs_row_stride,
+     sl,
+     bs: tl.constexpr,
+     n_qh: tl.constexpr,
+     n_kh: tl.constexpr,
+     hd: tl.constexpr,
+     BLOCK_Q: tl.constexpr,
+     BLOCK_K: tl.constexpr,
+     imag_sign: tl.constexpr,
+ ):
+     """
+     Llama4 RoPE on Ascend NPU for interleaved complex layout:
+     - q/k shape: (bs, sl, n_heads, hd)
+     - last dim layout: [real0, imag0, real1, imag1, ...]
+     - freqs_real/imag: (sl, hd//2)
+     """
+     pid = tl.program_id(0).to(tl.int64)
+     batch_idx = pid // sl
+     seq_idx = pid % sl
+
+     if batch_idx >= bs:
+         return
+
+     q_base = q_ptr + pid * q_row_stride
+     k_base = k_ptr + pid * k_row_stride
+
+     freq_base = seq_idx * freqs_row_stride
+     hd_idx = tl.arange(0, hd)
+     hd_mask = hd_idx < hd
+
+     freq_idx = tl.arange(0, hd // 2)
+     freq_mask = freq_idx < (hd // 2)
+
+     freqs_real = tl.load(freqs_real_ptr + freq_base + freq_idx, mask=freq_mask, other=0.0)
+     freqs_imag = tl.load(freqs_imag_ptr + freq_base + freq_idx, mask=freq_mask, other=0.0) * imag_sign
+
+     # Q heads (chunked for UB)
+     for qh_block in range(0, n_qh, BLOCK_Q):
+         qh_idx = tl.arange(0, BLOCK_Q) + qh_block
+         qh_mask = qh_idx < n_qh
+         block_mask = qh_mask[:, None] & hd_mask[None, :]
+
+         head_ptr = q_base + qh_idx[:, None] * q_head_stride
+
+         q_pair = tl.load(
+             head_ptr + hd_idx[None, :],
+             mask=block_mask,
+             other=0.0,
+         )
+         q_pair = q_pair.reshape(BLOCK_Q, hd // 2, 2, can_reorder=True)
+         q_real, q_imag = tl.split(q_pair)
+
+         new_real = tl.math.fma(q_real, freqs_real, -(q_imag * freqs_imag))
+         new_imag = tl.math.fma(q_real, freqs_imag, q_imag * freqs_real)
+         new_q_pair = tl.interleave(new_real, new_imag)
+
+         tl.store(head_ptr + hd_idx[None, :], new_q_pair, mask=block_mask)
+
+     # K heads (chunked for UB)
+     for kh_block in range(0, n_kh, BLOCK_K):
+         kh_idx = tl.arange(0, BLOCK_K) + kh_block
+         kh_mask = kh_idx < n_kh
+         block_mask = kh_mask[:, None] & hd_mask[None, :]
+
+         head_ptr = k_base + kh_idx[:, None] * k_head_stride
+
+         k_pair = tl.load(
+             head_ptr + hd_idx[None, :],
+             mask=block_mask,
+             other=0.0,
+         )
+
+         k_pair = k_pair.reshape(BLOCK_K, hd // 2, 2, can_reorder=True)
+         k_real, k_imag = tl.split(k_pair)
+
+         new_real = tl.math.fma(k_real, freqs_real, -(k_imag * freqs_imag))
+         new_imag = tl.math.fma(k_real, freqs_imag, k_imag * freqs_real)
+         new_k_pair = tl.interleave(new_real, new_imag)
+
+         tl.store(head_ptr + hd_idx[None, :], new_k_pair, mask=block_mask)
+
+
+ def llama4_rope_forward(q, k, freqs_cis):
+     """
+     Ascend NPU implementation of the Llama4 RoPE forward pass.
+
+     q/k: (bs, sl, n_heads, hd) with interleaved complex last-dim layout.
+     freqs_cis: complex (..., hd//2) OR packed (..., 2*(hd//2)).
+     """
+     original_dtype = q.dtype
+
+     bs, sl, n_qh, hd = q.shape
+     _, _, n_kh, _ = k.shape
+     if hd % 2 != 0:
+         raise ValueError(f"head_dim must be even for interleaved complex layout, got {hd}")
+     hd_half = hd // 2
+
+     freqs_real, freqs_imag = _prepare_freqs(freqs_cis, sl, hd_half)
+     q, k, freqs_real, freqs_imag, compute_dtype = _cast_and_contiguous(q, k, freqs_real, freqs_imag)
+
+     # UB tiling strategy: tile heads dimension only
+     dtype_size = q.element_size()
+     shapes = ((n_qh, hd), (n_kh, hd))
+     tile_shapes = compute_default_tiling_strategy(
+         safety_margin=0.90,
+         dtype_size=dtype_size,
+         memory_multiplier=12.0,
+         shapes=shapes,
+         tiling_dims=(0, 0),
+     )
+
+     if tile_shapes is not None and len(tile_shapes) == len(shapes):
+         q_tile_shape, k_tile_shape = tile_shapes
+         BLOCK_Q, _ = q_tile_shape
+         BLOCK_K, _ = k_tile_shape
+     else:
+         BLOCK_Q = triton.next_power_of_2(n_qh)
+         BLOCK_K = triton.next_power_of_2(n_kh)
+
+     n_row = bs * sl
+
+     _triton_llama4_rope_npu[(n_row,)](
+         q,
+         k,
+         freqs_real,
+         freqs_imag,
+         q.stride(1),
+         k.stride(1),
+         q.stride(2),
+         k.stride(2),
+         freqs_real.stride(0),
+         sl,
+         bs,
+         n_qh,
+         n_kh,
+         hd,
+         BLOCK_Q,
+         BLOCK_K,
+         imag_sign=1.0,
+     )
+
+     if compute_dtype != original_dtype:
+         q = q.to(original_dtype)
+         k = k.to(original_dtype)
+     return q, k
+
+
+ def llama4_rope_backward(dq, dk, freqs_cis):
+     """
+     Ascend NPU implementation of the Llama4 RoPE backward pass (rotates by the conjugate via imag_sign=-1.0).
+
+     dq/dk: (bs, sl, n_heads, hd) with interleaved complex last-dim layout.
+     freqs_cis: complex (..., hd//2) OR packed (..., 2*(hd//2)).
+     """
+     original_dtype = dq.dtype
+
+     bs, sl, n_qh, hd = dq.shape
+     _, _, n_kh, _ = dk.shape
+     if hd % 2 != 0:
+         raise ValueError(f"head_dim must be even for interleaved complex layout, got {hd}")
+     hd_half = hd // 2
+
+     freqs_real, freqs_imag = _prepare_freqs(freqs_cis, sl, hd_half)
+     dq, dk, freqs_real, freqs_imag, compute_dtype = _cast_and_contiguous(dq, dk, freqs_real, freqs_imag)
+
+     # UB tiling strategy: tile heads dimension only
+     dtype_size = dq.element_size()
+     shapes = ((n_qh, hd), (n_kh, hd))
+     tile_shapes = compute_default_tiling_strategy(
+         safety_margin=0.90,
+         dtype_size=dtype_size,
+         memory_multiplier=12.0,
+         shapes=shapes,
+         tiling_dims=(0, 0),
+     )
+
+     if tile_shapes is not None and len(tile_shapes) == len(shapes):
+         q_tile_shape, k_tile_shape = tile_shapes
+         BLOCK_Q, _ = q_tile_shape
+         BLOCK_K, _ = k_tile_shape
+     else:
+         BLOCK_Q = triton.next_power_of_2(n_qh)
+         BLOCK_K = triton.next_power_of_2(n_kh)
+
+     n_row = bs * sl
+
+     _triton_llama4_rope_npu[(n_row,)](
+         dq,
+         dk,
+         freqs_real,
+         freqs_imag,
+         dq.stride(1),
+         dk.stride(1),
+         dq.stride(2),
+         dk.stride(2),
+         freqs_real.stride(0),
+         sl,
+         bs,
+         n_qh,
+         n_kh,
+         hd,
+         BLOCK_Q,
+         BLOCK_K,
+         imag_sign=-1.0,
+     )
+
+     if compute_dtype != original_dtype:
+         dq = dq.to(original_dtype)
+         dk = dk.to(original_dtype)
+     return dq, dk
+
+
+ class LigerLlama4RopeFunction(torch.autograd.Function):
+     @staticmethod
+     def forward(ctx, q, k, freqs_cis, BLOCK_SIZE: int = None):
+         # BLOCK_SIZE is ignored for Ascend (we auto-tile heads by UB); kept for API compatibility
+         q_out, k_out = llama4_rope_forward(q, k, freqs_cis)
+         ctx.save_for_backward(freqs_cis.detach() if isinstance(freqs_cis, torch.Tensor) else freqs_cis)
+         return q_out, k_out
+
+     @staticmethod
+     def backward(ctx, dq, dk):
+         (freqs_cis,) = ctx.saved_tensors
+         dq_out, dk_out = llama4_rope_backward(dq, dk, freqs_cis)
+         return dq_out, dk_out, None, None
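Note: the kernel's pair update (new_real = q_real * f_real - q_imag * f_imag, new_imag = q_real * f_imag + q_imag * f_real) is an elementwise complex multiply, and imag_sign=-1.0 in the backward pass multiplies by the conjugate instead. Below is a minimal PyTorch sketch of the forward computation, usable as a numerical cross-check; the helper name is illustrative and not part of the wheel:

import torch

def llama4_rope_reference(q, k, freqs_cis):
    # q/k: (bs, sl, n_heads, hd), last dim interleaved as [real0, imag0, real1, imag1, ...]
    # freqs_cis: complex tensor of shape (sl, hd // 2)
    q_c = torch.view_as_complex(q.float().reshape(*q.shape[:-1], -1, 2))
    k_c = torch.view_as_complex(k.float().reshape(*k.shape[:-1], -1, 2))
    f = freqs_cis.reshape(1, q.shape[1], 1, -1)  # broadcast over batch and heads
    q_out = torch.view_as_real(q_c * f).flatten(-2).to(q.dtype)
    k_out = torch.view_as_real(k_c * f).flatten(-2).to(k.dtype)
    return q_out, k_out

# e.g. torch.testing.assert_close(llama4_rope_forward(q, k, freqs_cis)[0],
#                                 llama4_rope_reference(q, k, freqs_cis)[0], rtol=1e-3, atol=1e-3)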
liger_kernel/ops/backends/_ascend/ops/qwen2vl_mrope.py
@@ -0,0 +1,275 @@
+ import torch
+ import triton
+ import triton.language as tl
+
+ from liger_kernel.ops.backends._ascend.ub_manager import compute_default_tiling_strategy
+ from liger_kernel.ops.utils import get_npu_core_count
+
+
+ @triton.jit
+ def _triton_qwen2vl_mrope_npu(
+     q_ptr,
+     q_row_stride,
+     k_ptr,
+     k_row_stride,
+     cos,
+     sin,
+     sl,
+     bs: tl.constexpr,
+     total_rows: tl.constexpr,
+     n_qh: tl.constexpr,
+     n_kh: tl.constexpr,
+     hd: tl.constexpr,
+     mrope_section_t: tl.constexpr,
+     mrope_section_h: tl.constexpr,
+     BLOCK_Q: tl.constexpr,
+     BLOCK_K: tl.constexpr,
+     NUM_STAGES: tl.constexpr,
+     BACKWARD_PASS: tl.constexpr = False,
+ ):
+     program_id = tl.program_id(0)
+     num_programs = tl.num_programs(0)
+
+     rows_per_program = (total_rows + num_programs - 1) // num_programs
+     start_row = program_id * rows_per_program
+     actual_rows = tl.minimum(rows_per_program, total_rows - start_row)
+
+     for row_offset in tl.range(0, actual_rows, num_stages=NUM_STAGES):
+         pid = start_row + row_offset
+
+         t_end = mrope_section_t
+         h_end = t_end + mrope_section_h
+
+         t_cos = cos + pid * hd
+         h_cos = t_cos + bs * sl * hd
+         w_cos = h_cos + bs * sl * hd
+         t_sin = sin + pid * hd
+         h_sin = t_sin + bs * sl * hd
+         w_sin = h_sin + bs * sl * hd
+
+         q_base = q_ptr + pid * q_row_stride
+         k_base = k_ptr + pid * k_row_stride
+
+         d_idx = tl.arange(0, hd // 2)
+         d_mask = d_idx < (hd // 2)
+
+         pos_mask_t = d_idx < t_end
+         pos_mask_h = (d_idx >= t_end) & (d_idx < h_end)
+
+         text_cos_vals = tl.load(t_cos + d_idx, mask=d_mask, other=0)
+         text_sin_vals = tl.load(t_sin + d_idx, mask=d_mask, other=0)
+         height_cos_vals = tl.load(h_cos + d_idx, mask=d_mask, other=0)
+         height_sin_vals = tl.load(h_sin + d_idx, mask=d_mask, other=0)
+         width_cos_vals = tl.load(w_cos + d_idx, mask=d_mask, other=0)
+         width_sin_vals = tl.load(w_sin + d_idx, mask=d_mask, other=0)
+
+         cos_vals = tl.where(pos_mask_t, text_cos_vals, tl.where(pos_mask_h, height_cos_vals, width_cos_vals))
+         sin_vals = tl.where(pos_mask_t, text_sin_vals, tl.where(pos_mask_h, height_sin_vals, width_sin_vals))
+
+         # Process q heads in chunks to prevent UB overflow
+         for qh_block in range(0, n_qh, BLOCK_Q):
+             qh_idx = tl.arange(0, BLOCK_Q) + qh_block
+             qh_mask = qh_idx < n_qh
+
+             block_mask = qh_mask[:, None] & d_mask[None, :]
+             offsets = qh_idx[:, None] * hd + d_idx[None, :]
+
+             q_left = tl.load(q_base + offsets, mask=block_mask, other=0)
+             q_right = tl.load(q_base + offsets + (hd // 2), mask=block_mask, other=0)
+
+             if not BACKWARD_PASS:
+                 new_left = q_left * cos_vals - q_right * sin_vals
+                 new_right = q_right * cos_vals + q_left * sin_vals
+             else:
+                 new_left = q_left * cos_vals + q_right * sin_vals
+                 new_right = q_right * cos_vals - q_left * sin_vals
+
+             tl.store(q_base + offsets, new_left, mask=block_mask)
+             tl.store(q_base + offsets + (hd // 2), new_right, mask=block_mask)
+
+         # Process k heads in chunks to prevent UB overflow
+         for kh_block in range(0, n_kh, BLOCK_K):
+             kh_idx = tl.arange(0, BLOCK_K) + kh_block
+             kh_mask = kh_idx < n_kh
+
+             block_mask = kh_mask[:, None] & d_mask[None, :]
+             offsets = kh_idx[:, None] * hd + d_idx[None, :]
+
+             k_left = tl.load(k_base + offsets, mask=block_mask, other=0)
+             k_right = tl.load(k_base + offsets + (hd // 2), mask=block_mask, other=0)
+
+             if not BACKWARD_PASS:
+                 new_left = k_left * cos_vals - k_right * sin_vals
+                 new_right = k_right * cos_vals + k_left * sin_vals
+             else:
+                 new_left = k_left * cos_vals + k_right * sin_vals
+                 new_right = k_right * cos_vals - k_left * sin_vals
+
+             tl.store(k_base + offsets, new_left, mask=block_mask)
+             tl.store(k_base + offsets + (hd // 2), new_right, mask=block_mask)
+
+
+ def get_optimal_block_size_mrope(pad_n_q_head, pad_n_kv_head, pad_hd, dtype_size):
+     # MROPE forward tiling strategy:
+     # - cos_vals and sin_vals (text, height, and width) are loaded once per row, outside the head loops (shared): 6 * (pad_hd // 2) = 3 * pad_hd elements in total
+     # - In the q heads loop (peak memory):
+     #   * q_left: BLOCK_Q * (pad_hd // 2) elements
+     #   * q_right: BLOCK_Q * (pad_hd // 2) elements
+     #   * new_left: BLOCK_Q * (pad_hd // 2) elements (intermediate result)
+     #   * new_right: BLOCK_Q * (pad_hd // 2) elements (intermediate result)
+     #   * Total: 4 * BLOCK_Q * (pad_hd // 2) = 2 * BLOCK_Q * pad_hd elements
+     # - In the k heads loop (peak memory):
+     #   * k_left: BLOCK_K * (pad_hd // 2) elements
+     #   * k_right: BLOCK_K * (pad_hd // 2) elements
+     #   * new_left: BLOCK_K * (pad_hd // 2) elements (intermediate result)
+     #   * new_right: BLOCK_K * (pad_hd // 2) elements (intermediate result)
+     #   * Total: 4 * BLOCK_K * (pad_hd // 2) = 2 * BLOCK_K * pad_hd elements
+     # - Since q and k are processed separately, peak memory is set by the larger of BLOCK_Q and BLOCK_K
+     # - Plus shared cos/sin: 6 * (pad_hd // 2) = 3 * pad_hd elements
+     # - Conservative estimate: (2 * BLOCK_SIZE * pad_hd + 3 * pad_hd) * dtype_size * 8 bits
+     # - Simplified: (2 * BLOCK_SIZE + 3) * pad_hd * dtype_size * 8 bits
+     # - For safety, use memory_multiplier=3.0: 3.0 * BLOCK_SIZE * pad_hd * dtype_size * 8 bits
+     # - shapes: ((pad_n_q_head, pad_hd), (pad_n_kv_head, pad_hd))
+     # - tiling_dims: (0, 0) means the first dimension of each shape can be tiled
+     # - Returns: ((block_size_q, pad_hd), (block_size_kv, pad_hd))
+     shapes = ((pad_n_q_head, pad_hd), (pad_n_kv_head, pad_hd))
+     tile_shapes = compute_default_tiling_strategy(
+         safety_margin=0.90,
+         dtype_size=dtype_size,
+         memory_multiplier=3.0,
+         shapes=shapes,
+         tiling_dims=(0, 0),
+     )
+
+     if tile_shapes is not None and len(tile_shapes) == len(shapes):
+         # Strategy returns ((block_size_q, pad_hd), (block_size_kv, pad_hd))
+         q_tile_shape, k_tile_shape = tile_shapes
+         BLOCK_Q, _ = q_tile_shape
+         BLOCK_K, _ = k_tile_shape
+     else:
+         # Fallback to conservative defaults
+         BLOCK_Q = 2048
+         BLOCK_K = 2048
+
+     return BLOCK_Q, BLOCK_K
+
+
+ def qwen2vl_mrope_forward(q, k, cos, sin, mrope_section):
+     # Transpose to (bs, sl, n_heads, hd): Triton indexes the physical storage layout
+     q = q.transpose(1, 2)
+     k = k.transpose(1, 2)
+
+     batch_size, seq_len, n_q_head, head_dim = q.shape
+     n_kv_head = k.shape[2]
+     pad_hd = triton.next_power_of_2(head_dim)
+     pad_n_q_head = triton.next_power_of_2(n_q_head)
+     pad_n_kv_head = triton.next_power_of_2(n_kv_head)
+
+     n_row = batch_size * seq_len
+
+     # ensure tensors passed into the kernel are contiguous
+     q = q.contiguous()
+     k = k.contiguous()
+     cos = cos.contiguous()
+     sin = sin.contiguous()
+
+     dtype_size = q.element_size()
+     BLOCK_Q, BLOCK_K = get_optimal_block_size_mrope(pad_n_q_head, pad_n_kv_head, pad_hd, dtype_size)
+
+     num_cores = get_npu_core_count()
+     grid_size = min(num_cores, n_row)
+
+     _triton_qwen2vl_mrope_npu[(grid_size,)](
+         q,
+         q.stride(1),
+         k,
+         k.stride(1),
+         cos,
+         sin,
+         seq_len,
+         batch_size,
+         n_row,
+         n_q_head,
+         n_kv_head,
+         head_dim,
+         mrope_section[0],
+         mrope_section[1],
+         BLOCK_Q,
+         BLOCK_K,
+         NUM_STAGES=3,
+         BACKWARD_PASS=False,
+     )
+     return q.transpose(1, 2), k.transpose(1, 2), cos, sin
+
+
+ def qwen2vl_mrope_backward(dq, dk, cos, sin, mrope_section):
+     dq = dq.transpose(1, 2)
+     dk = dk.transpose(1, 2)
+
+     batch_size, seq_len, n_q_head, head_dim = dq.shape
+     n_kv_head = dk.shape[2]
+     pad_hd = triton.next_power_of_2(head_dim)
+     pad_n_q_head = triton.next_power_of_2(n_q_head)
+     pad_n_kv_head = triton.next_power_of_2(n_kv_head)
+
+     n_row = batch_size * seq_len
+
+     # ensure dq and dk are contiguous
+     dq = dq.contiguous()
+     dk = dk.contiguous()
+
+     dtype_size = dq.element_size()
+     BLOCK_Q, BLOCK_K = get_optimal_block_size_mrope(pad_n_q_head, pad_n_kv_head, pad_hd, dtype_size)
+
+     num_cores = get_npu_core_count()
+     grid_size = min(num_cores, n_row)
+
+     _triton_qwen2vl_mrope_npu[(grid_size,)](
+         dq,
+         dq.stride(1),
+         dk,
+         dk.stride(1),
+         cos,
+         sin,
+         seq_len,
+         batch_size,
+         n_row,
+         n_q_head,
+         n_kv_head,
+         head_dim,
+         mrope_section[0],
+         mrope_section[1],
+         BLOCK_Q,
+         BLOCK_K,
+         NUM_STAGES=3,
+         BACKWARD_PASS=True,
+     )
+     return dq.transpose(1, 2), dk.transpose(1, 2)
+
+
+ class LigerQwen2VLMRopeFunction(torch.autograd.Function):
+     @staticmethod
+     def forward(ctx, q, k, cos, sin, mrope_section, unsqueeze_dim=1):
+         """
+         q size: (bsz, n_q_head, seq_len, head_dim)
+         k size: (bsz, n_kv_head, seq_len, head_dim)
+         cos size: (3, bsz, seq_len, head_dim)
+         sin size: (3, bsz, seq_len, head_dim)
+         """
+         q, k, cos, sin = qwen2vl_mrope_forward(q, k, cos, sin, mrope_section)
+         ctx.save_for_backward(cos, sin)
+         ctx.mrope_section = mrope_section
+         return q, k
+
+     @staticmethod
+     def backward(ctx, dq, dk):
+         """
+         dq size: (bsz, n_q_head, seq_len, head_dim)
+         dk size: (bsz, n_kv_head, seq_len, head_dim)
+         cos size: (3, bsz, seq_len, head_dim)
+         sin size: (3, bsz, seq_len, head_dim)
+         """
+         cos, sin = ctx.saved_tensors
+         mrope_section = ctx.mrope_section
+         dq, dk = qwen2vl_mrope_backward(dq, dk, cos, sin, mrope_section)
+         return dq, dk, None, None, None, None
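Note: per (batch, seq) row, the kernel picks cos/sin for each rotation-pair index from one of the three streams stacked along dim 0 of cos/sin (temporal for indices below mrope_section_t, height for the next mrope_section_h indices, width for the rest), then applies a standard rotate-half update in place. A plain PyTorch sketch of the forward rotation for one tensor; the helper is illustrative and not shipped in the wheel:

import torch

def mrope_rotate_reference(x, cos, sin, mrope_section):
    # x: (bs, n_heads, sl, hd); cos/sin: (3, bs, sl, hd) with both halves of hd duplicated
    bs, _, sl, hd = x.shape
    t, h = mrope_section[0], mrope_section[1]
    d = torch.arange(hd // 2, device=x.device)
    sec = (d >= t).long() + (d >= t + h).long()  # 0 = temporal, 1 = height, 2 = width
    idx = sec.view(1, 1, 1, -1).expand(1, bs, sl, hd // 2)
    cos_sel = cos[..., : hd // 2].gather(0, idx).squeeze(0).unsqueeze(1)  # (bs, 1, sl, hd//2)
    sin_sel = sin[..., : hd // 2].gather(0, idx).squeeze(0).unsqueeze(1)
    x1, x2 = x[..., : hd // 2], x[..., hd // 2 :]
    # Forward rotate-half; BACKWARD_PASS in the kernel flips the sign of sin
    return torch.cat((x1 * cos_sel - x2 * sin_sel, x2 * cos_sel + x1 * sin_sel), dim=-1)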
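As a worked instance of the UB accounting in the get_optimal_block_size_mrope comment: peak usage is modeled as about 3 * BLOCK * pad_hd elements, so the largest head tile that fits follows directly from the budget. The capacity figure below is an assumption for illustration only; the actual limit is resolved by compute_default_tiling_strategy at runtime:

# All concrete numbers here are assumptions for illustration, not values from the package.
UB_BYTES = 256 * 1024        # assumed per-core unified-buffer capacity
SAFETY_MARGIN = 0.90         # mirrors safety_margin=0.90 above
pad_hd, dtype_size = 128, 2  # e.g. head_dim 128 padded, bf16 (2 bytes)

# memory_multiplier=3.0 models ~3 * BLOCK * pad_hd live elements at peak
max_block = int(UB_BYTES * SAFETY_MARGIN // (3.0 * pad_hd * dtype_size))
print(max_block)  # 307 -> a power-of-two tile such as 256 fits the budget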