liger-kernel-nightly 0.5.10.dev20250624183504__py3-none-any.whl → 0.6.3.dev20251121010306__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of liger-kernel-nightly might be problematic. Click here for more details.
- liger_kernel/chunked_loss/__init__.py +1 -0
- liger_kernel/chunked_loss/cosine_similarity_loss.py +136 -0
- liger_kernel/chunked_loss/dpo_loss.py +54 -3
- liger_kernel/chunked_loss/functional.py +2 -0
- liger_kernel/chunked_loss/fused_linear_distillation.py +13 -2
- liger_kernel/chunked_loss/fused_linear_ppo.py +4 -0
- liger_kernel/chunked_loss/grpo_loss.py +38 -4
- liger_kernel/chunked_loss/jsd_loss.py +23 -7
- liger_kernel/ops/cross_entropy.py +118 -62
- liger_kernel/ops/fused_add_rms_norm.py +412 -0
- liger_kernel/ops/fused_linear_cross_entropy.py +113 -21
- liger_kernel/ops/geglu.py +1 -1
- liger_kernel/ops/layer_norm.py +124 -89
- liger_kernel/ops/llama4_rope.py +225 -0
- liger_kernel/ops/poly_norm.py +386 -0
- liger_kernel/ops/rms_norm.py +2 -2
- liger_kernel/ops/rope.py +1 -1
- liger_kernel/ops/swiglu.py +1 -1
- liger_kernel/ops/tiled_mlp.py +136 -0
- liger_kernel/transformers/__init__.py +50 -0
- liger_kernel/transformers/cross_entropy.py +8 -3
- liger_kernel/transformers/experimental/__init__.py +5 -0
- liger_kernel/transformers/functional.py +38 -6
- liger_kernel/transformers/fused_add_rms_norm.py +39 -0
- liger_kernel/transformers/fused_linear_cross_entropy.py +16 -4
- liger_kernel/transformers/llama4_rope.py +93 -0
- liger_kernel/transformers/model/falcon_h1.py +122 -0
- liger_kernel/transformers/model/gemma.py +28 -8
- liger_kernel/transformers/model/gemma2.py +31 -8
- liger_kernel/transformers/model/gemma3.py +100 -110
- liger_kernel/transformers/model/glm4.py +18 -5
- liger_kernel/transformers/model/glm4v.py +163 -0
- liger_kernel/transformers/model/glm4v_moe.py +172 -0
- liger_kernel/transformers/model/internvl.py +157 -0
- liger_kernel/transformers/model/llama.py +26 -7
- liger_kernel/transformers/model/llama4.py +121 -0
- liger_kernel/transformers/model/llava.py +18 -6
- liger_kernel/transformers/model/loss_utils.py +34 -3
- liger_kernel/transformers/model/mistral.py +17 -10
- liger_kernel/transformers/model/mixtral.py +24 -9
- liger_kernel/transformers/model/mllama.py +18 -7
- liger_kernel/transformers/model/olmo2.py +18 -5
- liger_kernel/transformers/model/output_classes.py +147 -0
- liger_kernel/transformers/model/paligemma.py +41 -5
- liger_kernel/transformers/model/phi3.py +24 -159
- liger_kernel/transformers/model/qwen2.py +26 -4
- liger_kernel/transformers/model/qwen2_5_vl.py +21 -8
- liger_kernel/transformers/model/qwen2_vl.py +24 -7
- liger_kernel/transformers/model/qwen3.py +22 -6
- liger_kernel/transformers/model/qwen3_moe.py +27 -7
- liger_kernel/transformers/model/qwen3_next.py +146 -0
- liger_kernel/transformers/model/qwen3_vl.py +150 -0
- liger_kernel/transformers/model/qwen3_vl_moe.py +126 -0
- liger_kernel/transformers/model/smollm3.py +199 -0
- liger_kernel/transformers/model/smolvlm.py +158 -0
- liger_kernel/transformers/monkey_patch.py +1090 -116
- liger_kernel/transformers/multi_token_attention.py +1 -1
- liger_kernel/transformers/poly_norm.py +42 -0
- liger_kernel/transformers/rms_norm.py +7 -0
- liger_kernel/transformers/rope.py +43 -0
- liger_kernel/transformers/tiled_mlp.py +133 -0
- {liger_kernel_nightly-0.5.10.dev20250624183504.dist-info → liger_kernel_nightly-0.6.3.dev20251121010306.dist-info}/METADATA +26 -24
- liger_kernel_nightly-0.6.3.dev20251121010306.dist-info/RECORD +116 -0
- liger_kernel_nightly-0.5.10.dev20250624183504.dist-info/RECORD +0 -95
- {liger_kernel_nightly-0.5.10.dev20250624183504.dist-info → liger_kernel_nightly-0.6.3.dev20251121010306.dist-info}/LICENSE +0 -0
- {liger_kernel_nightly-0.5.10.dev20250624183504.dist-info → liger_kernel_nightly-0.6.3.dev20251121010306.dist-info}/NOTICE +0 -0
- {liger_kernel_nightly-0.5.10.dev20250624183504.dist-info → liger_kernel_nightly-0.6.3.dev20251121010306.dist-info}/WHEEL +0 -0
- {liger_kernel_nightly-0.5.10.dev20250624183504.dist-info → liger_kernel_nightly-0.6.3.dev20251121010306.dist-info}/top_level.txt +0 -0
liger_kernel/ops/layer_norm.py
CHANGED
|
@@ -1,4 +1,3 @@
|
|
|
1
|
-
import math
|
|
2
1
|
import operator
|
|
3
2
|
|
|
4
3
|
import torch
|
|
@@ -43,30 +42,44 @@ def _layer_norm_forward_kernel(
|
|
|
43
42
|
https://arxiv.org/abs/1607.06450
|
|
44
43
|
https://github.com/karpathy/llm.c/blob/master/doc/layernorm/layernorm.md
|
|
45
44
|
"""
|
|
46
|
-
row_idx = tl.program_id(0)
|
|
45
|
+
row_idx = tl.program_id(0).to(tl.int64)
|
|
47
46
|
col_offsets = tl.arange(0, BLOCK_SIZE)
|
|
48
47
|
mask = col_offsets < n_cols
|
|
49
48
|
|
|
50
|
-
|
|
51
|
-
|
|
52
|
-
|
|
53
|
-
|
|
54
|
-
|
|
55
|
-
|
|
56
|
-
|
|
57
|
-
|
|
58
|
-
|
|
59
|
-
|
|
60
|
-
|
|
61
|
-
|
|
49
|
+
# Pre-load weights and bias in fp32 to avoid repeated conversions
|
|
50
|
+
W_row = tl.load(W_ptr + col_offsets, mask=mask, other=0.0)
|
|
51
|
+
B_row = tl.load(B_ptr + col_offsets, mask=mask, other=0.0)
|
|
52
|
+
W_f32 = W_row.to(tl.float32)
|
|
53
|
+
B_f32 = B_row.to(tl.float32)
|
|
54
|
+
|
|
55
|
+
# Calculate pointers for this row
|
|
56
|
+
row_X_ptr = X_ptr + row_idx * X_row_stride
|
|
57
|
+
row_Y_ptr = Y_ptr + row_idx * Y_row_stride
|
|
58
|
+
row_Mean_ptr = Mean_ptr + row_idx * Mean_row_stride
|
|
59
|
+
row_RSTD_ptr = RSTD_ptr + row_idx * RSTD_row_stride
|
|
60
|
+
|
|
61
|
+
# Load input data and convert to fp32 for numerical stability
|
|
62
|
+
X_row = tl.load(row_X_ptr + col_offsets, mask=mask, other=0.0)
|
|
63
|
+
X_f32 = X_row.to(tl.float32)
|
|
64
|
+
|
|
65
|
+
# Compute statistics in fp32 for numerical stability
|
|
66
|
+
mean = tl.sum(X_f32, axis=0) / n_cols
|
|
67
|
+
X_centered = X_f32 - mean
|
|
68
|
+
# Apply mask to variance calculation to exclude contributions from masked elements
|
|
69
|
+
X_centered_masked = tl.where(mask, X_centered, 0.0)
|
|
70
|
+
var = tl.sum(X_centered_masked * X_centered_masked, axis=0) / n_cols
|
|
62
71
|
rstd = rsqrt(var + eps)
|
|
63
72
|
|
|
64
|
-
|
|
65
|
-
tl.store(
|
|
73
|
+
# Store statistics (convert back to original dtype only once)
|
|
74
|
+
tl.store(row_Mean_ptr, mean.to(X_row.dtype))
|
|
75
|
+
tl.store(row_RSTD_ptr, rstd.to(X_row.dtype))
|
|
66
76
|
|
|
67
|
-
|
|
77
|
+
# Fused normalization and affine transformation
|
|
78
|
+
# Y = (X - mean) * rstd * W + B = X_centered * rstd * W + B
|
|
79
|
+
Y_f32 = X_centered * rstd * W_f32 + B_f32
|
|
68
80
|
|
|
69
|
-
|
|
81
|
+
# Store output (single conversion back to original dtype)
|
|
82
|
+
tl.store(row_Y_ptr + col_offsets, Y_f32.to(X_row.dtype), mask=mask)
|
|
70
83
|
|
|
71
84
|
|
|
72
85
|
@triton.jit
|
|
@@ -81,73 +94,86 @@ def _layer_norm_backward_kernel(
|
|
|
81
94
|
DY_ptr, # pointer to output grad, shape (n_rows, n_cols)
|
|
82
95
|
stride_x, # stride of each row in input
|
|
83
96
|
stride_dx, # stride of each row in input grad
|
|
84
|
-
stride_dw, # stride of each row in weights grad
|
|
85
|
-
stride_db, # stride of each row in bias grad
|
|
86
97
|
stride_dy, # stride of each row in output grad
|
|
87
|
-
n_rows,
|
|
88
98
|
n_cols,
|
|
89
|
-
rows_per_program: tl.constexpr,
|
|
90
99
|
BLOCK_SIZE: tl.constexpr,
|
|
91
100
|
dtype: tl.constexpr,
|
|
101
|
+
atomic_dtype: tl.constexpr,
|
|
92
102
|
):
|
|
93
103
|
"""
|
|
94
104
|
References:
|
|
95
105
|
https://arxiv.org/abs/1607.06450
|
|
96
106
|
https://github.com/karpathy/llm.c/blob/master/doc/layernorm/layernorm.md
|
|
97
|
-
https://triton-lang.org/main/getting-started/tutorials/05-layer-norm.html
|
|
98
|
-
https://github.com/Dao-AILab/flash-attention/blob/main/flash_attn/ops/triton/layer_norm.py
|
|
99
107
|
"""
|
|
100
|
-
|
|
101
|
-
row_start = row_block_id * rows_per_program
|
|
102
|
-
row_end = min((row_block_id + 1) * rows_per_program, n_rows)
|
|
108
|
+
row_idx = tl.program_id(0).to(tl.int64)
|
|
103
109
|
cols = tl.arange(0, BLOCK_SIZE)
|
|
104
110
|
mask = cols < n_cols
|
|
105
111
|
|
|
106
|
-
|
|
107
|
-
|
|
108
|
-
|
|
109
|
-
|
|
110
|
-
|
|
111
|
-
|
|
112
|
-
DX_ptr
|
|
113
|
-
DY_ptr
|
|
114
|
-
|
|
115
|
-
|
|
116
|
-
|
|
117
|
-
|
|
118
|
-
|
|
119
|
-
|
|
120
|
-
|
|
121
|
-
|
|
122
|
-
|
|
123
|
-
|
|
124
|
-
|
|
125
|
-
|
|
126
|
-
|
|
127
|
-
|
|
128
|
-
|
|
129
|
-
|
|
130
|
-
|
|
131
|
-
|
|
132
|
-
|
|
133
|
-
|
|
134
|
-
|
|
135
|
-
|
|
136
|
-
|
|
137
|
-
|
|
138
|
-
|
|
139
|
-
|
|
112
|
+
# Pre-load weights once (same optimization as forward pass)
|
|
113
|
+
w = tl.load(W_ptr + cols, mask=mask, other=0.0)
|
|
114
|
+
w_f32 = w.to(tl.float32)
|
|
115
|
+
|
|
116
|
+
# Calculate pointers for this specific row
|
|
117
|
+
row_X_ptr = X_ptr + row_idx * stride_x
|
|
118
|
+
row_DX_ptr = DX_ptr + row_idx * stride_dx
|
|
119
|
+
row_DY_ptr = DY_ptr + row_idx * stride_dy
|
|
120
|
+
row_Mean_ptr = Mean_ptr + row_idx
|
|
121
|
+
row_RSTD_ptr = RSTD_ptr + row_idx
|
|
122
|
+
|
|
123
|
+
# Load data for this row
|
|
124
|
+
x = tl.load(row_X_ptr + cols, mask=mask, other=0.0)
|
|
125
|
+
dy = tl.load(row_DY_ptr + cols, mask=mask, other=0.0)
|
|
126
|
+
mean = tl.load(row_Mean_ptr)
|
|
127
|
+
rstd = tl.load(row_RSTD_ptr)
|
|
128
|
+
|
|
129
|
+
# Convert to fp32 for numerical stability
|
|
130
|
+
x_f32 = x.to(tl.float32)
|
|
131
|
+
dy_f32 = dy.to(tl.float32)
|
|
132
|
+
mean_f32 = mean.to(tl.float32)
|
|
133
|
+
rstd_f32 = rstd.to(tl.float32)
|
|
134
|
+
|
|
135
|
+
# Compute backward pass for this row
|
|
136
|
+
x_hat = (x_f32 - mean_f32) * rstd_f32
|
|
137
|
+
wdy = w_f32 * dy_f32
|
|
138
|
+
c1 = tl.sum(x_hat * wdy, axis=0) / n_cols
|
|
139
|
+
c2 = tl.sum(wdy, axis=0) / n_cols
|
|
140
|
+
dx = (wdy - (x_hat * c1 + c2)) * rstd_f32
|
|
141
|
+
|
|
142
|
+
# Store input gradient
|
|
143
|
+
tl.store(row_DX_ptr + cols, dx.to(dtype), mask=mask)
|
|
144
|
+
|
|
145
|
+
# Accumulate weight and bias gradients using atomic operations
|
|
146
|
+
dw = dy_f32 * x_hat
|
|
147
|
+
db = dy_f32
|
|
148
|
+
tl.atomic_add(DW_ptr + cols, dw.to(atomic_dtype), mask=mask)
|
|
149
|
+
tl.atomic_add(DB_ptr + cols, db.to(atomic_dtype), mask=mask)
|
|
140
150
|
|
|
141
151
|
|
|
142
152
|
def layer_norm_forward(X, W, B, eps):
|
|
153
|
+
"""
|
|
154
|
+
Args:
|
|
155
|
+
X: Input tensor of shape (..., hidden_size)
|
|
156
|
+
W: Weight tensor of shape (hidden_size,)
|
|
157
|
+
B: Bias tensor of shape (hidden_size,)
|
|
158
|
+
eps: Small constant for numerical stability
|
|
159
|
+
|
|
160
|
+
Returns:
|
|
161
|
+
Tuple of (output, input, mean, rstd, block_size, num_warps)
|
|
162
|
+
"""
|
|
143
163
|
shape = X.shape
|
|
144
164
|
dim = shape[-1]
|
|
145
165
|
X = X.view(-1, dim)
|
|
146
166
|
n_rows, n_cols = X.shape
|
|
167
|
+
|
|
168
|
+
# Calculate optimal block size and warp configuration
|
|
147
169
|
BLOCK_SIZE, num_warps = calculate_settings(n_cols)
|
|
170
|
+
|
|
171
|
+
# Allocate output tensors
|
|
148
172
|
Y = torch.empty((n_rows, n_cols), dtype=X.dtype, device=X.device)
|
|
149
173
|
Mean = torch.empty(n_rows, dtype=X.dtype, device=X.device)
|
|
150
174
|
RSTD = torch.empty(n_rows, dtype=X.dtype, device=X.device)
|
|
175
|
+
|
|
176
|
+
# Validate input dimensions
|
|
151
177
|
if X.shape[1] != W.shape[0]:
|
|
152
178
|
raise ValueError(
|
|
153
179
|
f"Incompatible dimensions: input feature size (X.shape[1]={X.shape[1]}) "
|
|
@@ -159,7 +185,9 @@ def layer_norm_forward(X, W, B, eps):
|
|
|
159
185
|
if X.device.type == "xpu":
|
|
160
186
|
kernel_args["grf_mode"] = "large"
|
|
161
187
|
|
|
162
|
-
|
|
188
|
+
# Launch kernel with one thread block per row for optimal performance
|
|
189
|
+
grid = (n_rows,)
|
|
190
|
+
_layer_norm_forward_kernel[grid](
|
|
163
191
|
Y,
|
|
164
192
|
Y.stride(0),
|
|
165
193
|
X,
|
|
@@ -176,35 +204,43 @@ def layer_norm_forward(X, W, B, eps):
|
|
|
176
204
|
eps,
|
|
177
205
|
BLOCK_SIZE=BLOCK_SIZE,
|
|
178
206
|
num_warps=num_warps,
|
|
179
|
-
**kernel_args,
|
|
207
|
+
**kernel_args,
|
|
180
208
|
)
|
|
209
|
+
|
|
181
210
|
return Y.view(*shape), X, Mean, RSTD, BLOCK_SIZE, num_warps
|
|
182
211
|
|
|
183
212
|
|
|
184
213
|
def layer_norm_backward(dY, X, W, B, Mean, RSTD):
|
|
214
|
+
"""
|
|
215
|
+
Args:
|
|
216
|
+
dY: Gradient of output
|
|
217
|
+
X: Input tensor
|
|
218
|
+
W: Weight tensor
|
|
219
|
+
B: Bias tensor
|
|
220
|
+
Mean: Pre-computed mean
|
|
221
|
+
RSTD: Pre-computed reciprocal standard deviation
|
|
222
|
+
|
|
223
|
+
Returns:
|
|
224
|
+
Tuple of (input_grad, weight_grad, bias_grad)
|
|
225
|
+
"""
|
|
185
226
|
shape = dY.shape
|
|
186
227
|
dim = shape[-1]
|
|
187
228
|
dY = dY.view(-1, dim)
|
|
188
229
|
n_rows, n_cols = dY.shape
|
|
189
230
|
|
|
190
|
-
|
|
191
|
-
if X.device.type == "cuda":
|
|
192
|
-
sm_count = torch.cuda.get_device_properties(X.device).multi_processor_count
|
|
193
|
-
elif X.device.type == "xpu":
|
|
194
|
-
sm_count = torch.xpu.get_device_properties(X.device).gpu_eu_count
|
|
195
|
-
|
|
231
|
+
# Allocate gradient tensors
|
|
196
232
|
DX = torch.empty((n_rows, n_cols), dtype=X.dtype, device=X.device)
|
|
197
|
-
|
|
198
|
-
|
|
233
|
+
# Use float32 for weight/bias gradients if bfloat16 (due to atomic_add limitation)
|
|
234
|
+
grad_dtype = torch.float32 if W.dtype == torch.bfloat16 else W.dtype
|
|
235
|
+
DW = torch.zeros(n_cols, dtype=grad_dtype, device=W.device)
|
|
236
|
+
DB = torch.zeros(n_cols, dtype=grad_dtype, device=W.device)
|
|
199
237
|
|
|
238
|
+
# Calculate optimal block size and warp configuration
|
|
200
239
|
BLOCK_SIZE, num_warps = calculate_settings(n_cols)
|
|
201
240
|
if n_cols > BLOCK_SIZE:
|
|
202
|
-
raise RuntimeError(
|
|
203
|
-
f"Feature dimension {n_cols} exceeds maximum supported size of {BLOCK_SIZE}. Consider using a smaller feature dimension."
|
|
204
|
-
)
|
|
241
|
+
raise RuntimeError(f"Feature dimension {n_cols} exceeds maximum supported size of {BLOCK_SIZE}.")
|
|
205
242
|
|
|
206
|
-
|
|
207
|
-
grid = (sm_count,)
|
|
243
|
+
# Determine dtype for triton operations
|
|
208
244
|
triton_dtype = (
|
|
209
245
|
tl.float32
|
|
210
246
|
if X.dtype == torch.float32
|
|
@@ -212,41 +248,40 @@ def layer_norm_backward(dY, X, W, B, Mean, RSTD):
|
|
|
212
248
|
if X.dtype == torch.bfloat16
|
|
213
249
|
else tl.float16
|
|
214
250
|
if X.dtype == torch.float16
|
|
215
|
-
else tl.float32 # fallback
|
|
251
|
+
else tl.float32 # fallback
|
|
216
252
|
)
|
|
217
253
|
|
|
254
|
+
# Use float32 for atomic operations if bfloat16 is not supported
|
|
255
|
+
atomic_dtype = tl.float32 if triton_dtype == tl.bfloat16 else triton_dtype
|
|
256
|
+
|
|
257
|
+
kernel_args = {"num_warps": num_warps}
|
|
218
258
|
# XPU-specific optimization
|
|
219
|
-
kernel_args = {}
|
|
220
259
|
if X.device.type == "xpu":
|
|
221
260
|
kernel_args.update({"grf_mode": "large", "num_warps": 32, "num_stages": 4})
|
|
222
261
|
|
|
262
|
+
# Launch kernel with one thread block per row for optimal performance
|
|
263
|
+
grid = (n_rows,)
|
|
223
264
|
_layer_norm_backward_kernel[grid](
|
|
224
265
|
X,
|
|
225
266
|
W,
|
|
226
267
|
Mean,
|
|
227
268
|
RSTD,
|
|
228
269
|
DX,
|
|
229
|
-
|
|
230
|
-
|
|
270
|
+
DW,
|
|
271
|
+
DB,
|
|
231
272
|
dY,
|
|
232
273
|
X.stride(0),
|
|
233
274
|
DX.stride(0),
|
|
234
|
-
_DW.stride(0),
|
|
235
|
-
_DB.stride(0),
|
|
236
275
|
dY.stride(0),
|
|
237
|
-
n_rows,
|
|
238
276
|
n_cols,
|
|
239
|
-
rows_per_program,
|
|
240
277
|
BLOCK_SIZE=BLOCK_SIZE,
|
|
241
278
|
dtype=triton_dtype,
|
|
242
|
-
|
|
279
|
+
atomic_dtype=atomic_dtype,
|
|
280
|
+
**kernel_args,
|
|
243
281
|
)
|
|
244
282
|
|
|
245
|
-
DW = _DW.sum(dim=0).to(W.dtype)
|
|
246
|
-
DB = _DB.sum(dim=0).to(W.dtype)
|
|
247
|
-
|
|
248
283
|
DX = DX.view(*shape)
|
|
249
|
-
return DX, DW, DB
|
|
284
|
+
return DX, DW.to(W.dtype), DB.to(W.dtype)
|
|
250
285
|
|
|
251
286
|
|
|
252
287
|
class LigerLayerNormFunction(torch.autograd.Function):
|
|
@@ -0,0 +1,225 @@
|
|
|
1
|
+
import torch
|
|
2
|
+
import triton
|
|
3
|
+
import triton.language as tl
|
|
4
|
+
|
|
5
|
+
|
|
6
|
+
def _prepare_freqs(freqs_cis: torch.Tensor, seq_len: int, head_dim_half: int):
|
|
7
|
+
# Split or unpack complex frequencies into real and imag parts
|
|
8
|
+
if freqs_cis.is_complex():
|
|
9
|
+
freqs_real = freqs_cis.real
|
|
10
|
+
freqs_imag = freqs_cis.imag
|
|
11
|
+
else:
|
|
12
|
+
# Already split: last dim should be 2*head_dim_half
|
|
13
|
+
if freqs_cis.shape[-1] == 2 * head_dim_half:
|
|
14
|
+
freqs_real = freqs_cis[..., :head_dim_half]
|
|
15
|
+
freqs_imag = freqs_cis[..., head_dim_half:]
|
|
16
|
+
else:
|
|
17
|
+
raise ValueError(
|
|
18
|
+
f"Unexpected freqs_cis shape for non-complex input: {freqs_cis.shape}, expected last dim = {2 * head_dim_half}"
|
|
19
|
+
)
|
|
20
|
+
|
|
21
|
+
# Canonicalize to shape (seq_len, head_dim_half):
|
|
22
|
+
# 1) Ensure the last dimension is head_dim_half
|
|
23
|
+
if freqs_real.shape[-1] != head_dim_half:
|
|
24
|
+
raise ValueError(f"Unexpected last dim for freqs: {freqs_real.shape[-1]} (expected {head_dim_half})")
|
|
25
|
+
# 2) Flatten all leading dims to a single row dimension
|
|
26
|
+
freqs_real = freqs_real.reshape(-1, head_dim_half)
|
|
27
|
+
freqs_imag = freqs_imag.reshape(-1, head_dim_half)
|
|
28
|
+
# 3) If we have fewer rows than seq_len, allow broadcasting when single row
|
|
29
|
+
if freqs_real.shape[0] < seq_len:
|
|
30
|
+
if freqs_real.shape[0] == 1:
|
|
31
|
+
freqs_real = freqs_real.expand(seq_len, -1)
|
|
32
|
+
freqs_imag = freqs_imag.expand(seq_len, -1)
|
|
33
|
+
else:
|
|
34
|
+
raise ValueError(f"Insufficient rows in freqs: {freqs_real.shape[0]} < seq_len={seq_len}")
|
|
35
|
+
# 4) If we have more rows than seq_len (e.g., batch present), take the first seq_len rows
|
|
36
|
+
elif freqs_real.shape[0] > seq_len:
|
|
37
|
+
freqs_real = freqs_real[:seq_len]
|
|
38
|
+
freqs_imag = freqs_imag[:seq_len]
|
|
39
|
+
|
|
40
|
+
return freqs_real, freqs_imag
|
|
41
|
+
|
|
42
|
+
|
|
43
|
+
def _maybe_to_dtype(t: torch.Tensor, dtype: torch.dtype) -> torch.Tensor:
|
|
44
|
+
return t if t.dtype == dtype else t.to(dtype)
|
|
45
|
+
|
|
46
|
+
|
|
47
|
+
def _maybe_contiguous(t: torch.Tensor) -> torch.Tensor:
|
|
48
|
+
return t if t.is_contiguous() else t.contiguous()
|
|
49
|
+
|
|
50
|
+
|
|
51
|
+
def _cast_and_contiguous(q, k, freqs_real, freqs_imag):
    """Bring q, k and the frequency planes to one compute dtype and a contiguous layout.

    The compute dtype is ``q.dtype``. fp32 inputs are therefore computed in
    fp32, and lower-precision inputs keep their own dtype for performance.
    ``k``, ``freqs_real`` and ``freqs_imag`` are cast to that dtype. Tensors
    that already have the right dtype and are contiguous are returned unchanged
    (no copy).

    Returns:
        Tuple ``(q, k, freqs_real, freqs_imag)``.
    """
    # The previous expression `torch.float32 if q.dtype == torch.float32 else q.dtype`
    # always evaluated to q.dtype; state that directly.
    # NOTE(review): casting fp32 frequencies down to a low-precision q dtype may
    # reduce RoPE accuracy; kept as-is to preserve existing numerics.
    compute_dtype = q.dtype

    q, k, freqs_real, freqs_imag = (
        _maybe_contiguous(_maybe_to_dtype(t, compute_dtype)) for t in (q, k, freqs_real, freqs_imag)
    )
    return q, k, freqs_real, freqs_imag
|
|
64
|
+
|
|
65
|
+
|
|
66
|
+
@triton.jit
def _llama4_rope_kernel(
    q_ptr,
    k_ptr,
    freqs_real_ptr,
    freqs_imag_ptr,
    q_row_stride,
    k_row_stride,
    q_head_stride,
    k_head_stride,
    freqs_row_stride,
    seq_len,
    batch_size,
    imag_sign,
    head_dim_half: tl.constexpr,
    n_q_heads: tl.constexpr,
    n_k_heads: tl.constexpr,
    BLOCK_SIZE: tl.constexpr,
):
    """
    Rotate one (batch, seq) row of q and k, for one head, in place.

    Grid: (batch*seq, head). Program (pid_bs, pid_h) rotates query head pid_h
    when pid_h < n_q_heads and key head pid_h when pid_h < n_k_heads.

    For each d in [0, head_dim_half), the pair (x[2*d], x[2*d + 1]) is treated
    as (real, imag) and multiplied by the complex frequency
    (freqs_real[seq, d], imag_sign * freqs_imag[seq, d]). imag_sign = -1 applies
    the conjugate rotation. Results are stored back through q_ptr / k_ptr.
    """
    # 2D grid. The row program id is promoted to int64 so that
    # base_offset * {q,k}_row_stride cannot overflow int32 once the tensors hold
    # more than 2**31 elements.
    pid_bs = tl.program_id(0).to(tl.int64)  # over batch*seq
    pid_h = tl.program_id(1)  # over heads

    batch_idx = pid_bs // seq_len
    seq_idx = pid_bs % seq_len

    # Bounds check (defensive; the launcher sizes the first grid axis as batch*seq)
    if batch_idx >= batch_size or seq_idx >= seq_len:
        return

    # Base pointers for this (batch, seq) position
    base_offset = batch_idx * seq_len + seq_idx
    q_base = q_ptr + base_offset * q_row_stride
    k_base = k_ptr + base_offset * k_row_stride

    # Tiling over dim/2
    for d_start in tl.static_range(0, head_dim_half, BLOCK_SIZE):
        d_indices = d_start + tl.arange(0, BLOCK_SIZE)
        mask_d = d_indices < head_dim_half

        # Load frequencies once per tile (freqs layout: [seq_len, head_dim_half])
        freq_idx = d_indices
        freqs_real = tl.load(freqs_real_ptr + seq_idx * freqs_row_stride + freq_idx, mask=mask_d, other=0.0)
        freqs_imag = tl.load(freqs_imag_ptr + seq_idx * freqs_row_stride + freq_idx, mask=mask_d, other=0.0)
        freqs_imag = freqs_imag * imag_sign

        # Process one query head per program in pid_h
        if pid_h < n_q_heads:
            q_head_ptr = q_base + pid_h * q_head_stride
            q_real = tl.load(q_head_ptr + d_indices * 2, mask=mask_d, other=0.0)
            q_imag = tl.load(q_head_ptr + d_indices * 2 + 1, mask=mask_d, other=0.0)

            # Complex multiply with FMAs: (a+ib)*(c+i d) = (a*c - b*d) + i(a*d + b*c)
            new_q_real = tl.math.fma(q_real, freqs_real, -(q_imag * freqs_imag))
            new_q_imag = tl.math.fma(q_real, freqs_imag, q_imag * freqs_real)

            tl.store(q_head_ptr + d_indices * 2, new_q_real, mask=mask_d)
            tl.store(q_head_ptr + d_indices * 2 + 1, new_q_imag, mask=mask_d)

        # Process one key head per program in pid_h
        if pid_h < n_k_heads:
            k_head_ptr = k_base + pid_h * k_head_stride
            k_real = tl.load(k_head_ptr + d_indices * 2, mask=mask_d, other=0.0)
            k_imag = tl.load(k_head_ptr + d_indices * 2 + 1, mask=mask_d, other=0.0)

            new_k_real = tl.math.fma(k_real, freqs_real, -(k_imag * freqs_imag))
            new_k_imag = tl.math.fma(k_real, freqs_imag, k_imag * freqs_real)

            tl.store(k_head_ptr + d_indices * 2, new_k_real, mask=mask_d)
            tl.store(k_head_ptr + d_indices * 2 + 1, new_k_imag, mask=mask_d)
|
|
140
|
+
|
|
141
|
+
|
|
142
|
+
def _select_kernel_meta(head_dim_half: int):
|
|
143
|
+
# Heuristic tuning for block size and num_warps
|
|
144
|
+
if head_dim_half >= 256:
|
|
145
|
+
return 128, 8
|
|
146
|
+
if head_dim_half >= 96:
|
|
147
|
+
return 128, 4
|
|
148
|
+
if head_dim_half >= 48:
|
|
149
|
+
return 64, 4
|
|
150
|
+
if head_dim_half >= 24:
|
|
151
|
+
return 32, 2
|
|
152
|
+
return 16, 2
|
|
153
|
+
|
|
154
|
+
|
|
155
|
+
def llama4_rope_forward(q, k, freqs_cis, BLOCK_SIZE: int = None, imag_sign: float = 1.0):
    """Apply Llama4-style rotary embedding to ``q`` and ``k`` with a Triton kernel.

    Args:
        q: Query tensor of shape (batch, seq_len, n_q_heads, head_dim). The kernel
            treats consecutive pairs along head_dim as (real, imag).
        k: Key tensor of shape (batch, seq_len, n_k_heads, head_dim).
        freqs_cis: Complex frequencies, or a real tensor packing [real | imag]
            along its last dimension (see ``_prepare_freqs``).
        BLOCK_SIZE: Optional tile size over head_dim // 2. When None, it is chosen
            heuristically from head_dim // 2.
        imag_sign: Multiplier applied to the imaginary part of the frequencies.
            ``-1.0`` applies the conjugate rotation, which the autograd backward
            uses.

    Returns:
        Tuple ``(q_out, k_out)`` in the original dtypes of ``q`` and ``k``.

    Raises:
        ValueError: if head_dim is odd, if k's batch/seq_len/head_dim differ from
            q's, or if ``_prepare_freqs`` rejects ``freqs_cis``.

    Note:
        The kernel writes its results into the buffers it is given. When ``q`` /
        ``k`` are already contiguous and need no dtype cast, those buffers are the
        caller's tensors, which are therefore modified in place.
    """
    # Remember each input's dtype separately: _cast_and_contiguous moves k to
    # q's dtype, so k must be restored to its own dtype, not q's.
    q_original_dtype = q.dtype
    k_original_dtype = k.dtype

    batch_size, seq_len, n_q_heads, head_dim = q.shape
    k_batch, k_seq_len, n_k_heads, k_head_dim = k.shape
    # The kernel addresses k with q's (batch, seq) row index and q's head_dim,
    # so any mismatch would read/write out of bounds or rotate the wrong elements.
    if (k_batch, k_seq_len, k_head_dim) != (batch_size, seq_len, head_dim):
        raise ValueError(
            f"q and k must share batch, seq_len and head_dim; got q={tuple(q.shape)}, k={tuple(k.shape)}"
        )
    if head_dim % 2 != 0:
        raise ValueError(f"head_dim must be even for rotary embedding, got {head_dim}")
    head_dim_half = head_dim // 2

    # Prepare frequencies as (seq_len, head_dim_half) real/imag planes
    freqs_real, freqs_imag = _prepare_freqs(freqs_cis, seq_len, head_dim_half)

    # Cast to appropriate dtype and make contiguous only when needed
    q, k, freqs_real, freqs_imag = _cast_and_contiguous(q, k, freqs_real, freqs_imag)

    # Heuristic meta-params; num_warps is still derived when the caller pins BLOCK_SIZE
    if BLOCK_SIZE is None:
        BLOCK_SIZE, num_warps = _select_kernel_meta(head_dim_half)
    else:
        _, num_warps = _select_kernel_meta(head_dim_half)

    # 2D grid: one program per (batch*seq row, head); heads beyond n_q_heads or
    # n_k_heads are skipped inside the kernel.
    n_heads_max = max(n_q_heads, n_k_heads)
    grid = (batch_size * seq_len, n_heads_max)

    _llama4_rope_kernel[grid](
        q,
        k,
        freqs_real,
        freqs_imag,
        q.stride(1),
        k.stride(1),
        q.stride(2),
        k.stride(2),
        freqs_real.stride(0),
        seq_len,
        batch_size,
        imag_sign,
        head_dim_half,
        n_q_heads,
        n_k_heads,
        BLOCK_SIZE,
        num_warps=num_warps,
        num_stages=2,
    )

    # Cast back to each input's original dtype only if it differs from the compute dtype
    if q.dtype != q_original_dtype:
        q = q.to(q_original_dtype)
    if k.dtype != k_original_dtype:
        k = k.to(k_original_dtype)

    return q, k
|
|
209
|
+
|
|
210
|
+
|
|
211
|
+
class LigerLlama4RopeFunction(torch.autograd.Function):
    """Autograd wrapper around ``llama4_rope_forward``.

    Forward rotates q/k by ``freqs_cis``. Backward applies the conjugate rotation
    (``imag_sign=-1.0``) to the incoming gradients. ``freqs_cis`` and
    ``BLOCK_SIZE`` receive no gradient.

    NOTE(review): ``llama4_rope_forward`` may write into q/k (and into dq/dk in
    backward) in place when they are already contiguous and of the compute dtype.
    ``ctx.mark_dirty`` is not used; confirm callers never reuse the pre-rotation
    tensors.
    """

    @staticmethod
    def forward(ctx, q, k, freqs_cis, BLOCK_SIZE: int = None):
        q_out, k_out = llama4_rope_forward(q, k, freqs_cis, BLOCK_SIZE, imag_sign=1.0)
        # Only the frequencies are needed to undo the rotation in backward.
        ctx.save_for_backward(freqs_cis.detach() if isinstance(freqs_cis, torch.Tensor) else freqs_cis)
        ctx.BLOCK_SIZE = BLOCK_SIZE
        return q_out, k_out

    @staticmethod
    def backward(ctx, dq, dk):
        (freqs_cis,) = ctx.saved_tensors
        BLOCK_SIZE = getattr(ctx, "BLOCK_SIZE", None)
        # Use imag_sign=-1.0 for conjugate without materializing a new tensor
        dq_out, dk_out = llama4_rope_forward(dq, dk, freqs_cis, BLOCK_SIZE, imag_sign=-1.0)
        # One gradient slot per forward input: (q, k, freqs_cis, BLOCK_SIZE).
        # Returning only three made autograd reject the result whenever BLOCK_SIZE
        # was passed to `apply`. When it is omitted, the trailing None is trimmed.
        return dq_out, dk_out, None, None
|