liger-kernel-nightly 0.5.5.dev20250402185702__py3-none-any.whl → 0.6.4.dev20260112233432__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Note: the registry flags this version of liger-kernel-nightly as a potentially problematic release.
- liger_kernel/chunked_loss/__init__.py +1 -0
- liger_kernel/chunked_loss/cosine_similarity_loss.py +142 -0
- liger_kernel/chunked_loss/dpo_loss.py +61 -3
- liger_kernel/chunked_loss/functional.py +2 -0
- liger_kernel/chunked_loss/fused_linear_distillation.py +23 -5
- liger_kernel/chunked_loss/fused_linear_ppo.py +36 -0
- liger_kernel/chunked_loss/fused_linear_preference.py +0 -1
- liger_kernel/chunked_loss/grpo_loss.py +76 -5
- liger_kernel/chunked_loss/jsd_loss.py +46 -15
- liger_kernel/ops/__init__.py +141 -0
- liger_kernel/ops/backends/README.md +151 -0
- liger_kernel/ops/backends/__init__.py +13 -0
- liger_kernel/ops/backends/_ascend/__init__.py +5 -0
- liger_kernel/ops/backends/_ascend/ascend-ub-manager-design.md +485 -0
- liger_kernel/ops/backends/_ascend/ops/__init__.py +49 -0
- liger_kernel/ops/backends/_ascend/ops/geglu.py +266 -0
- liger_kernel/ops/backends/_ascend/ops/qwen2vl_mrope.py +285 -0
- liger_kernel/ops/backends/_ascend/ops/rope.py +290 -0
- liger_kernel/ops/backends/_ascend/ops/swiglu.py +142 -0
- liger_kernel/ops/backends/_ascend/ops/tvd.py +221 -0
- liger_kernel/ops/backends/_ascend/ub_manager.py +349 -0
- liger_kernel/ops/backends/registry.py +61 -0
- liger_kernel/ops/cross_entropy.py +134 -65
- liger_kernel/ops/dyt.py +115 -180
- liger_kernel/ops/fused_add_rms_norm.py +416 -0
- liger_kernel/ops/fused_linear_cross_entropy.py +117 -23
- liger_kernel/ops/fused_neighborhood_attention.py +1022 -0
- liger_kernel/ops/geglu.py +6 -4
- liger_kernel/ops/group_norm.py +7 -7
- liger_kernel/ops/grpo_loss.py +312 -0
- liger_kernel/ops/jsd.py +2 -1
- liger_kernel/ops/kl_div.py +9 -5
- liger_kernel/ops/layer_norm.py +146 -78
- liger_kernel/ops/llama4_rope.py +225 -0
- liger_kernel/ops/multi_token_attention.py +207 -0
- liger_kernel/ops/poly_norm.py +390 -0
- liger_kernel/ops/rms_norm.py +398 -99
- liger_kernel/ops/rope.py +1 -1
- liger_kernel/ops/softmax.py +201 -0
- liger_kernel/ops/sparsemax.py +179 -0
- liger_kernel/ops/swiglu.py +1 -1
- liger_kernel/ops/tiled_mlp.py +136 -0
- liger_kernel/ops/utils.py +14 -0
- liger_kernel/transformers/__init__.py +208 -17
- liger_kernel/transformers/auto_model.py +21 -0
- liger_kernel/transformers/cross_entropy.py +9 -4
- liger_kernel/transformers/dyt.py +6 -4
- liger_kernel/transformers/experimental/__init__.py +5 -0
- liger_kernel/transformers/experimental/embedding.py +1 -1
- liger_kernel/transformers/fsdp.py +55 -0
- liger_kernel/transformers/functional.py +122 -20
- liger_kernel/transformers/fused_add_rms_norm.py +39 -0
- liger_kernel/transformers/fused_linear_cross_entropy.py +16 -5
- liger_kernel/transformers/fused_linear_jsd.py +1 -1
- liger_kernel/transformers/fused_neighborhood_attention.py +234 -0
- liger_kernel/transformers/geglu.py +1 -1
- liger_kernel/transformers/group_norm.py +1 -1
- liger_kernel/transformers/grpo_loss.py +153 -0
- liger_kernel/transformers/jsd.py +1 -1
- liger_kernel/transformers/kl_div.py +1 -1
- liger_kernel/transformers/layer_norm.py +1 -1
- liger_kernel/transformers/llama4_rope.py +93 -0
- liger_kernel/transformers/model/exaone4.py +136 -0
- liger_kernel/transformers/model/falcon_h1.py +122 -0
- liger_kernel/transformers/model/gemma.py +57 -27
- liger_kernel/transformers/model/gemma2.py +65 -28
- liger_kernel/transformers/model/gemma3.py +331 -0
- liger_kernel/transformers/model/glm4.py +141 -0
- liger_kernel/transformers/model/glm4v.py +163 -0
- liger_kernel/transformers/model/glm4v_moe.py +172 -0
- liger_kernel/transformers/model/gpt_oss.py +211 -0
- liger_kernel/transformers/model/hunyuan_v1.py +134 -0
- liger_kernel/transformers/model/internvl.py +157 -0
- liger_kernel/transformers/model/llama.py +109 -27
- liger_kernel/transformers/model/llama4.py +121 -0
- liger_kernel/transformers/model/llava.py +111 -136
- liger_kernel/transformers/model/loss_utils.py +50 -12
- liger_kernel/transformers/model/mistral.py +51 -34
- liger_kernel/transformers/model/mixtral.py +50 -29
- liger_kernel/transformers/model/mllama.py +46 -24
- liger_kernel/transformers/model/olmo2.py +47 -22
- liger_kernel/transformers/model/olmo3.py +142 -0
- liger_kernel/transformers/model/output_classes.py +147 -0
- liger_kernel/transformers/model/paligemma.py +50 -14
- liger_kernel/transformers/model/phi3.py +47 -172
- liger_kernel/transformers/model/qwen2.py +55 -23
- liger_kernel/transformers/model/qwen2_5_vl.py +62 -103
- liger_kernel/transformers/model/qwen2_vl.py +59 -108
- liger_kernel/transformers/model/qwen3.py +136 -0
- liger_kernel/transformers/model/qwen3_moe.py +152 -0
- liger_kernel/transformers/model/qwen3_next.py +146 -0
- liger_kernel/transformers/model/qwen3_vl.py +150 -0
- liger_kernel/transformers/model/qwen3_vl_moe.py +126 -0
- liger_kernel/transformers/model/smollm3.py +199 -0
- liger_kernel/transformers/model/smolvlm.py +158 -0
- liger_kernel/transformers/monkey_patch.py +2018 -244
- liger_kernel/transformers/multi_token_attention.py +64 -0
- liger_kernel/transformers/poly_norm.py +42 -0
- liger_kernel/transformers/qwen2vl_mrope.py +1 -1
- liger_kernel/transformers/rms_norm.py +54 -6
- liger_kernel/transformers/rope.py +45 -1
- liger_kernel/transformers/softmax.py +12 -0
- liger_kernel/transformers/sparsemax.py +16 -0
- liger_kernel/transformers/swiglu.py +39 -1
- liger_kernel/transformers/tiled_mlp.py +125 -0
- liger_kernel/transformers/trainer/orpo_trainer.py +1 -53
- liger_kernel/transformers/tvd.py +1 -1
- liger_kernel/utils.py +63 -0
- {liger_kernel_nightly-0.5.5.dev20250402185702.dist-info → liger_kernel_nightly-0.6.4.dev20260112233432.dist-info}/METADATA +73 -39
- liger_kernel_nightly-0.6.4.dev20260112233432.dist-info/RECORD +132 -0
- liger_kernel_nightly-0.5.5.dev20250402185702.dist-info/RECORD +0 -80
- {liger_kernel_nightly-0.5.5.dev20250402185702.dist-info → liger_kernel_nightly-0.6.4.dev20260112233432.dist-info}/LICENSE +0 -0
- {liger_kernel_nightly-0.5.5.dev20250402185702.dist-info → liger_kernel_nightly-0.6.4.dev20260112233432.dist-info}/NOTICE +0 -0
- {liger_kernel_nightly-0.5.5.dev20250402185702.dist-info → liger_kernel_nightly-0.6.4.dev20260112233432.dist-info}/WHEEL +0 -0
- {liger_kernel_nightly-0.5.5.dev20250402185702.dist-info → liger_kernel_nightly-0.6.4.dev20260112233432.dist-info}/top_level.txt +0 -0
Selected file contents (new Ascend NPU backend kernels):

liger_kernel/ops/backends/_ascend/ops/rope.py (new file)
@@ -0,0 +1,290 @@
import torch
import triton
import triton.language as tl

from liger_kernel.ops.backends._ascend.ub_manager import compute_default_tiling_strategy


@triton.jit
def _triton_rope_npu(
    q_ptr,
    q_row_stride,
    k_ptr,
    k_row_stride,
    cos,
    cos_row_stride,
    sin,
    sin_row_stride,
    sl,
    bs: tl.constexpr,
    cos_bs: tl.constexpr,
    n_qh: tl.constexpr,
    n_kh: tl.constexpr,
    hd: tl.constexpr,
    BLOCK_Q: tl.constexpr,
    BLOCK_K: tl.constexpr,
    BACKWARD_PASS: tl.constexpr = False,
):
    pid = tl.program_id(0).to(tl.int64)
    batch_idx = pid // sl
    cos_row_idx = pid % sl

    cos = cos + tl.where(
        cos_bs == 1,
        cos_row_idx * cos_row_stride,
        batch_idx * (sl * cos_row_stride) + cos_row_idx * cos_row_stride,
    )
    sin = sin + tl.where(
        cos_bs == 1,
        cos_row_idx * sin_row_stride,
        batch_idx * (sl * sin_row_stride) + cos_row_idx * sin_row_stride,
    )

    q_base = q_ptr + pid * q_row_stride
    k_base = k_ptr + pid * k_row_stride

    # Pre-compute d_idx and the cos/sin values outside the loops (they don't depend on heads)
    d_idx = tl.arange(0, hd // 2)
    d_mask = d_idx < (hd // 2)  # Always True, but kept for clarity
    cos_vals = tl.load(cos + d_idx, mask=d_mask, other=0)
    sin_vals = tl.load(sin + d_idx, mask=d_mask, other=0)

    # Process q heads in chunks to prevent UB overflow
    for qh_block in range(0, n_qh, BLOCK_Q):
        qh_idx = tl.arange(0, BLOCK_Q) + qh_block
        qh_mask = qh_idx < n_qh

        # block_mask: qh_mask broadcast over the d_idx dimension
        block_mask = qh_mask[:, None]

        offsets = qh_idx[:, None] * hd + d_idx[None, :]

        q_left = tl.load(q_base + offsets, mask=block_mask, other=0)
        q_right = tl.load(q_base + offsets + (hd // 2), mask=block_mask, other=0)

        if not BACKWARD_PASS:
            new_left = q_left * cos_vals - q_right * sin_vals
            new_right = q_right * cos_vals + q_left * sin_vals
        else:
            new_left = q_left * cos_vals + q_right * sin_vals
            new_right = q_right * cos_vals - q_left * sin_vals

        tl.store(q_base + offsets, new_left, mask=block_mask)
        tl.store(q_base + offsets + (hd // 2), new_right, mask=block_mask)

    # Process k heads in chunks to prevent UB overflow
    for kh_block in range(0, n_kh, BLOCK_K):
        kh_idx = tl.arange(0, BLOCK_K) + kh_block
        kh_mask = kh_idx < n_kh

        # block_mask: kh_mask broadcast over the d_idx dimension
        block_mask = kh_mask[:, None]

        offsets = kh_idx[:, None] * hd + d_idx[None, :]

        k_left = tl.load(k_base + offsets, mask=block_mask, other=0)
        k_right = tl.load(k_base + offsets + (hd // 2), mask=block_mask, other=0)

        if not BACKWARD_PASS:
            new_left = k_left * cos_vals - k_right * sin_vals
            new_right = k_right * cos_vals + k_left * sin_vals
        else:
            new_left = k_left * cos_vals + k_right * sin_vals
            new_right = k_right * cos_vals - k_left * sin_vals

        tl.store(k_base + offsets, new_left, mask=block_mask)
        tl.store(k_base + offsets + (hd // 2), new_right, mask=block_mask)


def rope_forward(q, k, cos, sin):
    # Transpose back to the physical shape because Triton looks at the physical storage.
    # Note: q and k are non-contiguous before the transformation and become contiguous after the transpose.
    q = q.transpose(1, 2)
    k = k.transpose(1, 2)

    batch_size, seq_len, n_q_head, head_dim = q.shape
    n_kv_head = k.shape[2]
    pad_hd = triton.next_power_of_2(head_dim)
    pad_n_q_head = triton.next_power_of_2(n_q_head)
    pad_n_kv_head = triton.next_power_of_2(n_kv_head)

    n_row = batch_size * seq_len

    # Ensure tensors passed into the kernel are contiguous (a no-op if they already are)
    q = q.contiguous()
    k = k.contiguous()
    cos = cos.contiguous()
    sin = sin.contiguous()
    cos_batch_size = cos.shape[0]

    # Compute tiling strategy based on UB capacity
    dtype_size = q.element_size()
    # ROPE forward tiling strategy (based on the optimized ROPE kernel):
    # - cos_vals and sin_vals are loaded once outside the loops (shared): pad_hd // 2 elements each
    # - In the q heads loop (peak memory):
    #   * q_left:    BLOCK_Q * (pad_hd // 2) elements
    #   * q_right:   BLOCK_Q * (pad_hd // 2) elements
    #   * new_left:  BLOCK_Q * (pad_hd // 2) elements (intermediate result)
    #   * new_right: BLOCK_Q * (pad_hd // 2) elements (intermediate result)
    #   * Total: 4 * BLOCK_Q * (pad_hd // 2) = 2 * BLOCK_Q * pad_hd elements
    # - In the k heads loop (peak memory):
    #   * k_left:    BLOCK_K * (pad_hd // 2) elements
    #   * k_right:   BLOCK_K * (pad_hd // 2) elements
    #   * new_left:  BLOCK_K * (pad_hd // 2) elements (intermediate result)
    #   * new_right: BLOCK_K * (pad_hd // 2) elements (intermediate result)
    #   * Total: 4 * BLOCK_K * (pad_hd // 2) = 2 * BLOCK_K * pad_hd elements
    # - Since q and k are processed separately, peak memory is the max(BLOCK_Q, BLOCK_K) case
    # - Plus the shared cos/sin: 2 * (pad_hd // 2) = pad_hd elements
    # - Conservative estimate: (2 * BLOCK_SIZE * pad_hd + pad_hd) * dtype_size * 8 bits
    # - Simplified: (2 * BLOCK_SIZE + 1) * pad_hd * dtype_size * 8 bits
    # - For safety, use memory_multiplier=3.0: 3.0 * BLOCK_SIZE * pad_hd * dtype_size * 8 bits
    # - shapes: ((pad_n_q_head, pad_hd), (pad_n_kv_head, pad_hd))
    # - tiling_dims: (0, 0) means the first dimension of each shape can be tiled
    # - Returns: ((block_size_q, pad_hd), (block_size_kv, pad_hd))
    shapes = ((pad_n_q_head, pad_hd), (pad_n_kv_head, pad_hd))
    tile_shapes = compute_default_tiling_strategy(
        safety_margin=0.90,
        dtype_size=dtype_size,
        memory_multiplier=3.0,
        shapes=shapes,
        tiling_dims=(0, 0),
    )

    if tile_shapes is not None and len(tile_shapes) == len(shapes):
        # Strategy returns ((block_size_q, pad_hd), (block_size_kv, pad_hd))
        q_tile_shape, k_tile_shape = tile_shapes
        BLOCK_Q, _ = q_tile_shape
        BLOCK_K, _ = k_tile_shape
    else:
        # Fallback to conservative defaults
        BLOCK_Q = triton.next_power_of_2(pad_n_q_head)
        BLOCK_K = triton.next_power_of_2(pad_n_kv_head)

    _triton_rope_npu[(n_row,)](
        q,
        q.stride(1),
        k,
        k.stride(1),
        cos,
        cos.stride(-2),
        sin,
        sin.stride(-2),
        seq_len,
        batch_size,
        cos_batch_size,
        n_q_head,
        n_kv_head,
        head_dim,
        BLOCK_Q,
        BLOCK_K,
        BACKWARD_PASS=False,
    )
    return q.transpose(1, 2), k.transpose(1, 2), cos, sin


def rope_backward(dq, dk, cos, sin):
    dq = dq.transpose(1, 2)
    dk = dk.transpose(1, 2)

    batch_size, seq_len, n_q_head, head_dim = dq.shape
    cos_batch_size = cos.shape[0]
    n_kv_head = dk.shape[2]
    pad_hd = triton.next_power_of_2(head_dim)
    pad_n_q_head = triton.next_power_of_2(n_q_head)
    pad_n_kv_head = triton.next_power_of_2(n_kv_head)

    n_row = batch_size * seq_len

    # Ensure dq and dk are contiguous
    dq = dq.contiguous()
    dk = dk.contiguous()

    # Compute tiling strategy based on UB capacity
    dtype_size = dq.element_size()
    # ROPE backward tiling strategy: same analysis as the forward pass (see rope_forward above).
    # Peak per-loop usage is 2 * BLOCK * pad_hd elements plus the shared cos/sin values (pad_hd elements),
    # so memory_multiplier=3.0 is used as a conservative bound.
    # - shapes: ((pad_n_q_head, pad_hd), (pad_n_kv_head, pad_hd))
    # - tiling_dims: (0, 0) means the first dimension of each shape can be tiled
    # - Returns: ((block_size_q, pad_hd), (block_size_kv, pad_hd))
    shapes = ((pad_n_q_head, pad_hd), (pad_n_kv_head, pad_hd))
    tile_shapes = compute_default_tiling_strategy(
        safety_margin=0.90,
        dtype_size=dtype_size,
        memory_multiplier=3.0,
        shapes=shapes,
        tiling_dims=(0, 0),
    )

    if tile_shapes is not None and len(tile_shapes) == len(shapes):
        # Strategy returns ((block_size_q, pad_hd), (block_size_kv, pad_hd))
        q_tile_shape, k_tile_shape = tile_shapes
        BLOCK_Q, _ = q_tile_shape
        BLOCK_K, _ = k_tile_shape
    else:
        # Fallback to conservative defaults
        BLOCK_Q = triton.next_power_of_2(pad_n_q_head)
        BLOCK_K = triton.next_power_of_2(pad_n_kv_head)

    _triton_rope_npu[(n_row,)](
        dq,
        dq.stride(1),
        dk,
        dk.stride(1),
        cos,
        cos.stride(-2),
        sin,
        sin.stride(-2),
        seq_len,
        batch_size,
        cos_batch_size,
        n_q_head,
        n_kv_head,
        head_dim,
        BLOCK_Q,
        BLOCK_K,
        BACKWARD_PASS=True,
    )
    return dq.transpose(1, 2), dk.transpose(1, 2)


class LigerRopeFunction(torch.autograd.Function):
    @staticmethod
    def forward(ctx, q, k, cos, sin, position_ids=None, unsqueeze_dim=1):
        """
        q size: (bsz, n_q_head, seq_len, head_dim)
        k size: (bsz, n_kv_head, seq_len, head_dim)
        cos size: (1, seq_len, head_dim) or (bsz, seq_len, head_dim)
        sin size: (1, seq_len, head_dim) or (bsz, seq_len, head_dim)
        """
        q, k, cos, sin = rope_forward(q, k, cos, sin)
        ctx.save_for_backward(cos, sin)
        return q, k

    @staticmethod
    def backward(ctx, dq, dk):
        """
        dq size: (bsz, n_q_head, seq_len, head_dim)
        dk size: (bsz, n_kv_head, seq_len, head_dim)
        cos size: (1, seq_len, head_dim) or (bsz, seq_len, head_dim)
        sin size: (1, seq_len, head_dim) or (bsz, seq_len, head_dim)
        """
        cos, sin = ctx.saved_tensors
        dq, dk = rope_backward(dq, dk, cos, sin)
        return dq, dk, None, None, None, None
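For context, the rotation applied above is new_left = x_left * cos - x_right * sin and new_right = x_right * cos + x_left * sin, with the signs flipped in the backward pass. Below is a minimal usage sketch of the class added in this file; it is not part of the diff, assumes an Ascend environment where torch_npu exposes the "npu" device, and uses random placeholder values for cos and sin with the shapes documented in LigerRopeFunction.forward.

import torch

from liger_kernel.ops.backends._ascend.ops.rope import LigerRopeFunction

bsz, n_q_head, n_kv_head, seq_len, head_dim = 2, 8, 2, 128, 64
device = "npu"  # assumption: torch_npu is installed and an Ascend NPU is visible

q = torch.randn(bsz, n_q_head, seq_len, head_dim, device=device, requires_grad=True)
k = torch.randn(bsz, n_kv_head, seq_len, head_dim, device=device, requires_grad=True)
# Placeholder rotary tables; in a real model these come from the rotary embedding module.
cos = torch.randn(1, seq_len, head_dim, device=device)
sin = torch.randn(1, seq_len, head_dim, device=device)

# The forward pass runs the NPU Triton kernel; rope_backward handles the gradients.
q_rot, k_rot = LigerRopeFunction.apply(q, k, cos, sin)
(q_rot.sum() + k_rot.sum()).backward()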
liger_kernel/ops/backends/_ascend/ops/swiglu.py (new file)
@@ -0,0 +1,142 @@
import torch
import triton
import triton.language as tl

from liger_kernel.ops.backends._ascend.ub_manager import compute_default_tiling_strategy
from liger_kernel.ops.utils import get_npu_core_count

# -----------------------------------------------------------------------------
# Kernels (high-performance 1D flattened implementation)
# -----------------------------------------------------------------------------


@triton.jit
def _swiglu_forward_kernel_flat(
    a_ptr, b_ptr, c_ptr, total_elements, BLOCK_SIZE: tl.constexpr, NUM_STAGES: tl.constexpr
):
    pid = tl.program_id(0)
    num_progs = tl.num_programs(0)

    # Grid-stride loop
    start_idx = pid * BLOCK_SIZE
    stride = num_progs * BLOCK_SIZE

    for idx in tl.range(start_idx, total_elements, stride, num_stages=NUM_STAGES):
        offsets = idx + tl.arange(0, BLOCK_SIZE)
        mask = offsets < total_elements

        a_val = tl.load(a_ptr + offsets, mask=mask, other=0.0).to(tl.float32)
        b_val = tl.load(b_ptr + offsets, mask=mask, other=0.0).to(tl.float32)
        res = (a_val * tl.sigmoid(a_val)) * b_val
        tl.store(c_ptr + offsets, res, mask=mask)


@triton.jit
def _swiglu_backward_kernel_flat(
    dc_ptr, a_ptr, b_ptr, da_ptr, db_ptr, total_elements, BLOCK_SIZE: tl.constexpr, NUM_STAGES: tl.constexpr
):
    pid = tl.program_id(0)
    num_progs = tl.num_programs(0)
    start_idx = pid * BLOCK_SIZE
    stride = num_progs * BLOCK_SIZE

    for idx in tl.range(start_idx, total_elements, stride, num_stages=NUM_STAGES):
        offsets = idx + tl.arange(0, BLOCK_SIZE)
        mask = offsets < total_elements

        dc = tl.load(dc_ptr + offsets, mask=mask, other=0.0).to(tl.float32)
        a = tl.load(a_ptr + offsets, mask=mask, other=0.0).to(tl.float32)
        b = tl.load(b_ptr + offsets, mask=mask, other=0.0).to(tl.float32)

        sig_a = tl.sigmoid(a)
        silu_a = a * sig_a
        term1 = silu_a * (1.0 - sig_a) + sig_a

        db = dc * silu_a
        da = dc * b * term1

        tl.store(da_ptr + offsets, da, mask=mask)
        tl.store(db_ptr + offsets, db, mask=mask)


# -----------------------------------------------------------------------------
# Helper: call compute_default_tiling_strategy
# -----------------------------------------------------------------------------


def get_optimal_block_size(total_elements, is_backward=False):
    """
    Calculate the optimal block size using compute_default_tiling_strategy.
    """
    # 1. Set the memory multiplier.
    #    The forward pass is lighter; the backward pass needs more memory for intermediate variables.
    #    8.0 and 12.0 are empirical values based on the 910B UB (192 KB).
    multiplier = 12.0 if is_backward else 8.0

    # 2. Call the calculation function.
    #    Treat the input as 1D (total_elements,) and tile only on dim 0.
    tile_shapes = compute_default_tiling_strategy(
        safety_margin=0.9, dtype_size=4, memory_multiplier=multiplier, shapes=((total_elements,),), tiling_dims=(0,)
    )

    # 3. Parse the result.
    if tile_shapes and len(tile_shapes) > 0:
        block_size = tile_shapes[0][0]
        return max(256, block_size)
    else:
        return 2048


def swiglu_forward(a, b):
    if not a.is_contiguous():
        a = a.contiguous()
    if not b.is_contiguous():
        b = b.contiguous()

    total_elements = a.numel()
    c = torch.empty_like(a)

    block_size = get_optimal_block_size(total_elements, is_backward=False)

    num_cores = get_npu_core_count()
    grid_size = min(num_cores, (total_elements + block_size - 1) // block_size)

    _swiglu_forward_kernel_flat[(grid_size,)](a, b, c, total_elements, BLOCK_SIZE=block_size, NUM_STAGES=3, num_warps=4)
    return c


def swiglu_backward(a, b, dc):
    if not dc.is_contiguous():
        dc = dc.contiguous()
    if not a.is_contiguous():
        a = a.contiguous()
    if not b.is_contiguous():
        b = b.contiguous()

    total_elements = dc.numel()
    grad_a = torch.empty_like(a)
    grad_b = torch.empty_like(b)

    block_size = get_optimal_block_size(total_elements, is_backward=True)

    num_cores = get_npu_core_count()
    grid_size = min(num_cores, (total_elements + block_size - 1) // block_size)

    _swiglu_backward_kernel_flat[(grid_size,)](
        dc, a, b, grad_a, grad_b, total_elements, BLOCK_SIZE=block_size, NUM_STAGES=3, num_warps=4
    )
    return grad_a, grad_b


class LigerSiLUMulFunction(torch.autograd.Function):
    @staticmethod
    def forward(ctx, a, b):
        c = swiglu_forward(a, b)
        ctx.save_for_backward(a, b)
        return c

    @staticmethod
    def backward(ctx, dc):
        a, b = ctx.saved_tensors
        grad_a, grad_b = swiglu_backward(a, b, dc)
        return grad_a, grad_b
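The backward kernel above implements the SwiGLU derivative: with c = silu(a) * b and silu(a) = a * sigmoid(a), dc/db = silu(a) and dc/da = b * (silu(a) * (1 - sigmoid(a)) + sigmoid(a)), which is the term1 expression in the kernel. A minimal usage sketch of the autograd wrapper follows; it is not part of the diff and assumes an Ascend "npu" device is available.

import torch

from liger_kernel.ops.backends._ascend.ops.swiglu import LigerSiLUMulFunction

# a plays the role of the gate projection output and b the up projection output of a SwiGLU MLP.
a = torch.randn(4, 2048, device="npu", requires_grad=True)
b = torch.randn(4, 2048, device="npu", requires_grad=True)

c = LigerSiLUMulFunction.apply(a, b)  # c = silu(a) * b via the flat forward kernel
c.sum().backward()                    # grad_a, grad_b come from the flat backward kernel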
liger_kernel/ops/backends/_ascend/ops/tvd.py (new file)
@@ -0,0 +1,221 @@
from typing import Literal
from typing import Optional

import torch
import triton
import triton.language as tl

from liger_kernel.ops.backends._ascend.ub_manager import compute_default_tiling_strategy
from liger_kernel.ops.utils import ensure_contiguous

MAX_FUSED_SIZE = 65536 // 4

REDUCTION_LITERAL = Literal["none", "sum", "mean", "batchmean"]


@triton.jit
def _tv_distance_kernel(
    p_ptr,
    p_stride,
    q_ptr,
    q_stride,
    loss_ptr,
    loss_stride,
    grads_ptr,
    grads_stride,
    label_ptr,
    ignore_index: tl.constexpr,
    n_cols,  # V
    total_rows: tl.constexpr,  # BT
    BLOCK_SIZE: tl.constexpr,
    HAS_LABEL: tl.constexpr,
    reduction: tl.constexpr = "batchmean",
):
    thread_id = tl.program_id(0)
    num_threads = tl.num_programs(0)

    for pid in range(thread_id, total_rows, num_threads):
        p_row_ptr = p_ptr + pid * p_stride
        q_row_ptr = q_ptr + pid * q_stride
        loss_row_ptr = loss_ptr + pid * loss_stride
        grads_row_ptr = grads_ptr + pid * grads_stride
        label_row_ptr = label_ptr + pid

        base_offsets = tl.arange(0, BLOCK_SIZE)

        should_skip = False
        if HAS_LABEL:
            label = tl.load(label_row_ptr)
            if label == ignore_index:
                should_skip = True

        if should_skip:
            for i in range(0, n_cols, BLOCK_SIZE):
                offsets = i + base_offsets
                mask = offsets < n_cols
                tl.store(grads_row_ptr + offsets, 0.0, mask=mask)
                if reduction == "none":
                    tl.store(loss_row_ptr + offsets, 0.0, mask=mask)
        else:
            loss_sum = 0.0
            for i in range(0, n_cols, BLOCK_SIZE):
                offsets = i + base_offsets
                mask = offsets < n_cols

                p = tl.load(p_row_ptr + offsets, mask=mask, other=0.0)
                q = tl.load(q_row_ptr + offsets, mask=mask, other=0.0)

                # TVD(P || Q) = 0.5 * |P - Q|
                tv_loss = 0.5 * tl.abs(p - q)
                grad_res = tl.where(p > q, 0.5, -0.5)

                tl.store(grads_row_ptr + offsets, grad_res, mask=mask)

                if reduction == "none":
                    tl.store(loss_row_ptr + offsets, tv_loss, mask=mask)
                else:
                    loss_sum += tl.sum(tv_loss, axis=0)

            if reduction != "none":
                tl.store(loss_row_ptr, loss_sum)


def tv_distance_forward_triton(p, q, shift_labels, reduction, ignore_index, has_label):
    BT, V = p.shape

    # TVD forward tiling strategy
    # - In the main loop (loss and grad computation):
    #   * p:        BLOCK_SIZE elements
    #   * q:        BLOCK_SIZE elements
    #   * tv_loss:  BLOCK_SIZE elements
    #   * grad_res: BLOCK_SIZE elements
    #   * loss_sum: BLOCK_SIZE elements (when reduction != "none")
    #   * Total: 4 * BLOCK_SIZE elements, or 5 * BLOCK_SIZE elements when reduction != "none"
    # - Since loss_sum is not used in every configuration, and to account for other shared memory
    #   usage plus the potential memory consumption of the HAS_LABEL path, a conservative estimate
    #   is 5 * BLOCK_SIZE * dtype_size * 8 bits, i.e. memory_multiplier=5.0.
    # - shapes: ((V,),)
    # - tiling_dims: (0,) means the first dimension of each shape can be tiled
    # - Returns: ((block_size,),)
    shapes = ((V,),)
    tile_shapes = compute_default_tiling_strategy(
        safety_margin=0.80,
        # Much of the TVD computation is implicitly promoted to float32, so the float32 size is used directly.
        dtype_size=4,
        memory_multiplier=5.0,
        shapes=shapes,
        tiling_dims=(0,),
    )

    if tile_shapes is not None and len(tile_shapes) > 0 and len(tile_shapes[0]) > 0:
        # Strategy returns ((block_size,),)
        BLOCK_SIZE = tile_shapes[0][0]
    else:
        # Fallback to the desired block size if no strategy is found (no tiling needed)
        BLOCK_SIZE = min(MAX_FUSED_SIZE, triton.next_power_of_2(V))

    MAX_BATCH_PER_KERNEL = 65535  # Maximum grid size per kernel launch on the NPU
    if BT <= MAX_BATCH_PER_KERNEL:
        grid = (BT,)
    else:
        grid = (MAX_BATCH_PER_KERNEL,)

    out_size = (BT, V) if reduction == "none" else (BT,)
    output_tensor = torch.zeros(out_size, device=p.device, dtype=torch.float32)
    grads = torch.empty_like(p)

    n_non_ignore = (shift_labels != ignore_index).sum().item() if has_label else BT

    _tv_distance_kernel[grid](
        p,
        p.stride(0),
        q,
        q.stride(0),
        output_tensor,
        output_tensor.stride(0),
        grads,
        grads.stride(0),
        shift_labels if has_label else torch.empty(1, device=p.device),
        ignore_index,
        V,
        BT,
        BLOCK_SIZE=BLOCK_SIZE,
        HAS_LABEL=has_label,
        reduction=reduction,
    )

    if reduction == "batchmean":
        return output_tensor.sum() / n_non_ignore, grads / n_non_ignore
    elif reduction == "sum":
        return output_tensor.sum(dim=0), grads
    elif reduction == "mean":
        return output_tensor.sum() / (n_non_ignore * V), grads / (n_non_ignore * V)
    else:
        return output_tensor, grads


def tvd_backward_triton(grad_output, grads):
    # If this is the last layer, grad_output is 1.0. Skip the multiplication in that case.
    if torch.equal(grad_output, torch.tensor(1.0, device=grad_output.device)):
        return grads

    return grads * grad_output


class LigerTVDLossFunction(torch.autograd.Function):
    """
    Class implementing the forward and backward pass for the Total Variation Distance Loss using Triton.
    """

    @staticmethod
    @ensure_contiguous
    def forward(
        ctx,
        p: torch.Tensor,
        q: torch.Tensor,
        shift_labels: Optional[torch.Tensor] = None,
        reduction: REDUCTION_LITERAL = "batchmean",
        ignore_index: int = -100,
    ) -> torch.Tensor:
        """A forward pass for the Total Variation Distance Loss.

        Args:
            ctx: Torch autograd context
            p (torch.Tensor): A tensor of shape (BT, V) containing the first distribution.
            q (torch.Tensor): A tensor of shape (BT, V) containing the second distribution.
            shift_labels (Optional[torch.Tensor]): A tensor of shape (BT,) containing the labels.
            reduction (REDUCTION_LITERAL, optional): The reduction method to be applied. Defaults to "batchmean".
            ignore_index (int, optional): The index to ignore during loss calculation. Defaults to -100.

        Returns:
            torch.Tensor: The computed Total Variation Distance Loss.
        """
        has_label = False
        if shift_labels is not None:
            assert shift_labels.shape == (p.shape[0],), (
                f"the shape of shift_labels must be (BT,). Got: {shift_labels.shape}"
            )
            shift_labels = shift_labels.contiguous()
            has_label = True

        loss, grads = tv_distance_forward_triton(p, q, shift_labels, reduction, ignore_index, has_label)
        ctx.save_for_backward(grads)
        return loss

    @staticmethod
    @ensure_contiguous
    def backward(ctx, grad_output: torch.Tensor) -> torch.Tensor:
        """A backward pass for the Total Variation Distance Loss.

        Args:
            ctx: Torch autograd context
            grad_output (torch.Tensor): The gradient of the loss with respect to the output.

        Returns:
            tuple[torch.Tensor, None, None, None, None]: The gradient of the loss with respect to the inputs.
        """
        (grads,) = ctx.saved_tensors
        grads = tvd_backward_triton(grad_output, grads)

        return grads, None, None, None, None