liger-kernel 0.6.1__py3-none-any.whl → 0.6.3__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- liger_kernel/chunked_loss/dpo_loss.py +54 -3
- liger_kernel/chunked_loss/fused_linear_ppo.py +4 -0
- liger_kernel/chunked_loss/grpo_loss.py +38 -4
- liger_kernel/chunked_loss/jsd_loss.py +5 -2
- liger_kernel/ops/cross_entropy.py +59 -53
- liger_kernel/ops/fused_linear_cross_entropy.py +83 -17
- liger_kernel/ops/layer_norm.py +4 -6
- liger_kernel/ops/llama4_rope.py +225 -0
- liger_kernel/ops/poly_norm.py +386 -0
- liger_kernel/transformers/__init__.py +32 -0
- liger_kernel/transformers/experimental/__init__.py +5 -0
- liger_kernel/transformers/functional.py +9 -0
- liger_kernel/transformers/fused_linear_cross_entropy.py +8 -1
- liger_kernel/transformers/llama4_rope.py +93 -0
- liger_kernel/transformers/model/falcon_h1.py +108 -0
- liger_kernel/transformers/model/gemma.py +2 -1
- liger_kernel/transformers/model/gemma2.py +8 -2
- liger_kernel/transformers/model/gemma3.py +27 -2
- liger_kernel/transformers/model/glm4.py +2 -1
- liger_kernel/transformers/model/glm4v.py +151 -0
- liger_kernel/transformers/model/glm4v_moe.py +153 -0
- liger_kernel/transformers/model/internvl.py +150 -0
- liger_kernel/transformers/model/llama.py +2 -1
- liger_kernel/transformers/model/llama4.py +2 -1
- liger_kernel/transformers/model/llava.py +6 -2
- liger_kernel/transformers/model/loss_utils.py +3 -0
- liger_kernel/transformers/model/mistral.py +2 -1
- liger_kernel/transformers/model/mixtral.py +8 -2
- liger_kernel/transformers/model/mllama.py +6 -3
- liger_kernel/transformers/model/olmo2.py +2 -1
- liger_kernel/transformers/model/paligemma.py +19 -0
- liger_kernel/transformers/model/phi3.py +10 -160
- liger_kernel/transformers/model/qwen2.py +2 -1
- liger_kernel/transformers/model/qwen2_5_vl.py +7 -2
- liger_kernel/transformers/model/qwen2_vl.py +7 -2
- liger_kernel/transformers/model/qwen3.py +2 -1
- liger_kernel/transformers/model/qwen3_moe.py +8 -2
- liger_kernel/transformers/model/qwen3_next.py +134 -0
- liger_kernel/transformers/model/smollm3.py +2 -1
- liger_kernel/transformers/model/smolvlm.py +158 -0
- liger_kernel/transformers/monkey_patch.py +552 -23
- liger_kernel/transformers/multi_token_attention.py +1 -1
- liger_kernel/transformers/poly_norm.py +42 -0
- liger_kernel/transformers/rms_norm.py +7 -0
- {liger_kernel-0.6.1.dist-info → liger_kernel-0.6.3.dist-info}/METADATA +14 -11
- {liger_kernel-0.6.1.dist-info → liger_kernel-0.6.3.dist-info}/RECORD +50 -39
- {liger_kernel-0.6.1.dist-info → liger_kernel-0.6.3.dist-info}/WHEEL +0 -0
- {liger_kernel-0.6.1.dist-info → liger_kernel-0.6.3.dist-info}/licenses/LICENSE +0 -0
- {liger_kernel-0.6.1.dist-info → liger_kernel-0.6.3.dist-info}/licenses/NOTICE +0 -0
- {liger_kernel-0.6.1.dist-info → liger_kernel-0.6.3.dist-info}/top_level.txt +0 -0
|
@@ -0,0 +1,225 @@
|
|
|
1
|
+
import torch
|
|
2
|
+
import triton
|
|
3
|
+
import triton.language as tl
|
|
4
|
+
|
|
5
|
+
|
|
6
|
+
def _prepare_freqs(freqs_cis: torch.Tensor, seq_len: int, head_dim_half: int):
|
|
7
|
+
# Split or unpack complex frequencies into real and imag parts
|
|
8
|
+
if freqs_cis.is_complex():
|
|
9
|
+
freqs_real = freqs_cis.real
|
|
10
|
+
freqs_imag = freqs_cis.imag
|
|
11
|
+
else:
|
|
12
|
+
# Already split: last dim should be 2*head_dim_half
|
|
13
|
+
if freqs_cis.shape[-1] == 2 * head_dim_half:
|
|
14
|
+
freqs_real = freqs_cis[..., :head_dim_half]
|
|
15
|
+
freqs_imag = freqs_cis[..., head_dim_half:]
|
|
16
|
+
else:
|
|
17
|
+
raise ValueError(
|
|
18
|
+
f"Unexpected freqs_cis shape for non-complex input: {freqs_cis.shape}, expected last dim = {2 * head_dim_half}"
|
|
19
|
+
)
|
|
20
|
+
|
|
21
|
+
# Canonicalize to shape (seq_len, head_dim_half):
|
|
22
|
+
# 1) Ensure the last dimension is head_dim_half
|
|
23
|
+
if freqs_real.shape[-1] != head_dim_half:
|
|
24
|
+
raise ValueError(f"Unexpected last dim for freqs: {freqs_real.shape[-1]} (expected {head_dim_half})")
|
|
25
|
+
# 2) Flatten all leading dims to a single row dimension
|
|
26
|
+
freqs_real = freqs_real.reshape(-1, head_dim_half)
|
|
27
|
+
freqs_imag = freqs_imag.reshape(-1, head_dim_half)
|
|
28
|
+
# 3) If we have fewer rows than seq_len, allow broadcasting when single row
|
|
29
|
+
if freqs_real.shape[0] < seq_len:
|
|
30
|
+
if freqs_real.shape[0] == 1:
|
|
31
|
+
freqs_real = freqs_real.expand(seq_len, -1)
|
|
32
|
+
freqs_imag = freqs_imag.expand(seq_len, -1)
|
|
33
|
+
else:
|
|
34
|
+
raise ValueError(f"Insufficient rows in freqs: {freqs_real.shape[0]} < seq_len={seq_len}")
|
|
35
|
+
# 4) If we have more rows than seq_len (e.g., batch present), take the first seq_len rows
|
|
36
|
+
elif freqs_real.shape[0] > seq_len:
|
|
37
|
+
freqs_real = freqs_real[:seq_len]
|
|
38
|
+
freqs_imag = freqs_imag[:seq_len]
|
|
39
|
+
|
|
40
|
+
return freqs_real, freqs_imag
|
|
41
|
+
|
|
42
|
+
|
|
43
|
+
def _maybe_to_dtype(t: torch.Tensor, dtype: torch.dtype) -> torch.Tensor:
|
|
44
|
+
return t if t.dtype == dtype else t.to(dtype)
|
|
45
|
+
|
|
46
|
+
|
|
47
|
+
def _maybe_contiguous(t: torch.Tensor) -> torch.Tensor:
|
|
48
|
+
return t if t.is_contiguous() else t.contiguous()
|
|
49
|
+
|
|
50
|
+
|
|
51
|
+
def _cast_and_contiguous(q, k, freqs_real, freqs_imag):
    """Bring q, k and both frequency planes to q's dtype and a contiguous layout.

    The compute dtype is always ``q.dtype``: fp32 inputs stay fp32, and
    low-precision inputs are not upcast (the frequencies are downcast to q's dtype).
    Tensors that already match are returned as-is (no copy).

    Returns:
        Tuple ``(q, k, freqs_real, freqs_imag)``.
    """
    compute_dtype = q.dtype

    # Align k with q first so both share one dtype.
    if k.dtype != compute_dtype:
        k = k.to(compute_dtype)

    prepared = [_maybe_contiguous(_maybe_to_dtype(t, compute_dtype)) for t in (q, k, freqs_real, freqs_imag)]
    return tuple(prepared)
|
|
64
|
+
|
|
65
|
+
|
|
66
|
+
@triton.jit
def _llama4_rope_kernel(
    q_ptr,
    k_ptr,
    freqs_real_ptr,
    freqs_imag_ptr,
    q_row_stride,
    k_row_stride,
    q_head_stride,
    k_head_stride,
    freqs_row_stride,
    seq_len,
    batch_size,
    imag_sign,
    head_dim_half: tl.constexpr,
    n_q_heads: tl.constexpr,
    n_k_heads: tl.constexpr,
    BLOCK_SIZE: tl.constexpr,
):
    """
    Rotate q and k in place for one (batch*seq, head) program.

    Grid: (batch * seq_len, max(n_q_heads, n_k_heads)). Each program handles one
    (batch, seq) position and rotates query head ``pid_h`` (when
    ``pid_h < n_q_heads``) and key head ``pid_h`` (when ``pid_h < n_k_heads``).

    Element pairs (2j, 2j + 1) of a head are the real and imaginary parts of one
    complex number, multiplied by ``freqs_real[seq, j] + i * imag_sign * freqs_imag[seq, j]``.
    ``imag_sign = -1`` gives the conjugate rotation used by backward.

    Arithmetic is performed in fp32. ``tl.store`` casts the results back to the
    element type of ``q_ptr`` / ``k_ptr``.
    """
    # Promote the flattened (batch, seq) index to int64 before it is multiplied by a
    # row stride: batch * seq_len * row_stride overflows int32 for long-context inputs
    # (e.g. 8 * 128k tokens * 5120 elements per token > 2**31).
    pid_bs = tl.program_id(0).to(tl.int64)  # over batch*seq
    pid_h = tl.program_id(1)  # over heads

    batch_idx = pid_bs // seq_len
    seq_idx = pid_bs % seq_len

    # Bounds check
    if batch_idx >= batch_size or seq_idx >= seq_len:
        return

    # Base pointers for this (batch, seq) position; assumes q/k are contiguous
    # (batch, seq, heads, head_dim) so that batch stride == seq_len * row stride.
    base_offset = batch_idx * seq_len + seq_idx
    q_base = q_ptr + base_offset * q_row_stride
    k_base = k_ptr + base_offset * k_row_stride

    # Tiling over dim/2
    for d_start in tl.static_range(0, head_dim_half, BLOCK_SIZE):
        d_indices = d_start + tl.arange(0, BLOCK_SIZE)
        mask_d = d_indices < head_dim_half

        # Load frequencies once per tile (freqs layout: [seq_len, head_dim_half])
        freqs_offset = seq_idx * freqs_row_stride + d_indices
        freqs_real = tl.load(freqs_real_ptr + freqs_offset, mask=mask_d, other=0.0).to(tl.float32)
        freqs_imag = tl.load(freqs_imag_ptr + freqs_offset, mask=mask_d, other=0.0).to(tl.float32)
        freqs_imag = freqs_imag * imag_sign

        # Process one query head per program in pid_h
        if pid_h < n_q_heads:
            q_head_ptr = q_base + pid_h * q_head_stride
            q_real = tl.load(q_head_ptr + d_indices * 2, mask=mask_d, other=0.0).to(tl.float32)
            q_imag = tl.load(q_head_ptr + d_indices * 2 + 1, mask=mask_d, other=0.0).to(tl.float32)

            # Complex multiply with FMAs: (a+ib)*(c+id) = (a*c - b*d) + i(a*d + b*c)
            new_q_real = tl.math.fma(q_real, freqs_real, -(q_imag * freqs_imag))
            new_q_imag = tl.math.fma(q_real, freqs_imag, q_imag * freqs_real)

            tl.store(q_head_ptr + d_indices * 2, new_q_real, mask=mask_d)
            tl.store(q_head_ptr + d_indices * 2 + 1, new_q_imag, mask=mask_d)

        # Process one key head per program in pid_h
        if pid_h < n_k_heads:
            k_head_ptr = k_base + pid_h * k_head_stride
            k_real = tl.load(k_head_ptr + d_indices * 2, mask=mask_d, other=0.0).to(tl.float32)
            k_imag = tl.load(k_head_ptr + d_indices * 2 + 1, mask=mask_d, other=0.0).to(tl.float32)

            new_k_real = tl.math.fma(k_real, freqs_real, -(k_imag * freqs_imag))
            new_k_imag = tl.math.fma(k_real, freqs_imag, k_imag * freqs_real)

            tl.store(k_head_ptr + d_indices * 2, new_k_real, mask=mask_d)
            tl.store(k_head_ptr + d_indices * 2 + 1, new_k_imag, mask=mask_d)
|
|
140
|
+
|
|
141
|
+
|
|
142
|
+
def _select_kernel_meta(head_dim_half: int):
|
|
143
|
+
# Heuristic tuning for block size and num_warps
|
|
144
|
+
if head_dim_half >= 256:
|
|
145
|
+
return 128, 8
|
|
146
|
+
if head_dim_half >= 96:
|
|
147
|
+
return 128, 4
|
|
148
|
+
if head_dim_half >= 48:
|
|
149
|
+
return 64, 4
|
|
150
|
+
if head_dim_half >= 24:
|
|
151
|
+
return 32, 2
|
|
152
|
+
return 16, 2
|
|
153
|
+
|
|
154
|
+
|
|
155
|
+
def llama4_rope_forward(q, k, freqs_cis, BLOCK_SIZE: int = None, imag_sign: float = 1.0):
    """Apply Llama4-style rotary embedding to ``q`` and ``k`` with the Triton kernel.

    Args:
        q: query tensor unpacked as (batch, seq_len, n_q_heads, head_dim).
        k: key tensor unpacked as (batch, seq_len?, n_k_heads, head_dim); it is
            launched with q's batch/seq sizes, so its leading dims are assumed
            to match q's — TODO confirm at call sites.
        freqs_cis: complex (or [real | imag] split) frequencies, normalized by
            ``_prepare_freqs`` to (seq_len, head_dim // 2).
        BLOCK_SIZE: optional tile width over head_dim // 2; when None it is picked by
            ``_select_kernel_meta``. num_warps always comes from that heuristic.
        imag_sign: +1.0 for the forward rotation, -1.0 for the conjugate (backward).

    Returns:
        ``(q, k)`` rotated. The kernel writes in place into the prepared tensors, so
        when q/k are already contiguous and in the compute dtype the caller's storage
        is modified directly. k is returned in q's dtype.
    """
    original_dtype = q.dtype

    batch_size, seq_len, n_q_heads, head_dim = q.shape
    _, _, n_k_heads, _ = k.shape
    head_dim_half = head_dim // 2

    freqs_real, freqs_imag = _prepare_freqs(freqs_cis, seq_len, head_dim_half)
    q, k, freqs_real, freqs_imag = _cast_and_contiguous(q, k, freqs_real, freqs_imag)

    # num_warps comes from the heuristic even when the caller pins BLOCK_SIZE.
    heuristic_block, num_warps = _select_kernel_meta(head_dim_half)
    if BLOCK_SIZE is None:
        BLOCK_SIZE = heuristic_block

    # Axis 0 walks every (batch, seq) position; axis 1 covers the larger head count.
    grid = (batch_size * seq_len, max(n_q_heads, n_k_heads))

    _llama4_rope_kernel[grid](
        q,
        k,
        freqs_real,
        freqs_imag,
        q.stride(1),
        k.stride(1),
        q.stride(2),
        k.stride(2),
        freqs_real.stride(0),
        seq_len,
        batch_size,
        imag_sign,
        head_dim_half,
        n_q_heads,
        n_k_heads,
        BLOCK_SIZE,
        num_warps=num_warps,
        num_stages=2,
    )

    # Defensive cast-back in case the compute dtype ever differs from the input dtype.
    q = q if q.dtype == original_dtype else q.to(original_dtype)
    k = k if k.dtype == original_dtype else k.to(original_dtype)
    return q, k
|
|
209
|
+
|
|
210
|
+
|
|
211
|
+
class LigerLlama4RopeFunction(torch.autograd.Function):
    """Autograd wrapper around ``llama4_rope_forward``.

    Forward rotates q/k by ``freqs_cis``. Backward rotates the incoming gradients by
    the conjugate frequencies (``imag_sign=-1.0``), which is the adjoint of the
    rotation. No gradient is produced for ``freqs_cis`` or ``BLOCK_SIZE``.
    """

    @staticmethod
    def forward(ctx, q, k, freqs_cis, BLOCK_SIZE: int = None):
        q_out, k_out = llama4_rope_forward(q, k, freqs_cis, BLOCK_SIZE, imag_sign=1.0)
        ctx.save_for_backward(freqs_cis.detach() if isinstance(freqs_cis, torch.Tensor) else freqs_cis)
        ctx.BLOCK_SIZE = BLOCK_SIZE
        return q_out, k_out

    @staticmethod
    def backward(ctx, dq, dk):
        (freqs_cis,) = ctx.saved_tensors
        BLOCK_SIZE = getattr(ctx, "BLOCK_SIZE", None)
        # Use imag_sign=-1.0 for conjugate without materializing a new tensor
        dq_out, dk_out = llama4_rope_forward(dq, dk, freqs_cis, BLOCK_SIZE, imag_sign=-1.0)
        # One entry per forward input (q, k, freqs_cis, BLOCK_SIZE). Autograd rejects
        # too few gradients when BLOCK_SIZE is passed to apply(), and it drops trailing
        # Nones when only three inputs were passed.
        return dq_out, dk_out, None, None
|
|
@@ -0,0 +1,386 @@
|
|
|
1
|
+
import operator
|
|
2
|
+
|
|
3
|
+
import torch
|
|
4
|
+
import triton
|
|
5
|
+
import triton.language as tl
|
|
6
|
+
|
|
7
|
+
from liger_kernel.ops.utils import calculate_settings
|
|
8
|
+
from liger_kernel.ops.utils import compare_version
|
|
9
|
+
from liger_kernel.ops.utils import ensure_contiguous
|
|
10
|
+
|
|
11
|
+
if compare_version("triton", operator.ge, "3.0.0"):
|
|
12
|
+
try:
|
|
13
|
+
from triton.language.extra.libdevice import rsqrt
|
|
14
|
+
except ModuleNotFoundError:
|
|
15
|
+
from triton.language.extra.cuda.libdevice import rsqrt
|
|
16
|
+
else:
|
|
17
|
+
from triton.language.math import rsqrt
|
|
18
|
+
|
|
19
|
+
|
|
20
|
+
@triton.jit
def _poly_norm_forward_kernel(
    Y_ptr,
    Y_row_stride,
    X_ptr,
    X_row_stride,
    W_ptr,  # weight: [3] for [w0, w1, w2]
    B_ptr,  # bias: scalar
    RSTD_ptr,  # cache rstd for backward: shape (n_rows, 3)
    RSTD_row_stride,
    n_cols,
    eps,
    BLOCK_SIZE: tl.constexpr,
):
    """
    PolyNorm formula:
        y = w₀·norm(x³) + w₁·norm(x²) + w₂·norm(x) + b
        where norm(u) = u / sqrt(mean(u²) + ε)

    Reference:
    1. https://github.com/BryceZhuo/PolyCom/
    2. https://arxiv.org/pdf/2411.03884

    One program per row. All arithmetic is done in fp32. Y is stored in its own
    dtype, and the three rstd values (for x³, x², x, in that order) are cached in
    RSTD for the backward pass.
    """
    row_idx = tl.program_id(0).to(tl.int64)
    col_offsets = tl.arange(0, BLOCK_SIZE)
    mask = col_offsets < n_cols

    # Advance base pointers to this row
    Y_ptr += row_idx * Y_row_stride
    X_ptr += row_idx * X_row_stride
    RSTD_ptr += row_idx * RSTD_row_stride

    # Upcast to fp32. x⁶ summed over a row loses most of its precision in bf16 and
    # overflows fp16 once |x| > ~6.3. The backward kernel also works in fp32.
    # Masked (out-of-row) lanes load 0 and contribute nothing to the sums.
    X_row = tl.load(X_ptr + col_offsets, mask=mask, other=0.0).to(tl.float32)

    # Load weights and bias
    w0 = tl.load(W_ptr + 0).to(tl.float32)
    w1 = tl.load(W_ptr + 1).to(tl.float32)
    w2 = tl.load(W_ptr + 2).to(tl.float32)
    b = tl.load(B_ptr).to(tl.float32)

    # Compute x, x², x³
    X_pow1 = X_row
    X_pow2 = X_row * X_row
    X_pow3 = X_pow2 * X_row

    # norm(u) = u * rsqrt(mean(u²) + eps), for u = x³, x², x
    rstd_3 = rsqrt(tl.sum(X_pow3 * X_pow3, axis=0) / n_cols + eps)
    rstd_2 = rsqrt(tl.sum(X_pow2 * X_pow2, axis=0) / n_cols + eps)
    rstd_1 = rsqrt(tl.sum(X_pow1 * X_pow1, axis=0) / n_cols + eps)

    # Cache rstd values for backward
    tl.store(RSTD_ptr + 0, rstd_3)
    tl.store(RSTD_ptr + 1, rstd_2)
    tl.store(RSTD_ptr + 2, rstd_1)

    # y = w₀·norm(x³) + w₁·norm(x²) + w₂·norm(x) + b  (cast to Y's dtype on store)
    Y_row = w0 * (X_pow3 * rstd_3) + w1 * (X_pow2 * rstd_2) + w2 * (X_pow1 * rstd_1) + b
    tl.store(Y_ptr + col_offsets, Y_row, mask=mask)
|
|
93
|
+
|
|
94
|
+
|
|
95
|
+
@triton.jit
def _poly_norm_backward_kernel(
    dY_ptr,
    dY_row_stride,
    dX_ptr,
    dX_row_stride,
    X_ptr,
    X_row_stride,
    W_ptr,
    RSTD_ptr,
    RSTD_row_stride,
    dW_ptr,  # shape: (n_programs, 3)
    dW_row_stride,
    dB_ptr,  # shape: (n_programs,)
    n_rows,
    n_cols,
    rows_per_program: tl.constexpr,
    BLOCK_SIZE: tl.constexpr,
):
    """
    PolyNorm Backward Kernel Gradient:
        ∂L/∂x_i = Σ_p w_p * [p*x_i^(p-1) * grad_i/D_p - (p/d)*x_i^(2p-1) * S_p/(D_p³)]

        where:
        - D_p = RMS(x^p) = 1/rstd_p
        - S_p = sum(grad * x^p) over the row
        - d = n_cols
        - p ∈ {3, 2, 1}

    Each program processes the contiguous row range
    [row_block_id * rows_per_program, min((row_block_id + 1) * rows_per_program, n_rows)).
    It writes dX for those rows and one partial row of dW (w0, w1, w2) plus one
    partial dB. The host sums the partials across programs. All math is in fp32.
    A program whose row range is empty still stores zero partials.
    """
    row_block_id = tl.program_id(0).to(tl.int64)
    row_start = row_block_id * rows_per_program
    row_end = min((row_block_id + 1) * rows_per_program, n_rows)
    col_offsets = tl.arange(0, BLOCK_SIZE)
    mask = col_offsets < n_cols

    # Initialize accumulators for weight and bias gradients (scalars)
    dW0_acc = 0.0
    dW1_acc = 0.0
    dW2_acc = 0.0
    dB_acc = 0.0

    # Load weights
    w0 = tl.load(W_ptr + 0).to(tl.float32)
    w1 = tl.load(W_ptr + 1).to(tl.float32)
    w2 = tl.load(W_ptr + 2).to(tl.float32)

    # Advance base pointers to the first row owned by this program
    dY_ptr += row_start * dY_row_stride
    dX_ptr += row_start * dX_row_stride
    X_ptr += row_start * X_row_stride
    RSTD_ptr += row_start * RSTD_row_stride

    for _ in range(row_start, row_end):
        # Load input and gradient. The whole dY row is read before dX is written,
        # so dX may alias dY (in-place backward).
        dY_row = tl.load(dY_ptr + col_offsets, mask=mask, other=0.0).to(tl.float32)
        X_row = tl.load(X_ptr + col_offsets, mask=mask, other=0.0).to(tl.float32)

        # Load cached rstd values (order written by forward: x³, x², x)
        rstd_3 = tl.load(RSTD_ptr + 0).to(tl.float32)
        rstd_2 = tl.load(RSTD_ptr + 1).to(tl.float32)
        rstd_1 = tl.load(RSTD_ptr + 2).to(tl.float32)

        # Compute powers
        X_pow3 = X_row * X_row * X_row
        X_pow2 = X_row * X_row
        X_pow1 = X_row

        # Accumulate bias gradient: dB = sum(dY)
        dB_acc += tl.sum(dY_row, axis=0)

        # Compute gradient w.r.t. input using closed-form formula
        # For p=3: ∂L/∂x from w0 * norm(x³)
        # = w0 * (3x² * rstd_3 * dY - (3/d) * x⁵ * rstd_3³ * S_3)
        S_3 = tl.sum(dY_row * X_pow3, axis=0)  # scalar
        grad_x_3 = w0 * (
            3.0 * X_pow2 * rstd_3 * dY_row
            - (3.0 / n_cols) * X_row * X_row * X_row * X_row * X_row * (rstd_3 * rstd_3 * rstd_3) * S_3
        )

        # For p=2: ∂L/∂x from w1 * norm(x²)
        # = w1 * (2x * rstd_2 * dY - (2/d) * x³ * rstd_2³ * S_2)
        S_2 = tl.sum(dY_row * X_pow2, axis=0)  # scalar
        grad_x_2 = w1 * (
            2.0 * X_row * rstd_2 * dY_row - (2.0 / n_cols) * X_row * X_row * X_row * (rstd_2 * rstd_2 * rstd_2) * S_2
        )

        # For p=1: ∂L/∂x from w2 * norm(x)
        # = w2 * (rstd_1 * dY - (1/d) * x * rstd_1³ * S_1)
        S_1 = tl.sum(dY_row * X_pow1, axis=0)  # scalar
        grad_x_1 = w2 * (1.0 * rstd_1 * dY_row - (1.0 / n_cols) * X_row * (rstd_1 * rstd_1 * rstd_1) * S_1)

        # Accumulate weight gradients using closed-form: dW_p = rstd_p * S_p
        # (= sum(dY * norm(x^p)) over the row)
        dW0_acc += rstd_3 * S_3
        dW1_acc += rstd_2 * S_2
        dW2_acc += rstd_1 * S_1

        # Total gradient
        dX_row = grad_x_3 + grad_x_2 + grad_x_1

        # Store gradient (cast to dX's element type on store)
        tl.store(dX_ptr + col_offsets, dX_row, mask=mask)

        # Update pointers
        dY_ptr += dY_row_stride
        dX_ptr += dX_row_stride
        X_ptr += X_row_stride
        RSTD_ptr += RSTD_row_stride

    # Store accumulated gradients (scalars) as this program's partial results
    tl.store(dW_ptr + row_block_id * dW_row_stride + 0, dW0_acc)
    tl.store(dW_ptr + row_block_id * dW_row_stride + 1, dW1_acc)
    tl.store(dW_ptr + row_block_id * dW_row_stride + 2, dW2_acc)
    tl.store(dB_ptr + row_block_id, dB_acc)
|
|
204
|
+
|
|
205
|
+
|
|
206
|
+
def poly_norm_forward(X, W, B, eps=1e-6):
    """
    PolyNorm Forward Pass

    Args:
        X: input tensor of shape (*, H) where H is hidden dimension
        W: weight tensor of shape (3,) for [w0, w1, w2]
        B: bias scalar tensor
        eps: epsilon for numerical stability

    Returns:
        Y: output tensor of same shape as X
        X: input viewed as 2D (rows, H), kept for backward
        RSTD: fp32 tensor (rows, 3) of cached rstd values for x³, x², x
        BLOCK_SIZE: block size used
        num_warps: number of warps used
    """
    original_shape = X.shape
    X = X.view(-1, original_shape[-1])  # requires a view-compatible (e.g. contiguous) X
    n_rows, n_cols = X.shape
    BLOCK_SIZE, num_warps = calculate_settings(n_cols)

    Y = torch.empty((n_rows, n_cols), dtype=X.dtype, device=X.device)
    RSTD = torch.empty((n_rows, 3), dtype=torch.float32, device=X.device)

    assert W.shape[0] == 3, "Weight tensor must have shape (3,)"
    assert B.numel() == 1, "Bias must be a scalar"

    # Intel XPU: request the large register file
    extra_launch_args = {"grf_mode": "large"} if X.device.type == "xpu" else {}

    # One program per row
    _poly_norm_forward_kernel[(n_rows,)](
        Y,
        Y.stride(0),
        X,
        X.stride(0),
        W,
        B,
        RSTD,
        RSTD.stride(0),
        n_cols,
        eps,
        BLOCK_SIZE=BLOCK_SIZE,
        num_warps=num_warps,
        **extra_launch_args,
    )

    return Y.view(*original_shape), X, RSTD, BLOCK_SIZE, num_warps
|
|
260
|
+
|
|
261
|
+
|
|
262
|
+
def poly_norm_backward(dY, X, W, RSTD, BLOCK_SIZE, num_warps, in_place):
    """
    PolyNorm Backward Pass

    Args:
        dY: gradient of output (its last dim is the hidden dim)
        X: input tensor (already reshaped to 2D)
        W: weight tensor
        RSTD: cached rstd values from forward
        BLOCK_SIZE: block size from forward
        num_warps: number of warps from forward
        in_place: when exactly ``True``, dY's storage is reused for dX (saves memory)

    Returns:
        dX: gradient w.r.t. input, shaped like dY
        dW: gradient w.r.t. weight, shape (3,), in W's dtype
        dB: gradient w.r.t. bias, 0-dim, in W's dtype
    """
    shape = dY.shape
    dim = shape[-1]
    dY = dY.view(-1, dim)
    n_rows, n_cols = dY.shape

    # One program per SM (CUDA) / EU (XPU); each writes partial dW/dB for its rows.
    sm_count = 1
    if X.device.type == "cuda":
        sm_count = torch.cuda.get_device_properties(X.device).multi_processor_count
    elif X.device.type == "xpu":
        sm_count = torch.xpu.get_device_properties(X.device).gpu_eu_count

    # Allocate or reuse gradients
    if in_place is True:
        dX = dY
    else:
        # Every (row, col < n_cols) element of dX is written by the kernel, so a
        # zero-fill would be wasted work.
        dX = torch.empty_like(dY)

    # Every program stores its partial row (zeros if it owns no rows), so empty is safe.
    _dW = torch.empty((sm_count, 3), dtype=torch.float32, device=W.device)
    _dB = torch.empty((sm_count,), dtype=torch.float32, device=W.device)

    # Integer ceil-division (no float rounding for very large n_rows).
    rows_per_program = (n_rows + sm_count - 1) // sm_count
    grid = (sm_count,)

    # XPU-specific optimization
    kernel_args = {}
    if X.device.type == "xpu":
        kernel_args["grf_mode"] = "large"

    # Launch backward kernel
    _poly_norm_backward_kernel[grid](
        dY,
        dY.stride(0),
        dX,
        dX.stride(0),
        X,
        X.stride(0),
        W,
        RSTD,
        RSTD.stride(0),
        _dW,
        _dW.stride(0),
        _dB,
        n_rows,
        n_cols,
        rows_per_program,
        BLOCK_SIZE=BLOCK_SIZE,
        num_warps=num_warps,
        **kernel_args,
    )

    # Reduce the per-program partial gradients
    dX = dX.view(*shape)
    dW = _dW.sum(dim=0).to(W.dtype)
    dB = _dB.sum().to(W.dtype)

    return dX, dW, dB
|
|
339
|
+
|
|
340
|
+
|
|
341
|
+
class LigerPolyNormFunction(torch.autograd.Function):
    """
    PolyNorm Function with forward and backward pass

    PolyNorm formula:
        y = w₀·norm(x³) + w₁·norm(x²) + w₂·norm(x) + b
        where norm(u) = u / sqrt(mean(u²) + ε)

    Backward uses closed-form gradient:
        ∂L/∂x_i = Σ_p w_p * [p*x_i^(p-1) * grad_i/D_p - (p/d)*x_i^(2p-1) * S_p/(D_p³)]
    """

    @staticmethod
    @ensure_contiguous
    def forward(ctx, X, W, B, eps=1e-6, in_place=True):
        """
        Args:
            X: input tensor of shape (B, T, H) or (BxT, H)
            W: weight tensor of shape (3,) for [w0, w1, w2]
            B: bias with a single element (0-dim or shape (1,))
            eps: epsilon for numerical stability
            in_place: whether to in-place modify grad_output in backward (saves memory)

        Returns:
            Y: output tensor of same shape as X
        """
        Y, X, RSTD, BLOCK_SIZE, num_warps = poly_norm_forward(X, W, B, eps)
        ctx.BLOCK_SIZE = BLOCK_SIZE
        ctx.num_warps = num_warps
        ctx.in_place = in_place
        # The backward helper returns a 0-dim dB; remember B's shape so the gradient
        # matches it (autograd rejects a [] gradient for a (1,)-shaped bias).
        ctx.bias_shape = B.shape
        ctx.save_for_backward(X, W, RSTD)
        return Y

    @staticmethod
    @ensure_contiguous
    def backward(ctx, grad_output):
        """
        Args:
            grad_output: gradient of output

        Returns:
            dX, dW, dB: gradients w.r.t. X, W, B (None for eps and in_place)
        """
        X, W, RSTD = ctx.saved_tensors
        dX, dW, dB = poly_norm_backward(grad_output, X, W, RSTD, ctx.BLOCK_SIZE, ctx.num_warps, ctx.in_place)
        # Match the parameter shapes exactly (numel is 3 and 1 respectively).
        dW = dW.reshape(W.shape)
        dB = dB.reshape(ctx.bias_shape)
        return dX, dW, dB, None, None
|