liger-kernel 0.6.3__py3-none-any.whl → 0.6.5__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (101)
  1. liger_kernel/chunked_loss/cosine_similarity_loss.py +20 -5
  2. liger_kernel/chunked_loss/fused_linear_distillation.py +23 -5
  3. liger_kernel/chunked_loss/fused_linear_ppo.py +21 -5
  4. liger_kernel/chunked_loss/grpo_loss.py +8 -5
  5. liger_kernel/chunked_loss/jsd_loss.py +39 -11
  6. liger_kernel/ops/__init__.py +141 -0
  7. liger_kernel/ops/backends/README.md +151 -0
  8. liger_kernel/ops/backends/__init__.py +13 -0
  9. liger_kernel/ops/backends/_ascend/__init__.py +5 -0
  10. liger_kernel/ops/backends/_ascend/ascend-ub-manager-design.md +492 -0
  11. liger_kernel/ops/backends/_ascend/ops/__init__.py +61 -0
  12. liger_kernel/ops/backends/_ascend/ops/embedding.py +214 -0
  13. liger_kernel/ops/backends/_ascend/ops/geglu.py +191 -0
  14. liger_kernel/ops/backends/_ascend/ops/llama4_rope.py +298 -0
  15. liger_kernel/ops/backends/_ascend/ops/qwen2vl_mrope.py +275 -0
  16. liger_kernel/ops/backends/_ascend/ops/rope.py +265 -0
  17. liger_kernel/ops/backends/_ascend/ops/swiglu.py +142 -0
  18. liger_kernel/ops/backends/_ascend/ops/tvd.py +223 -0
  19. liger_kernel/ops/backends/_ascend/ub_manager.py +367 -0
  20. liger_kernel/ops/backends/registry.py +61 -0
  21. liger_kernel/ops/cross_entropy.py +71 -11
  22. liger_kernel/ops/dyt.py +5 -2
  23. liger_kernel/ops/fused_add_rms_norm.py +21 -23
  24. liger_kernel/ops/fused_linear_cross_entropy.py +32 -5
  25. liger_kernel/ops/geglu.py +5 -3
  26. liger_kernel/ops/group_norm.py +12 -8
  27. liger_kernel/ops/grpo_loss.py +3 -1
  28. liger_kernel/ops/kl_div.py +8 -11
  29. liger_kernel/ops/layer_norm.py +89 -69
  30. liger_kernel/ops/poly_norm.py +19 -21
  31. liger_kernel/ops/rms_norm.py +149 -71
  32. liger_kernel/ops/tiled_mlp.py +136 -0
  33. liger_kernel/ops/utils.py +25 -0
  34. liger_kernel/transformers/__init__.py +25 -0
  35. liger_kernel/transformers/auto_model.py +21 -0
  36. liger_kernel/transformers/cross_entropy.py +9 -4
  37. liger_kernel/transformers/dyt.py +1 -1
  38. liger_kernel/transformers/experimental/embedding.py +1 -1
  39. liger_kernel/transformers/functional.py +44 -26
  40. liger_kernel/transformers/fused_add_rms_norm.py +1 -1
  41. liger_kernel/transformers/fused_linear_cross_entropy.py +9 -4
  42. liger_kernel/transformers/fused_linear_jsd.py +1 -1
  43. liger_kernel/transformers/fused_neighborhood_attention.py +1 -1
  44. liger_kernel/transformers/geglu.py +1 -1
  45. liger_kernel/transformers/group_norm.py +1 -1
  46. liger_kernel/transformers/grpo_loss.py +57 -2
  47. liger_kernel/transformers/jsd.py +1 -1
  48. liger_kernel/transformers/kl_div.py +1 -1
  49. liger_kernel/transformers/layer_norm.py +1 -1
  50. liger_kernel/transformers/llama4_rope.py +1 -1
  51. liger_kernel/transformers/model/exaone4.py +136 -0
  52. liger_kernel/transformers/model/falcon_h1.py +19 -5
  53. liger_kernel/transformers/model/gemma.py +17 -6
  54. liger_kernel/transformers/model/gemma2.py +17 -8
  55. liger_kernel/transformers/model/gemma3.py +35 -16
  56. liger_kernel/transformers/model/glm4.py +16 -4
  57. liger_kernel/transformers/model/glm4v.py +16 -4
  58. liger_kernel/transformers/model/glm4v_moe.py +23 -4
  59. liger_kernel/transformers/model/gpt_oss.py +211 -0
  60. liger_kernel/transformers/model/hunyuan_v1.py +134 -0
  61. liger_kernel/transformers/model/internvl.py +12 -5
  62. liger_kernel/transformers/model/llama.py +14 -5
  63. liger_kernel/transformers/model/llama4.py +16 -4
  64. liger_kernel/transformers/model/llava.py +12 -4
  65. liger_kernel/transformers/model/loss_utils.py +37 -3
  66. liger_kernel/transformers/model/mistral.py +15 -6
  67. liger_kernel/transformers/model/mixtral.py +16 -7
  68. liger_kernel/transformers/model/mllama.py +12 -4
  69. liger_kernel/transformers/model/olmo2.py +16 -4
  70. liger_kernel/transformers/model/olmo3.py +142 -0
  71. liger_kernel/transformers/model/output_classes.py +147 -0
  72. liger_kernel/transformers/model/paligemma.py +23 -5
  73. liger_kernel/transformers/model/phi3.py +14 -7
  74. liger_kernel/transformers/model/qwen2.py +16 -3
  75. liger_kernel/transformers/model/qwen2_5_vl.py +14 -6
  76. liger_kernel/transformers/model/qwen2_vl.py +16 -4
  77. liger_kernel/transformers/model/qwen3.py +20 -5
  78. liger_kernel/transformers/model/qwen3_moe.py +19 -5
  79. liger_kernel/transformers/model/qwen3_next.py +17 -5
  80. liger_kernel/transformers/model/qwen3_vl.py +150 -0
  81. liger_kernel/transformers/model/qwen3_vl_moe.py +126 -0
  82. liger_kernel/transformers/model/smollm3.py +15 -6
  83. liger_kernel/transformers/monkey_patch.py +584 -49
  84. liger_kernel/transformers/multi_token_attention.py +1 -1
  85. liger_kernel/transformers/poly_norm.py +1 -1
  86. liger_kernel/transformers/qwen2vl_mrope.py +1 -1
  87. liger_kernel/transformers/rms_norm.py +8 -3
  88. liger_kernel/transformers/rope.py +45 -1
  89. liger_kernel/transformers/softmax.py +1 -1
  90. liger_kernel/transformers/sparsemax.py +1 -1
  91. liger_kernel/transformers/swiglu.py +18 -1
  92. liger_kernel/transformers/tiled_mlp.py +125 -0
  93. liger_kernel/transformers/tvd.py +1 -1
  94. liger_kernel/utils.py +54 -0
  95. {liger_kernel-0.6.3.dist-info → liger_kernel-0.6.5.dist-info}/METADATA +14 -4
  96. liger_kernel-0.6.5.dist-info/RECORD +134 -0
  97. {liger_kernel-0.6.3.dist-info → liger_kernel-0.6.5.dist-info}/WHEEL +1 -1
  98. liger_kernel-0.6.3.dist-info/RECORD +0 -111
  99. {liger_kernel-0.6.3.dist-info → liger_kernel-0.6.5.dist-info}/licenses/LICENSE +0 -0
  100. {liger_kernel-0.6.3.dist-info → liger_kernel-0.6.5.dist-info}/licenses/NOTICE +0 -0
  101. {liger_kernel-0.6.3.dist-info → liger_kernel-0.6.5.dist-info}/top_level.txt +0 -0
liger_kernel/ops/backends/_ascend/ops/rope.py
@@ -0,0 +1,265 @@
+ import torch
+ import triton
+ import triton.language as tl
+
+ from liger_kernel.ops.backends._ascend.ub_manager import compute_default_tiling_strategy
+ from liger_kernel.ops.utils import get_npu_core_count
+
+
+ @triton.jit
+ def _triton_rope_npu(
+     q_ptr,
+     q_row_stride,
+     k_ptr,
+     k_row_stride,
+     cos,
+     cos_row_stride,
+     sin,
+     sin_row_stride,
+     sl,
+     total_rows: tl.constexpr,
+     cos_bs: tl.constexpr,
+     n_qh: tl.constexpr,
+     n_kh: tl.constexpr,
+     hd: tl.constexpr,
+     BLOCK_Q: tl.constexpr,
+     BLOCK_K: tl.constexpr,
+     NUM_STAGES: tl.constexpr,
+     BACKWARD_PASS: tl.constexpr = False,
+ ):
+     program_id = tl.program_id(0)
+     num_programs = tl.num_programs(0)
+
+     rows_per_program = (total_rows + num_programs - 1) // num_programs
+     start_row = program_id * rows_per_program
+     actual_rows = tl.minimum(rows_per_program, total_rows - start_row)
+
+     for row_offset in tl.range(0, actual_rows, num_stages=NUM_STAGES):
+         pid = start_row + row_offset
+
+         row_idx = pid % sl
+         cos_ptr = cos + tl.where(cos_bs == 1, row_idx * cos_row_stride, pid * cos_row_stride)
+         sin_ptr = sin + tl.where(cos_bs == 1, row_idx * sin_row_stride, pid * sin_row_stride)
+
+         # Pre-compute d_idx and cos/sin values outside the head loops (they don't depend on heads)
+         d_idx = tl.arange(0, hd // 2)
+         d_mask = d_idx < (hd // 2)  # Always True, but kept for clarity
+         cos_vals = tl.load(cos_ptr + d_idx, mask=d_mask, other=0)
+         sin_vals = tl.load(sin_ptr + d_idx, mask=d_mask, other=0)
+
+         # Process q heads in chunks to prevent UB overflow
+         for qh_block in range(0, n_qh, BLOCK_Q):
+             qh_idx = tl.arange(0, BLOCK_Q) + qh_block
+             qh_mask = qh_idx < n_qh
+
+             # block_mask: qh_mask broadcasted over d_idx dimension
+             block_mask = qh_mask[:, None]
+
+             offsets = qh_idx[:, None] * hd + d_idx[None, :]
+             q_base = q_ptr + pid * q_row_stride
+
+             q_left = tl.load(q_base + offsets, mask=block_mask, other=0)
+             q_right = tl.load(q_base + offsets + (hd // 2), mask=block_mask, other=0)
+
+             if not BACKWARD_PASS:
+                 new_left = q_left * cos_vals - q_right * sin_vals
+                 new_right = q_right * cos_vals + q_left * sin_vals
+             else:
+                 new_left = q_left * cos_vals + q_right * sin_vals
+                 new_right = q_right * cos_vals - q_left * sin_vals
+
+             tl.store(q_base + offsets, new_left, mask=block_mask)
+             tl.store(q_base + offsets + (hd // 2), new_right, mask=block_mask)
+
+         # Process k heads in chunks to prevent UB overflow
+         for kh_block in range(0, n_kh, BLOCK_K):
+             kh_idx = tl.arange(0, BLOCK_K) + kh_block
+             kh_mask = kh_idx < n_kh
+
+             # block_mask: kh_mask broadcasted over d_idx dimension
+             block_mask = kh_mask[:, None]
+
+             offsets = kh_idx[:, None] * hd + d_idx[None, :]
+             k_base = k_ptr + pid * k_row_stride
+
+             k_left = tl.load(k_base + offsets, mask=block_mask, other=0)
+             k_right = tl.load(k_base + offsets + (hd // 2), mask=block_mask, other=0)
+
+             if not BACKWARD_PASS:
+                 new_left = k_left * cos_vals - k_right * sin_vals
+                 new_right = k_right * cos_vals + k_left * sin_vals
+             else:
+                 new_left = k_left * cos_vals + k_right * sin_vals
+                 new_right = k_right * cos_vals - k_left * sin_vals
+
+             tl.store(k_base + offsets, new_left, mask=block_mask)
+             tl.store(k_base + offsets + (hd // 2), new_right, mask=block_mask)
+
+
+ def get_optimal_block_size(pad_n_q_head, pad_n_kv_head, pad_hd, dtype_size):
+     # Compute the tiling strategy based on UB capacity.
+     # RoPE forward tiling strategy (based on the optimized RoPE kernel):
+     # - cos_vals and sin_vals are loaded once outside the loops (shared): pad_hd // 2 elements each
+     # - In the q-heads loop (peak memory):
+     #   * q_left: BLOCK_Q * (pad_hd // 2) elements
+     #   * q_right: BLOCK_Q * (pad_hd // 2) elements
+     #   * new_left: BLOCK_Q * (pad_hd // 2) elements (intermediate result)
+     #   * new_right: BLOCK_Q * (pad_hd // 2) elements (intermediate result)
+     #   * Total: 4 * BLOCK_Q * (pad_hd // 2) = 2 * BLOCK_Q * pad_hd elements
+     # - In the k-heads loop (peak memory):
+     #   * k_left: BLOCK_K * (pad_hd // 2) elements
+     #   * k_right: BLOCK_K * (pad_hd // 2) elements
+     #   * new_left: BLOCK_K * (pad_hd // 2) elements (intermediate result)
+     #   * new_right: BLOCK_K * (pad_hd // 2) elements (intermediate result)
+     #   * Total: 4 * BLOCK_K * (pad_hd // 2) = 2 * BLOCK_K * pad_hd elements
+     # - Since q and k are processed separately, peak memory is the max(BLOCK_Q, BLOCK_K) case
+     # - Plus shared cos/sin: 2 * (pad_hd // 2) = pad_hd elements
+     # - Conservative estimate: (2 * BLOCK_SIZE * pad_hd + pad_hd) * dtype_size * 8 bits
+     # - Simplified: (2 * BLOCK_SIZE + 1) * pad_hd * dtype_size * 8 bits
+     # - For safety, use memory_multiplier=3.0, i.e. 3.0 * BLOCK_SIZE * pad_hd * dtype_size * 8 bits
+     # - shapes: ((pad_n_q_head, pad_hd), (pad_n_kv_head, pad_hd))
+     # - tiling_dims: (0, 0) means the first dimension of each shape can be tiled
+     # - Returns: ((block_size_q, pad_hd), (block_size_kv, pad_hd))
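+     # Illustrative arithmetic (hypothetical numbers, only to show how the estimate is used;
+     # the actual UB budget and rounding live in compute_default_tiling_strategy):
+     # with pad_hd = 128 and dtype_size = 2 (bf16), one tiled row costs about
+     # 3.0 * 128 * 2 * 8 = 6144 bits, so a 192 KB UB with safety_margin = 0.90 leaves room for
+     # roughly 0.90 * 192 * 1024 * 8 / 6144 ≈ 230 rows, and BLOCK_Q / BLOCK_K are chosen to fit
+     # within that bound.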
+     shapes = ((pad_n_q_head, pad_hd), (pad_n_kv_head, pad_hd))
+     tile_shapes = compute_default_tiling_strategy(
+         safety_margin=0.90,
+         dtype_size=dtype_size,
+         memory_multiplier=3.0,
+         shapes=shapes,
+         tiling_dims=(0, 0),
+     )
+
+     if tile_shapes is not None and len(tile_shapes) == len(shapes):
+         # Strategy returns ((block_size_q, pad_hd), (block_size_kv, pad_hd))
+         q_tile_shape, k_tile_shape = tile_shapes
+         BLOCK_Q, _ = q_tile_shape
+         BLOCK_K, _ = k_tile_shape
+     else:
+         # Fallback to conservative defaults
+         BLOCK_Q = 2048
+         BLOCK_K = 2048
+
+     return BLOCK_Q, BLOCK_K
+
+
+ def rope_forward(q, k, cos, sin):
+     # Transpose back to the physical shape because Triton looks at the physical storage.
+     # Note: q and k are non-contiguous before the transformation and become contiguous after the transpose.
+     q = q.transpose(1, 2)
+     k = k.transpose(1, 2)
+
+     batch_size, seq_len, n_q_head, head_dim = q.shape
+     n_kv_head = k.shape[2]
+     pad_hd = triton.next_power_of_2(head_dim)
+     pad_n_q_head = triton.next_power_of_2(n_q_head)
+     pad_n_kv_head = triton.next_power_of_2(n_kv_head)
+
+     n_row = batch_size * seq_len
+
+     # Ensure tensors passed into the kernel are contiguous; this is a no-op if they already are.
+     q = q.contiguous()
+     k = k.contiguous()
+     cos = cos.contiguous()
+     sin = sin.contiguous()
+     cos_batch_size = cos.shape[0]
+
+     dtype_size = q.element_size()
+     BLOCK_Q, BLOCK_K = get_optimal_block_size(pad_n_q_head, pad_n_kv_head, pad_hd, dtype_size)
+
+     num_cores = get_npu_core_count()
+     grid_size = min(num_cores, n_row)
+
+     _triton_rope_npu[(grid_size,)](
+         q,
+         q.stride(1),
+         k,
+         k.stride(1),
+         cos,
+         cos.stride(-2),
+         sin,
+         sin.stride(-2),
+         seq_len,
+         n_row,
+         cos_batch_size,
+         n_q_head,
+         n_kv_head,
+         head_dim,
+         BLOCK_Q,
+         BLOCK_K,
+         NUM_STAGES=3,
+         BACKWARD_PASS=False,
+     )
+     return q.transpose(1, 2), k.transpose(1, 2), cos, sin
+
+
+ def rope_backward(dq, dk, cos, sin):
+     dq = dq.transpose(1, 2)
+     dk = dk.transpose(1, 2)
+
+     batch_size, seq_len, n_q_head, head_dim = dq.shape
+     cos_batch_size = cos.shape[0]
+     n_kv_head = dk.shape[2]
+     pad_hd = triton.next_power_of_2(head_dim)
+     pad_n_q_head = triton.next_power_of_2(n_q_head)
+     pad_n_kv_head = triton.next_power_of_2(n_kv_head)
+
+     n_row = batch_size * seq_len
+
+     # ensure dq and dk are contiguous
+     dq = dq.contiguous()
+     dk = dk.contiguous()
+
+     dtype_size = dq.element_size()
+     BLOCK_Q, BLOCK_K = get_optimal_block_size(pad_n_q_head, pad_n_kv_head, pad_hd, dtype_size)
+
+     num_cores = get_npu_core_count()
+     grid_size = min(num_cores, n_row)
+
+     _triton_rope_npu[(grid_size,)](
+         dq,
+         dq.stride(1),
+         dk,
+         dk.stride(1),
+         cos,
+         cos.stride(-2),
+         sin,
+         sin.stride(-2),
+         seq_len,
+         n_row,
+         cos_batch_size,
+         n_q_head,
+         n_kv_head,
+         head_dim,
+         BLOCK_Q,
+         BLOCK_K,
+         NUM_STAGES=3,
+         BACKWARD_PASS=True,
+     )
+     return dq.transpose(1, 2), dk.transpose(1, 2)
+
+
+ class LigerRopeFunction(torch.autograd.Function):
+     @staticmethod
+     def forward(ctx, q, k, cos, sin, position_ids=None, unsqueeze_dim=1):
+         """
+         q size: (bsz, n_q_head, seq_len, head_dim)
+         k size: (bsz, n_kv_head, seq_len, head_dim)
+         cos size: (1, seq_len, head_dim) or (bsz, seq_len, head_dim)
+         sin size: (1, seq_len, head_dim) or (bsz, seq_len, head_dim)
+         """
+         q, k, cos, sin = rope_forward(q, k, cos, sin)
+         ctx.save_for_backward(cos, sin)
+         return q, k
+
+     @staticmethod
+     def backward(ctx, dq, dk):
+         """
+         dq size: (bsz, n_q_head, seq_len, head_dim)
+         dk size: (bsz, n_kv_head, seq_len, head_dim)
+         cos size: (1, seq_len, head_dim) or (bsz, seq_len, head_dim)
+         sin size: (1, seq_len, head_dim) or (bsz, seq_len, head_dim)
+         """
+
+         cos, sin = ctx.saved_tensors
+         dq, dk = rope_backward(dq, dk, cos, sin)
+         return dq, dk, None, None, None, None
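+
+ # Usage sketch (hypothetical shapes; requires an Ascend NPU with the Triton backend available):
+ #   q = torch.randn(2, 32, 128, 64, dtype=torch.bfloat16, device="npu")  # (bsz, n_q_head, seq_len, head_dim)
+ #   k = torch.randn(2, 8, 128, 64, dtype=torch.bfloat16, device="npu")   # (bsz, n_kv_head, seq_len, head_dim)
+ #   cos = torch.randn(1, 128, 64, dtype=torch.bfloat16, device="npu")    # (1, seq_len, head_dim)
+ #   sin = torch.randn(1, 128, 64, dtype=torch.bfloat16, device="npu")
+ #   q_rot, k_rot = LigerRopeFunction.apply(q, k, cos, sin)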
liger_kernel/ops/backends/_ascend/ops/swiglu.py
@@ -0,0 +1,142 @@
+ import torch
+ import triton
+ import triton.language as tl
+
+ from liger_kernel.ops.backends._ascend.ub_manager import compute_default_tiling_strategy
+ from liger_kernel.ops.utils import get_npu_core_count
+
+ # -----------------------------------------------------------------------------
+ # Kernels (High-performance 1D Flatten Implementation)
+ # -----------------------------------------------------------------------------
+
+
+ @triton.jit
+ def _swiglu_forward_kernel_flat(
+     a_ptr, b_ptr, c_ptr, total_elements, BLOCK_SIZE: tl.constexpr, NUM_STAGES: tl.constexpr
+ ):
+     pid = tl.program_id(0)
+     num_progs = tl.num_programs(0)
+
+     # Grid-Stride Loop
+     start_idx = pid * BLOCK_SIZE
+     stride = num_progs * BLOCK_SIZE
+
+     for idx in tl.range(start_idx, total_elements, stride, num_stages=NUM_STAGES):
+         offsets = idx + tl.arange(0, BLOCK_SIZE)
+         mask = offsets < total_elements
+
+         a_val = tl.load(a_ptr + offsets, mask=mask, other=0.0).to(tl.float32)
+         b_val = tl.load(b_ptr + offsets, mask=mask, other=0.0).to(tl.float32)
+         res = (a_val * tl.sigmoid(a_val)) * b_val
+         tl.store(c_ptr + offsets, res, mask=mask)
+
+
+ @triton.jit
+ def _swiglu_backward_kernel_flat(
+     dc_ptr, a_ptr, b_ptr, da_ptr, db_ptr, total_elements, BLOCK_SIZE: tl.constexpr, NUM_STAGES: tl.constexpr
+ ):
+     pid = tl.program_id(0)
+     num_progs = tl.num_programs(0)
+     start_idx = pid * BLOCK_SIZE
+     stride = num_progs * BLOCK_SIZE
+
+     for idx in tl.range(start_idx, total_elements, stride, num_stages=NUM_STAGES):
+         offsets = idx + tl.arange(0, BLOCK_SIZE)
+         mask = offsets < total_elements
+
+         dc = tl.load(dc_ptr + offsets, mask=mask, other=0.0).to(tl.float32)
+         a = tl.load(a_ptr + offsets, mask=mask, other=0.0).to(tl.float32)
+         b = tl.load(b_ptr + offsets, mask=mask, other=0.0).to(tl.float32)
+
+         sig_a = tl.sigmoid(a)
+         silu_a = a * sig_a
+         term1 = silu_a * (1.0 - sig_a) + sig_a
+
+         db = dc * silu_a
+         da = dc * b * term1
+
+         tl.store(da_ptr + offsets, da, mask=mask)
+         tl.store(db_ptr + offsets, db, mask=mask)
+
+
+ # -----------------------------------------------------------------------------
+ # Helper: Call compute_default_tiling_strategy
+ # -----------------------------------------------------------------------------
+
+
+ def get_optimal_block_size(total_elements, is_backward=False):
+     """
+     Calculate optimal Block Size using compute_default_tiling_strategy
+     """
+     # 1. Set Memory Multiplier
+     # Forward is lighter; Backward requires more memory for intermediate variables.
+     # 8.0 and 12.0 are empirical values based on 910B UB (192KB).
+     multiplier = 12.0 if is_backward else 8.0
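+     # Illustrative arithmetic (hypothetical; the real budget comes from the UB manager):
+     # the backward pass keeps roughly dc, a, b, sig_a, silu_a, term1, da and db live at once
+     # (8 float32 blocks), so the empirical multiplier of 12.0 leaves headroom. With it, a block
+     # costs about 12.0 * BLOCK_SIZE * 4 * 8 = 384 * BLOCK_SIZE bits, and a 192 KB UB at
+     # safety_margin = 0.9 allows a BLOCK_SIZE of roughly 0.9 * 192 * 1024 * 8 / 384 ≈ 3686.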
+
+     # 2. Call calculation function
+     # Treat input as 1D (total_elements,), only tiling on dim 0
+     tile_shapes = compute_default_tiling_strategy(
+         safety_margin=0.9, dtype_size=4, memory_multiplier=multiplier, shapes=((total_elements,),), tiling_dims=(0,)
+     )
+
+     # 3. Parse result
+     if tile_shapes and len(tile_shapes) > 0:
+         block_size = tile_shapes[0][0]
+         return max(256, block_size)
+     else:
+         return 2048
+
+
+ def swiglu_forward(a, b):
+     if not a.is_contiguous():
+         a = a.contiguous()
+     if not b.is_contiguous():
+         b = b.contiguous()
+
+     total_elements = a.numel()
+     c = torch.empty_like(a)
+
+     block_size = get_optimal_block_size(total_elements, is_backward=False)
+
+     num_cores = get_npu_core_count()
+     grid_size = min(num_cores, (total_elements + block_size - 1) // block_size)
+
+     _swiglu_forward_kernel_flat[(grid_size,)](a, b, c, total_elements, BLOCK_SIZE=block_size, NUM_STAGES=3, num_warps=4)
+     return c
+
+
+ def swiglu_backward(a, b, dc):
+     if not dc.is_contiguous():
+         dc = dc.contiguous()
+     if not a.is_contiguous():
+         a = a.contiguous()
+     if not b.is_contiguous():
+         b = b.contiguous()
+
+     total_elements = dc.numel()
+     grad_a = torch.empty_like(a)
+     grad_b = torch.empty_like(b)
+
+     block_size = get_optimal_block_size(total_elements, is_backward=True)
+
+     num_cores = get_npu_core_count()
+     grid_size = min(num_cores, (total_elements + block_size - 1) // block_size)
+
+     _swiglu_backward_kernel_flat[(grid_size,)](
+         dc, a, b, grad_a, grad_b, total_elements, BLOCK_SIZE=block_size, NUM_STAGES=3, num_warps=4
+     )
+     return grad_a, grad_b
+
+
+ class LigerSiLUMulFunction(torch.autograd.Function):
+     @staticmethod
+     def forward(ctx, a, b):
+         c = swiglu_forward(a, b)
+         ctx.save_for_backward(a, b)
+         return c
+
+     @staticmethod
+     def backward(ctx, dc):
+         a, b = ctx.saved_tensors
+         grad_a, grad_b = swiglu_backward(a, b, dc)
+         return grad_a, grad_b
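+
+ # Usage sketch (hypothetical shapes; runs on an Ascend NPU with the Triton backend):
+ #   gate = torch.randn(4, 2048, 11008, dtype=torch.bfloat16, device="npu", requires_grad=True)
+ #   up = torch.randn(4, 2048, 11008, dtype=torch.bfloat16, device="npu", requires_grad=True)
+ #   out = LigerSiLUMulFunction.apply(gate, up)  # silu(gate) * up
+ #   out.sum().backward()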
liger_kernel/ops/backends/_ascend/ops/tvd.py
@@ -0,0 +1,223 @@
+ from typing import Literal
+ from typing import Optional
+
+ import torch
+ import triton
+ import triton.language as tl
+
+ from liger_kernel.ops.backends._ascend.ub_manager import compute_default_tiling_strategy
+ from liger_kernel.ops.utils import ensure_contiguous
+ from liger_kernel.ops.utils import get_npu_core_count
+
+ MAX_FUSED_SIZE = 65536 // 4
+
+ REDUCTION_LITERAL = Literal["none", "sum", "mean", "batchmean"]
+
+
+ @triton.jit
+ def _tv_distance_kernel(
+     p_ptr,
+     p_stride,
+     q_ptr,
+     q_stride,
+     loss_ptr,
+     loss_stride,
+     grads_ptr,
+     grads_stride,
+     label_ptr,
+     ignore_index: tl.constexpr,
+     n_cols,  # V
+     total_rows: tl.constexpr,  # BT
+     BLOCK_SIZE: tl.constexpr,
+     HAS_LABEL: tl.constexpr,
+     NUM_STAGES: tl.constexpr,
+     reduction: tl.constexpr = "batchmean",
+ ):
+     thread_id = tl.program_id(0)
+     num_threads = tl.num_programs(0)
+
+     for pid in tl.range(thread_id, total_rows, num_threads, num_stages=NUM_STAGES):
+         p_row_ptr = p_ptr + pid * p_stride
+         q_row_ptr = q_ptr + pid * q_stride
+         loss_row_ptr = loss_ptr + pid * loss_stride
+         grads_row_ptr = grads_ptr + pid * grads_stride
+         label_row_ptr = label_ptr + pid
+
+         base_offsets = tl.arange(0, BLOCK_SIZE)
+
+         should_skip = False
+         if HAS_LABEL:
+             label = tl.load(label_row_ptr)
+             if label == ignore_index:
+                 should_skip = True
+
+         if should_skip:
+             for i in range(0, n_cols, BLOCK_SIZE):
+                 offsets = i + base_offsets
+                 mask = offsets < n_cols
+                 tl.store(grads_row_ptr + offsets, 0.0, mask=mask)
+                 if reduction == "none":
+                     tl.store(loss_row_ptr + offsets, 0.0, mask=mask)
+         else:
+             loss_sum = 0.0
+             for i in range(0, n_cols, BLOCK_SIZE):
+                 offsets = i + base_offsets
+                 mask = offsets < n_cols
+
+                 p = tl.load(p_row_ptr + offsets, mask=mask, other=0.0)
+                 q = tl.load(q_row_ptr + offsets, mask=mask, other=0.0)
+
+                 # TVD(P || Q) = 0.5 * |P - Q|
+                 tv_loss = 0.5 * tl.abs(p - q)
+                 grad_res = tl.where(p > q, 0.5, -0.5)
+
+                 tl.store(grads_row_ptr + offsets, grad_res, mask=mask)
+
+                 if reduction == "none":
+                     tl.store(loss_row_ptr + offsets, tv_loss, mask=mask)
+                 else:
+                     loss_sum += tl.sum(tv_loss, axis=0)
+
+             if reduction != "none":
+                 tl.store(loss_row_ptr, loss_sum)
+
+
+ def tv_distance_forward_triton(p, q, shift_labels, reduction, ignore_index, has_label):
+     BT, V = p.shape
+
+     # TVD forward tiling strategy
+     # - In the main loop (loss and grad calculation):
+     #   * p: BLOCK_SIZE elements
+     #   * q: BLOCK_SIZE elements
+     #   * tv_loss: BLOCK_SIZE elements
+     #   * grad_res: BLOCK_SIZE elements
+     #   * loss_sum: BLOCK_SIZE elements (when reduction != "none")
+     #   * Total: 4 * BLOCK_SIZE elements, or 5 * BLOCK_SIZE elements when reduction != "none"
+     # - loss_sum is not live in every iteration, but other shared buffers and the potential
+     #   memory use of the HAS_LABEL path are also accounted for.
+     # - Conservative estimate: 5 * BLOCK_SIZE * dtype_size * 8 bits, hence memory_multiplier=5.0
+     # - shapes: ((V,),)
+     # - tiling_dims: (0,) means the first dimension of each shape can be tiled
+     # - Returns: ((block_size,),)
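+     # Illustrative example (hypothetical sizes): with V = 32000 and a computed block size of 8192,
+     # each row is swept in ceil(32000 / 8192) = 4 chunks of the column loop above.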
+     shapes = ((V,),)
+     tile_shapes = compute_default_tiling_strategy(
+         safety_margin=0.80,
+         # In the TVD calculation much of the data is implicitly converted to float32, so the float32 size is used directly.
+         dtype_size=4,
+         memory_multiplier=5.0,
+         shapes=shapes,
+         tiling_dims=(0,),
+     )
+
+     if tile_shapes is not None and len(tile_shapes) > 0 and len(tile_shapes[0]) > 0:
+         # Strategy returns ((block_size,),)
+         BLOCK_SIZE = tile_shapes[0][0]
+     else:
+         # Fall back to the desired block size if no tiling strategy was found (no tiling needed)
+         BLOCK_SIZE = min(MAX_FUSED_SIZE, triton.next_power_of_2(V))
+
+     num_cores = get_npu_core_count()
+     grid = (min(num_cores, BT),)
+
+     out_size = (BT, V) if reduction == "none" else (BT,)
+
+     # Loss and gradient accumulation in BF16 on the NPU suffers precision errors, so accumulate in float32.
+     output_tensor = torch.zeros(out_size, device=p.device, dtype=torch.float32)
+     grads = torch.empty_like(p, dtype=torch.float32)
+
+     n_non_ignore = (shift_labels != ignore_index).sum().item() if has_label else BT
+
+     _tv_distance_kernel[grid](
+         p,
+         p.stride(0),
+         q,
+         q.stride(0),
+         output_tensor,
+         output_tensor.stride(0),
+         grads,
+         grads.stride(0),
+         shift_labels if has_label else torch.empty(1, device=p.device),
+         ignore_index,
+         V,
+         BT,
+         BLOCK_SIZE=BLOCK_SIZE,
+         HAS_LABEL=has_label,
+         NUM_STAGES=3 if BT < 4096 else 4,
+         reduction=reduction,
+     )
+
+     if reduction == "batchmean":
+         return output_tensor.sum() / n_non_ignore, grads / n_non_ignore
+     elif reduction == "sum":
+         return output_tensor.sum(dim=0), grads
+     elif reduction == "mean":
+         return output_tensor.sum() / (n_non_ignore * V), grads / (n_non_ignore * V)
+     else:
+         return output_tensor, grads
+
+
+ def tvd_backward_triton(grad_output, grads):
+     # If this is the last layer, grad_output is 1.0. Skip the mul then.
+     if torch.equal(grad_output, torch.tensor(1.0, device=grad_output.device)):
+         return grads
+
+     return grads * grad_output
+
+
+ class LigerTVDLossFunction(torch.autograd.Function):
+     """
+     Class implementing the forward and backward pass for the Total Variation Distance Loss using Triton.
+     """
+
+     @staticmethod
+     @ensure_contiguous
+     def forward(
+         ctx,
+         p: torch.Tensor,
+         q: torch.Tensor,
+         shift_labels: Optional[torch.Tensor] = None,
+         reduction: REDUCTION_LITERAL = "batchmean",
+         ignore_index: int = -100,
+     ) -> torch.Tensor:
+         """A forward pass for the Total Variation Distance Loss.
+
+         Args:
+             ctx: Torch autograd context
+             p (torch.Tensor): A tensor of shape (BT, V) containing the first distribution.
+             q (torch.Tensor): A tensor of shape (BT, V) containing the second distribution.
+             shift_labels (Optional[torch.Tensor]): A tensor of shape (BT,) containing the labels.
+             reduction (REDUCTION_LITERAL, optional): The reduction method to be applied. Defaults to "batchmean".
+             ignore_index (int, optional): The index to ignore during loss calculation. Defaults to -100.
+
+         Returns:
+             torch.Tensor: The computed Total Variation Distance Loss.
+         """
+         has_label = False
+         if shift_labels is not None:
+             assert shift_labels.shape == (p.shape[0],), (
+                 f"the shape of shift_labels must be (BT,). Got: {shift_labels.shape}"
+             )
+             shift_labels = shift_labels.contiguous()
+             has_label = True
+
+         loss, grads = tv_distance_forward_triton(p, q, shift_labels, reduction, ignore_index, has_label)
+         ctx.save_for_backward(grads)
+         return loss
+
+     @staticmethod
+     @ensure_contiguous
+     def backward(ctx, grad_output: torch.Tensor) -> torch.Tensor:
+         """A backward pass for the Total Variation Distance Loss.
+
+         Args:
+             ctx: Torch autograd context
+             grad_output (torch.Tensor): The gradient of the loss with respect to the output.
+
+         Returns:
+             tuple[torch.Tensor, None, None, None, None]: The gradient of the loss with respect to the inputs.
+         """
+         (grads,) = ctx.saved_tensors
+         grads = tvd_backward_triton(grad_output, grads)
+
+         return grads, None, None, None, None
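+
+ # Usage sketch (hypothetical shapes; p and q must be probability distributions over V):
+ #   logits_student = torch.randn(8, 32000, device="npu", requires_grad=True)
+ #   logits_teacher = torch.randn(8, 32000, device="npu")
+ #   p = torch.softmax(logits_student, dim=-1)
+ #   q = torch.softmax(logits_teacher, dim=-1)
+ #   loss = LigerTVDLossFunction.apply(p, q)  # reduction defaults to "batchmean"
+ #   loss.backward()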