liger-kernel-nightly 0.5.5.dev20250402185702__py3-none-any.whl → 0.6.4.dev20260112233432__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release: this version of liger-kernel-nightly might be problematic.

Files changed (115)
  1. liger_kernel/chunked_loss/__init__.py +1 -0
  2. liger_kernel/chunked_loss/cosine_similarity_loss.py +142 -0
  3. liger_kernel/chunked_loss/dpo_loss.py +61 -3
  4. liger_kernel/chunked_loss/functional.py +2 -0
  5. liger_kernel/chunked_loss/fused_linear_distillation.py +23 -5
  6. liger_kernel/chunked_loss/fused_linear_ppo.py +36 -0
  7. liger_kernel/chunked_loss/fused_linear_preference.py +0 -1
  8. liger_kernel/chunked_loss/grpo_loss.py +76 -5
  9. liger_kernel/chunked_loss/jsd_loss.py +46 -15
  10. liger_kernel/ops/__init__.py +141 -0
  11. liger_kernel/ops/backends/README.md +151 -0
  12. liger_kernel/ops/backends/__init__.py +13 -0
  13. liger_kernel/ops/backends/_ascend/__init__.py +5 -0
  14. liger_kernel/ops/backends/_ascend/ascend-ub-manager-design.md +485 -0
  15. liger_kernel/ops/backends/_ascend/ops/__init__.py +49 -0
  16. liger_kernel/ops/backends/_ascend/ops/geglu.py +266 -0
  17. liger_kernel/ops/backends/_ascend/ops/qwen2vl_mrope.py +285 -0
  18. liger_kernel/ops/backends/_ascend/ops/rope.py +290 -0
  19. liger_kernel/ops/backends/_ascend/ops/swiglu.py +142 -0
  20. liger_kernel/ops/backends/_ascend/ops/tvd.py +221 -0
  21. liger_kernel/ops/backends/_ascend/ub_manager.py +349 -0
  22. liger_kernel/ops/backends/registry.py +61 -0
  23. liger_kernel/ops/cross_entropy.py +134 -65
  24. liger_kernel/ops/dyt.py +115 -180
  25. liger_kernel/ops/fused_add_rms_norm.py +416 -0
  26. liger_kernel/ops/fused_linear_cross_entropy.py +117 -23
  27. liger_kernel/ops/fused_neighborhood_attention.py +1022 -0
  28. liger_kernel/ops/geglu.py +6 -4
  29. liger_kernel/ops/group_norm.py +7 -7
  30. liger_kernel/ops/grpo_loss.py +312 -0
  31. liger_kernel/ops/jsd.py +2 -1
  32. liger_kernel/ops/kl_div.py +9 -5
  33. liger_kernel/ops/layer_norm.py +146 -78
  34. liger_kernel/ops/llama4_rope.py +225 -0
  35. liger_kernel/ops/multi_token_attention.py +207 -0
  36. liger_kernel/ops/poly_norm.py +390 -0
  37. liger_kernel/ops/rms_norm.py +398 -99
  38. liger_kernel/ops/rope.py +1 -1
  39. liger_kernel/ops/softmax.py +201 -0
  40. liger_kernel/ops/sparsemax.py +179 -0
  41. liger_kernel/ops/swiglu.py +1 -1
  42. liger_kernel/ops/tiled_mlp.py +136 -0
  43. liger_kernel/ops/utils.py +14 -0
  44. liger_kernel/transformers/__init__.py +208 -17
  45. liger_kernel/transformers/auto_model.py +21 -0
  46. liger_kernel/transformers/cross_entropy.py +9 -4
  47. liger_kernel/transformers/dyt.py +6 -4
  48. liger_kernel/transformers/experimental/__init__.py +5 -0
  49. liger_kernel/transformers/experimental/embedding.py +1 -1
  50. liger_kernel/transformers/fsdp.py +55 -0
  51. liger_kernel/transformers/functional.py +122 -20
  52. liger_kernel/transformers/fused_add_rms_norm.py +39 -0
  53. liger_kernel/transformers/fused_linear_cross_entropy.py +16 -5
  54. liger_kernel/transformers/fused_linear_jsd.py +1 -1
  55. liger_kernel/transformers/fused_neighborhood_attention.py +234 -0
  56. liger_kernel/transformers/geglu.py +1 -1
  57. liger_kernel/transformers/group_norm.py +1 -1
  58. liger_kernel/transformers/grpo_loss.py +153 -0
  59. liger_kernel/transformers/jsd.py +1 -1
  60. liger_kernel/transformers/kl_div.py +1 -1
  61. liger_kernel/transformers/layer_norm.py +1 -1
  62. liger_kernel/transformers/llama4_rope.py +93 -0
  63. liger_kernel/transformers/model/exaone4.py +136 -0
  64. liger_kernel/transformers/model/falcon_h1.py +122 -0
  65. liger_kernel/transformers/model/gemma.py +57 -27
  66. liger_kernel/transformers/model/gemma2.py +65 -28
  67. liger_kernel/transformers/model/gemma3.py +331 -0
  68. liger_kernel/transformers/model/glm4.py +141 -0
  69. liger_kernel/transformers/model/glm4v.py +163 -0
  70. liger_kernel/transformers/model/glm4v_moe.py +172 -0
  71. liger_kernel/transformers/model/gpt_oss.py +211 -0
  72. liger_kernel/transformers/model/hunyuan_v1.py +134 -0
  73. liger_kernel/transformers/model/internvl.py +157 -0
  74. liger_kernel/transformers/model/llama.py +109 -27
  75. liger_kernel/transformers/model/llama4.py +121 -0
  76. liger_kernel/transformers/model/llava.py +111 -136
  77. liger_kernel/transformers/model/loss_utils.py +50 -12
  78. liger_kernel/transformers/model/mistral.py +51 -34
  79. liger_kernel/transformers/model/mixtral.py +50 -29
  80. liger_kernel/transformers/model/mllama.py +46 -24
  81. liger_kernel/transformers/model/olmo2.py +47 -22
  82. liger_kernel/transformers/model/olmo3.py +142 -0
  83. liger_kernel/transformers/model/output_classes.py +147 -0
  84. liger_kernel/transformers/model/paligemma.py +50 -14
  85. liger_kernel/transformers/model/phi3.py +47 -172
  86. liger_kernel/transformers/model/qwen2.py +55 -23
  87. liger_kernel/transformers/model/qwen2_5_vl.py +62 -103
  88. liger_kernel/transformers/model/qwen2_vl.py +59 -108
  89. liger_kernel/transformers/model/qwen3.py +136 -0
  90. liger_kernel/transformers/model/qwen3_moe.py +152 -0
  91. liger_kernel/transformers/model/qwen3_next.py +146 -0
  92. liger_kernel/transformers/model/qwen3_vl.py +150 -0
  93. liger_kernel/transformers/model/qwen3_vl_moe.py +126 -0
  94. liger_kernel/transformers/model/smollm3.py +199 -0
  95. liger_kernel/transformers/model/smolvlm.py +158 -0
  96. liger_kernel/transformers/monkey_patch.py +2018 -244
  97. liger_kernel/transformers/multi_token_attention.py +64 -0
  98. liger_kernel/transformers/poly_norm.py +42 -0
  99. liger_kernel/transformers/qwen2vl_mrope.py +1 -1
  100. liger_kernel/transformers/rms_norm.py +54 -6
  101. liger_kernel/transformers/rope.py +45 -1
  102. liger_kernel/transformers/softmax.py +12 -0
  103. liger_kernel/transformers/sparsemax.py +16 -0
  104. liger_kernel/transformers/swiglu.py +39 -1
  105. liger_kernel/transformers/tiled_mlp.py +125 -0
  106. liger_kernel/transformers/trainer/orpo_trainer.py +1 -53
  107. liger_kernel/transformers/tvd.py +1 -1
  108. liger_kernel/utils.py +63 -0
  109. {liger_kernel_nightly-0.5.5.dev20250402185702.dist-info → liger_kernel_nightly-0.6.4.dev20260112233432.dist-info}/METADATA +73 -39
  110. liger_kernel_nightly-0.6.4.dev20260112233432.dist-info/RECORD +132 -0
  111. liger_kernel_nightly-0.5.5.dev20250402185702.dist-info/RECORD +0 -80
  112. {liger_kernel_nightly-0.5.5.dev20250402185702.dist-info → liger_kernel_nightly-0.6.4.dev20260112233432.dist-info}/LICENSE +0 -0
  113. {liger_kernel_nightly-0.5.5.dev20250402185702.dist-info → liger_kernel_nightly-0.6.4.dev20260112233432.dist-info}/NOTICE +0 -0
  114. {liger_kernel_nightly-0.5.5.dev20250402185702.dist-info → liger_kernel_nightly-0.6.4.dev20260112233432.dist-info}/WHEEL +0 -0
  115. {liger_kernel_nightly-0.5.5.dev20250402185702.dist-info → liger_kernel_nightly-0.6.4.dev20260112233432.dist-info}/top_level.txt +0 -0
liger_kernel/ops/backends/_ascend/ops/geglu.py (new file)
@@ -0,0 +1,266 @@
+ """
+ UB-aware GEGLU implementation for Ascend NPU.
+
+ This implementation automatically adjusts block sizes to fit within UB constraints,
+ preventing UB overflow errors when running on Ascend NPU.
+
+ It reuses the original kernels when possible, and only uses tiling when necessary.
+ """
+
+ import operator
+
+ import torch
+ import triton
+ import triton.language as tl
+
+ from liger_kernel.ops.backends._ascend.ub_manager import compute_default_tiling_strategy
+ from liger_kernel.ops.utils import calculate_settings
+ from liger_kernel.ops.utils import compare_version
+ from liger_kernel.ops.utils import ensure_contiguous
+ from liger_kernel.utils import is_npu_available
+
+ if compare_version("triton", operator.ge, "3.0.0") and not is_npu_available():
+     try:
+         from triton.language.extra.libdevice import tanh
+     except ModuleNotFoundError:
+         from triton.language.extra.cuda.libdevice import tanh
+ else:
+     from triton.language.math import tanh
+
+
+ @triton.jit
+ def _geglu_tanh_forward_kernel_npu(a, b, c, stride, n_cols: tl.constexpr, BLOCK_SIZE: tl.constexpr):
+     """
+     UB-aware GEGLU forward kernel for NPU.
+
+     Uses tiling loop to handle cases where BLOCK_SIZE < n_cols (due to UB constraints).
+     When BLOCK_SIZE >= n_cols, the loop executes only once, maintaining original behavior.
+     """
+     program_id = tl.program_id(0).to(tl.int64)
+
+     # locate start index
+     a += program_id * stride
+     b += program_id * stride
+     c += program_id * stride
+
+     # Process in tiles when BLOCK_SIZE < n_cols
+     for i in range(0, n_cols, BLOCK_SIZE):
+         col_offsets = i + tl.arange(0, BLOCK_SIZE)
+         mask = col_offsets < n_cols
+
+         a_row = tl.load(a + col_offsets, mask=mask, other=0).to(tl.float32)
+         b_row = tl.load(b + col_offsets, mask=mask, other=0)
+
+         # tanh approximation form of GELU is computed with:
+         # 0.5 * a * (1 + tanh(sqrt(2 / pi) * (a + 0.044715 * a^3)))
+         sqrt_2_over_pi = 0.7978845608028654  # sqrt(2 / pi)
+         a_cubed = a_row * a_row * a_row
+         tanh_arg = sqrt_2_over_pi * (a_row + 0.044715 * a_cubed)
+         tanh_result = tanh(tanh_arg)
+         geglu_a = 0.5 * a_row * (1 + tanh_result)
+         c_row = geglu_a.cast(b_row.dtype) * b_row
+
+         tl.store(c + col_offsets, c_row, mask=mask)
+
+
+ @triton.jit
+ def _geglu_tanh_backward_kernel_npu(dc, a, b, stride, n_cols: tl.constexpr, BLOCK_SIZE: tl.constexpr):
+     """
+     UB-aware GEGLU backward kernel for NPU.
+
+     Uses tiling loop to handle cases where BLOCK_SIZE < n_cols (due to UB constraints).
+     When BLOCK_SIZE >= n_cols, the loop executes only once, maintaining original behavior.
+     """
+     program_id = tl.program_id(0).to(tl.int64)
+
+     # locate start index
+     dc += program_id * stride
+     a += program_id * stride
+     b += program_id * stride
+
+     # Process in tiles when BLOCK_SIZE < n_cols
+     for i in range(0, n_cols, BLOCK_SIZE):
+         col_offsets = i + tl.arange(0, BLOCK_SIZE)
+         mask = col_offsets < n_cols
+
+         dc_row = tl.load(dc + col_offsets, mask=mask, other=0)
+         a_row = tl.load(a + col_offsets, mask=mask, other=0).to(tl.float32)
+         b_row = tl.load(b + col_offsets, mask=mask, other=0)
+
+         # recomputation to save memory
+         sqrt_2_over_pi = 0.7978845608028654  # sqrt(2 / pi)
+         a_cubed = a_row * a_row * a_row
+         tanh_arg = sqrt_2_over_pi * (a_row + 0.044715 * a_cubed)
+         tanh_result = tanh(tanh_arg)
+         geglu_a = 0.5 * a_row * (1 + tanh_result)
+         geglu_a = geglu_a.to(dc_row.dtype).to(tl.float32)
+
+         db_row = dc_row.cast(tl.float32) * geglu_a
+
+         # Gradient w.r.t. a can be computed with:
+         # b * (0.5 * (1 + tanh(z)) + 0.5 * a * (1 - tanh(z)^2) * (sqrt(2/pi) * (1 + 3 * 0.044715 * a^2)))
+         # where z = sqrt(2/pi) * (a + 0.044715 * a^3)
+         term1 = 0.5 * (1 + tanh_result)
+         tanh_sq = tanh_result * tanh_result
+         term2 = 0.5 * a_row * (1 - tanh_sq) * (sqrt_2_over_pi * (1 + 3 * 0.044715 * a_row * a_row))
+         da_row = dc_row * b_row * (term1 + term2)
+
+         tl.store(a + col_offsets, da_row, mask=mask)
+         tl.store(b + col_offsets, db_row.to(dc_row.dtype), mask=mask)
+
+
+ def geglu_forward(a, b):
+     """
+     UB-aware GEGLU forward pass for NPU.
+
+     Automatically adjusts block size to fit within UB constraints.
+     """
+     ori_shape = a.shape
+
+     n_cols = ori_shape[-1]
+     a = a.view(-1, n_cols)
+     b = b.view(-1, n_cols)
+     c = torch.empty_like(a)
+     n_rows = a.shape[0]
+
+     # Calculate desired block size
+     desired_block_size, num_warps = calculate_settings(n_cols)
+
+     # Compute tiling strategy based on UB capacity
+     dtype_size = a.element_size()
+     # GEGLU forward tiling strategy:
+     # - Calculates maximum safe block size based on UB capacity
+     # - Memory analysis (only buffers that occupy UB, excluding temporary variables):
+     #   * Inputs: a_row (4 bytes, float32), b_row (dtype_size bytes)
+     #   * Output: c_row (dtype_size bytes)
+     #   * Temporary variables (a_cubed, tanh_arg, tanh_result, geglu_a) are optimized to registers
+     #     and don't occupy UB since they are only used once
+     #   * For float16: a_row(4) + b_row(2) + c_row(2) = 8 bytes/element, ratio = 8/2 = 4.0
+     #   * For float32: a_row(4) + b_row(4) + c_row(4) = 12 bytes/element, ratio = 12/4 = 3.0
+     # - Uses memory_multiplier=4.0 (float16) or 3.0 (float32) * BLOCK_SIZE * dtype_size * 8 bits
+     # - shapes: ((n_cols,),)
+     # - tiling_dims: (0,) means first dimension can be tiled
+     # - Returns: ((block_size,),)
+     shapes = ((n_cols,),)
+     if dtype_size == 2:
+         memory_multiplier = 4.0
+     else:
+         memory_multiplier = 3.0
+     tile_shapes = compute_default_tiling_strategy(
+         safety_margin=0.80,
+         dtype_size=dtype_size,
+         memory_multiplier=memory_multiplier,
+         shapes=shapes,
+         tiling_dims=(0,),
+     )
+
+     if tile_shapes is not None and len(tile_shapes) > 0 and len(tile_shapes[0]) > 0:
+         # Strategy returns ((block_size,),)
+         adjusted_block_size = tile_shapes[0][0]
+     else:
+         # Fallback to desired block size if no best practice found (no tiling needed)
+         adjusted_block_size = desired_block_size
+     # Always use the unified NPU kernel
+     # When adjusted_block_size >= n_cols, the loop executes only once (no tiling)
+     # When adjusted_block_size < n_cols, the loop handles tiling automatically
+     _geglu_tanh_forward_kernel_npu[(n_rows,)](
+         a,
+         b,
+         c,
+         c.stride(-2),
+         n_cols=n_cols,
+         BLOCK_SIZE=adjusted_block_size,
+         num_warps=num_warps,
+     )
+     return a, b, c.view(*ori_shape)
+
+
+ def geglu_backward(a, b, dc):
+     """
+     UB-aware GEGLU backward pass for NPU.
+
+     Automatically adjusts block size to fit within UB constraints.
+     """
+     ori_shape = dc.shape
+     n_cols = ori_shape[-1]
+     dc = dc.view(-1, n_cols)
+     n_rows = dc.shape[0]
+
+     # Calculate desired block size
+     desired_block_size, num_warps = calculate_settings(n_cols)
+
+     # Compute tiling strategy based on UB capacity
+     dtype_size = dc.element_size()
+     # GEGLU backward tiling strategy:
+     # - Calculates maximum safe block size based on UB capacity
+     # - Memory analysis: Peak memory usage occurs when executing line 103 (term1 calculation)
+     #   At this point, the following buffers simultaneously occupy UB:
+     #   1. dc_row = tl.load(dc + col_offsets, ...)  # dtype_size bytes
+     #   2. a_row = tl.load(a + col_offsets, ...).to(tl.float32)  # 4 bytes (float32)
+     #   3. b_row = tl.load(b + col_offsets, ...)  # dtype_size bytes
+     #   4. tanh_result = tanh(tanh_arg)  # 4 bytes (float32), used in lines 95, 103, 104
+     #   5. geglu_a = 0.5 * a_row * (1 + tanh_result)  # 4 bytes (float32), used in lines 96, 98
+     #   6. db_row = dc_row.cast(tl.float32) * geglu_a  # 4 bytes (float32, computed at line 98, stored at line 109)
+     #   Note: term1 (line 103) is a temporary variable optimized to registers and doesn't occupy UB
+     #   Temporary variables (a_cubed, tanh_arg, term1, tanh_sq, term2) are optimized to registers
+     #   and don't occupy UB since they are only used once
+     #   * For float16: dc_row(2) + a_row(4) + b_row(2) + tanh_result(4) + geglu_a(4) + db_row(4)
+     #     = 20 bytes/element, ratio = 20/2 = 10.0
+     #   * For float32: dc_row(4) + a_row(4) + b_row(4) + tanh_result(4) + geglu_a(4) + db_row(4)
+     #     = 24 bytes/element, ratio = 24/4 = 6.0
+     # - Uses memory_multiplier=10.0 (float16) or 6.0 (float32) * BLOCK_SIZE * dtype_size * 8 bits
+     # - shapes: ((n_cols,),)
+     # - tiling_dims: (0,) means first dimension can be tiled
+     # - Returns: ((block_size,),)
+     shapes = ((n_cols,),)
+     if dtype_size == 2:
+         memory_multiplier = 10.0
+     else:
+         memory_multiplier = 6.0
+     tile_shapes = compute_default_tiling_strategy(
+         safety_margin=0.80,
+         dtype_size=dtype_size,
+         memory_multiplier=memory_multiplier,
+         shapes=shapes,
+         tiling_dims=(0,),
+     )
+
+     if tile_shapes is not None and len(tile_shapes) > 0 and len(tile_shapes[0]) > 0:
+         # Strategy returns ((block_size,),)
+         adjusted_block_size = tile_shapes[0][0]
+     else:
+         # Fallback to desired block size if no best practice found (no tiling needed)
+         adjusted_block_size = desired_block_size
+
+     # Always use the unified NPU kernel
+     # When adjusted_block_size >= n_cols, the loop executes only once (no tiling)
+     # When adjusted_block_size < n_cols, the loop handles tiling automatically
+     _geglu_tanh_backward_kernel_npu[(n_rows,)](
+         dc,
+         a,
+         b,
+         dc.stride(-2),
+         n_cols=n_cols,
+         BLOCK_SIZE=adjusted_block_size,
+         num_warps=num_warps,
+     )
+
+     return a.view(*ori_shape), b.view(*ori_shape)
+
+
+ class LigerGELUMulFunction(torch.autograd.Function):
+     """UB-aware GEGLU function for Ascend NPU."""
+
+     @staticmethod
+     @ensure_contiguous
+     def forward(ctx, a, b):
+         a, b, c = geglu_forward(a, b)
+         ctx.save_for_backward(a, b)
+         return c
+
+     @staticmethod
+     @ensure_contiguous
+     def backward(ctx, dc):
+         a, b = ctx.saved_tensors
+         a, b = geglu_backward(a, b, dc)
+         return a, b
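
For orientation, here is a minimal usage sketch of the autograd function added above; it is not part of the diff. It assumes an Ascend NPU environment with torch_npu installed (so the "npu" device and this backend are available), and the shapes and dtype are illustrative.

import torch
from liger_kernel.ops.backends._ascend.ops.geglu import LigerGELUMulFunction

# Illustrative gate ("a") and up ("b") projection activations of an MLP block.
a = torch.randn(2, 128, 4096, dtype=torch.float16, device="npu", requires_grad=True)
b = torch.randn(2, 128, 4096, dtype=torch.float16, device="npu", requires_grad=True)

# c = GELU_tanh(a) * b; geglu_forward picks a BLOCK_SIZE that fits within UB, tiling if needed.
c = LigerGELUMulFunction.apply(a, b)
c.sum().backward()  # the backward kernel recomputes the activation and writes da/db in place
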
liger_kernel/ops/backends/_ascend/ops/qwen2vl_mrope.py (new file)
@@ -0,0 +1,285 @@
+ import torch
+ import triton
+ import triton.language as tl
+
+ from liger_kernel.ops.backends._ascend.ub_manager import compute_default_tiling_strategy
+
+
+ @triton.jit
+ def _triton_qwen2vl_mrope_npu(
+     q_ptr,
+     q_row_stride,
+     k_ptr,
+     k_row_stride,
+     cos,
+     sin,
+     sl,
+     bs: tl.constexpr,
+     n_qh: tl.constexpr,
+     n_kh: tl.constexpr,
+     hd: tl.constexpr,
+     mrope_section_t: tl.constexpr,
+     mrope_section_h: tl.constexpr,
+     BLOCK_Q: tl.constexpr,
+     BLOCK_K: tl.constexpr,
+     BACKWARD_PASS: tl.constexpr = False,
+ ):
+     pid = tl.program_id(0).to(tl.int64)
+
+     t_end = mrope_section_t
+     h_end = t_end + mrope_section_h
+
+     t_cos = cos + pid * hd
+     h_cos = t_cos + bs * sl * hd
+     w_cos = h_cos + bs * sl * hd
+     t_sin = sin + pid * hd
+     h_sin = t_sin + bs * sl * hd
+     w_sin = h_sin + bs * sl * hd
+
+     q_base = q_ptr + pid * q_row_stride
+     k_base = k_ptr + pid * k_row_stride
+
+     d_idx = tl.arange(0, hd // 2)
+     d_mask = d_idx < (hd // 2)
+
+     pos_mask_t = d_idx < t_end
+     pos_mask_h = (d_idx >= t_end) & (d_idx < h_end)
+
+     text_cos_vals = tl.load(t_cos + d_idx, mask=d_mask, other=0)
+     text_sin_vals = tl.load(t_sin + d_idx, mask=d_mask, other=0)
+     height_cos_vals = tl.load(h_cos + d_idx, mask=d_mask, other=0)
+     height_sin_vals = tl.load(h_sin + d_idx, mask=d_mask, other=0)
+     width_cos_vals = tl.load(w_cos + d_idx, mask=d_mask, other=0)
+     width_sin_vals = tl.load(w_sin + d_idx, mask=d_mask, other=0)
+
+     cos_vals = tl.where(pos_mask_t, text_cos_vals, tl.where(pos_mask_h, height_cos_vals, width_cos_vals))
+     sin_vals = tl.where(pos_mask_t, text_sin_vals, tl.where(pos_mask_h, height_sin_vals, width_sin_vals))
+
+     for qh_block in range(0, n_qh, BLOCK_Q):
+         qh_idx = tl.arange(0, BLOCK_Q) + qh_block
+         qh_mask = qh_idx < n_qh
+
+         block_mask = qh_mask[:, None] & d_mask[None, :]
+         offsets = qh_idx[:, None] * hd + d_idx[None, :]
+
+         q_left = tl.load(q_base + offsets, mask=block_mask, other=0)
+         q_right = tl.load(q_base + offsets + (hd // 2), mask=block_mask, other=0)
+
+         if not BACKWARD_PASS:
+             new_left = q_left * cos_vals - q_right * sin_vals
+             new_right = q_right * cos_vals + q_left * sin_vals
+         else:
+             new_left = q_left * cos_vals + q_right * sin_vals
+             new_right = q_right * cos_vals - q_left * sin_vals
+
+         tl.store(q_base + offsets, new_left, mask=block_mask)
+         tl.store(q_base + offsets + (hd // 2), new_right, mask=block_mask)
+
+     for kh_block in range(0, n_kh, BLOCK_K):
+         kh_idx = tl.arange(0, BLOCK_K) + kh_block
+         kh_mask = kh_idx < n_kh
+
+         block_mask = kh_mask[:, None] & d_mask[None, :]
+         offsets = kh_idx[:, None] * hd + d_idx[None, :]
+
+         k_left = tl.load(k_base + offsets, mask=block_mask, other=0)
+         k_right = tl.load(k_base + offsets + (hd // 2), mask=block_mask, other=0)
+
+         if not BACKWARD_PASS:
+             new_left = k_left * cos_vals - k_right * sin_vals
+             new_right = k_right * cos_vals + k_left * sin_vals
+         else:
+             new_left = k_left * cos_vals + k_right * sin_vals
+             new_right = k_right * cos_vals - k_left * sin_vals
+
+         tl.store(k_base + offsets, new_left, mask=block_mask)
+         tl.store(k_base + offsets + (hd // 2), new_right, mask=block_mask)
+
+
+ def qwen2vl_mrope_forward(q, k, cos, sin, mrope_section):
+     # transpose it back to the physical shape because Triton looks at the physical storage
+     # note: q and k are incontiguous before the transformation and will become contiguous after transpose
+     q = q.transpose(1, 2)
+     k = k.transpose(1, 2)
+
+     batch_size, seq_len, n_q_head, head_dim = q.shape
+     n_kv_head = k.shape[2]
+     pad_hd = triton.next_power_of_2(head_dim)
+     pad_n_q_head = triton.next_power_of_2(n_q_head)
+     pad_n_kv_head = triton.next_power_of_2(n_kv_head)
+
+     n_row = batch_size * seq_len
+
+     # ensure tensors passed into the kernel are contiguous. It will be no-op if they are already contiguous
+     q = q.contiguous()
+     k = k.contiguous()
+     cos = cos.contiguous()
+     sin = sin.contiguous()
+
+     # Compute tiling strategy based on UB capacity
+     dtype_size = q.element_size()
+     # MROPE forward tiling strategy:
+     # - cos_vals and sin_vals (include text, height and width) are loaded once outside loops (shared): (pad_hd // 2) * 4 = 2 * pad_hd elements each
+     # - In q heads loop (peak memory):
+     #   * q_left: BLOCK_Q * (pad_hd // 2) elements
+     #   * q_right: BLOCK_Q * (pad_hd // 2) elements
+     #   * new_left: BLOCK_Q * (pad_hd // 2) elements (intermediate result)
+     #   * new_right: BLOCK_Q * (pad_hd // 2) elements (intermediate result)
+     #   * Total: 4 * BLOCK_Q * (pad_hd // 2) = 2 * BLOCK_Q * pad_hd elements
+     # - In k heads loop (peak memory):
+     #   * k_left: BLOCK_K * (pad_hd // 2) elements
+     #   * k_right: BLOCK_K * (pad_hd // 2) elements
+     #   * new_left: BLOCK_K * (pad_hd // 2) elements (intermediate result)
+     #   * new_right: BLOCK_K * (pad_hd // 2) elements (intermediate result)
+     #   * Total: 4 * BLOCK_K * (pad_hd // 2) = 2 * BLOCK_K * pad_hd elements
+     # - Since q and k are processed separately, peak memory is max(BLOCK_Q, BLOCK_K) case
+     # - Plus shared cos/sin: 2 * (pad_hd // 2) = pad_hd elements
+     # - Conservative estimate: (2 * BLOCK_SIZE * pad_hd + pad_hd) * dtype_size * 8 bits
+     # - Simplified: (2 * BLOCK_SIZE + 2) * pad_hd * dtype_size * 8 bits
+     # - For safety, use: memory_multiplier=3.0 * BLOCK_SIZE * pad_hd * dtype_size * 8 bits
+     # - shapes: ((pad_n_q_head, pad_hd), (pad_n_kv_head, pad_hd))
+     # - tiling_dims: (0, 0) means first dimension of each shape can be tiled
+     # - Returns: ((block_size_q, pad_hd), (block_size_kv, pad_hd))
+     shapes = ((pad_n_q_head, pad_hd), (pad_n_kv_head, pad_hd))
+     tile_shapes = compute_default_tiling_strategy(
+         safety_margin=0.90,
+         dtype_size=dtype_size,
+         memory_multiplier=3.0,
+         shapes=shapes,
+         tiling_dims=(0, 0),
+     )
+
+     if tile_shapes is not None and len(tile_shapes) == len(shapes):
+         # Strategy returns ((block_size_q, pad_hd), (block_size_kv, pad_hd))
+         q_tile_shape, k_tile_shape = tile_shapes
+         BLOCK_Q, _ = q_tile_shape
+         BLOCK_K, _ = k_tile_shape
+     else:
+         # Fallback to conservative defaults
+         BLOCK_Q = triton.next_power_of_2(pad_n_q_head)
+         BLOCK_K = triton.next_power_of_2(pad_n_kv_head)
+     _triton_qwen2vl_mrope_npu[(n_row,)](
+         q,
+         q.stride(1),
+         k,
+         k.stride(1),
+         cos,
+         sin,
+         seq_len,
+         batch_size,
+         n_q_head,
+         n_kv_head,
+         head_dim,
+         mrope_section[0],
+         mrope_section[1],
+         BLOCK_Q,
+         BLOCK_K,
+         BACKWARD_PASS=False,
+     )
+     return q.transpose(1, 2), k.transpose(1, 2), cos, sin
+
+
+ def qwen2vl_mrope_backward(dq, dk, cos, sin, mrope_section):
+     dq = dq.transpose(1, 2)
+     dk = dk.transpose(1, 2)
+
+     batch_size, seq_len, n_q_head, head_dim = dq.shape
+     n_kv_head = dk.shape[2]
+     pad_hd = triton.next_power_of_2(head_dim)
+     pad_n_q_head = triton.next_power_of_2(n_q_head)
+     pad_n_kv_head = triton.next_power_of_2(n_kv_head)
+
+     n_row = batch_size * seq_len
+
+     # ensure dq and dk are contiguous
+     dq = dq.contiguous()
+     dk = dk.contiguous()
+
+     # Compute tiling strategy based on UB capacity
+     dtype_size = dq.element_size()
+     # MROPE backward tiling strategy:
+     # - cos_vals and sin_vals (include text, height and width) are loaded once outside loops (shared): (pad_hd // 2) * 4 = 2 * pad_hd elements each
+     # - In q heads loop (peak memory):
+     #   * q_left: BLOCK_Q * (pad_hd // 2) elements
+     #   * q_right: BLOCK_Q * (pad_hd // 2) elements
+     #   * new_left: BLOCK_Q * (pad_hd // 2) elements (intermediate result)
+     #   * new_right: BLOCK_Q * (pad_hd // 2) elements (intermediate result)
+     #   * Total: 4 * BLOCK_Q * (pad_hd // 2) = 2 * BLOCK_Q * pad_hd elements
+     # - In k heads loop (peak memory):
+     #   * k_left: BLOCK_K * (pad_hd // 2) elements
+     #   * k_right: BLOCK_K * (pad_hd // 2) elements
+     #   * new_left: BLOCK_K * (pad_hd // 2) elements (intermediate result)
+     #   * new_right: BLOCK_K * (pad_hd // 2) elements (intermediate result)
+     #   * Total: 4 * BLOCK_K * (pad_hd // 2) = 2 * BLOCK_K * pad_hd elements
+     # - Since q and k are processed separately, peak memory is max(BLOCK_Q, BLOCK_K) case
+     # - Plus shared cos/sin: 2 * (pad_hd // 2) = pad_hd elements
+     # - Conservative estimate: (2 * BLOCK_SIZE * pad_hd + pad_hd) * dtype_size * 8 bits
+     # - Simplified: (2 * BLOCK_SIZE + 2) * pad_hd * dtype_size * 8 bits
+     # - For safety, use: memory_multiplier=3.0 * BLOCK_SIZE * pad_hd * dtype_size * 8 bits
+     # - shapes: ((pad_n_q_head, pad_hd), (pad_n_kv_head, pad_hd))
+     # - tiling_dims: (0, 0) means first dimension of each shape can be tiled
+     # - Returns: ((block_size_q, pad_hd), (block_size_kv, pad_hd))
+     shapes = ((pad_n_q_head, pad_hd), (pad_n_kv_head, pad_hd))
+     tile_shapes = compute_default_tiling_strategy(
+         safety_margin=0.90,
+         dtype_size=dtype_size,
+         memory_multiplier=3.0,
+         shapes=shapes,
+         tiling_dims=(0, 0),
+     )
+
+     if tile_shapes is not None and len(tile_shapes) == len(shapes):
+         # Strategy returns ((block_size_q, pad_hd), (block_size_kv, pad_hd))
+         q_tile_shape, k_tile_shape = tile_shapes
+         BLOCK_Q, _ = q_tile_shape
+         BLOCK_K, _ = k_tile_shape
+     else:
+         # Fallback to conservative defaults
+         BLOCK_Q = triton.next_power_of_2(pad_n_q_head)
+         BLOCK_K = triton.next_power_of_2(pad_n_kv_head)
+     _triton_qwen2vl_mrope_npu[(n_row,)](
+         dq,
+         dq.stride(1),
+         dk,
+         dk.stride(1),
+         cos,
+         sin,
+         seq_len,
+         batch_size,
+         n_q_head,
+         n_kv_head,
+         head_dim,
+         mrope_section[0],
+         mrope_section[1],
+         BLOCK_Q,
+         BLOCK_K,
+         BACKWARD_PASS=True,
+     )
+     return dq.transpose(1, 2), dk.transpose(1, 2)
+
+
+ class LigerQwen2VLMRopeFunction(torch.autograd.Function):
+     @staticmethod
+     def forward(ctx, q, k, cos, sin, mrope_section, unsqueeze_dim=1):
+         """
+         q size: (bsz, n_q_head, seq_len, head_dim)
+         k size: (bsz, n_kv_head, seq_len, head_dim)
+         cos size: (3, bsz, seq_len, head_dim)
+         sin size: (3, bsz, seq_len, head_dim)
+         """
+         q, k, cos, sin = qwen2vl_mrope_forward(q, k, cos, sin, mrope_section)
+         ctx.save_for_backward(cos, sin)
+         ctx.mrope_section = mrope_section
+         return q, k
+
+     def backward(ctx, dq, dk):
+         """
+         dq size: (bsz, n_q_head, seq_len, head_dim)
+         dk size: (bsz, n_kv_head, seq_len, head_dim)
+         cos size: (3, bsz, seq_len, head_dim)
+         sin size: (3, bsz, seq_len, head_dim)
+         """
+         cos, sin = ctx.saved_tensors
+         mrope_section = ctx.mrope_section
+         dq, dk = qwen2vl_mrope_backward(dq, dk, cos, sin, mrope_section)
+         return dq, dk, None, None, None, None
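
Likewise, a minimal usage sketch for the multimodal RoPE function above; it is not part of the diff. It assumes an Ascend NPU environment with torch_npu installed, and head_dim=128 with mrope_section=[16, 24, 24] (temporal/height/width sections summing to head_dim // 2) are illustrative values.

import torch
from liger_kernel.ops.backends._ascend.ops.qwen2vl_mrope import LigerQwen2VLMRopeFunction

bsz, seq_len, n_q_head, n_kv_head, head_dim = 1, 32, 28, 4, 128
q = torch.randn(bsz, n_q_head, seq_len, head_dim, dtype=torch.float16, device="npu", requires_grad=True)
k = torch.randn(bsz, n_kv_head, seq_len, head_dim, dtype=torch.float16, device="npu", requires_grad=True)
# cos/sin stack the temporal, height, and width frequency tables along dim 0.
cos = torch.randn(3, bsz, seq_len, head_dim, dtype=torch.float16, device="npu")
sin = torch.randn(3, bsz, seq_len, head_dim, dtype=torch.float16, device="npu")
mrope_section = [16, 24, 24]  # t/h/w split of the rotary half-dimension

q_rot, k_rot = LigerQwen2VLMRopeFunction.apply(q, k, cos, sin, mrope_section)
(q_rot.float().sum() + k_rot.float().sum()).backward()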