liger-kernel 0.6.4__py3-none-any.whl → 0.6.5__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (71)
  1. liger_kernel/chunked_loss/cosine_similarity_loss.py +7 -1
  2. liger_kernel/chunked_loss/fused_linear_distillation.py +10 -3
  3. liger_kernel/chunked_loss/jsd_loss.py +21 -6
  4. liger_kernel/ops/__init__.py +141 -0
  5. liger_kernel/ops/backends/README.md +151 -0
  6. liger_kernel/ops/backends/__init__.py +13 -0
  7. liger_kernel/ops/backends/_ascend/__init__.py +5 -0
  8. liger_kernel/ops/backends/_ascend/ascend-ub-manager-design.md +492 -0
  9. liger_kernel/ops/backends/_ascend/ops/__init__.py +61 -0
  10. liger_kernel/ops/backends/_ascend/ops/embedding.py +214 -0
  11. liger_kernel/ops/backends/_ascend/ops/geglu.py +191 -0
  12. liger_kernel/ops/backends/_ascend/ops/llama4_rope.py +298 -0
  13. liger_kernel/ops/backends/_ascend/ops/qwen2vl_mrope.py +275 -0
  14. liger_kernel/ops/backends/_ascend/ops/rope.py +265 -0
  15. liger_kernel/ops/backends/_ascend/ops/swiglu.py +142 -0
  16. liger_kernel/ops/backends/_ascend/ops/tvd.py +223 -0
  17. liger_kernel/ops/backends/_ascend/ub_manager.py +367 -0
  18. liger_kernel/ops/backends/registry.py +61 -0
  19. liger_kernel/ops/cross_entropy.py +14 -4
  20. liger_kernel/ops/dyt.py +5 -2
  21. liger_kernel/ops/fused_add_rms_norm.py +21 -23
  22. liger_kernel/ops/fused_linear_cross_entropy.py +2 -1
  23. liger_kernel/ops/geglu.py +5 -3
  24. liger_kernel/ops/group_norm.py +12 -8
  25. liger_kernel/ops/kl_div.py +8 -11
  26. liger_kernel/ops/layer_norm.py +17 -16
  27. liger_kernel/ops/poly_norm.py +19 -21
  28. liger_kernel/ops/rms_norm.py +149 -71
  29. liger_kernel/ops/utils.py +25 -0
  30. liger_kernel/transformers/__init__.py +6 -0
  31. liger_kernel/transformers/auto_model.py +21 -0
  32. liger_kernel/transformers/cross_entropy.py +1 -1
  33. liger_kernel/transformers/dyt.py +1 -1
  34. liger_kernel/transformers/experimental/embedding.py +1 -1
  35. liger_kernel/transformers/functional.py +20 -20
  36. liger_kernel/transformers/fused_add_rms_norm.py +1 -1
  37. liger_kernel/transformers/fused_linear_cross_entropy.py +1 -1
  38. liger_kernel/transformers/fused_linear_jsd.py +1 -1
  39. liger_kernel/transformers/fused_neighborhood_attention.py +1 -1
  40. liger_kernel/transformers/geglu.py +1 -1
  41. liger_kernel/transformers/group_norm.py +1 -1
  42. liger_kernel/transformers/grpo_loss.py +1 -1
  43. liger_kernel/transformers/jsd.py +1 -1
  44. liger_kernel/transformers/kl_div.py +1 -1
  45. liger_kernel/transformers/layer_norm.py +1 -1
  46. liger_kernel/transformers/llama4_rope.py +1 -1
  47. liger_kernel/transformers/model/exaone4.py +136 -0
  48. liger_kernel/transformers/model/gemma2.py +3 -3
  49. liger_kernel/transformers/model/gemma3.py +11 -5
  50. liger_kernel/transformers/model/gpt_oss.py +211 -0
  51. liger_kernel/transformers/model/loss_utils.py +6 -0
  52. liger_kernel/transformers/model/paligemma.py +1 -0
  53. liger_kernel/transformers/monkey_patch.py +196 -39
  54. liger_kernel/transformers/multi_token_attention.py +1 -1
  55. liger_kernel/transformers/poly_norm.py +1 -1
  56. liger_kernel/transformers/qwen2vl_mrope.py +1 -1
  57. liger_kernel/transformers/rms_norm.py +8 -3
  58. liger_kernel/transformers/rope.py +28 -27
  59. liger_kernel/transformers/softmax.py +1 -1
  60. liger_kernel/transformers/sparsemax.py +1 -1
  61. liger_kernel/transformers/swiglu.py +1 -1
  62. liger_kernel/transformers/tiled_mlp.py +5 -13
  63. liger_kernel/transformers/tvd.py +1 -1
  64. liger_kernel/utils.py +54 -0
  65. {liger_kernel-0.6.4.dist-info → liger_kernel-0.6.5.dist-info}/METADATA +11 -4
  66. liger_kernel-0.6.5.dist-info/RECORD +134 -0
  67. {liger_kernel-0.6.4.dist-info → liger_kernel-0.6.5.dist-info}/WHEEL +1 -1
  68. liger_kernel-0.6.4.dist-info/RECORD +0 -118
  69. {liger_kernel-0.6.4.dist-info → liger_kernel-0.6.5.dist-info}/licenses/LICENSE +0 -0
  70. {liger_kernel-0.6.4.dist-info → liger_kernel-0.6.5.dist-info}/licenses/NOTICE +0 -0
  71. {liger_kernel-0.6.4.dist-info → liger_kernel-0.6.5.dist-info}/top_level.txt +0 -0
liger_kernel/ops/backends/_ascend/ops/qwen2vl_mrope.py (new file)
@@ -0,0 +1,275 @@
+ import torch
+ import triton
+ import triton.language as tl
+
+ from liger_kernel.ops.backends._ascend.ub_manager import compute_default_tiling_strategy
+ from liger_kernel.ops.utils import get_npu_core_count
+
+
+ @triton.jit
+ def _triton_qwen2vl_mrope_npu(
+     q_ptr,
+     q_row_stride,
+     k_ptr,
+     k_row_stride,
+     cos,
+     sin,
+     sl,
+     bs: tl.constexpr,
+     total_rows: tl.constexpr,
+     n_qh: tl.constexpr,
+     n_kh: tl.constexpr,
+     hd: tl.constexpr,
+     mrope_section_t: tl.constexpr,
+     mrope_section_h: tl.constexpr,
+     BLOCK_Q: tl.constexpr,
+     BLOCK_K: tl.constexpr,
+     NUM_STAGES: tl.constexpr,
+     BACKWARD_PASS: tl.constexpr = False,
+ ):
+     program_id = tl.program_id(0)
+     num_programs = tl.num_programs(0)
+
+     rows_per_program = (total_rows + num_programs - 1) // num_programs
+     start_row = program_id * rows_per_program
+     actual_rows = tl.minimum(rows_per_program, total_rows - start_row)
+
+     for row_offset in tl.range(0, actual_rows, num_stages=NUM_STAGES):
+         pid = start_row + row_offset
+
+         t_end = mrope_section_t
+         h_end = t_end + mrope_section_h
+
+         t_cos = cos + pid * hd
+         h_cos = t_cos + bs * sl * hd
+         w_cos = h_cos + bs * sl * hd
+         t_sin = sin + pid * hd
+         h_sin = t_sin + bs * sl * hd
+         w_sin = h_sin + bs * sl * hd
+
+         q_base = q_ptr + pid * q_row_stride
+         k_base = k_ptr + pid * k_row_stride
+
+         d_idx = tl.arange(0, hd // 2)
+         d_mask = d_idx < (hd // 2)
+
+         pos_mask_t = d_idx < t_end
+         pos_mask_h = (d_idx >= t_end) & (d_idx < h_end)
+
+         text_cos_vals = tl.load(t_cos + d_idx, mask=d_mask, other=0)
+         text_sin_vals = tl.load(t_sin + d_idx, mask=d_mask, other=0)
+         height_cos_vals = tl.load(h_cos + d_idx, mask=d_mask, other=0)
+         height_sin_vals = tl.load(h_sin + d_idx, mask=d_mask, other=0)
+         width_cos_vals = tl.load(w_cos + d_idx, mask=d_mask, other=0)
+         width_sin_vals = tl.load(w_sin + d_idx, mask=d_mask, other=0)
+
+         cos_vals = tl.where(pos_mask_t, text_cos_vals, tl.where(pos_mask_h, height_cos_vals, width_cos_vals))
+         sin_vals = tl.where(pos_mask_t, text_sin_vals, tl.where(pos_mask_h, height_sin_vals, width_sin_vals))
+
+         # Process q heads in chunks to prevent UB overflow
+         for qh_block in range(0, n_qh, BLOCK_Q):
+             qh_idx = tl.arange(0, BLOCK_Q) + qh_block
+             qh_mask = qh_idx < n_qh
+
+             block_mask = qh_mask[:, None] & d_mask[None, :]
+             offsets = qh_idx[:, None] * hd + d_idx[None, :]
+
+             q_left = tl.load(q_base + offsets, mask=block_mask, other=0)
+             q_right = tl.load(q_base + offsets + (hd // 2), mask=block_mask, other=0)
+
+             if not BACKWARD_PASS:
+                 new_left = q_left * cos_vals - q_right * sin_vals
+                 new_right = q_right * cos_vals + q_left * sin_vals
+             else:
+                 new_left = q_left * cos_vals + q_right * sin_vals
+                 new_right = q_right * cos_vals - q_left * sin_vals
+
+             tl.store(q_base + offsets, new_left, mask=block_mask)
+             tl.store(q_base + offsets + (hd // 2), new_right, mask=block_mask)
+
+         # Process k heads in chunks to prevent UB overflow
+         for kh_block in range(0, n_kh, BLOCK_K):
+             kh_idx = tl.arange(0, BLOCK_K) + kh_block
+             kh_mask = kh_idx < n_kh
+
+             block_mask = kh_mask[:, None] & d_mask[None, :]
+             offsets = kh_idx[:, None] * hd + d_idx[None, :]
+
+             k_left = tl.load(k_base + offsets, mask=block_mask, other=0)
+             k_right = tl.load(k_base + offsets + (hd // 2), mask=block_mask, other=0)
+
+             if not BACKWARD_PASS:
+                 new_left = k_left * cos_vals - k_right * sin_vals
+                 new_right = k_right * cos_vals + k_left * sin_vals
+             else:
+                 new_left = k_left * cos_vals + k_right * sin_vals
+                 new_right = k_right * cos_vals - k_left * sin_vals
+
+             tl.store(k_base + offsets, new_left, mask=block_mask)
+             tl.store(k_base + offsets + (hd // 2), new_right, mask=block_mask)
+
+
+ def get_optimal_block_size_mrope(pad_n_q_head, pad_n_kv_head, pad_hd, dtype_size):
+     # MROPE forward tiling strategy:
+     # - cos_vals and sin_vals (include text, height and width) are loaded once outside loops (shared): (pad_hd // 2) * 6 = 3 * pad_hd elements each
+     # - In q heads loop (peak memory):
+     #   * q_left: BLOCK_Q * (pad_hd // 2) elements
+     #   * q_right: BLOCK_Q * (pad_hd // 2) elements
+     #   * new_left: BLOCK_Q * (pad_hd // 2) elements (intermediate result)
+     #   * new_right: BLOCK_Q * (pad_hd // 2) elements (intermediate result)
+     #   * Total: 4 * BLOCK_Q * (pad_hd // 2) = 2 * BLOCK_Q * pad_hd elements
+     # - In k heads loop (peak memory):
+     #   * k_left: BLOCK_K * (pad_hd // 2) elements
+     #   * k_right: BLOCK_K * (pad_hd // 2) elements
+     #   * new_left: BLOCK_K * (pad_hd // 2) elements (intermediate result)
+     #   * new_right: BLOCK_K * (pad_hd // 2) elements (intermediate result)
+     #   * Total: 4 * BLOCK_K * (pad_hd // 2) = 2 * BLOCK_K * pad_hd elements
+     # - Since q and k are processed separately, peak memory is max(BLOCK_Q, BLOCK_K) case
+     # - Plus shared cos/sin: 6 * (pad_hd // 2) = 3 * pad_hd elements
+     # - Conservative estimate: (2 * BLOCK_SIZE * pad_hd + 3 * pad_hd) * dtype_size * 8 bits
+     # - Simplified: (2 * BLOCK_SIZE + 3) * pad_hd * dtype_size * 8 bits
+     # - For safety, use: memory_multiplier=3.0 * BLOCK_SIZE * pad_hd * dtype_size * 8 bits
+     # - shapes: ((pad_n_q_head, pad_hd), (pad_n_kv_head, pad_hd))
+     # - tiling_dims: (0, 0) means first dimension of each shape can be tiled
+     # - Returns: ((block_size_q, pad_hd), (block_size_kv, pad_hd))
+     shapes = ((pad_n_q_head, pad_hd), (pad_n_kv_head, pad_hd))
+     tile_shapes = compute_default_tiling_strategy(
+         safety_margin=0.90,
+         dtype_size=dtype_size,
+         memory_multiplier=3.0,
+         shapes=shapes,
+         tiling_dims=(0, 0),
+     )
+
+     if tile_shapes is not None and len(tile_shapes) == len(shapes):
+         # Strategy returns ((block_size_q, pad_hd), (block_size_kv, pad_hd))
+         q_tile_shape, k_tile_shape = tile_shapes
+         BLOCK_Q, _ = q_tile_shape
+         BLOCK_K, _ = k_tile_shape
+     else:
+         # Fallback to conservative defaults
+         BLOCK_Q = 2048
+         BLOCK_K = 2048
+
+     return BLOCK_Q, BLOCK_K
+
+
+ def qwen2vl_mrope_forward(q, k, cos, sin, mrope_section):
+     # transpose it back to the physical shape because Triton looks at the physical storage
+     q = q.transpose(1, 2)
+     k = k.transpose(1, 2)
+
+     batch_size, seq_len, n_q_head, head_dim = q.shape
+     n_kv_head = k.shape[2]
+     pad_hd = triton.next_power_of_2(head_dim)
+     pad_n_q_head = triton.next_power_of_2(n_q_head)
+     pad_n_kv_head = triton.next_power_of_2(n_kv_head)
+
+     n_row = batch_size * seq_len
+
+     # ensure tensors passed into the kernel are contiguous
+     q = q.contiguous()
+     k = k.contiguous()
+     cos = cos.contiguous()
+     sin = sin.contiguous()
+
+     dtype_size = q.element_size()
+     BLOCK_Q, BLOCK_K = get_optimal_block_size_mrope(pad_n_q_head, pad_n_kv_head, pad_hd, dtype_size)
+
+     num_cores = get_npu_core_count()
+     grid_size = min(num_cores, n_row)
+
+     _triton_qwen2vl_mrope_npu[(grid_size,)](
+         q,
+         q.stride(1),
+         k,
+         k.stride(1),
+         cos,
+         sin,
+         seq_len,
+         batch_size,
+         n_row,
+         n_q_head,
+         n_kv_head,
+         head_dim,
+         mrope_section[0],
+         mrope_section[1],
+         BLOCK_Q,
+         BLOCK_K,
+         NUM_STAGES=3,
+         BACKWARD_PASS=False,
+     )
+     return q.transpose(1, 2), k.transpose(1, 2), cos, sin
+
+
+ def qwen2vl_mrope_backward(dq, dk, cos, sin, mrope_section):
+     dq = dq.transpose(1, 2)
+     dk = dk.transpose(1, 2)
+
+     batch_size, seq_len, n_q_head, head_dim = dq.shape
+     n_kv_head = dk.shape[2]
+     pad_hd = triton.next_power_of_2(head_dim)
+     pad_n_q_head = triton.next_power_of_2(n_q_head)
+     pad_n_kv_head = triton.next_power_of_2(n_kv_head)
+
+     n_row = batch_size * seq_len
+
+     # ensure dq and dk are contiguous
+     dq = dq.contiguous()
+     dk = dk.contiguous()
+
+     dtype_size = dq.element_size()
+     BLOCK_Q, BLOCK_K = get_optimal_block_size_mrope(pad_n_q_head, pad_n_kv_head, pad_hd, dtype_size)
+
+     num_cores = get_npu_core_count()
+     grid_size = min(num_cores, n_row)
+
+     _triton_qwen2vl_mrope_npu[(grid_size,)](
+         dq,
+         dq.stride(1),
+         dk,
+         dk.stride(1),
+         cos,
+         sin,
+         seq_len,
+         batch_size,
+         n_row,
+         n_q_head,
+         n_kv_head,
+         head_dim,
+         mrope_section[0],
+         mrope_section[1],
+         BLOCK_Q,
+         BLOCK_K,
+         NUM_STAGES=3,
+         BACKWARD_PASS=True,
+     )
+     return dq.transpose(1, 2), dk.transpose(1, 2)
+
+
+ class LigerQwen2VLMRopeFunction(torch.autograd.Function):
+     @staticmethod
+     def forward(ctx, q, k, cos, sin, mrope_section, unsqueeze_dim=1):
+         """
+         q size: (bsz, n_q_head, seq_len, head_dim)
+         k size: (bsz, n_kv_head, seq_len, head_dim)
+         cos size: (3, bsz, seq_len, head_dim)
+         sin size: (3, bsz, seq_len, head_dim)
+         """
+         q, k, cos, sin = qwen2vl_mrope_forward(q, k, cos, sin, mrope_section)
+         ctx.save_for_backward(cos, sin)
+         ctx.mrope_section = mrope_section
+         return q, k
+
+     @staticmethod
+     def backward(ctx, dq, dk):
+         """
+         dq size: (bsz, n_q_head, seq_len, head_dim)
+         dk size: (bsz, n_kv_head, seq_len, head_dim)
+         cos size: (3, bsz, seq_len, head_dim)
+         sin size: (3, bsz, seq_len, head_dim)
+         """
+         cos, sin = ctx.saved_tensors
+         mrope_section = ctx.mrope_section
+         dq, dk = qwen2vl_mrope_backward(dq, dk, cos, sin, mrope_section)
+         return dq, dk, None, None, None, None
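
For orientation only, a minimal sketch of how the autograd function above could be driven. It assumes a torch_npu environment that exposes the device as "npu" and an illustrative mrope_section of (16, 24, 24), which sums to head_dim // 2; neither value is fixed by this file, and the shapes simply follow the docstring.

import torch

from liger_kernel.ops.backends._ascend.ops.qwen2vl_mrope import LigerQwen2VLMRopeFunction

# Illustrative shapes: 2 sequences, 8 query heads, 2 kv heads, seq_len 128, head_dim 128.
bsz, n_q_head, n_kv_head, seq_len, head_dim = 2, 8, 2, 128, 128
mrope_section = (16, 24, 24)  # assumed t/h/w split of head_dim // 2; the model config defines the real value

q = torch.randn(bsz, n_q_head, seq_len, head_dim, device="npu", requires_grad=True)
k = torch.randn(bsz, n_kv_head, seq_len, head_dim, device="npu", requires_grad=True)
# cos/sin carry one plane per (t, h, w) section, matching the docstring layout above.
cos = torch.randn(3, bsz, seq_len, head_dim, device="npu")
sin = torch.randn(3, bsz, seq_len, head_dim, device="npu")

q_rot, k_rot = LigerQwen2VLMRopeFunction.apply(q, k, cos, sin, mrope_section)
(q_rot.sum() + k_rot.sum()).backward()  # drives qwen2vl_mrope_backward through the same kernel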
liger_kernel/ops/backends/_ascend/ops/rope.py (new file)
@@ -0,0 +1,265 @@
+ import torch
+ import triton
+ import triton.language as tl
+
+ from liger_kernel.ops.backends._ascend.ub_manager import compute_default_tiling_strategy
+ from liger_kernel.ops.utils import get_npu_core_count
+
+
+ @triton.jit
+ def _triton_rope_npu(
+     q_ptr,
+     q_row_stride,
+     k_ptr,
+     k_row_stride,
+     cos,
+     cos_row_stride,
+     sin,
+     sin_row_stride,
+     sl,
+     total_rows: tl.constexpr,
+     cos_bs: tl.constexpr,
+     n_qh: tl.constexpr,
+     n_kh: tl.constexpr,
+     hd: tl.constexpr,
+     BLOCK_Q: tl.constexpr,
+     BLOCK_K: tl.constexpr,
+     NUM_STAGES: tl.constexpr,
+     BACKWARD_PASS: tl.constexpr = False,
+ ):
+     program_id = tl.program_id(0)
+     num_programs = tl.num_programs(0)
+
+     rows_per_program = (total_rows + num_programs - 1) // num_programs
+     start_row = program_id * rows_per_program
+     actual_rows = tl.minimum(rows_per_program, total_rows - start_row)
+
+     for row_offset in tl.range(0, actual_rows, num_stages=NUM_STAGES):
+         pid = start_row + row_offset
+
+         row_idx = pid % sl
+         cos_ptr = cos + tl.where(cos_bs == 1, row_idx * cos_row_stride, pid * cos_row_stride)
+         sin_ptr = sin + tl.where(cos_bs == 1, row_idx * sin_row_stride, pid * sin_row_stride)
+
+         # Pre-compute d_idx and cos/sin values outside loops (they don't depend on heads)
+         d_idx = tl.arange(0, hd // 2)
+         d_mask = d_idx < (hd // 2)  # Always True, but kept for clarity
+         cos_vals = tl.load(cos_ptr + d_idx, mask=d_mask, other=0)
+         sin_vals = tl.load(sin_ptr + d_idx, mask=d_mask, other=0)
+
+         # Process q heads in chunks to prevent UB overflow
+         for qh_block in range(0, n_qh, BLOCK_Q):
+             qh_idx = tl.arange(0, BLOCK_Q) + qh_block
+             qh_mask = qh_idx < n_qh
+
+             # block_mask: qh_mask broadcasted over d_idx dimension
+             block_mask = qh_mask[:, None]
+
+             offsets = qh_idx[:, None] * hd + d_idx[None, :]
+             q_base = q_ptr + pid * q_row_stride
+
+             q_left = tl.load(q_base + offsets, mask=block_mask, other=0)
+             q_right = tl.load(q_base + offsets + (hd // 2), mask=block_mask, other=0)
+
+             if not BACKWARD_PASS:
+                 new_left = q_left * cos_vals - q_right * sin_vals
+                 new_right = q_right * cos_vals + q_left * sin_vals
+             else:
+                 new_left = q_left * cos_vals + q_right * sin_vals
+                 new_right = q_right * cos_vals - q_left * sin_vals
+
+             tl.store(q_base + offsets, new_left, mask=block_mask)
+             tl.store(q_base + offsets + (hd // 2), new_right, mask=block_mask)
+
+         # Process k heads in chunks to prevent UB overflow
+         for kh_block in range(0, n_kh, BLOCK_K):
+             kh_idx = tl.arange(0, BLOCK_K) + kh_block
+             kh_mask = kh_idx < n_kh
+
+             # block_mask: kh_mask broadcasted over d_idx dimension
+             block_mask = kh_mask[:, None]
+
+             offsets = kh_idx[:, None] * hd + d_idx[None, :]
+             k_base = k_ptr + pid * k_row_stride
+
+             k_left = tl.load(k_base + offsets, mask=block_mask, other=0)
+             k_right = tl.load(k_base + offsets + (hd // 2), mask=block_mask, other=0)
+
+             if not BACKWARD_PASS:
+                 new_left = k_left * cos_vals - k_right * sin_vals
+                 new_right = k_right * cos_vals + k_left * sin_vals
+             else:
+                 new_left = k_left * cos_vals + k_right * sin_vals
+                 new_right = k_right * cos_vals - k_left * sin_vals
+
+             tl.store(k_base + offsets, new_left, mask=block_mask)
+             tl.store(k_base + offsets + (hd // 2), new_right, mask=block_mask)
+
+
+ def get_optimal_block_size(pad_n_q_head, pad_n_kv_head, pad_hd, dtype_size):
+     # Compute tiling strategy based on UB capacity
+     # ROPE forward tiling strategy (based on optimized ROPE kernel):
+     # - cos_vals and sin_vals are loaded once outside loops (shared): pad_hd // 2 elements each
+     # - In q heads loop (peak memory):
+     #   * q_left: BLOCK_Q * (pad_hd // 2) elements
+     #   * q_right: BLOCK_Q * (pad_hd // 2) elements
+     #   * new_left: BLOCK_Q * (pad_hd // 2) elements (intermediate result)
+     #   * new_right: BLOCK_Q * (pad_hd // 2) elements (intermediate result)
+     #   * Total: 4 * BLOCK_Q * (pad_hd // 2) = 2 * BLOCK_Q * pad_hd elements
+     # - In k heads loop (peak memory):
+     #   * k_left: BLOCK_K * (pad_hd // 2) elements
+     #   * k_right: BLOCK_K * (pad_hd // 2) elements
+     #   * new_left: BLOCK_K * (pad_hd // 2) elements (intermediate result)
+     #   * new_right: BLOCK_K * (pad_hd // 2) elements (intermediate result)
+     #   * Total: 4 * BLOCK_K * (pad_hd // 2) = 2 * BLOCK_K * pad_hd elements
+     # - Since q and k are processed separately, peak memory is max(BLOCK_Q, BLOCK_K) case
+     # - Plus shared cos/sin: 2 * (pad_hd // 2) = pad_hd elements
+     # - Conservative estimate: (2 * BLOCK_SIZE * pad_hd + pad_hd) * dtype_size * 8 bits
+     # - Simplified: (2 * BLOCK_SIZE + 1) * pad_hd * dtype_size * 8 bits
+     # - For safety, use: memory_multiplier=3.0 * BLOCK_SIZE * pad_hd * dtype_size * 8 bits
+     # - shapes: ((pad_n_q_head, pad_hd), (pad_n_kv_head, pad_hd))
+     # - tiling_dims: (0, 0) means first dimension of each shape can be tiled
+     # - Returns: ((block_size_q, pad_hd), (block_size_kv, pad_hd))
+     shapes = ((pad_n_q_head, pad_hd), (pad_n_kv_head, pad_hd))
+     tile_shapes = compute_default_tiling_strategy(
+         safety_margin=0.90,
+         dtype_size=dtype_size,
+         memory_multiplier=3.0,
+         shapes=shapes,
+         tiling_dims=(0, 0),
+     )
+
+     if tile_shapes is not None and len(tile_shapes) == len(shapes):
+         # Strategy returns ((block_size_q, pad_hd), (block_size_kv, pad_hd))
+         q_tile_shape, k_tile_shape = tile_shapes
+         BLOCK_Q, _ = q_tile_shape
+         BLOCK_K, _ = k_tile_shape
+     else:
+         # Fallback to conservative defaults
+         BLOCK_Q = 2048
+         BLOCK_K = 2048
+
+     return BLOCK_Q, BLOCK_K
+
+
+ def rope_forward(q, k, cos, sin):
+     # transpose it back to the physical shape because Triton looks at the physical storage
+     # note: q and k are non-contiguous before the transformation and will become contiguous after transpose
+     q = q.transpose(1, 2)
+     k = k.transpose(1, 2)
+
+     batch_size, seq_len, n_q_head, head_dim = q.shape
+     n_kv_head = k.shape[2]
+     pad_hd = triton.next_power_of_2(head_dim)
+     pad_n_q_head = triton.next_power_of_2(n_q_head)
+     pad_n_kv_head = triton.next_power_of_2(n_kv_head)
+
+     n_row = batch_size * seq_len
+
+     # ensure tensors passed into the kernel are contiguous. It will be a no-op if they are already contiguous
+     q = q.contiguous()
+     k = k.contiguous()
+     cos = cos.contiguous()
+     sin = sin.contiguous()
+     cos_batch_size = cos.shape[0]
+
+     dtype_size = q.element_size()
+     BLOCK_Q, BLOCK_K = get_optimal_block_size(pad_n_q_head, pad_n_kv_head, pad_hd, dtype_size)
+
+     num_cores = get_npu_core_count()
+     grid_size = min(num_cores, n_row)
+
+     _triton_rope_npu[(grid_size,)](
+         q,
+         q.stride(1),
+         k,
+         k.stride(1),
+         cos,
+         cos.stride(-2),
+         sin,
+         sin.stride(-2),
+         seq_len,
+         n_row,
+         cos_batch_size,
+         n_q_head,
+         n_kv_head,
+         head_dim,
+         BLOCK_Q,
+         BLOCK_K,
+         NUM_STAGES=3,
+         BACKWARD_PASS=False,
+     )
+     return q.transpose(1, 2), k.transpose(1, 2), cos, sin
+
+
+ def rope_backward(dq, dk, cos, sin):
+     dq = dq.transpose(1, 2)
+     dk = dk.transpose(1, 2)
+
+     batch_size, seq_len, n_q_head, head_dim = dq.shape
+     cos_batch_size = cos.shape[0]
+     n_kv_head = dk.shape[2]
+     pad_hd = triton.next_power_of_2(head_dim)
+     pad_n_q_head = triton.next_power_of_2(n_q_head)
+     pad_n_kv_head = triton.next_power_of_2(n_kv_head)
+
+     n_row = batch_size * seq_len
+
+     # ensure dq and dk are contiguous
+     dq = dq.contiguous()
+     dk = dk.contiguous()
+
+     dtype_size = dq.element_size()
+     BLOCK_Q, BLOCK_K = get_optimal_block_size(pad_n_q_head, pad_n_kv_head, pad_hd, dtype_size)
+
+     num_cores = get_npu_core_count()
+     grid_size = min(num_cores, n_row)
+
+     _triton_rope_npu[(grid_size,)](
+         dq,
+         dq.stride(1),
+         dk,
+         dk.stride(1),
+         cos,
+         cos.stride(-2),
+         sin,
+         sin.stride(-2),
+         seq_len,
+         n_row,
+         cos_batch_size,
+         n_q_head,
+         n_kv_head,
+         head_dim,
+         BLOCK_Q,
+         BLOCK_K,
+         NUM_STAGES=3,
+         BACKWARD_PASS=True,
+     )
+     return dq.transpose(1, 2), dk.transpose(1, 2)
+
+
+ class LigerRopeFunction(torch.autograd.Function):
+     @staticmethod
+     def forward(ctx, q, k, cos, sin, position_ids=None, unsqueeze_dim=1):
+         """
+         q size: (bsz, n_q_head, seq_len, head_dim)
+         k size: (bsz, n_kv_head, seq_len, head_dim)
+         cos size: (1, seq_len, head_dim) or (bsz, seq_len, head_dim)
+         sin size: (1, seq_len, head_dim) or (bsz, seq_len, head_dim)
+         """
+         q, k, cos, sin = rope_forward(q, k, cos, sin)
+         ctx.save_for_backward(cos, sin)
+         return q, k
+
+     @staticmethod
+     def backward(ctx, dq, dk):
+         """
+         dq size: (bsz, n_q_head, seq_len, head_dim)
+         dk size: (bsz, n_kv_head, seq_len, head_dim)
+         cos size: (1, seq_len, head_dim) or (bsz, seq_len, head_dim)
+         sin size: (1, seq_len, head_dim) or (bsz, seq_len, head_dim)
+         """
+
+         cos, sin = ctx.saved_tensors
+         dq, dk = rope_backward(dq, dk, cos, sin)
+         return dq, dk, None, None, None, None
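
To make the unified-buffer budgeting in the get_optimal_block_size comment above concrete, here is a back-of-envelope sketch of the arithmetic. It assumes the 192KB 910B UB figure quoted in the SwiGLU helper below; the actual splitting decision is made by compute_default_tiling_strategy in ub_manager.py, so this only illustrates the order of magnitude, not the library routine.

# Illustrative arithmetic only -- mirrors the comment block above, not ub_manager.py.
UB_BYTES = 192 * 1024          # assumed 910B unified buffer size (see the SwiGLU helper's comment)
SAFETY_MARGIN = 0.90           # as passed to compute_default_tiling_strategy above
MEMORY_MULTIPLIER = 3.0        # conservative per-element factor used by the ROPE kernels
dtype_size = 2                 # e.g. bf16 / fp16
pad_hd = 128                   # padded head_dim, illustrative

budget = UB_BYTES * SAFETY_MARGIN
# Peak usage is modeled as MEMORY_MULTIPLIER * BLOCK * pad_hd * dtype_size bytes per head loop,
# so the largest head tile that fits is roughly:
block = int(budget // (MEMORY_MULTIPLIER * pad_hd * dtype_size))
print(block)  # ~230 heads per tile under these assumptions, well above typical head counts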
liger_kernel/ops/backends/_ascend/ops/swiglu.py (new file)
@@ -0,0 +1,142 @@
+ import torch
+ import triton
+ import triton.language as tl
+
+ from liger_kernel.ops.backends._ascend.ub_manager import compute_default_tiling_strategy
+ from liger_kernel.ops.utils import get_npu_core_count
+
+ # -----------------------------------------------------------------------------
+ # Kernels (High-performance 1D Flatten Implementation)
+ # -----------------------------------------------------------------------------
+
+
+ @triton.jit
+ def _swiglu_forward_kernel_flat(
+     a_ptr, b_ptr, c_ptr, total_elements, BLOCK_SIZE: tl.constexpr, NUM_STAGES: tl.constexpr
+ ):
+     pid = tl.program_id(0)
+     num_progs = tl.num_programs(0)
+
+     # Grid-Stride Loop
+     start_idx = pid * BLOCK_SIZE
+     stride = num_progs * BLOCK_SIZE
+
+     for idx in tl.range(start_idx, total_elements, stride, num_stages=NUM_STAGES):
+         offsets = idx + tl.arange(0, BLOCK_SIZE)
+         mask = offsets < total_elements
+
+         a_val = tl.load(a_ptr + offsets, mask=mask, other=0.0).to(tl.float32)
+         b_val = tl.load(b_ptr + offsets, mask=mask, other=0.0).to(tl.float32)
+         res = (a_val * tl.sigmoid(a_val)) * b_val
+         tl.store(c_ptr + offsets, res, mask=mask)
+
+
+ @triton.jit
+ def _swiglu_backward_kernel_flat(
+     dc_ptr, a_ptr, b_ptr, da_ptr, db_ptr, total_elements, BLOCK_SIZE: tl.constexpr, NUM_STAGES: tl.constexpr
+ ):
+     pid = tl.program_id(0)
+     num_progs = tl.num_programs(0)
+     start_idx = pid * BLOCK_SIZE
+     stride = num_progs * BLOCK_SIZE
+
+     for idx in tl.range(start_idx, total_elements, stride, num_stages=NUM_STAGES):
+         offsets = idx + tl.arange(0, BLOCK_SIZE)
+         mask = offsets < total_elements
+
+         dc = tl.load(dc_ptr + offsets, mask=mask, other=0.0).to(tl.float32)
+         a = tl.load(a_ptr + offsets, mask=mask, other=0.0).to(tl.float32)
+         b = tl.load(b_ptr + offsets, mask=mask, other=0.0).to(tl.float32)
+
+         sig_a = tl.sigmoid(a)
+         silu_a = a * sig_a
+         term1 = silu_a * (1.0 - sig_a) + sig_a
+
+         db = dc * silu_a
+         da = dc * b * term1
+
+         tl.store(da_ptr + offsets, da, mask=mask)
+         tl.store(db_ptr + offsets, db, mask=mask)
+
+
+ # -----------------------------------------------------------------------------
+ # Helper: Call compute_default_tiling_strategy
+ # -----------------------------------------------------------------------------
+
+
+ def get_optimal_block_size(total_elements, is_backward=False):
+     """
+     Calculate optimal Block Size using compute_default_tiling_strategy
+     """
+     # 1. Set Memory Multiplier
+     # Forward is lighter, Backward requires more memory for intermediate variables
+     # 8.0 and 12.0 are empirical values based on 910B UB (192KB)
+     multiplier = 12.0 if is_backward else 8.0
+
+     # 2. Call calculation function
+     # Treat input as 1D (total_elements,), only tiling on dim 0
+     tile_shapes = compute_default_tiling_strategy(
+         safety_margin=0.9, dtype_size=4, memory_multiplier=multiplier, shapes=((total_elements,),), tiling_dims=(0,)
+     )
+
+     # 3. Parse result
+     if tile_shapes and len(tile_shapes) > 0:
+         block_size = tile_shapes[0][0]
+         return max(256, block_size)
+     else:
+         return 2048
+
+
+ def swiglu_forward(a, b):
+     if not a.is_contiguous():
+         a = a.contiguous()
+     if not b.is_contiguous():
+         b = b.contiguous()
+
+     total_elements = a.numel()
+     c = torch.empty_like(a)
+
+     block_size = get_optimal_block_size(total_elements, is_backward=False)
+
+     num_cores = get_npu_core_count()
+     grid_size = min(num_cores, (total_elements + block_size - 1) // block_size)
+
+     _swiglu_forward_kernel_flat[(grid_size,)](a, b, c, total_elements, BLOCK_SIZE=block_size, NUM_STAGES=3, num_warps=4)
+     return c
+
+
+ def swiglu_backward(a, b, dc):
+     if not dc.is_contiguous():
+         dc = dc.contiguous()
+     if not a.is_contiguous():
+         a = a.contiguous()
+     if not b.is_contiguous():
+         b = b.contiguous()
+
+     total_elements = dc.numel()
+     grad_a = torch.empty_like(a)
+     grad_b = torch.empty_like(b)
+
+     block_size = get_optimal_block_size(total_elements, is_backward=True)
+
+     num_cores = get_npu_core_count()
+     grid_size = min(num_cores, (total_elements + block_size - 1) // block_size)
+
+     _swiglu_backward_kernel_flat[(grid_size,)](
+         dc, a, b, grad_a, grad_b, total_elements, BLOCK_SIZE=block_size, NUM_STAGES=3, num_warps=4
+     )
+     return grad_a, grad_b
+
+
+ class LigerSiLUMulFunction(torch.autograd.Function):
+     @staticmethod
+     def forward(ctx, a, b):
+         c = swiglu_forward(a, b)
+         ctx.save_for_backward(a, b)
+         return c
+
+     @staticmethod
+     def backward(ctx, dc):
+         a, b = ctx.saved_tensors
+         grad_a, grad_b = swiglu_backward(a, b, dc)
+         return grad_a, grad_b
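
A small usage sketch for the flattened SwiGLU above, assuming a torch_npu environment ("npu" device) and illustrative shapes; the eager reference silu(a) * b matches the kernel's math, while the tolerances are only indicative.

import torch
import torch.nn.functional as F

from liger_kernel.ops.backends._ascend.ops.swiglu import LigerSiLUMulFunction

# Illustrative gate/up projections of an MLP block.
a = torch.randn(4, 512, 2816, device="npu", dtype=torch.float32, requires_grad=True)
b = torch.randn(4, 512, 2816, device="npu", dtype=torch.float32, requires_grad=True)

out = LigerSiLUMulFunction.apply(a, b)   # fused flat kernel: (a * sigmoid(a)) * b
ref = F.silu(a) * b                      # eager reference

torch.testing.assert_close(out, ref, rtol=1e-4, atol=1e-4)
out.sum().backward()                     # exercises _swiglu_backward_kernel_flat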