liger-kernel-nightly 0.6.4.dev20251202054858__py3-none-any.whl → 0.6.4.dev20260107111351__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Note: this release of liger-kernel-nightly has been flagged as potentially problematic.

Files changed (58)
  1. liger_kernel/chunked_loss/cosine_similarity_loss.py +7 -1
  2. liger_kernel/chunked_loss/fused_linear_distillation.py +10 -3
  3. liger_kernel/chunked_loss/jsd_loss.py +21 -6
  4. liger_kernel/ops/__init__.py +141 -0
  5. liger_kernel/ops/backends/README.md +151 -0
  6. liger_kernel/ops/backends/__init__.py +13 -0
  7. liger_kernel/ops/backends/_ascend/__init__.py +5 -0
  8. liger_kernel/ops/backends/_ascend/ascend-ub-manager-design.md +485 -0
  9. liger_kernel/ops/backends/_ascend/ops/__init__.py +43 -0
  10. liger_kernel/ops/backends/_ascend/ops/geglu.py +244 -0
  11. liger_kernel/ops/backends/_ascend/ops/qwen2vl_mrope.py +285 -0
  12. liger_kernel/ops/backends/_ascend/ops/rope.py +290 -0
  13. liger_kernel/ops/backends/_ascend/ops/swiglu.py +142 -0
  14. liger_kernel/ops/backends/_ascend/ub_manager.py +349 -0
  15. liger_kernel/ops/backends/registry.py +61 -0
  16. liger_kernel/ops/cross_entropy.py +12 -3
  17. liger_kernel/ops/fused_linear_cross_entropy.py +2 -1
  18. liger_kernel/ops/geglu.py +3 -2
  19. liger_kernel/ops/rms_norm.py +126 -49
  20. liger_kernel/ops/utils.py +12 -0
  21. liger_kernel/transformers/__init__.py +3 -0
  22. liger_kernel/transformers/auto_model.py +21 -0
  23. liger_kernel/transformers/cross_entropy.py +1 -1
  24. liger_kernel/transformers/dyt.py +1 -1
  25. liger_kernel/transformers/experimental/embedding.py +1 -1
  26. liger_kernel/transformers/functional.py +20 -20
  27. liger_kernel/transformers/fused_add_rms_norm.py +1 -1
  28. liger_kernel/transformers/fused_linear_cross_entropy.py +1 -1
  29. liger_kernel/transformers/fused_linear_jsd.py +1 -1
  30. liger_kernel/transformers/fused_neighborhood_attention.py +1 -1
  31. liger_kernel/transformers/geglu.py +1 -1
  32. liger_kernel/transformers/group_norm.py +1 -1
  33. liger_kernel/transformers/grpo_loss.py +1 -1
  34. liger_kernel/transformers/jsd.py +1 -1
  35. liger_kernel/transformers/kl_div.py +1 -1
  36. liger_kernel/transformers/layer_norm.py +1 -1
  37. liger_kernel/transformers/llama4_rope.py +1 -1
  38. liger_kernel/transformers/model/gemma3.py +1 -0
  39. liger_kernel/transformers/model/gpt_oss.py +211 -0
  40. liger_kernel/transformers/model/paligemma.py +1 -0
  41. liger_kernel/transformers/monkey_patch.py +118 -39
  42. liger_kernel/transformers/multi_token_attention.py +1 -1
  43. liger_kernel/transformers/poly_norm.py +1 -1
  44. liger_kernel/transformers/qwen2vl_mrope.py +1 -1
  45. liger_kernel/transformers/rms_norm.py +8 -3
  46. liger_kernel/transformers/rope.py +28 -27
  47. liger_kernel/transformers/softmax.py +1 -1
  48. liger_kernel/transformers/sparsemax.py +1 -1
  49. liger_kernel/transformers/swiglu.py +1 -1
  50. liger_kernel/transformers/tiled_mlp.py +3 -3
  51. liger_kernel/transformers/tvd.py +1 -1
  52. liger_kernel/utils.py +27 -0
  53. {liger_kernel_nightly-0.6.4.dev20251202054858.dist-info → liger_kernel_nightly-0.6.4.dev20260107111351.dist-info}/METADATA +9 -3
  54. {liger_kernel_nightly-0.6.4.dev20251202054858.dist-info → liger_kernel_nightly-0.6.4.dev20260107111351.dist-info}/RECORD +58 -46
  55. {liger_kernel_nightly-0.6.4.dev20251202054858.dist-info → liger_kernel_nightly-0.6.4.dev20260107111351.dist-info}/LICENSE +0 -0
  56. {liger_kernel_nightly-0.6.4.dev20251202054858.dist-info → liger_kernel_nightly-0.6.4.dev20260107111351.dist-info}/NOTICE +0 -0
  57. {liger_kernel_nightly-0.6.4.dev20251202054858.dist-info → liger_kernel_nightly-0.6.4.dev20260107111351.dist-info}/WHEEL +0 -0
  58. {liger_kernel_nightly-0.6.4.dev20251202054858.dist-info → liger_kernel_nightly-0.6.4.dev20260107111351.dist-info}/top_level.txt +0 -0
@@ -0,0 +1,290 @@
+ import torch
+ import triton
+ import triton.language as tl
+
+ from liger_kernel.ops.backends._ascend.ub_manager import compute_default_tiling_strategy
+
+
+ @triton.jit
+ def _triton_rope_npu(
+     q_ptr,
+     q_row_stride,
+     k_ptr,
+     k_row_stride,
+     cos,
+     cos_row_stride,
+     sin,
+     sin_row_stride,
+     sl,
+     bs: tl.constexpr,
+     cos_bs: tl.constexpr,
+     n_qh: tl.constexpr,
+     n_kh: tl.constexpr,
+     hd: tl.constexpr,
+     BLOCK_Q: tl.constexpr,
+     BLOCK_K: tl.constexpr,
+     BACKWARD_PASS: tl.constexpr = False,
+ ):
+     pid = tl.program_id(0).to(tl.int64)
+     batch_idx = pid // sl
+     cos_row_idx = pid % sl
+
+     cos = cos + tl.where(
+         cos_bs == 1,
+         cos_row_idx * cos_row_stride,
+         batch_idx * (sl * cos_row_stride) + cos_row_idx * cos_row_stride,
+     )
+     sin = sin + tl.where(
+         cos_bs == 1,
+         cos_row_idx * sin_row_stride,
+         batch_idx * (sl * sin_row_stride) + cos_row_idx * sin_row_stride,
+     )
+
+     q_base = q_ptr + pid * q_row_stride
+     k_base = k_ptr + pid * k_row_stride
+
+     # Pre-compute d_idx and cos/sin values outside loops (they don't depend on heads)
+     d_idx = tl.arange(0, hd // 2)
+     d_mask = d_idx < (hd // 2)  # Always True, but kept for clarity
+     cos_vals = tl.load(cos + d_idx, mask=d_mask, other=0)
+     sin_vals = tl.load(sin + d_idx, mask=d_mask, other=0)
+
+     # Process q heads in chunks to prevent UB overflow
+     for qh_block in range(0, n_qh, BLOCK_Q):
+         qh_idx = tl.arange(0, BLOCK_Q) + qh_block
+         qh_mask = qh_idx < n_qh
+
+         # block_mask: qh_mask broadcasted over d_idx dimension
+         block_mask = qh_mask[:, None]
+
+         offsets = qh_idx[:, None] * hd + d_idx[None, :]
+
+         q_left = tl.load(q_base + offsets, mask=block_mask, other=0)
+         q_right = tl.load(q_base + offsets + (hd // 2), mask=block_mask, other=0)
+
+         if not BACKWARD_PASS:
+             new_left = q_left * cos_vals - q_right * sin_vals
+             new_right = q_right * cos_vals + q_left * sin_vals
+         else:
+             new_left = q_left * cos_vals + q_right * sin_vals
+             new_right = q_right * cos_vals - q_left * sin_vals
+
+         tl.store(q_base + offsets, new_left, mask=block_mask)
+         tl.store(q_base + offsets + (hd // 2), new_right, mask=block_mask)
+
+     # Process k heads in chunks to prevent UB overflow
+     for kh_block in range(0, n_kh, BLOCK_K):
+         kh_idx = tl.arange(0, BLOCK_K) + kh_block
+         kh_mask = kh_idx < n_kh
+
+         # block_mask: kh_mask broadcasted over d_idx dimension
+         block_mask = kh_mask[:, None]
+
+         offsets = kh_idx[:, None] * hd + d_idx[None, :]
+
+         k_left = tl.load(k_base + offsets, mask=block_mask, other=0)
+         k_right = tl.load(k_base + offsets + (hd // 2), mask=block_mask, other=0)
+
+         if not BACKWARD_PASS:
+             new_left = k_left * cos_vals - k_right * sin_vals
+             new_right = k_right * cos_vals + k_left * sin_vals
+         else:
+             new_left = k_left * cos_vals + k_right * sin_vals
+             new_right = k_right * cos_vals - k_left * sin_vals
+
+         tl.store(k_base + offsets, new_left, mask=block_mask)
+         tl.store(k_base + offsets + (hd // 2), new_right, mask=block_mask)
+
+
+ def rope_forward(q, k, cos, sin):
+     # transpose it back to the physical shape because Triton looks at the physical storage
+     # note: q and k are non-contiguous before the transformation and will become contiguous after transpose
+     q = q.transpose(1, 2)
+     k = k.transpose(1, 2)
+
+     batch_size, seq_len, n_q_head, head_dim = q.shape
+     n_kv_head = k.shape[2]
+     pad_hd = triton.next_power_of_2(head_dim)
+     pad_n_q_head = triton.next_power_of_2(n_q_head)
+     pad_n_kv_head = triton.next_power_of_2(n_kv_head)
+
+     n_row = batch_size * seq_len
+
+     # ensure tensors passed into the kernel are contiguous. It will be no-op if they are already contiguous
+     q = q.contiguous()
+     k = k.contiguous()
+     cos = cos.contiguous()
+     sin = sin.contiguous()
+     cos_batch_size = cos.shape[0]
+
+     # Compute tiling strategy based on UB capacity
+     dtype_size = q.element_size()
+     # ROPE forward tiling strategy (based on optimized ROPE kernel):
+     # - cos_vals and sin_vals are loaded once outside loops (shared): pad_hd // 2 elements each
+     # - In q heads loop (peak memory):
+     #   * q_left: BLOCK_Q * (pad_hd // 2) elements
+     #   * q_right: BLOCK_Q * (pad_hd // 2) elements
+     #   * new_left: BLOCK_Q * (pad_hd // 2) elements (intermediate result)
+     #   * new_right: BLOCK_Q * (pad_hd // 2) elements (intermediate result)
+     #   * Total: 4 * BLOCK_Q * (pad_hd // 2) = 2 * BLOCK_Q * pad_hd elements
+     # - In k heads loop (peak memory):
+     #   * k_left: BLOCK_K * (pad_hd // 2) elements
+     #   * k_right: BLOCK_K * (pad_hd // 2) elements
+     #   * new_left: BLOCK_K * (pad_hd // 2) elements (intermediate result)
+     #   * new_right: BLOCK_K * (pad_hd // 2) elements (intermediate result)
+     #   * Total: 4 * BLOCK_K * (pad_hd // 2) = 2 * BLOCK_K * pad_hd elements
+     # - Since q and k are processed separately, peak memory is max(BLOCK_Q, BLOCK_K) case
+     # - Plus shared cos/sin: 2 * (pad_hd // 2) = pad_hd elements
+     # - Conservative estimate: (2 * BLOCK_SIZE * pad_hd + pad_hd) * dtype_size * 8 bits
+     # - Simplified: (2 * BLOCK_SIZE + 1) * pad_hd * dtype_size * 8 bits
+     # - For safety, use: memory_multiplier=3.0 * BLOCK_SIZE * pad_hd * dtype_size * 8 bits
+     # - shapes: ((pad_n_q_head, pad_hd), (pad_n_kv_head, pad_hd))
+     # - tiling_dims: (0, 0) means first dimension of each shape can be tiled
+     # - Returns: ((block_size_q, pad_hd), (block_size_kv, pad_hd))
+     shapes = ((pad_n_q_head, pad_hd), (pad_n_kv_head, pad_hd))
+     tile_shapes = compute_default_tiling_strategy(
+         safety_margin=0.90,
+         dtype_size=dtype_size,
+         memory_multiplier=3.0,
+         shapes=shapes,
+         tiling_dims=(0, 0),
+     )
+
+     if tile_shapes is not None and len(tile_shapes) == len(shapes):
+         # Strategy returns ((block_size_q, pad_hd), (block_size_kv, pad_hd))
+         q_tile_shape, k_tile_shape = tile_shapes
+         BLOCK_Q, _ = q_tile_shape
+         BLOCK_K, _ = k_tile_shape
+     else:
+         # Fallback to conservative defaults
+         BLOCK_Q = triton.next_power_of_2(pad_n_q_head)
+         BLOCK_K = triton.next_power_of_2(pad_n_kv_head)
+
+     _triton_rope_npu[(n_row,)](
+         q,
+         q.stride(1),
+         k,
+         k.stride(1),
+         cos,
+         cos.stride(-2),
+         sin,
+         sin.stride(-2),
+         seq_len,
+         batch_size,
+         cos_batch_size,
+         n_q_head,
+         n_kv_head,
+         head_dim,
+         BLOCK_Q,
+         BLOCK_K,
+         BACKWARD_PASS=False,
+     )
+     return q.transpose(1, 2), k.transpose(1, 2), cos, sin
+
+
+ def rope_backward(dq, dk, cos, sin):
+     dq = dq.transpose(1, 2)
+     dk = dk.transpose(1, 2)
+
+     batch_size, seq_len, n_q_head, head_dim = dq.shape
+     cos_batch_size = cos.shape[0]
+     n_kv_head = dk.shape[2]
+     pad_hd = triton.next_power_of_2(head_dim)
+     pad_n_q_head = triton.next_power_of_2(n_q_head)
+     pad_n_kv_head = triton.next_power_of_2(n_kv_head)
+
+     n_row = batch_size * seq_len
+
+     # ensure dq and dk are contiguous
+     dq = dq.contiguous()
+     dk = dk.contiguous()
+
+     # Compute tiling strategy based on UB capacity
+     dtype_size = dq.element_size()
+     # ROPE backward tiling strategy (based on optimized ROPE kernel):
+     # - cos_vals and sin_vals are loaded once outside loops (shared): pad_hd // 2 elements each
+     # - In q heads loop (peak memory):
+     #   * q_left: BLOCK_Q * (pad_hd // 2) elements
+     #   * q_right: BLOCK_Q * (pad_hd // 2) elements
+     #   * new_left: BLOCK_Q * (pad_hd // 2) elements (intermediate result)
+     #   * new_right: BLOCK_Q * (pad_hd // 2) elements (intermediate result)
+     #   * Total: 4 * BLOCK_Q * (pad_hd // 2) = 2 * BLOCK_Q * pad_hd elements
+     # - In k heads loop (peak memory):
+     #   * k_left: BLOCK_K * (pad_hd // 2) elements
+     #   * k_right: BLOCK_K * (pad_hd // 2) elements
+     #   * new_left: BLOCK_K * (pad_hd // 2) elements (intermediate result)
+     #   * new_right: BLOCK_K * (pad_hd // 2) elements (intermediate result)
+     #   * Total: 4 * BLOCK_K * (pad_hd // 2) = 2 * BLOCK_K * pad_hd elements
+     # - Since q and k are processed separately, peak memory is max(BLOCK_Q, BLOCK_K) case
+     # - Plus shared cos/sin: 2 * (pad_hd // 2) = pad_hd elements
+     # - Conservative estimate: (2 * BLOCK_SIZE * pad_hd + pad_hd) * dtype_size * 8 bits
+     # - Simplified: (2 * BLOCK_SIZE + 1) * pad_hd * dtype_size * 8 bits
+     # - For safety, use: memory_multiplier=3.0 * BLOCK_SIZE * pad_hd * dtype_size * 8 bits
+     # - shapes: ((pad_n_q_head, pad_hd), (pad_n_kv_head, pad_hd))
+     # - tiling_dims: (0, 0) means first dimension of each shape can be tiled
+     # - Returns: ((block_size_q, pad_hd), (block_size_kv, pad_hd))
+     shapes = ((pad_n_q_head, pad_hd), (pad_n_kv_head, pad_hd))
+     tile_shapes = compute_default_tiling_strategy(
+         safety_margin=0.90,
+         dtype_size=dtype_size,
+         memory_multiplier=3.0,
+         shapes=shapes,
+         tiling_dims=(0, 0),
+     )
+
+     if tile_shapes is not None and len(tile_shapes) == len(shapes):
+         # Strategy returns ((block_size_q, pad_hd), (block_size_kv, pad_hd))
+         q_tile_shape, k_tile_shape = tile_shapes
+         BLOCK_Q, _ = q_tile_shape
+         BLOCK_K, _ = k_tile_shape
+     else:
+         # Fallback to conservative defaults
+         BLOCK_Q = triton.next_power_of_2(pad_n_q_head)
+         BLOCK_K = triton.next_power_of_2(pad_n_kv_head)
+
+     _triton_rope_npu[(n_row,)](
+         dq,
+         dq.stride(1),
+         dk,
+         dk.stride(1),
+         cos,
+         cos.stride(-2),
+         sin,
+         sin.stride(-2),
+         seq_len,
+         batch_size,
+         cos_batch_size,
+         n_q_head,
+         n_kv_head,
+         head_dim,
+         BLOCK_Q,
+         BLOCK_K,
+         BACKWARD_PASS=True,
+     )
+     return dq.transpose(1, 2), dk.transpose(1, 2)
+
+
+ class LigerRopeFunction(torch.autograd.Function):
+     @staticmethod
+     def forward(ctx, q, k, cos, sin, position_ids=None, unsqueeze_dim=1):
+         """
+         q size: (bsz, n_q_head, seq_len, head_dim)
+         k size: (bsz, n_kv_head, seq_len, head_dim)
+         cos size: (1, seq_len, head_dim) or (bsz, seq_len, head_dim)
+         sin size: (1, seq_len, head_dim) or (bsz, seq_len, head_dim)
+         """
+         q, k, cos, sin = rope_forward(q, k, cos, sin)
+         ctx.save_for_backward(cos, sin)
+         return q, k
+
+     def backward(ctx, dq, dk):
+         """
+         dq size: (bsz, n_q_head, seq_len, head_dim)
+         dk size: (bsz, n_kv_head, seq_len, head_dim)
+         cos size: (1, seq_len, head_dim) or (bsz, seq_len, head_dim)
+         sin size: (1, seq_len, head_dim) or (bsz, seq_len, head_dim)
+         """
+
+         cos, sin = ctx.saved_tensors
+         dq, dk = rope_backward(dq, dk, cos, sin)
+         return dq, dk, None, None, None, None
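For reference, a minimal usage sketch of the new Ascend RoPE autograd function added above. It assumes an Ascend NPU environment with torch_npu installed (so "npu" tensors are available) and uses the tensor shapes documented in the forward docstring; the sizes, dtype, and random values below are illustrative only, not taken from the package.

    import torch
    from liger_kernel.ops.backends._ascend.ops.rope import LigerRopeFunction

    bsz, n_q_head, n_kv_head, seq_len, head_dim = 2, 8, 2, 16, 64
    q = torch.randn(bsz, n_q_head, seq_len, head_dim, device="npu", dtype=torch.float16, requires_grad=True)
    k = torch.randn(bsz, n_kv_head, seq_len, head_dim, device="npu", dtype=torch.float16, requires_grad=True)
    cos = torch.randn(1, seq_len, head_dim, device="npu", dtype=torch.float16)
    sin = torch.randn(1, seq_len, head_dim, device="npu", dtype=torch.float16)

    # rope_forward transposes to (bsz, seq_len, n_head, head_dim) internally and transposes
    # back before returning, so callers keep the (bsz, n_head, seq_len, head_dim) layout.
    q_rot, k_rot = LigerRopeFunction.apply(q, k, cos, sin)
    (q_rot.float().sum() + k_rot.float().sum()).backward()  # exercises rope_backward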
@@ -0,0 +1,142 @@
+ import torch
+ import triton
+ import triton.language as tl
+
+ from liger_kernel.ops.backends._ascend.ub_manager import compute_default_tiling_strategy
+ from liger_kernel.ops.utils import get_npu_core_count
+
+ # -----------------------------------------------------------------------------
+ # Kernels (High-performance 1D Flatten Implementation)
+ # -----------------------------------------------------------------------------
+
+
+ @triton.jit
+ def _swiglu_forward_kernel_flat(
+     a_ptr, b_ptr, c_ptr, total_elements, BLOCK_SIZE: tl.constexpr, NUM_STAGES: tl.constexpr
+ ):
+     pid = tl.program_id(0)
+     num_progs = tl.num_programs(0)
+
+     # Grid-Stride Loop
+     start_idx = pid * BLOCK_SIZE
+     stride = num_progs * BLOCK_SIZE
+
+     for idx in tl.range(start_idx, total_elements, stride, num_stages=NUM_STAGES):
+         offsets = idx + tl.arange(0, BLOCK_SIZE)
+         mask = offsets < total_elements
+
+         a_val = tl.load(a_ptr + offsets, mask=mask, other=0.0).to(tl.float32)
+         b_val = tl.load(b_ptr + offsets, mask=mask, other=0.0).to(tl.float32)
+         res = (a_val * tl.sigmoid(a_val)) * b_val
+         tl.store(c_ptr + offsets, res, mask=mask)
+
+
+ @triton.jit
+ def _swiglu_backward_kernel_flat(
+     dc_ptr, a_ptr, b_ptr, da_ptr, db_ptr, total_elements, BLOCK_SIZE: tl.constexpr, NUM_STAGES: tl.constexpr
+ ):
+     pid = tl.program_id(0)
+     num_progs = tl.num_programs(0)
+     start_idx = pid * BLOCK_SIZE
+     stride = num_progs * BLOCK_SIZE
+
+     for idx in tl.range(start_idx, total_elements, stride, num_stages=NUM_STAGES):
+         offsets = idx + tl.arange(0, BLOCK_SIZE)
+         mask = offsets < total_elements
+
+         dc = tl.load(dc_ptr + offsets, mask=mask, other=0.0).to(tl.float32)
+         a = tl.load(a_ptr + offsets, mask=mask, other=0.0).to(tl.float32)
+         b = tl.load(b_ptr + offsets, mask=mask, other=0.0).to(tl.float32)
+
+         sig_a = tl.sigmoid(a)
+         silu_a = a * sig_a
+         term1 = silu_a * (1.0 - sig_a) + sig_a
+
+         db = dc * silu_a
+         da = dc * b * term1
+
+         tl.store(da_ptr + offsets, da, mask=mask)
+         tl.store(db_ptr + offsets, db, mask=mask)
+
+
+ # -----------------------------------------------------------------------------
+ # Helper: Call compute_default_tiling_strategy
+ # -----------------------------------------------------------------------------
+
+
+ def get_optimal_block_size(total_elements, is_backward=False):
+     """
+     Calculate optimal Block Size using compute_default_tiling_strategy
+     """
+     # 1. Set Memory Multiplier
+     # Forward is lighter, Backward requires more memory for intermediate variables
+     # 8.0 and 12.0 are empirical values based on 910B UB (192KB)
+     multiplier = 12.0 if is_backward else 8.0
+
+     # 2. Call calculation function
+     # Treat input as 1D (total_elements,), only tiling on dim 0
+     tile_shapes = compute_default_tiling_strategy(
+         safety_margin=0.9, dtype_size=4, memory_multiplier=multiplier, shapes=((total_elements,),), tiling_dims=(0,)
+     )
+
+     # 3. Parse result
+     if tile_shapes and len(tile_shapes) > 0:
+         block_size = tile_shapes[0][0]
+         return max(256, block_size)
+     else:
+         return 2048
+
+
+ def swiglu_forward(a, b):
+     if not a.is_contiguous():
+         a = a.contiguous()
+     if not b.is_contiguous():
+         b = b.contiguous()
+
+     total_elements = a.numel()
+     c = torch.empty_like(a)
+
+     block_size = get_optimal_block_size(total_elements, is_backward=False)
+
+     num_cores = get_npu_core_count()
+     grid_size = min(num_cores, (total_elements + block_size - 1) // block_size)
+
+     _swiglu_forward_kernel_flat[(grid_size,)](a, b, c, total_elements, BLOCK_SIZE=block_size, NUM_STAGES=3, num_warps=4)
+     return c
+
+
+ def swiglu_backward(a, b, dc):
+     if not dc.is_contiguous():
+         dc = dc.contiguous()
+     if not a.is_contiguous():
+         a = a.contiguous()
+     if not b.is_contiguous():
+         b = b.contiguous()
+
+     total_elements = dc.numel()
+     grad_a = torch.empty_like(a)
+     grad_b = torch.empty_like(b)
+
+     block_size = get_optimal_block_size(total_elements, is_backward=True)
+
+     num_cores = get_npu_core_count()
+     grid_size = min(num_cores, (total_elements + block_size - 1) // block_size)
+
+     _swiglu_backward_kernel_flat[(grid_size,)](
+         dc, a, b, grad_a, grad_b, total_elements, BLOCK_SIZE=block_size, NUM_STAGES=3, num_warps=4
+     )
+     return grad_a, grad_b
+
+
+ class LigerSiLUMulFunction(torch.autograd.Function):
+     @staticmethod
+     def forward(ctx, a, b):
+         c = swiglu_forward(a, b)
+         ctx.save_for_backward(a, b)
+         return c
+
+     @staticmethod
+     def backward(ctx, dc):
+         a, b = ctx.saved_tensors
+         grad_a, grad_b = swiglu_backward(a, b, dc)
+         return grad_a, grad_b
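Similarly, a minimal usage sketch for the flattened SwiGLU kernels, under the same Ascend NPU assumption. Here a and b play the roles of the gate and up projection outputs and must share a shape; the sizes and dtype are illustrative only.

    import torch
    from liger_kernel.ops.backends._ascend.ops.swiglu import LigerSiLUMulFunction

    a = torch.randn(4, 2048, device="npu", dtype=torch.float16, requires_grad=True)  # gate projection output
    b = torch.randn(4, 2048, device="npu", dtype=torch.float16, requires_grad=True)  # up projection output

    c = LigerSiLUMulFunction.apply(a, b)  # silu(a) * b via _swiglu_forward_kernel_flat
    c.float().sum().backward()            # grads for a and b come from _swiglu_backward_kernel_flat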