liger-kernel-nightly 0.5.10.dev20250611191801__py3-none-any.whl → 0.6.4.dev20260112233432__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of liger-kernel-nightly has been flagged as potentially problematic; see the package's registry page for more details.
- liger_kernel/chunked_loss/__init__.py +1 -0
- liger_kernel/chunked_loss/cosine_similarity_loss.py +142 -0
- liger_kernel/chunked_loss/dpo_loss.py +54 -3
- liger_kernel/chunked_loss/functional.py +2 -0
- liger_kernel/chunked_loss/fused_linear_distillation.py +23 -5
- liger_kernel/chunked_loss/fused_linear_ppo.py +25 -5
- liger_kernel/chunked_loss/grpo_loss.py +46 -9
- liger_kernel/chunked_loss/jsd_loss.py +44 -13
- liger_kernel/ops/__init__.py +141 -0
- liger_kernel/ops/backends/README.md +151 -0
- liger_kernel/ops/backends/__init__.py +13 -0
- liger_kernel/ops/backends/_ascend/__init__.py +5 -0
- liger_kernel/ops/backends/_ascend/ascend-ub-manager-design.md +485 -0
- liger_kernel/ops/backends/_ascend/ops/__init__.py +49 -0
- liger_kernel/ops/backends/_ascend/ops/geglu.py +266 -0
- liger_kernel/ops/backends/_ascend/ops/qwen2vl_mrope.py +285 -0
- liger_kernel/ops/backends/_ascend/ops/rope.py +290 -0
- liger_kernel/ops/backends/_ascend/ops/swiglu.py +142 -0
- liger_kernel/ops/backends/_ascend/ops/tvd.py +221 -0
- liger_kernel/ops/backends/_ascend/ub_manager.py +349 -0
- liger_kernel/ops/backends/registry.py +61 -0
- liger_kernel/ops/cross_entropy.py +130 -64
- liger_kernel/ops/dyt.py +5 -4
- liger_kernel/ops/fused_add_rms_norm.py +416 -0
- liger_kernel/ops/fused_linear_cross_entropy.py +115 -22
- liger_kernel/ops/geglu.py +6 -4
- liger_kernel/ops/group_norm.py +7 -7
- liger_kernel/ops/grpo_loss.py +3 -1
- liger_kernel/ops/kl_div.py +8 -11
- liger_kernel/ops/layer_norm.py +135 -80
- liger_kernel/ops/llama4_rope.py +225 -0
- liger_kernel/ops/poly_norm.py +390 -0
- liger_kernel/ops/rms_norm.py +148 -71
- liger_kernel/ops/rope.py +1 -1
- liger_kernel/ops/swiglu.py +1 -1
- liger_kernel/ops/tiled_mlp.py +136 -0
- liger_kernel/ops/utils.py +14 -0
- liger_kernel/transformers/__init__.py +65 -0
- liger_kernel/transformers/auto_model.py +21 -0
- liger_kernel/transformers/cross_entropy.py +9 -4
- liger_kernel/transformers/dyt.py +1 -1
- liger_kernel/transformers/experimental/__init__.py +5 -0
- liger_kernel/transformers/experimental/embedding.py +1 -1
- liger_kernel/transformers/functional.py +56 -24
- liger_kernel/transformers/fused_add_rms_norm.py +39 -0
- liger_kernel/transformers/fused_linear_cross_entropy.py +17 -5
- liger_kernel/transformers/fused_linear_jsd.py +1 -1
- liger_kernel/transformers/fused_neighborhood_attention.py +1 -1
- liger_kernel/transformers/geglu.py +1 -1
- liger_kernel/transformers/group_norm.py +1 -1
- liger_kernel/transformers/grpo_loss.py +57 -2
- liger_kernel/transformers/jsd.py +1 -1
- liger_kernel/transformers/kl_div.py +1 -1
- liger_kernel/transformers/layer_norm.py +1 -1
- liger_kernel/transformers/llama4_rope.py +93 -0
- liger_kernel/transformers/model/exaone4.py +136 -0
- liger_kernel/transformers/model/falcon_h1.py +122 -0
- liger_kernel/transformers/model/gemma.py +28 -8
- liger_kernel/transformers/model/gemma2.py +34 -11
- liger_kernel/transformers/model/gemma3.py +102 -112
- liger_kernel/transformers/model/glm4.py +18 -5
- liger_kernel/transformers/model/glm4v.py +163 -0
- liger_kernel/transformers/model/glm4v_moe.py +172 -0
- liger_kernel/transformers/model/gpt_oss.py +211 -0
- liger_kernel/transformers/model/hunyuan_v1.py +134 -0
- liger_kernel/transformers/model/internvl.py +157 -0
- liger_kernel/transformers/model/llama.py +26 -7
- liger_kernel/transformers/model/llama4.py +121 -0
- liger_kernel/transformers/model/llava.py +18 -6
- liger_kernel/transformers/model/loss_utils.py +34 -3
- liger_kernel/transformers/model/mistral.py +17 -10
- liger_kernel/transformers/model/mixtral.py +24 -9
- liger_kernel/transformers/model/mllama.py +18 -7
- liger_kernel/transformers/model/olmo2.py +18 -5
- liger_kernel/transformers/model/olmo3.py +142 -0
- liger_kernel/transformers/model/output_classes.py +147 -0
- liger_kernel/transformers/model/paligemma.py +42 -5
- liger_kernel/transformers/model/phi3.py +24 -159
- liger_kernel/transformers/model/qwen2.py +26 -4
- liger_kernel/transformers/model/qwen2_5_vl.py +21 -8
- liger_kernel/transformers/model/qwen2_vl.py +24 -7
- liger_kernel/transformers/model/qwen3.py +22 -6
- liger_kernel/transformers/model/qwen3_moe.py +27 -7
- liger_kernel/transformers/model/qwen3_next.py +146 -0
- liger_kernel/transformers/model/qwen3_vl.py +150 -0
- liger_kernel/transformers/model/qwen3_vl_moe.py +126 -0
- liger_kernel/transformers/model/smollm3.py +199 -0
- liger_kernel/transformers/model/smolvlm.py +158 -0
- liger_kernel/transformers/monkey_patch.py +1423 -100
- liger_kernel/transformers/multi_token_attention.py +2 -2
- liger_kernel/transformers/poly_norm.py +42 -0
- liger_kernel/transformers/qwen2vl_mrope.py +1 -1
- liger_kernel/transformers/rms_norm.py +15 -5
- liger_kernel/transformers/rope.py +45 -1
- liger_kernel/transformers/softmax.py +1 -1
- liger_kernel/transformers/sparsemax.py +1 -1
- liger_kernel/transformers/swiglu.py +18 -1
- liger_kernel/transformers/tiled_mlp.py +125 -0
- liger_kernel/transformers/tvd.py +1 -1
- liger_kernel/utils.py +52 -0
- {liger_kernel_nightly-0.5.10.dev20250611191801.dist-info → liger_kernel_nightly-0.6.4.dev20260112233432.dist-info}/METADATA +37 -25
- liger_kernel_nightly-0.6.4.dev20260112233432.dist-info/RECORD +132 -0
- liger_kernel_nightly-0.5.10.dev20250611191801.dist-info/RECORD +0 -95
- {liger_kernel_nightly-0.5.10.dev20250611191801.dist-info → liger_kernel_nightly-0.6.4.dev20260112233432.dist-info}/LICENSE +0 -0
- {liger_kernel_nightly-0.5.10.dev20250611191801.dist-info → liger_kernel_nightly-0.6.4.dev20260112233432.dist-info}/NOTICE +0 -0
- {liger_kernel_nightly-0.5.10.dev20250611191801.dist-info → liger_kernel_nightly-0.6.4.dev20260112233432.dist-info}/WHEEL +0 -0
- {liger_kernel_nightly-0.5.10.dev20250611191801.dist-info → liger_kernel_nightly-0.6.4.dev20260112233432.dist-info}/top_level.txt +0 -0
liger_kernel/ops/geglu.py
CHANGED
|
@@ -7,8 +7,9 @@ import triton.language as tl
|
|
|
7
7
|
from liger_kernel.ops.utils import calculate_settings
|
|
8
8
|
from liger_kernel.ops.utils import compare_version
|
|
9
9
|
from liger_kernel.ops.utils import ensure_contiguous
|
|
10
|
+
from liger_kernel.utils import is_npu_available
|
|
10
11
|
|
|
11
|
-
if compare_version("triton", operator.ge, "3.0.0"):
|
|
12
|
+
if compare_version("triton", operator.ge, "3.0.0") and not is_npu_available():
|
|
12
13
|
try:
|
|
13
14
|
# typical import path with dispatch available
|
|
14
15
|
from triton.language.extra.libdevice import tanh
|
|
@@ -40,7 +41,7 @@ def _geglu_tanh_forward_kernel(a, b, c, stride, n_cols: tl.constexpr, BLOCK_SIZE
|
|
|
40
41
|
tanh_arg = sqrt_2_over_pi * (a_row + 0.044715 * a_cubed)
|
|
41
42
|
tanh_result = tanh(tanh_arg)
|
|
42
43
|
geglu_a = 0.5 * a_row * (1 + tanh_result)
|
|
43
|
-
c_row = geglu_a * b_row
|
|
44
|
+
c_row = geglu_a.cast(b_row.dtype) * b_row
|
|
44
45
|
tl.store(c + col_offsets, c_row, mask=mask)
|
|
45
46
|
|
|
46
47
|
|
|
@@ -66,8 +67,9 @@ def _geglu_tanh_backward_kernel(dc, a, b, stride, n_cols: tl.constexpr, BLOCK_SI
|
|
|
66
67
|
tanh_arg = sqrt_2_over_pi * (a_row + 0.044715 * a_cubed)
|
|
67
68
|
tanh_result = tanh(tanh_arg)
|
|
68
69
|
geglu_a = 0.5 * a_row * (1 + tanh_result)
|
|
70
|
+
geglu_a = geglu_a.to(dc_row.dtype).to(tl.float32)
|
|
69
71
|
|
|
70
|
-
db_row = dc_row * geglu_a
|
|
72
|
+
db_row = dc_row.cast(tl.float32) * geglu_a
|
|
71
73
|
|
|
72
74
|
# Gradient w.r.t. a can be computed with:
|
|
73
75
|
# b * (0.5 * (1 + tanh(z)) + 0.5 * a * (1 - tanh(z)^2) * (sqrt(2/pi) * (1 + 3 * 0.044715 * a^2)))
|
|
@@ -78,7 +80,7 @@ def _geglu_tanh_backward_kernel(dc, a, b, stride, n_cols: tl.constexpr, BLOCK_SI
|
|
|
78
80
|
da_row = dc_row * b_row * (term1 + term2)
|
|
79
81
|
|
|
80
82
|
tl.store(a + col_offsets, da_row, mask=mask)
|
|
81
|
-
tl.store(b + col_offsets, db_row, mask=mask)
|
|
83
|
+
tl.store(b + col_offsets, db_row.to(dc_row.dtype), mask=mask)
|
|
82
84
|
|
|
83
85
|
|
|
84
86
|
def geglu_forward(a, b):
|
liger_kernel/ops/group_norm.py
CHANGED
|
@@ -6,8 +6,9 @@ import triton.language as tl
|
|
|
6
6
|
|
|
7
7
|
from liger_kernel.ops.utils import compare_version
|
|
8
8
|
from liger_kernel.ops.utils import ensure_contiguous
|
|
9
|
+
from liger_kernel.utils import is_npu_available
|
|
9
10
|
|
|
10
|
-
if compare_version("triton", operator.ge, "3.0.0"):
|
|
11
|
+
if compare_version("triton", operator.ge, "3.0.0") and not is_npu_available():
|
|
11
12
|
try:
|
|
12
13
|
# typical import path with dispatch available
|
|
13
14
|
from triton.language.extra.libdevice import rsqrt
|
|
@@ -77,15 +78,14 @@ def _group_norm_forward_kernel(
|
|
|
77
78
|
for channel_idx in tl.range(group_idx * channels_per_group, (group_idx + 1) * channels_per_group):
|
|
78
79
|
W = tl.load(W_ptr + channel_idx)
|
|
79
80
|
B = tl.load(B_ptr + channel_idx)
|
|
80
|
-
|
|
81
|
+
# Calculate channel offset within the group
|
|
82
|
+
channel_offset = (channel_idx - group_idx * channels_per_group) * hidden_size_per_channel
|
|
83
|
+
for i in tl.range(0, hidden_size_per_channel, BLOCK_SIZE):
|
|
81
84
|
hidden_size_offsets = i + block_range
|
|
82
85
|
mask = hidden_size_offsets < hidden_size_per_channel
|
|
83
|
-
X = tl.load(X_ptr + hidden_size_offsets, mask=mask, other=m)
|
|
86
|
+
X = tl.load(X_ptr + channel_offset + hidden_size_offsets, mask=mask, other=m)
|
|
84
87
|
Y = (X - m) * rstd * W + B
|
|
85
|
-
tl.store(Y_ptr + hidden_size_offsets, Y, mask=mask)
|
|
86
|
-
|
|
87
|
-
X_ptr += hidden_size_per_channel
|
|
88
|
-
Y_ptr += hidden_size_per_channel
|
|
88
|
+
tl.store(Y_ptr + channel_offset + hidden_size_offsets, Y, mask=mask)
|
|
89
89
|
|
|
90
90
|
tl.store(Mean_ptr + batch_idx * Mean_row_stride + group_idx * Mean_col_stride, m)
|
|
91
91
|
tl.store(RSTD_ptr + batch_idx * RSTD_row_stride + group_idx * RSTD_col_stride, rstd)
|
liger_kernel/ops/grpo_loss.py
CHANGED
|
@@ -128,7 +128,9 @@ def _grpo_loss_fwd_kernel(
|
|
|
128
128
|
per_token_loss1 = coef_1 * advantage
|
|
129
129
|
per_token_loss2 = coef_2 * advantage
|
|
130
130
|
per_token_loss = -tl.minimum(per_token_loss1, per_token_loss2)
|
|
131
|
-
|
|
131
|
+
is_low_clipped = (coef_1 < 1 - EPS_LOW) & (advantage < 0)
|
|
132
|
+
is_high_clipped = (coef_1 > 1 + EPS_HIGH) & (advantage > 0)
|
|
133
|
+
is_clipped = is_low_clipped | is_high_clipped
|
|
132
134
|
|
|
133
135
|
if BETA != 0.0:
|
|
134
136
|
REF_LOGP += off_b * L + off_l
|
liger_kernel/ops/kl_div.py
CHANGED
|
@@ -21,7 +21,12 @@ def get_num_warps(BLOCK_SIZE):
|
|
|
21
21
|
return num_warps
|
|
22
22
|
|
|
23
23
|
|
|
24
|
-
|
|
24
|
+
if infer_device() == "xpu":
|
|
25
|
+
MAX_FUSED_SIZE = 8192
|
|
26
|
+
elif infer_device() == "npu":
|
|
27
|
+
MAX_FUSED_SIZE = 8192
|
|
28
|
+
else:
|
|
29
|
+
MAX_FUSED_SIZE = 65536 // 4 # 65536 // 4 or 8 works the best
|
|
25
30
|
|
|
26
31
|
REDUCTION_LITERAL = Literal["none", "sum", "mean", "batchmean"]
|
|
27
32
|
|
|
@@ -116,11 +121,7 @@ def _kldiv_kernel_backward(
|
|
|
116
121
|
|
|
117
122
|
def kldiv_forward_triton(y_pred, y_true, log_target, reduction, eps): # [BT, V]
|
|
118
123
|
BT, V = y_pred.shape
|
|
119
|
-
BLOCK_SIZE = (
|
|
120
|
-
min(8192, triton.next_power_of_2(V))
|
|
121
|
-
if infer_device() == "xpu"
|
|
122
|
-
else min(MAX_FUSED_SIZE, triton.next_power_of_2(V))
|
|
123
|
-
)
|
|
124
|
+
BLOCK_SIZE = min(MAX_FUSED_SIZE, triton.next_power_of_2(V))
|
|
124
125
|
num_warps = 32 if infer_device() == "xpu" else get_num_warps(BLOCK_SIZE)
|
|
125
126
|
|
|
126
127
|
grid = (BT,)
|
|
@@ -159,11 +160,7 @@ def kldiv_forward_triton(y_pred, y_true, log_target, reduction, eps): # [BT, V]
|
|
|
159
160
|
|
|
160
161
|
def kldiv_backward_triton(target, grad_output, new_grads, log_target):
|
|
161
162
|
BT, V = target.shape
|
|
162
|
-
BLOCK_SIZE = (
|
|
163
|
-
min(8192, triton.next_power_of_2(V))
|
|
164
|
-
if infer_device() == "xpu"
|
|
165
|
-
else min(MAX_FUSED_SIZE, triton.next_power_of_2(V))
|
|
166
|
-
)
|
|
163
|
+
BLOCK_SIZE = min(MAX_FUSED_SIZE, triton.next_power_of_2(V))
|
|
167
164
|
num_warps = 32 if infer_device() == "xpu" else get_num_warps(BLOCK_SIZE)
|
|
168
165
|
|
|
169
166
|
grid = (BT,)
|
liger_kernel/ops/layer_norm.py
CHANGED
|
@@ -8,8 +8,9 @@ import triton.language as tl
|
|
|
8
8
|
from liger_kernel.ops.utils import calculate_settings
|
|
9
9
|
from liger_kernel.ops.utils import compare_version
|
|
10
10
|
from liger_kernel.ops.utils import ensure_contiguous
|
|
11
|
+
from liger_kernel.utils import is_npu_available
|
|
11
12
|
|
|
12
|
-
if compare_version("triton", operator.ge, "3.0.0"):
|
|
13
|
+
if compare_version("triton", operator.ge, "3.0.0") and not is_npu_available():
|
|
13
14
|
try:
|
|
14
15
|
# typical import path with dispatch available
|
|
15
16
|
from triton.language.extra.libdevice import rsqrt
|
|
@@ -43,111 +44,157 @@ def _layer_norm_forward_kernel(
|
|
|
43
44
|
https://arxiv.org/abs/1607.06450
|
|
44
45
|
https://github.com/karpathy/llm.c/blob/master/doc/layernorm/layernorm.md
|
|
45
46
|
"""
|
|
46
|
-
row_idx = tl.program_id(0)
|
|
47
|
+
row_idx = tl.program_id(0).to(tl.int64)
|
|
47
48
|
col_offsets = tl.arange(0, BLOCK_SIZE)
|
|
48
49
|
mask = col_offsets < n_cols
|
|
49
50
|
|
|
50
|
-
|
|
51
|
-
|
|
52
|
-
|
|
53
|
-
|
|
54
|
-
|
|
55
|
-
|
|
56
|
-
|
|
57
|
-
|
|
58
|
-
|
|
59
|
-
|
|
60
|
-
|
|
61
|
-
|
|
51
|
+
# Pre-load weights and bias in fp32 to avoid repeated conversions
|
|
52
|
+
W_row = tl.load(W_ptr + col_offsets, mask=mask, other=0.0)
|
|
53
|
+
B_row = tl.load(B_ptr + col_offsets, mask=mask, other=0.0)
|
|
54
|
+
W_f32 = W_row.to(tl.float32)
|
|
55
|
+
B_f32 = B_row.to(tl.float32)
|
|
56
|
+
|
|
57
|
+
# Calculate pointers for this row
|
|
58
|
+
row_X_ptr = X_ptr + row_idx * X_row_stride
|
|
59
|
+
row_Y_ptr = Y_ptr + row_idx * Y_row_stride
|
|
60
|
+
row_Mean_ptr = Mean_ptr + row_idx * Mean_row_stride
|
|
61
|
+
row_RSTD_ptr = RSTD_ptr + row_idx * RSTD_row_stride
|
|
62
|
+
|
|
63
|
+
# Load input data and convert to fp32 for numerical stability
|
|
64
|
+
X_row = tl.load(row_X_ptr + col_offsets, mask=mask, other=0.0)
|
|
65
|
+
X_f32 = X_row.to(tl.float32)
|
|
66
|
+
|
|
67
|
+
# Compute statistics in fp32 for numerical stability
|
|
68
|
+
mean = tl.sum(X_f32, axis=0) / n_cols
|
|
69
|
+
X_centered = X_f32 - mean
|
|
70
|
+
# Apply mask to variance calculation to exclude contributions from masked elements
|
|
71
|
+
X_centered_masked = tl.where(mask, X_centered, 0.0)
|
|
72
|
+
var = tl.sum(X_centered_masked * X_centered_masked, axis=0) / n_cols
|
|
62
73
|
rstd = rsqrt(var + eps)
|
|
63
74
|
|
|
64
|
-
|
|
65
|
-
tl.store(
|
|
75
|
+
# Store statistics (convert back to original dtype only once)
|
|
76
|
+
tl.store(row_Mean_ptr, mean.to(X_row.dtype))
|
|
77
|
+
tl.store(row_RSTD_ptr, rstd.to(X_row.dtype))
|
|
66
78
|
|
|
67
|
-
|
|
79
|
+
# Fused normalization and affine transformation
|
|
80
|
+
# Y = (X - mean) * rstd * W + B = X_centered * rstd * W + B
|
|
81
|
+
Y_f32 = X_centered * rstd * W_f32 + B_f32
|
|
68
82
|
|
|
69
|
-
|
|
83
|
+
# Store output (single conversion back to original dtype)
|
|
84
|
+
tl.store(row_Y_ptr + col_offsets, Y_f32.to(X_row.dtype), mask=mask)
|
|
70
85
|
|
|
71
86
|
|
|
72
87
|
@triton.jit
|
|
73
88
|
def _layer_norm_backward_kernel(
|
|
74
89
|
X_ptr, # pointer to input, shape (n_rows, n_cols)
|
|
90
|
+
stride_x, # stride of each row in input
|
|
75
91
|
W_ptr, # pointer to weights, shape (n_cols,)
|
|
76
92
|
Mean_ptr, # pointer to mean, shape (n_rows,)
|
|
93
|
+
stride_mean, # stride of each row in mean
|
|
77
94
|
RSTD_ptr, # pointer to rstd, shape (n_rows,)
|
|
95
|
+
stride_rstd, # stride of each row in rstd
|
|
78
96
|
DX_ptr, # pointer to input grad, shape (n_rows, n_cols)
|
|
79
|
-
DW_ptr, # pointer to weights grad, shape (n_cols,)
|
|
80
|
-
DB_ptr, # pointer to bias grad, shape (n_cols,)
|
|
81
|
-
DY_ptr, # pointer to output grad, shape (n_rows, n_cols)
|
|
82
|
-
stride_x, # stride of each row in input
|
|
83
97
|
stride_dx, # stride of each row in input grad
|
|
98
|
+
DW_ptr, # pointer to weights grad, shape (n_cols,)
|
|
84
99
|
stride_dw, # stride of each row in weights grad
|
|
100
|
+
DB_ptr, # pointer to bias grad, shape (n_cols,)
|
|
85
101
|
stride_db, # stride of each row in bias grad
|
|
102
|
+
DY_ptr, # pointer to output grad, shape (n_rows, n_cols)
|
|
86
103
|
stride_dy, # stride of each row in output grad
|
|
87
104
|
n_rows,
|
|
88
105
|
n_cols,
|
|
89
106
|
rows_per_program: tl.constexpr,
|
|
90
107
|
BLOCK_SIZE: tl.constexpr,
|
|
91
|
-
dtype: tl.constexpr,
|
|
92
108
|
):
|
|
93
109
|
"""
|
|
94
110
|
References:
|
|
95
111
|
https://arxiv.org/abs/1607.06450
|
|
96
112
|
https://github.com/karpathy/llm.c/blob/master/doc/layernorm/layernorm.md
|
|
97
|
-
https://triton-lang.org/main/getting-started/tutorials/05-layer-norm.html
|
|
98
|
-
https://github.com/Dao-AILab/flash-attention/blob/main/flash_attn/ops/triton/layer_norm.py
|
|
99
113
|
"""
|
|
100
|
-
row_block_id = tl.program_id(0)
|
|
114
|
+
row_block_id = tl.program_id(0).to(tl.int64)
|
|
101
115
|
row_start = row_block_id * rows_per_program
|
|
102
116
|
row_end = min((row_block_id + 1) * rows_per_program, n_rows)
|
|
103
117
|
cols = tl.arange(0, BLOCK_SIZE)
|
|
104
118
|
mask = cols < n_cols
|
|
105
119
|
|
|
106
|
-
|
|
120
|
+
dW_row = tl.zeros((BLOCK_SIZE,), dtype=tl.float32)
|
|
107
121
|
db_row = tl.zeros((BLOCK_SIZE,), dtype=tl.float32)
|
|
108
122
|
|
|
109
|
-
|
|
110
|
-
|
|
111
|
-
|
|
112
|
-
|
|
113
|
-
|
|
123
|
+
# Pre-load weights once (same optimization as forward pass)
|
|
124
|
+
w = tl.load(W_ptr + cols, mask=mask, other=0.0)
|
|
125
|
+
w_f32 = w.to(tl.float32)
|
|
126
|
+
|
|
127
|
+
# Calculate pointers for this specific row
|
|
128
|
+
row_X_ptr = X_ptr + row_start * stride_x
|
|
129
|
+
row_DX_ptr = DX_ptr + row_start * stride_dx
|
|
130
|
+
row_DY_ptr = DY_ptr + row_start * stride_dy
|
|
131
|
+
row_Mean_ptr = Mean_ptr + row_start
|
|
132
|
+
row_RSTD_ptr = RSTD_ptr + row_start
|
|
114
133
|
|
|
115
134
|
for _ in range(row_start, row_end):
|
|
116
|
-
|
|
117
|
-
|
|
118
|
-
dy = tl.load(
|
|
119
|
-
mean = tl.load(
|
|
120
|
-
rstd = tl.load(
|
|
121
|
-
|
|
122
|
-
|
|
123
|
-
|
|
135
|
+
# Load data for this row
|
|
136
|
+
x = tl.load(row_X_ptr + cols, mask=mask, other=0.0)
|
|
137
|
+
dy = tl.load(row_DY_ptr + cols, mask=mask, other=0.0)
|
|
138
|
+
mean = tl.load(row_Mean_ptr)
|
|
139
|
+
rstd = tl.load(row_RSTD_ptr)
|
|
140
|
+
|
|
141
|
+
# Convert to fp32 for numerical stability
|
|
142
|
+
x_f32 = x.to(tl.float32)
|
|
143
|
+
dy_f32 = dy.to(tl.float32)
|
|
144
|
+
mean_f32 = mean.to(tl.float32)
|
|
145
|
+
rstd_f32 = rstd.to(tl.float32)
|
|
146
|
+
|
|
147
|
+
# Compute backward pass for this row
|
|
148
|
+
x_hat = (x_f32 - mean_f32) * rstd_f32
|
|
149
|
+
wdy = w_f32 * dy_f32
|
|
124
150
|
c1 = tl.sum(x_hat * wdy, axis=0) / n_cols
|
|
125
151
|
c2 = tl.sum(wdy, axis=0) / n_cols
|
|
126
|
-
dx = (wdy - (x_hat * c1 + c2)) *
|
|
127
|
-
|
|
152
|
+
dx = (wdy - (x_hat * c1 + c2)) * rstd_f32
|
|
153
|
+
|
|
154
|
+
# Store input gradient
|
|
155
|
+
tl.store(row_DX_ptr + cols, dx, mask=mask)
|
|
128
156
|
|
|
129
|
-
|
|
130
|
-
|
|
157
|
+
# Accumulate weight and bias gradients for this thread block's assigned rows
|
|
158
|
+
dw = dy_f32 * x_hat
|
|
159
|
+
db = dy_f32
|
|
160
|
+
dW_row += dw
|
|
161
|
+
db_row += db
|
|
131
162
|
|
|
132
|
-
|
|
133
|
-
|
|
134
|
-
|
|
135
|
-
|
|
136
|
-
|
|
163
|
+
row_X_ptr += stride_x
|
|
164
|
+
row_DX_ptr += stride_dx
|
|
165
|
+
row_DY_ptr += stride_dy
|
|
166
|
+
row_Mean_ptr += stride_mean
|
|
167
|
+
row_RSTD_ptr += stride_rstd
|
|
137
168
|
|
|
138
|
-
tl.store(DW_ptr + row_block_id * stride_dw + cols,
|
|
139
|
-
tl.store(DB_ptr + row_block_id * stride_db + cols, db_row
|
|
169
|
+
tl.store(DW_ptr + row_block_id * stride_dw + cols, dW_row, mask=mask)
|
|
170
|
+
tl.store(DB_ptr + row_block_id * stride_db + cols, db_row, mask=mask)
|
|
140
171
|
|
|
141
172
|
|
|
142
173
|
def layer_norm_forward(X, W, B, eps):
|
|
174
|
+
"""
|
|
175
|
+
Args:
|
|
176
|
+
X: Input tensor of shape (..., hidden_size)
|
|
177
|
+
W: Weight tensor of shape (hidden_size,)
|
|
178
|
+
B: Bias tensor of shape (hidden_size,)
|
|
179
|
+
eps: Small constant for numerical stability
|
|
180
|
+
|
|
181
|
+
Returns:
|
|
182
|
+
Tuple of (output, input, mean, rstd, block_size, num_warps)
|
|
183
|
+
"""
|
|
143
184
|
shape = X.shape
|
|
144
185
|
dim = shape[-1]
|
|
145
186
|
X = X.view(-1, dim)
|
|
146
187
|
n_rows, n_cols = X.shape
|
|
188
|
+
|
|
189
|
+
# Calculate optimal block size and warp configuration
|
|
147
190
|
BLOCK_SIZE, num_warps = calculate_settings(n_cols)
|
|
191
|
+
|
|
192
|
+
# Allocate output tensors
|
|
148
193
|
Y = torch.empty((n_rows, n_cols), dtype=X.dtype, device=X.device)
|
|
149
194
|
Mean = torch.empty(n_rows, dtype=X.dtype, device=X.device)
|
|
150
195
|
RSTD = torch.empty(n_rows, dtype=X.dtype, device=X.device)
|
|
196
|
+
|
|
197
|
+
# Validate input dimensions
|
|
151
198
|
if X.shape[1] != W.shape[0]:
|
|
152
199
|
raise ValueError(
|
|
153
200
|
f"Incompatible dimensions: input feature size (X.shape[1]={X.shape[1]}) "
|
|
@@ -159,7 +206,9 @@ def layer_norm_forward(X, W, B, eps):
|
|
|
159
206
|
if X.device.type == "xpu":
|
|
160
207
|
kernel_args["grf_mode"] = "large"
|
|
161
208
|
|
|
162
|
-
|
|
209
|
+
# Launch kernel with one thread block per row for optimal performance
|
|
210
|
+
grid = (n_rows,)
|
|
211
|
+
_layer_norm_forward_kernel[grid](
|
|
163
212
|
Y,
|
|
164
213
|
Y.stride(0),
|
|
165
214
|
X,
|
|
@@ -176,12 +225,25 @@ def layer_norm_forward(X, W, B, eps):
|
|
|
176
225
|
eps,
|
|
177
226
|
BLOCK_SIZE=BLOCK_SIZE,
|
|
178
227
|
num_warps=num_warps,
|
|
179
|
-
**kernel_args,
|
|
228
|
+
**kernel_args,
|
|
180
229
|
)
|
|
230
|
+
|
|
181
231
|
return Y.view(*shape), X, Mean, RSTD, BLOCK_SIZE, num_warps
|
|
182
232
|
|
|
183
233
|
|
|
184
234
|
def layer_norm_backward(dY, X, W, B, Mean, RSTD):
|
|
235
|
+
"""
|
|
236
|
+
Args:
|
|
237
|
+
dY: Gradient of output
|
|
238
|
+
X: Input tensor
|
|
239
|
+
W: Weight tensor
|
|
240
|
+
B: Bias tensor
|
|
241
|
+
Mean: Pre-computed mean
|
|
242
|
+
RSTD: Pre-computed reciprocal standard deviation
|
|
243
|
+
|
|
244
|
+
Returns:
|
|
245
|
+
Tuple of (input_grad, weight_grad, bias_grad)
|
|
246
|
+
"""
|
|
185
247
|
shape = dY.shape
|
|
186
248
|
dim = shape[-1]
|
|
187
249
|
dY = dY.view(-1, dim)
|
|
@@ -193,59 +255,52 @@ def layer_norm_backward(dY, X, W, B, Mean, RSTD):
|
|
|
193
255
|
elif X.device.type == "xpu":
|
|
194
256
|
sm_count = torch.xpu.get_device_properties(X.device).gpu_eu_count
|
|
195
257
|
|
|
196
|
-
|
|
197
|
-
_DW = torch.empty((sm_count, n_cols), dtype=
|
|
198
|
-
_DB = torch.empty((sm_count, n_cols), dtype=
|
|
258
|
+
# fp32 for numerical stability especially.
|
|
259
|
+
_DW = torch.empty((sm_count, n_cols), dtype=torch.float32, device=W.device)
|
|
260
|
+
_DB = torch.empty((sm_count, n_cols), dtype=torch.float32, device=W.device)
|
|
199
261
|
|
|
262
|
+
# Calculate optimal block size and warp configuration
|
|
200
263
|
BLOCK_SIZE, num_warps = calculate_settings(n_cols)
|
|
201
264
|
if n_cols > BLOCK_SIZE:
|
|
202
|
-
raise RuntimeError(
|
|
203
|
-
f"Feature dimension {n_cols} exceeds maximum supported size of {BLOCK_SIZE}. Consider using a smaller feature dimension."
|
|
204
|
-
)
|
|
205
|
-
|
|
265
|
+
raise RuntimeError(f"Feature dimension {n_cols} exceeds maximum supported size of {BLOCK_SIZE}.")
|
|
206
266
|
rows_per_program = math.ceil(n_rows / sm_count)
|
|
207
267
|
grid = (sm_count,)
|
|
208
|
-
triton_dtype = (
|
|
209
|
-
tl.float32
|
|
210
|
-
if X.dtype == torch.float32
|
|
211
|
-
else tl.bfloat16
|
|
212
|
-
if X.dtype == torch.bfloat16
|
|
213
|
-
else tl.float16
|
|
214
|
-
if X.dtype == torch.float16
|
|
215
|
-
else tl.float32 # fallback to float32 for other types
|
|
216
|
-
)
|
|
217
268
|
|
|
269
|
+
# Allocate gradient tensors
|
|
270
|
+
DX = torch.empty((n_rows, n_cols), dtype=X.dtype, device=X.device)
|
|
271
|
+
|
|
272
|
+
kernel_args = {"num_warps": num_warps}
|
|
218
273
|
# XPU-specific optimization
|
|
219
|
-
kernel_args = {}
|
|
220
274
|
if X.device.type == "xpu":
|
|
221
275
|
kernel_args.update({"grf_mode": "large", "num_warps": 32, "num_stages": 4})
|
|
222
276
|
|
|
277
|
+
# Launch kernel with one thread block per row for optimal performance
|
|
223
278
|
_layer_norm_backward_kernel[grid](
|
|
224
279
|
X,
|
|
280
|
+
X.stride(0),
|
|
225
281
|
W,
|
|
226
282
|
Mean,
|
|
283
|
+
Mean.stride(0),
|
|
227
284
|
RSTD,
|
|
285
|
+
RSTD.stride(0),
|
|
228
286
|
DX,
|
|
229
|
-
_DW,
|
|
230
|
-
_DB,
|
|
231
|
-
dY,
|
|
232
|
-
X.stride(0),
|
|
233
287
|
DX.stride(0),
|
|
288
|
+
_DW,
|
|
234
289
|
_DW.stride(0),
|
|
290
|
+
_DB,
|
|
235
291
|
_DB.stride(0),
|
|
292
|
+
dY,
|
|
236
293
|
dY.stride(0),
|
|
237
294
|
n_rows,
|
|
238
295
|
n_cols,
|
|
239
|
-
rows_per_program,
|
|
296
|
+
rows_per_program=rows_per_program,
|
|
240
297
|
BLOCK_SIZE=BLOCK_SIZE,
|
|
241
|
-
|
|
242
|
-
**kernel_args, # XPU-specific optimization
|
|
298
|
+
**kernel_args,
|
|
243
299
|
)
|
|
244
300
|
|
|
245
|
-
DW = _DW.sum(dim=0).to(W.dtype)
|
|
246
|
-
DB = _DB.sum(dim=0).to(W.dtype)
|
|
247
|
-
|
|
248
301
|
DX = DX.view(*shape)
|
|
302
|
+
DW = _DW.sum(dim=0).to(W.dtype)
|
|
303
|
+
DB = _DB.sum(dim=0).to(B.dtype)
|
|
249
304
|
return DX, DW, DB
|
|
250
305
|
|
|
251
306
|
|
|
@@ -0,0 +1,225 @@
|
|
|
1
|
+
import torch
|
|
2
|
+
import triton
|
|
3
|
+
import triton.language as tl
|
|
4
|
+
|
|
5
|
+
|
|
6
|
+
def _prepare_freqs(freqs_cis: torch.Tensor, seq_len: int, head_dim_half: int):
|
|
7
|
+
# Split or unpack complex frequencies into real and imag parts
|
|
8
|
+
if freqs_cis.is_complex():
|
|
9
|
+
freqs_real = freqs_cis.real
|
|
10
|
+
freqs_imag = freqs_cis.imag
|
|
11
|
+
else:
|
|
12
|
+
# Already split: last dim should be 2*head_dim_half
|
|
13
|
+
if freqs_cis.shape[-1] == 2 * head_dim_half:
|
|
14
|
+
freqs_real = freqs_cis[..., :head_dim_half]
|
|
15
|
+
freqs_imag = freqs_cis[..., head_dim_half:]
|
|
16
|
+
else:
|
|
17
|
+
raise ValueError(
|
|
18
|
+
f"Unexpected freqs_cis shape for non-complex input: {freqs_cis.shape}, expected last dim = {2 * head_dim_half}"
|
|
19
|
+
)
|
|
20
|
+
|
|
21
|
+
# Canonicalize to shape (seq_len, head_dim_half):
|
|
22
|
+
# 1) Ensure the last dimension is head_dim_half
|
|
23
|
+
if freqs_real.shape[-1] != head_dim_half:
|
|
24
|
+
raise ValueError(f"Unexpected last dim for freqs: {freqs_real.shape[-1]} (expected {head_dim_half})")
|
|
25
|
+
# 2) Flatten all leading dims to a single row dimension
|
|
26
|
+
freqs_real = freqs_real.reshape(-1, head_dim_half)
|
|
27
|
+
freqs_imag = freqs_imag.reshape(-1, head_dim_half)
|
|
28
|
+
# 3) If we have fewer rows than seq_len, allow broadcasting when single row
|
|
29
|
+
if freqs_real.shape[0] < seq_len:
|
|
30
|
+
if freqs_real.shape[0] == 1:
|
|
31
|
+
freqs_real = freqs_real.expand(seq_len, -1)
|
|
32
|
+
freqs_imag = freqs_imag.expand(seq_len, -1)
|
|
33
|
+
else:
|
|
34
|
+
raise ValueError(f"Insufficient rows in freqs: {freqs_real.shape[0]} < seq_len={seq_len}")
|
|
35
|
+
# 4) If we have more rows than seq_len (e.g., batch present), take the first seq_len rows
|
|
36
|
+
elif freqs_real.shape[0] > seq_len:
|
|
37
|
+
freqs_real = freqs_real[:seq_len]
|
|
38
|
+
freqs_imag = freqs_imag[:seq_len]
|
|
39
|
+
|
|
40
|
+
return freqs_real, freqs_imag
|
|
41
|
+
|
|
42
|
+
|
|
43
|
+
def _maybe_to_dtype(t: torch.Tensor, dtype: torch.dtype) -> torch.Tensor:
|
|
44
|
+
return t if t.dtype == dtype else t.to(dtype)
|
|
45
|
+
|
|
46
|
+
|
|
47
|
+
def _maybe_contiguous(t: torch.Tensor) -> torch.Tensor:
|
|
48
|
+
return t if t.is_contiguous() else t.contiguous()
|
|
49
|
+
|
|
50
|
+
|
|
51
|
+
def _cast_and_contiguous(q, k, freqs_real, freqs_imag):
    """Bring q, k and the frequency tables to q's dtype and a contiguous layout.

    The compute dtype is q's dtype: fp32 inputs are processed in fp32 and
    half-precision inputs stay in half precision for performance. k and both
    frequency tables are converted to that dtype, so the frequencies may be
    downcast.

    Tensors that already have the compute dtype and are contiguous are returned
    unchanged (the same object, not a copy). A kernel that writes through them
    therefore writes into the caller's storage.

    Returns:
        Tuple ``(q, k, freqs_real, freqs_imag)``.
    """
    # Every tensor is computed in q's dtype. (The previous
    # "float32 if q is float32 else q.dtype" spelling always reduced to this.)
    compute_dtype = q.dtype

    q = _maybe_contiguous(_maybe_to_dtype(q, compute_dtype))
    # Converting k straight to the compute dtype also aligns it with q.
    k = _maybe_contiguous(_maybe_to_dtype(k, compute_dtype))
    freqs_real = _maybe_contiguous(_maybe_to_dtype(freqs_real, compute_dtype))
    freqs_imag = _maybe_contiguous(_maybe_to_dtype(freqs_imag, compute_dtype))
    return q, k, freqs_real, freqs_imag
|
|
64
|
+
|
|
65
|
+
|
|
66
|
+
@triton.jit
def _llama4_rope_kernel(
    q_ptr,
    k_ptr,
    freqs_real_ptr,
    freqs_imag_ptr,
    q_row_stride,
    k_row_stride,
    q_head_stride,
    k_head_stride,
    freqs_row_stride,
    seq_len,
    batch_size,
    imag_sign,
    head_dim_half: tl.constexpr,
    n_q_heads: tl.constexpr,
    n_k_heads: tl.constexpr,
    BLOCK_SIZE: tl.constexpr,
):
    """
    H100-oriented RoPE kernel that rotates q and k in place, one program per (token, head).

    Grid: (batch * seq_len, max(n_q_heads, n_k_heads)).

    Each head vector is read as interleaved (real, imag) pairs: element 2*d is the
    real part and 2*d + 1 the imaginary part of pair d. Pair d is multiplied by the
    complex frequency (freqs_real[seq, d] + i * imag_sign * freqs_imag[seq, d]), and
    the result is stored back to the same addresses. Passing imag_sign = -1.0
    rotates by the conjugate frequency.

    NOTE(review): the token base offset is computed as
    (batch_idx * seq_len + seq_idx) * row_stride, which presumes q/k rows are
    packed contiguously across batch and sequence — the host wrapper makes them
    contiguous before launch.
    """
    # Promote the token index to int64 so token_index * row_stride cannot overflow
    # int32 when batch * seq_len * n_heads * head_dim exceeds 2**31 elements.
    pid_bs = tl.program_id(0).to(tl.int64)  # over batch*seq
    pid_h = tl.program_id(1)  # over heads

    batch_idx = pid_bs // seq_len
    seq_idx = pid_bs % seq_len

    # Bounds check
    if batch_idx >= batch_size or seq_idx >= seq_len:
        return

    # Base pointers for this (batch, seq) position
    base_offset = batch_idx * seq_len + seq_idx
    q_base = q_ptr + base_offset * q_row_stride
    k_base = k_ptr + base_offset * k_row_stride

    # Tile over the head_dim_half complex pairs
    for d_start in tl.static_range(0, head_dim_half, BLOCK_SIZE):
        d_indices = d_start + tl.arange(0, BLOCK_SIZE)
        mask_d = d_indices < head_dim_half

        # Load frequencies once per tile (freqs layout: [seq_len, head_dim_half])
        freq_idx = d_indices
        freqs_real = tl.load(freqs_real_ptr + seq_idx * freqs_row_stride + freq_idx, mask=mask_d, other=0.0)
        freqs_imag = tl.load(freqs_imag_ptr + seq_idx * freqs_row_stride + freq_idx, mask=mask_d, other=0.0)
        freqs_imag = freqs_imag * imag_sign

        # Process one query head per program in pid_h
        if pid_h < n_q_heads:
            q_head_ptr = q_base + pid_h * q_head_stride
            q_real = tl.load(q_head_ptr + d_indices * 2, mask=mask_d, other=0.0)
            q_imag = tl.load(q_head_ptr + d_indices * 2 + 1, mask=mask_d, other=0.0)

            # Complex multiply with FMAs: (a+ib)*(c+id) = (a*c - b*d) + i(a*d + b*c)
            new_q_real = tl.math.fma(q_real, freqs_real, -(q_imag * freqs_imag))
            new_q_imag = tl.math.fma(q_real, freqs_imag, q_imag * freqs_real)

            tl.store(q_head_ptr + d_indices * 2, new_q_real, mask=mask_d)
            tl.store(q_head_ptr + d_indices * 2 + 1, new_q_imag, mask=mask_d)

        # Process one key head per program in pid_h
        if pid_h < n_k_heads:
            k_head_ptr = k_base + pid_h * k_head_stride
            k_real = tl.load(k_head_ptr + d_indices * 2, mask=mask_d, other=0.0)
            k_imag = tl.load(k_head_ptr + d_indices * 2 + 1, mask=mask_d, other=0.0)

            new_k_real = tl.math.fma(k_real, freqs_real, -(k_imag * freqs_imag))
            new_k_imag = tl.math.fma(k_real, freqs_imag, k_imag * freqs_real)

            tl.store(k_head_ptr + d_indices * 2, new_k_real, mask=mask_d)
            tl.store(k_head_ptr + d_indices * 2 + 1, new_k_imag, mask=mask_d)
|
|
140
|
+
|
|
141
|
+
|
|
142
|
+
def _select_kernel_meta(head_dim_half: int):
|
|
143
|
+
# Heuristic tuning for block size and num_warps
|
|
144
|
+
if head_dim_half >= 256:
|
|
145
|
+
return 128, 8
|
|
146
|
+
if head_dim_half >= 96:
|
|
147
|
+
return 128, 4
|
|
148
|
+
if head_dim_half >= 48:
|
|
149
|
+
return 64, 4
|
|
150
|
+
if head_dim_half >= 24:
|
|
151
|
+
return 32, 2
|
|
152
|
+
return 16, 2
|
|
153
|
+
|
|
154
|
+
|
|
155
|
+
def llama4_rope_forward(q, k, freqs_cis, BLOCK_SIZE: int = None, imag_sign: float = 1.0):
    """Apply the Llama4 rotary embedding to q and k with the Triton kernel.

    Args:
        q: Query tensor indexed as (batch, seq_len, n_q_heads, head_dim). Adjacent
            elements along head_dim are treated as (real, imag) pairs.
        k: Key tensor indexed as (batch, seq_len, n_k_heads, head_dim). Its batch,
            seq_len and head_dim must match q; the head count may differ.
        freqs_cis: Rotary frequencies, either complex with last dim head_dim // 2, or
            real with last dim head_dim laid out as [real half | imag half].
        BLOCK_SIZE: Tile width over head_dim // 2. Chosen heuristically when None;
            a pinned value should be a power of two (Triton's tl.arange requires it).
        imag_sign: Multiplier applied to the imaginary part of the frequencies.
            -1.0 rotates by the conjugate, which is what the backward pass uses.

    Returns:
        Tuple (q_out, k_out), each in the dtype its input had.

    Raises:
        ValueError: q or k is not 4-D, k disagrees with q on batch/seq_len/head_dim,
            head_dim is odd, or freqs_cis has an unexpected shape.

    Note:
        When q (or k) is already contiguous and already in the compute dtype, no
        copy is made. The kernel then writes the result into that input's storage.
    """
    if q.dim() != 4 or k.dim() != 4:
        raise ValueError(
            f"q and k must be 4-D (batch, seq_len, n_heads, head_dim), got {tuple(q.shape)} and {tuple(k.shape)}"
        )
    batch_size, seq_len, n_q_heads, head_dim = q.shape
    k_batch, k_seq_len, n_k_heads, k_head_dim = k.shape
    # The kernel indexes k with q's batch/seq_len/head_dim; a mismatch would read
    # and write out of bounds on the device.
    if (k_batch, k_seq_len, k_head_dim) != (batch_size, seq_len, head_dim):
        raise ValueError(
            "k must match q in batch, seq_len and head_dim: "
            f"q.shape={tuple(q.shape)}, k.shape={tuple(k.shape)}"
        )
    # Rotation acts on (real, imag) pairs; an odd trailing element would be left untouched.
    if head_dim % 2 != 0:
        raise ValueError(f"head_dim must be even for rotary embedding, got {head_dim}")
    head_dim_half = head_dim // 2

    # Remember each input's dtype so outputs match inputs, even when k is
    # temporarily converted to q's dtype for the kernel.
    q_dtype, k_dtype = q.dtype, k.dtype

    # Prepare frequencies as (seq_len, head_dim_half) real/imag tables
    freqs_real, freqs_imag = _prepare_freqs(freqs_cis, seq_len, head_dim_half)

    # Cast to the compute dtype and make contiguous only when needed
    q, k, freqs_real, freqs_imag = _cast_and_contiguous(q, k, freqs_real, freqs_imag)

    # Heuristic meta-params; num_warps is taken from the heuristic even when the
    # caller pins BLOCK_SIZE.
    default_block_size, num_warps = _select_kernel_meta(head_dim_half)
    if BLOCK_SIZE is None:
        BLOCK_SIZE = default_block_size

    # 2D grid: one program per (batch*seq, head)
    grid = (batch_size * seq_len, max(n_q_heads, n_k_heads))

    _llama4_rope_kernel[grid](
        q,
        k,
        freqs_real,
        freqs_imag,
        q.stride(1),
        k.stride(1),
        q.stride(2),
        k.stride(2),
        freqs_real.stride(0),
        seq_len,
        batch_size,
        imag_sign,
        head_dim_half,
        n_q_heads,
        n_k_heads,
        BLOCK_SIZE,
        num_warps=num_warps,
        num_stages=2,
    )

    # Restore each tensor's original dtype only if the compute dtype differed
    if q.dtype != q_dtype:
        q = q.to(q_dtype)
    if k.dtype != k_dtype:
        k = k.to(k_dtype)

    return q, k
|
|
209
|
+
|
|
210
|
+
|
|
211
|
+
class LigerLlama4RopeFunction(torch.autograd.Function):
    """Autograd wrapper around ``llama4_rope_forward``.

    Forward rotates (q, k) by ``freqs_cis``. Backward rotates the incoming
    gradients by the conjugate frequencies (``imag_sign=-1.0``), which is the
    adjoint of a complex multiplication. No gradient is produced for
    ``freqs_cis`` or ``BLOCK_SIZE``.
    """

    @staticmethod
    def forward(ctx, q, k, freqs_cis, BLOCK_SIZE: int = None):
        q_out, k_out = llama4_rope_forward(q, k, freqs_cis, BLOCK_SIZE, imag_sign=1.0)
        # Frequencies are treated as constants, so a detached copy of the reference is saved.
        ctx.save_for_backward(freqs_cis.detach() if isinstance(freqs_cis, torch.Tensor) else freqs_cis)
        ctx.BLOCK_SIZE = BLOCK_SIZE
        return q_out, k_out

    @staticmethod
    def backward(ctx, dq, dk):
        (freqs_cis,) = ctx.saved_tensors
        BLOCK_SIZE = getattr(ctx, "BLOCK_SIZE", None)
        # Use imag_sign=-1.0 for conjugate without materializing a new tensor
        dq_out, dk_out = llama4_rope_forward(dq, dk, freqs_cis, BLOCK_SIZE, imag_sign=-1.0)
        # Return one gradient per forward input (q, k, freqs_cis, BLOCK_SIZE).
        # Autograd rejects too few gradients when BLOCK_SIZE is passed to apply(),
        # and it ignores extra trailing None values when BLOCK_SIZE is omitted.
        return dq_out, dk_out, None, None
|