liger-kernel-nightly 0.0.1.dev20240819184814__py3-none-any.whl → 0.6.4.dev20251212103629__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- liger_kernel/__init__.py +0 -0
- liger_kernel/chunked_loss/README.md +25 -0
- liger_kernel/chunked_loss/__init__.py +8 -0
- liger_kernel/chunked_loss/cosine_similarity_loss.py +136 -0
- liger_kernel/chunked_loss/cpo_loss.py +157 -0
- liger_kernel/chunked_loss/dpo_loss.py +229 -0
- liger_kernel/chunked_loss/functional.py +17 -0
- liger_kernel/chunked_loss/fused_linear_distillation.py +292 -0
- liger_kernel/chunked_loss/fused_linear_ppo.py +366 -0
- liger_kernel/chunked_loss/fused_linear_preference.py +433 -0
- liger_kernel/chunked_loss/fused_linear_unpaired_preference.py +341 -0
- liger_kernel/chunked_loss/grpo_loss.py +307 -0
- liger_kernel/chunked_loss/jsd_loss.py +200 -0
- liger_kernel/chunked_loss/kto_loss.py +210 -0
- liger_kernel/chunked_loss/orpo_loss.py +144 -0
- liger_kernel/chunked_loss/simpo_loss.py +165 -0
- liger_kernel/env_report.py +63 -0
- liger_kernel/ops/__init__.py +141 -0
- liger_kernel/ops/backends/README.md +151 -0
- liger_kernel/ops/backends/__init__.py +13 -0
- liger_kernel/ops/backends/_ascend/__init__.py +5 -0
- liger_kernel/ops/backends/_ascend/ops/__init__.py +15 -0
- liger_kernel/ops/backends/registry.py +61 -0
- liger_kernel/ops/cross_entropy.py +383 -114
- liger_kernel/ops/dyt.py +160 -0
- liger_kernel/ops/experimental/embedding.py +141 -0
- liger_kernel/ops/experimental/mm_int8int2.py +349 -0
- liger_kernel/ops/fused_add_rms_norm.py +416 -0
- liger_kernel/ops/fused_linear_cross_entropy.py +346 -132
- liger_kernel/ops/fused_linear_jsd.py +228 -0
- liger_kernel/ops/fused_neighborhood_attention.py +1022 -0
- liger_kernel/ops/geglu.py +66 -64
- liger_kernel/ops/group_norm.py +306 -0
- liger_kernel/ops/grpo_loss.py +312 -0
- liger_kernel/ops/jsd.py +201 -0
- liger_kernel/ops/kl_div.py +262 -0
- liger_kernel/ops/layer_norm.py +320 -0
- liger_kernel/ops/llama4_rope.py +225 -0
- liger_kernel/ops/multi_token_attention.py +207 -0
- liger_kernel/ops/poly_norm.py +390 -0
- liger_kernel/ops/qwen2vl_mrope.py +222 -0
- liger_kernel/ops/rms_norm.py +484 -88
- liger_kernel/ops/rope.py +122 -117
- liger_kernel/ops/softmax.py +201 -0
- liger_kernel/ops/sparsemax.py +179 -0
- liger_kernel/ops/swiglu.py +68 -65
- liger_kernel/ops/tiled_mlp.py +136 -0
- liger_kernel/ops/tvd.py +207 -0
- liger_kernel/ops/utils.py +82 -3
- liger_kernel/transformers/__init__.py +218 -6
- liger_kernel/transformers/auto_model.py +38 -0
- liger_kernel/transformers/cross_entropy.py +52 -7
- liger_kernel/transformers/dyt.py +22 -0
- liger_kernel/transformers/experimental/__init__.py +5 -0
- liger_kernel/transformers/experimental/embedding.py +26 -0
- liger_kernel/transformers/fsdp.py +55 -0
- liger_kernel/transformers/functional.py +301 -0
- liger_kernel/transformers/fused_add_rms_norm.py +39 -0
- liger_kernel/transformers/fused_linear_cross_entropy.py +59 -10
- liger_kernel/transformers/fused_linear_jsd.py +95 -0
- liger_kernel/transformers/fused_neighborhood_attention.py +234 -0
- liger_kernel/transformers/geglu.py +6 -7
- liger_kernel/transformers/group_norm.py +50 -0
- liger_kernel/transformers/grpo_loss.py +153 -0
- liger_kernel/transformers/jsd.py +70 -0
- liger_kernel/transformers/kl_div.py +12 -0
- liger_kernel/transformers/layer_norm.py +24 -0
- liger_kernel/transformers/llama4_rope.py +93 -0
- liger_kernel/transformers/model/falcon_h1.py +122 -0
- liger_kernel/transformers/model/gemma.py +261 -0
- liger_kernel/transformers/model/gemma2.py +283 -0
- liger_kernel/transformers/model/gemma3.py +332 -0
- liger_kernel/transformers/model/glm4.py +141 -0
- liger_kernel/transformers/model/glm4v.py +163 -0
- liger_kernel/transformers/model/glm4v_moe.py +172 -0
- liger_kernel/transformers/model/gpt_oss.py +211 -0
- liger_kernel/transformers/model/hunyuan_v1.py +134 -0
- liger_kernel/transformers/model/internvl.py +157 -0
- liger_kernel/transformers/model/llama.py +221 -41
- liger_kernel/transformers/model/llama4.py +121 -0
- liger_kernel/transformers/model/llava.py +344 -0
- liger_kernel/transformers/model/loss_utils.py +95 -0
- liger_kernel/transformers/model/mistral.py +145 -0
- liger_kernel/transformers/model/mixtral.py +293 -0
- liger_kernel/transformers/model/mllama.py +269 -0
- liger_kernel/transformers/model/olmo2.py +141 -0
- liger_kernel/transformers/model/olmo3.py +142 -0
- liger_kernel/transformers/model/output_classes.py +147 -0
- liger_kernel/transformers/model/paligemma.py +433 -0
- liger_kernel/transformers/model/phi3.py +120 -0
- liger_kernel/transformers/model/qwen2.py +259 -0
- liger_kernel/transformers/model/qwen2_5_vl.py +163 -0
- liger_kernel/transformers/model/qwen2_vl.py +159 -0
- liger_kernel/transformers/model/qwen3.py +136 -0
- liger_kernel/transformers/model/qwen3_moe.py +152 -0
- liger_kernel/transformers/model/qwen3_next.py +146 -0
- liger_kernel/transformers/model/qwen3_vl.py +150 -0
- liger_kernel/transformers/model/qwen3_vl_moe.py +126 -0
- liger_kernel/transformers/model/smollm3.py +199 -0
- liger_kernel/transformers/model/smolvlm.py +158 -0
- liger_kernel/transformers/monkey_patch.py +2816 -21
- liger_kernel/transformers/multi_token_attention.py +64 -0
- liger_kernel/transformers/poly_norm.py +42 -0
- liger_kernel/transformers/qwen2vl_mrope.py +20 -0
- liger_kernel/transformers/rms_norm.py +75 -5
- liger_kernel/transformers/rope.py +47 -3
- liger_kernel/transformers/softmax.py +12 -0
- liger_kernel/transformers/sparsemax.py +16 -0
- liger_kernel/transformers/swiglu.py +62 -6
- liger_kernel/transformers/tiled_mlp.py +133 -0
- liger_kernel/transformers/trainer/__init__.py +4 -0
- liger_kernel/transformers/trainer/orpo_trainer.py +130 -0
- liger_kernel/transformers/trainer_integration.py +2 -45
- liger_kernel/transformers/tvd.py +13 -0
- liger_kernel/triton/__init__.py +1 -3
- liger_kernel/triton/monkey_patch.py +1 -5
- liger_kernel/utils.py +96 -0
- liger_kernel_nightly-0.6.4.dev20251212103629.dist-info/METADATA +447 -0
- liger_kernel_nightly-0.6.4.dev20251212103629.dist-info/NOTICE +58 -0
- liger_kernel_nightly-0.6.4.dev20251212103629.dist-info/RECORD +124 -0
- {liger_kernel_nightly-0.0.1.dev20240819184814.dist-info → liger_kernel_nightly-0.6.4.dev20251212103629.dist-info}/WHEEL +1 -1
- liger_kernel_nightly-0.0.1.dev20240819184814.dist-info/METADATA +0 -21
- liger_kernel_nightly-0.0.1.dev20240819184814.dist-info/NOTICE +0 -4
- liger_kernel_nightly-0.0.1.dev20240819184814.dist-info/RECORD +0 -27
- {liger_kernel_nightly-0.0.1.dev20240819184814.dist-info → liger_kernel_nightly-0.6.4.dev20251212103629.dist-info}/LICENSE +0 -0
- {liger_kernel_nightly-0.0.1.dev20240819184814.dist-info → liger_kernel_nightly-0.6.4.dev20251212103629.dist-info}/top_level.txt +0 -0
|
@@ -0,0 +1,390 @@
|
|
|
1
|
+
import operator
|
|
2
|
+
|
|
3
|
+
import torch
|
|
4
|
+
import triton
|
|
5
|
+
import triton.language as tl
|
|
6
|
+
|
|
7
|
+
from liger_kernel.ops.utils import calculate_settings
|
|
8
|
+
from liger_kernel.ops.utils import compare_version
|
|
9
|
+
from liger_kernel.ops.utils import ensure_contiguous
|
|
10
|
+
from liger_kernel.utils import get_npu_multi_processor_count
|
|
11
|
+
from liger_kernel.utils import is_npu_available
|
|
12
|
+
|
|
13
|
+
if compare_version("triton", operator.ge, "3.0.0") and not is_npu_available():
|
|
14
|
+
try:
|
|
15
|
+
from triton.language.extra.libdevice import rsqrt
|
|
16
|
+
except ModuleNotFoundError:
|
|
17
|
+
from triton.language.extra.cuda.libdevice import rsqrt
|
|
18
|
+
else:
|
|
19
|
+
from triton.language.math import rsqrt
|
|
20
|
+
|
|
21
|
+
|
|
22
|
+
@triton.jit
def _poly_norm_forward_kernel(
    Y_ptr,
    Y_row_stride,
    X_ptr,
    X_row_stride,
    W_ptr,  # weight: [3] for [w0, w1, w2]
    B_ptr,  # bias: scalar
    RSTD_ptr,  # cache rstd for backward: shape (n_rows, 3)
    RSTD_row_stride,
    n_cols,
    eps,
    BLOCK_SIZE: tl.constexpr,
):
    """
    PolyNorm forward kernel. One program instance processes one row.

    PolyNorm formula:
        y = w₀·norm(x³) + w₁·norm(x²) + w₂·norm(x) + b
        where norm(u) = u / sqrt(mean(u²) + ε)

    Reference:
    1. https://github.com/BryceZhuo/PolyCom/
    2. https://arxiv.org/pdf/2411.03884

    The row, weights and bias are promoted to float32 before any arithmetic,
    mirroring the backward kernel. The x³ branch needs mean(x⁶); in float16
    that product overflows once |x| exceeds roughly 6.3, and a half-precision
    reduction over the row loses accuracy.

    Side outputs:
        RSTD[row, 0..2] receives [rstd(x³), rstd(x²), rstd(x)], where
        rstd(u) = rsqrt(mean(u²) + eps); the backward kernel reads them back
        in the same order.
    """
    row_idx = tl.program_id(0).to(tl.int64)
    col_offsets = tl.arange(0, BLOCK_SIZE)
    mask = col_offsets < n_cols

    # Advance the base pointers to this program's row
    Y_ptr += row_idx * Y_row_stride
    X_ptr += row_idx * X_row_stride
    RSTD_ptr += row_idx * RSTD_row_stride

    # Load the input row; masked-out lanes read 0.0 so they add nothing to the sums below
    X_row = tl.load(X_ptr + col_offsets, mask=mask, other=0.0).to(tl.float32)

    # Load weights and bias (float32 so the whole expression is evaluated in float32)
    w0 = tl.load(W_ptr + 0).to(tl.float32)
    w1 = tl.load(W_ptr + 1).to(tl.float32)
    w2 = tl.load(W_ptr + 2).to(tl.float32)
    b = tl.load(B_ptr).to(tl.float32)

    # Compute x³, x², x
    X_pow3 = X_row * X_row * X_row
    X_pow2 = X_row * X_row
    X_pow1 = X_row

    # Compute norm(x³): norm(u) = u * rsqrt(mean(u²) + eps)
    # The mean divides by n_cols (the true row width), not BLOCK_SIZE.
    mean_square_3 = tl.sum(X_pow3 * X_pow3, axis=0) / n_cols
    rstd_3 = rsqrt(mean_square_3 + eps)
    norm_x3 = X_pow3 * rstd_3

    # Compute norm(x²)
    mean_square_2 = tl.sum(X_pow2 * X_pow2, axis=0) / n_cols
    rstd_2 = rsqrt(mean_square_2 + eps)
    norm_x2 = X_pow2 * rstd_2

    # Compute norm(x)
    mean_square_1 = tl.sum(X_pow1 * X_pow1, axis=0) / n_cols
    rstd_1 = rsqrt(mean_square_1 + eps)
    norm_x1 = X_pow1 * rstd_1

    # Cache rstd values for backward
    tl.store(RSTD_ptr + 0, rstd_3)
    tl.store(RSTD_ptr + 1, rstd_2)
    tl.store(RSTD_ptr + 2, rstd_1)

    # Compute output: y = w₀·norm(x³) + w₁·norm(x²) + w₂·norm(x) + b
    Y_row = w0 * norm_x3 + w1 * norm_x2 + w2 * norm_x1 + b

    # Store output; tl.store converts the float32 result to Y's element type
    tl.store(Y_ptr + col_offsets, Y_row, mask=mask)
|
|
95
|
+
|
|
96
|
+
|
|
97
|
+
@triton.jit
def _poly_norm_backward_kernel(
    dY_ptr,
    dY_row_stride,
    dX_ptr,
    dX_row_stride,
    X_ptr,
    X_row_stride,
    W_ptr,
    RSTD_ptr,
    RSTD_row_stride,
    dW_ptr,  # shape: (n_programs, 3)
    dW_row_stride,
    dB_ptr,  # shape: (n_programs,)
    n_rows,
    n_cols,
    rows_per_program: tl.constexpr,
    BLOCK_SIZE: tl.constexpr,
):
    """
    PolyNorm backward kernel.

    Each program walks a contiguous range of up to `rows_per_program` rows,
    writes dX for every row it visits, and emits one partial (dW0, dW1, dW2)
    triple plus one partial dB into its own slot of dW_ptr / dB_ptr. The
    caller is expected to sum those partials over programs.

    Closed-form input gradient:
        ∂L/∂x_i = Σ_p w_p * [p*x_i^(p-1) * grad_i/D_p - (p/d)*x_i^(2p-1) * S_p/(D_p³)]

    where:
        - D_p = RMS(x^p) = 1/rstd_p
        - S_p = sum(grad * x^p) over the row
        - d = n_cols
        - p ∈ {3, 2, 1}

    Parameter gradients per row: dW_p = rstd_p * S_p and dB = sum(grad).
    """
    pid = tl.program_id(0).to(tl.int64)
    first_row = pid * rows_per_program
    # The last program may own fewer rows (or none) when n_rows is not a multiple
    last_row = min((pid + 1) * rows_per_program, n_rows)
    cols = tl.arange(0, BLOCK_SIZE)
    col_mask = cols < n_cols

    # Weights in float32, ordered [w0 -> x³, w1 -> x², w2 -> x]
    w0 = tl.load(W_ptr + 0).to(tl.float32)
    w1 = tl.load(W_ptr + 1).to(tl.float32)
    w2 = tl.load(W_ptr + 2).to(tl.float32)

    # Scalar running sums for this program's share of dB and dW
    acc_b = 0.0
    acc_w0 = 0.0
    acc_w1 = 0.0
    acc_w2 = 0.0

    # Jump to the first row owned by this program
    X_ptr += first_row * X_row_stride
    dY_ptr += first_row * dY_row_stride
    dX_ptr += first_row * dX_row_stride
    RSTD_ptr += first_row * RSTD_row_stride

    for _ in range(first_row, last_row):
        # Read the whole row before writing dX: dX_ptr may alias dY_ptr
        x = tl.load(X_ptr + cols, mask=col_mask, other=0.0).to(tl.float32)
        dy = tl.load(dY_ptr + cols, mask=col_mask, other=0.0).to(tl.float32)

        # Reciprocal RMS values cached by the forward pass: [x³, x², x]
        rstd_3 = tl.load(RSTD_ptr + 0).to(tl.float32)
        rstd_2 = tl.load(RSTD_ptr + 1).to(tl.float32)
        rstd_1 = tl.load(RSTD_ptr + 2).to(tl.float32)

        x_cu = x * x * x
        x_sq = x * x

        # Row reductions S_p = sum(dy * x^p)
        s_3 = tl.sum(dy * x_cu, axis=0)
        s_2 = tl.sum(dy * x_sq, axis=0)
        s_1 = tl.sum(dy * x, axis=0)

        # Parameter gradients for this row
        acc_b += tl.sum(dy, axis=0)
        acc_w0 += rstd_3 * s_3
        acc_w1 += rstd_2 * s_2
        acc_w2 += rstd_1 * s_1

        # p = 3 branch: direct term minus the correction from differentiating rstd
        direct_3 = 3.0 * x_sq * rstd_3 * dy
        correction_3 = (3.0 / n_cols) * x * x * x * x * x * (rstd_3 * rstd_3 * rstd_3) * s_3

        # p = 2 branch
        direct_2 = 2.0 * x * rstd_2 * dy
        correction_2 = (2.0 / n_cols) * x * x * x * (rstd_2 * rstd_2 * rstd_2) * s_2

        # p = 1 branch
        direct_1 = 1.0 * rstd_1 * dy
        correction_1 = (1.0 / n_cols) * x * (rstd_1 * rstd_1 * rstd_1) * s_1

        dx = w0 * (direct_3 - correction_3) + w1 * (direct_2 - correction_2) + w2 * (direct_1 - correction_1)
        tl.store(dX_ptr + cols, dx, mask=col_mask)

        # Step every pointer to the next row
        X_ptr += X_row_stride
        dY_ptr += dY_row_stride
        dX_ptr += dX_row_stride
        RSTD_ptr += RSTD_row_stride

    # Publish this program's partial sums
    dW_out = dW_ptr + pid * dW_row_stride
    tl.store(dW_out + 0, acc_w0)
    tl.store(dW_out + 1, acc_w1)
    tl.store(dW_out + 2, acc_w2)
    tl.store(dB_ptr + pid, acc_b)
|
|
206
|
+
|
|
207
|
+
|
|
208
|
+
def poly_norm_forward(X, W, B, eps=1e-6):
    """
    PolyNorm Forward Pass

    Flattens X to 2D over its last dimension and launches
    `_poly_norm_forward_kernel` with one program per row.

    Args:
        X: input tensor of shape (*, H) where H is hidden dimension
        W: weight tensor of shape (3,) for [w0, w1, w2]
        B: bias scalar tensor
        eps: epsilon for numerical stability

    Returns:
        Y: output tensor of same shape as X
        X: reshaped input (for backward)
        RSTD: cached rstd values (for backward), float32 of shape (n_rows, 3)
        BLOCK_SIZE: block size used
        num_warps: number of warps used

    Raises:
        AssertionError: if W does not have 3 entries along dim 0 or B does not
            hold exactly one element.
    """
    # Check constraints first so an invalid call fails before any output
    # buffer is allocated or any launch configuration is computed.
    assert W.shape[0] == 3, "Weight tensor must have shape (3,)"
    assert B.numel() == 1, "Bias must be a scalar"

    shape = X.shape
    dim = shape[-1]
    X = X.view(-1, dim)
    n_rows, n_cols = X.shape
    BLOCK_SIZE, num_warps = calculate_settings(n_cols)

    # Y has X's dtype; RSTD caches the three reciprocal RMS values of each row
    Y = torch.empty((n_rows, n_cols), dtype=X.dtype, device=X.device)
    RSTD = torch.empty((n_rows, 3), dtype=torch.float32, device=X.device)

    # XPU-specific optimization
    kernel_args = {}
    if X.device.type == "xpu":
        kernel_args["grf_mode"] = "large"

    # Launch kernel: one program per row
    _poly_norm_forward_kernel[(n_rows,)](
        Y,
        Y.stride(0),
        X,
        X.stride(0),
        W,
        B,
        RSTD,
        RSTD.stride(0),
        n_cols,
        eps,
        BLOCK_SIZE=BLOCK_SIZE,
        num_warps=num_warps,
        **kernel_args,
    )

    return Y.view(*shape), X, RSTD, BLOCK_SIZE, num_warps
|
|
262
|
+
|
|
263
|
+
|
|
264
|
+
def poly_norm_backward(dY, X, W, RSTD, BLOCK_SIZE, num_warps, in_place):
    """
    PolyNorm Backward Pass

    Splits the rows evenly across `sm_count` programs. Each program writes dX
    for its rows and one partial (dW, dB) entry; the partials are summed here.

    Args:
        dY: gradient of output
        X: input tensor (already reshaped to 2D)
        W: weight tensor
        RSTD: cached rstd values from forward
        BLOCK_SIZE: block size from forward
        num_warps: number of warps from forward
        in_place: whether to in-place modify dY to store dX (saves memory).
            Only the literal `True` enables this; anything else allocates a
            separate dX buffer.

    Returns:
        dX: gradient w.r.t. input, with the same shape as the incoming dY
        dW: gradient w.r.t. weight, cast to W.dtype
        dB: gradient w.r.t. bias, cast to W.dtype
    """
    shape = dY.shape
    dim = shape[-1]
    dY = dY.view(-1, dim)
    n_rows, n_cols = dY.shape

    # Number of programs to launch: the device's multiprocessor count
    # (falls back to a single program on other device types)
    sm_count = 1
    if X.device.type == "cuda":
        sm_count = torch.cuda.get_device_properties(X.device).multi_processor_count
    elif X.device.type == "xpu":
        sm_count = torch.xpu.get_device_properties(X.device).gpu_eu_count
    elif X.device.type == "npu":
        sm_count = get_npu_multi_processor_count()

    # Allocate or reuse gradients
    if in_place is True:
        dX = dY
    else:
        # The kernel stores every column of every row, so the buffer does not
        # need to be zero-filled first.
        dX = torch.empty_like(dY)

    # One partial-sum slot per program; every program stores its slot unconditionally
    _dW = torch.empty((sm_count, 3), dtype=torch.float32, device=W.device)
    _dB = torch.empty((sm_count,), dtype=torch.float32, device=W.device)

    # Integer ceil(n_rows / sm_count)
    rows_per_program = (n_rows + sm_count - 1) // sm_count
    grid = (sm_count,)

    # XPU-specific optimization
    kernel_args = {}
    if X.device.type == "xpu":
        kernel_args["grf_mode"] = "large"

    # Launch backward kernel
    _poly_norm_backward_kernel[grid](
        dY,
        dY.stride(0),
        dX,
        dX.stride(0),
        X,
        X.stride(0),
        W,
        RSTD,
        RSTD.stride(0),
        _dW,
        _dW.stride(0),
        _dB,
        n_rows,
        n_cols,
        rows_per_program,
        BLOCK_SIZE=BLOCK_SIZE,
        num_warps=num_warps,
        **kernel_args,
    )

    # Reduce the per-program partial gradients
    dX = dX.view(*shape)
    dW = _dW.sum(dim=0).to(W.dtype)
    dB = _dB.sum().to(W.dtype)

    return dX, dW, dB
|
|
343
|
+
|
|
344
|
+
|
|
345
|
+
class LigerPolyNormFunction(torch.autograd.Function):
    """
    Autograd wrapper around the PolyNorm forward and backward kernels.

    PolyNorm formula:
        y = w₀·norm(x³) + w₁·norm(x²) + w₂·norm(x) + b
        where norm(u) = u / sqrt(mean(u²) + ε)

    The backward pass evaluates the closed-form gradient:
        ∂L/∂x_i = Σ_p w_p * [p*x_i^(p-1) * grad_i/D_p - (p/d)*x_i^(2p-1) * S_p/(D_p³)]
    """

    @staticmethod
    @ensure_contiguous
    def forward(ctx, X, W, B, eps=1e-6, in_place=True):
        """
        Run the forward kernel and stash what backward needs on `ctx`.

        Args:
            X: input tensor of shape (B, T, H) or (BxT, H)
            W: weight tensor of shape (3,) for [w0, w1, w2]
            B: bias scalar
            eps: epsilon for numerical stability
            in_place: whether backward may overwrite grad_output with dX (saves memory)

        Returns:
            Y: output tensor of same shape as X
        """
        Y, X_2d, RSTD, BLOCK_SIZE, num_warps = poly_norm_forward(X, W, B, eps)
        # Tensors go through save_for_backward; launch settings are plain attributes
        ctx.save_for_backward(X_2d, W, RSTD)
        ctx.in_place = in_place
        ctx.num_warps = num_warps
        ctx.BLOCK_SIZE = BLOCK_SIZE
        return Y

    @staticmethod
    @ensure_contiguous
    def backward(ctx, grad_output):
        """
        Compute gradients for the tensor inputs of `forward`.

        Args:
            grad_output: gradient of output

        Returns:
            dX, dW, dB for X, W, B, followed by None for `eps` and `in_place`
        """
        X, W, RSTD = ctx.saved_tensors
        grads = poly_norm_backward(
            grad_output,
            X,
            W,
            RSTD,
            ctx.BLOCK_SIZE,
            ctx.num_warps,
            ctx.in_place,
        )
        # eps and in_place are non-differentiable inputs
        return (*grads, None, None)
|
|
@@ -0,0 +1,222 @@
|
|
|
1
|
+
import torch
|
|
2
|
+
import triton
|
|
3
|
+
import triton.language as tl
|
|
4
|
+
|
|
5
|
+
|
|
6
|
+
@triton.jit
def _triton_qwen2vl_mrope(
    q_ptr,
    k_ptr,
    cos,
    sin,
    sl,
    bs: tl.constexpr,
    n_qh: tl.constexpr,
    n_kh: tl.constexpr,
    hd: tl.constexpr,
    pad_n_qh: tl.constexpr,
    pad_n_kh: tl.constexpr,
    pad_hd: tl.constexpr,
    mrope_section_t: tl.constexpr,
    mrope_section_h: tl.constexpr,
    BLOCK_SIZE: tl.constexpr,
    BACKWARD_PASS: tl.constexpr = False,
):
    """
    Apply Qwen2-VL multimodal rotary embedding (M-RoPE) to q and k in place.

    One program instance handles one token, i.e. one (batch, position) pair,
    and rotates all of that token's query and key heads.

    Args:
        q_ptr, k_ptr: q / k storage, addressed as n_qh * hd (resp. n_kh * hd)
            contiguous elements per token; overwritten with the result.
        cos, sin: three consecutive sections (temporal, height, width), each
            addressed as bs * sl * hd elements with hd elements per token.
        sl: sequence length.
        bs: batch size.
        n_qh, n_kh: number of query / key heads.
        hd: head dimension.
        pad_n_qh, pad_n_kh, pad_hd: sizes used for tl.arange; lanes beyond the
            real n_qh, n_kh and hd are masked off.
        mrope_section_t, mrope_section_h: how many of the hd // 2 rotary
            channels use the temporal and the height section; the remaining
            channels use the width section.
        BLOCK_SIZE: accepted for launch compatibility; not read by the kernel.
        BACKWARD_PASS: when True, apply the inverse rotation (sin negated).
    """
    # int64 so that pid * (n_qh * hd) cannot wrap around in 32-bit arithmetic
    # when the token count is large.
    pid = tl.program_id(0).to(tl.int64)

    # locate start address
    q_ptr = q_ptr + pid * (n_qh * hd)
    k_ptr = k_ptr + pid * (n_kh * hd)

    # ####################################################################
    # get the cos(mθ_{i...d/2}) and sin(mθ_{i...d/2}) for token position
    # m of this program instance
    # ####################################################################

    # 1. program instances are laid out in a 1D vector of size bsz * seq_len, which
    # effectively represents a 2D grid of size [bsz, seq_len] with seq_len dimension
    # being the fastest changing dimension, so pid indexes the token directly.
    # 2. We only need the left half of cos and sin matrix because the right half is just
    # a clone of the left half.
    t_end = mrope_section_t
    h_end = t_end + mrope_section_h

    # The temporal / height / width sections sit bs * sl * hd elements apart
    t_cos = cos + pid * hd
    h_cos = t_cos + bs * sl * hd
    w_cos = h_cos + bs * sl * hd
    t_sin = sin + pid * hd
    h_sin = t_sin + bs * sl * hd
    w_sin = h_sin + bs * sl * hd

    # Channels [0, t_end) come from the temporal section, [t_end, h_end) from
    # height, and [h_end, hd // 2) from width. The masks are disjoint, so
    # summing the three zero-filled loads selects one source per channel.
    cos_offsets = tl.arange(0, pad_hd // 2)
    t_mask = cos_offsets < t_end
    h_mask = (t_end <= cos_offsets) & (cos_offsets < h_end)
    w_mask = (h_end <= cos_offsets) & (cos_offsets < hd // 2)
    t_cos_row = tl.load(t_cos + cos_offsets, mask=t_mask, other=0)
    h_cos_row = tl.load(h_cos + cos_offsets, mask=h_mask, other=0)
    w_cos_row = tl.load(w_cos + cos_offsets, mask=w_mask, other=0)
    t_sin_row = tl.load(t_sin + cos_offsets, mask=t_mask, other=0)
    h_sin_row = tl.load(h_sin + cos_offsets, mask=h_mask, other=0)
    w_sin_row = tl.load(w_sin + cos_offsets, mask=w_mask, other=0)
    cos_row = t_cos_row + h_cos_row + w_cos_row
    sin_row = t_sin_row + h_sin_row + w_sin_row

    # ####################################################################
    # Load the left and right half of q and k for the current
    # program instance (i.e. for the current token) separately
    # ####################################################################
    # left half of the head
    first_half_q_offsets = tl.arange(0, pad_n_qh)[:, None] * hd + tl.arange(0, pad_hd // 2)[None, :]
    first_half_k_offsets = tl.arange(0, pad_n_kh)[:, None] * hd + tl.arange(0, pad_hd // 2)[None, :]
    first_q_mask = (tl.arange(0, pad_n_qh)[:, None] < n_qh) & (tl.arange(0, pad_hd // 2)[None, :] < hd // 2)
    first_k_mask = (tl.arange(0, pad_n_kh)[:, None] < n_kh) & (tl.arange(0, pad_hd // 2)[None, :] < hd // 2)
    q_tile_1 = tl.load(q_ptr + first_half_q_offsets, mask=first_q_mask, other=0).to(sin_row.dtype)
    k_tile_1 = tl.load(k_ptr + first_half_k_offsets, mask=first_k_mask, other=0).to(sin_row.dtype)

    # right half of the head
    second_half_q_offsets = first_half_q_offsets + (hd // 2)
    second_half_k_offsets = first_half_k_offsets + (hd // 2)
    second_q_mask = first_q_mask
    second_k_mask = first_k_mask
    q_tile_2 = tl.load(q_ptr + second_half_q_offsets, mask=second_q_mask, other=0).to(sin_row.dtype)
    k_tile_2 = tl.load(k_ptr + second_half_k_offsets, mask=second_k_mask, other=0).to(sin_row.dtype)

    if not BACKWARD_PASS:
        # y = [x1, x2] * [cos, cos] + [-x2, x1] * [sin, sin]
        new_q_tile_1 = q_tile_1 * cos_row - q_tile_2 * sin_row
        tl.store(q_ptr + first_half_q_offsets, new_q_tile_1, mask=first_q_mask)
        new_q_tile_2 = q_tile_2 * cos_row + q_tile_1 * sin_row
        tl.store(q_ptr + second_half_q_offsets, new_q_tile_2, mask=second_q_mask)

        new_k_tile_1 = k_tile_1 * cos_row - k_tile_2 * sin_row
        tl.store(k_ptr + first_half_k_offsets, new_k_tile_1, mask=first_k_mask)
        new_k_tile_2 = k_tile_2 * cos_row + k_tile_1 * sin_row
        tl.store(k_ptr + second_half_k_offsets, new_k_tile_2, mask=second_k_mask)
    else:
        # with some math, we can get:
        # dy = [dx1, dx2] * [cos, cos] + [-dx2, dx1] * [-sin, -sin]
        new_q_tile_1 = q_tile_1 * cos_row + q_tile_2 * sin_row
        tl.store(q_ptr + first_half_q_offsets, new_q_tile_1, mask=first_q_mask)
        new_q_tile_2 = q_tile_2 * cos_row - q_tile_1 * sin_row
        tl.store(q_ptr + second_half_q_offsets, new_q_tile_2, mask=second_q_mask)

        new_k_tile_1 = k_tile_1 * cos_row + k_tile_2 * sin_row
        tl.store(k_ptr + first_half_k_offsets, new_k_tile_1, mask=first_k_mask)
        new_k_tile_2 = k_tile_2 * cos_row - k_tile_1 * sin_row
        tl.store(k_ptr + second_half_k_offsets, new_k_tile_2, mask=second_k_mask)
|
|
108
|
+
|
|
109
|
+
|
|
110
|
+
def qwen2vl_mrope_forward(q, k, cos, sin, mrope_section):
    """
    Apply M-RoPE to q and k with the Triton kernel.

    Args:
        q, k: tensors whose dims 1 and 2 are swapped before the launch, so the
            kernel sees them as (batch, seq_len, n_head, head_dim).
        cos, sin: rotary tables handed to the kernel after `.contiguous()`.
        mrope_section: sequence whose entries 0 and 1 are forwarded as the
            temporal and height section sizes.

    Returns:
        (q, k, cos, sin): rotated q and k transposed back to the incoming dim
        order, plus the contiguous cos and sin that were actually used.
    """
    # Triton addresses the physical storage, so move heads next to head_dim and
    # materialise that layout (.contiguous() is a no-op if it already holds).
    q = q.transpose(1, 2).contiguous()
    k = k.transpose(1, 2).contiguous()
    cos = cos.contiguous()
    sin = sin.contiguous()

    batch_size, seq_len, n_q_head, head_dim = q.shape
    n_kv_head = k.shape[2]

    # tl.arange requires power-of-two extents
    pad_hd = triton.next_power_of_2(head_dim)
    pad_n_q_head = triton.next_power_of_2(n_q_head)
    pad_n_kv_head = triton.next_power_of_2(n_kv_head)

    # One program per token
    grid = (batch_size * seq_len,)
    _triton_qwen2vl_mrope[grid](
        q,
        k,
        cos,
        sin,
        seq_len,
        batch_size,
        n_q_head,
        n_kv_head,
        head_dim,
        pad_n_q_head,
        pad_n_kv_head,
        pad_hd,
        mrope_section[0],
        mrope_section[1],
        BLOCK_SIZE=max(pad_n_q_head, pad_n_kv_head),
        BACKWARD_PASS=False,
    )
    return q.transpose(1, 2), k.transpose(1, 2), cos, sin
|
|
150
|
+
|
|
151
|
+
|
|
152
|
+
def qwen2vl_mrope_backward(dq, dk, cos, sin, mrope_section):
    """
    Backward of M-RoPE: run the same kernel with BACKWARD_PASS=True.

    Args:
        dq, dk: upstream gradients, laid out like the q / k given to forward.
        cos, sin: rotary tables, passed to the kernel as received.
        mrope_section: sequence whose entries 0 and 1 are forwarded as the
            temporal and height section sizes.

    Returns:
        (dq, dk): gradients w.r.t. the un-rotated q and k, transposed back to
        the incoming dim order.
    """
    # Same physical layout as the forward pass; make sure it is contiguous
    dq = dq.transpose(1, 2).contiguous()
    dk = dk.transpose(1, 2).contiguous()

    batch_size, seq_len, n_q_head, head_dim = dq.shape
    n_kv_head = dk.shape[2]

    # tl.arange requires power-of-two extents
    pad_hd = triton.next_power_of_2(head_dim)
    pad_n_q_head = triton.next_power_of_2(n_q_head)
    pad_n_kv_head = triton.next_power_of_2(n_kv_head)

    # One program per token; only the sign of the sin terms differs from forward
    grid = (batch_size * seq_len,)
    _triton_qwen2vl_mrope[grid](
        dq,
        dk,
        cos,
        sin,
        seq_len,
        batch_size,
        n_q_head,
        n_kv_head,
        head_dim,
        pad_n_q_head,
        pad_n_kv_head,
        pad_hd,
        mrope_section[0],
        mrope_section[1],
        BLOCK_SIZE=max(pad_n_q_head, pad_n_kv_head),
        BACKWARD_PASS=True,
    )
    return dq.transpose(1, 2), dk.transpose(1, 2)
|
|
189
|
+
|
|
190
|
+
|
|
191
|
+
class LigerQwen2VLMRopeFunction(torch.autograd.Function):
    """
    Triton implementation of the Qwen2VL Multimodal Rotary Positional Embedding (M-RoPE) operation.

    Please find the corresponding HuggingFace implementation here:
    https://github.com/huggingface/transformers/blob/main/src/transformers/models/qwen2_vl/modeling_qwen2_vl.py
    """

    @staticmethod
    def forward(ctx, q, k, cos, sin, mrope_section, unsqueeze_dim=1):
        """
        Rotate q and k and save cos / sin for the backward pass.

        q size: (bsz, n_q_head, seq_len, head_dim)
        k size: (bsz, n_kv_head, seq_len, head_dim)
        cos size: (3, bsz, seq_len, head_dim)
        sin size: (3, bsz, seq_len, head_dim)

        `unsqueeze_dim` is accepted but not used by this implementation.

        Returns the rotated (q, k).
        """
        q, k, cos, sin = qwen2vl_mrope_forward(q, k, cos, sin, mrope_section)
        # Save the contiguous cos / sin returned by the forward helper
        ctx.save_for_backward(cos, sin)
        ctx.mrope_section = mrope_section
        return q, k

    # torch.autograd.Function requires backward to be a static method, just like forward.
    @staticmethod
    def backward(ctx, dq, dk):
        """
        Propagate gradients through the rotation.

        dq size: (bsz, n_q_head, seq_len, head_dim)
        dk size: (bsz, n_kv_head, seq_len, head_dim)
        cos size: (3, bsz, seq_len, head_dim)
        sin size: (3, bsz, seq_len, head_dim)

        Returns (dq, dk) followed by None for cos, sin, mrope_section and
        unsqueeze_dim, which receive no gradient.
        """
        cos, sin = ctx.saved_tensors
        mrope_section = ctx.mrope_section
        dq, dk = qwen2vl_mrope_backward(dq, dk, cos, sin, mrope_section)
        return dq, dk, None, None, None, None
|