liger-kernel-nightly 0.5.5.dev20250402185702__py3-none-any.whl → 0.6.4.dev20260112233432__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release: this version of liger-kernel-nightly might be problematic.
- liger_kernel/chunked_loss/__init__.py +1 -0
- liger_kernel/chunked_loss/cosine_similarity_loss.py +142 -0
- liger_kernel/chunked_loss/dpo_loss.py +61 -3
- liger_kernel/chunked_loss/functional.py +2 -0
- liger_kernel/chunked_loss/fused_linear_distillation.py +23 -5
- liger_kernel/chunked_loss/fused_linear_ppo.py +36 -0
- liger_kernel/chunked_loss/fused_linear_preference.py +0 -1
- liger_kernel/chunked_loss/grpo_loss.py +76 -5
- liger_kernel/chunked_loss/jsd_loss.py +46 -15
- liger_kernel/ops/__init__.py +141 -0
- liger_kernel/ops/backends/README.md +151 -0
- liger_kernel/ops/backends/__init__.py +13 -0
- liger_kernel/ops/backends/_ascend/__init__.py +5 -0
- liger_kernel/ops/backends/_ascend/ascend-ub-manager-design.md +485 -0
- liger_kernel/ops/backends/_ascend/ops/__init__.py +49 -0
- liger_kernel/ops/backends/_ascend/ops/geglu.py +266 -0
- liger_kernel/ops/backends/_ascend/ops/qwen2vl_mrope.py +285 -0
- liger_kernel/ops/backends/_ascend/ops/rope.py +290 -0
- liger_kernel/ops/backends/_ascend/ops/swiglu.py +142 -0
- liger_kernel/ops/backends/_ascend/ops/tvd.py +221 -0
- liger_kernel/ops/backends/_ascend/ub_manager.py +349 -0
- liger_kernel/ops/backends/registry.py +61 -0
- liger_kernel/ops/cross_entropy.py +134 -65
- liger_kernel/ops/dyt.py +115 -180
- liger_kernel/ops/fused_add_rms_norm.py +416 -0
- liger_kernel/ops/fused_linear_cross_entropy.py +117 -23
- liger_kernel/ops/fused_neighborhood_attention.py +1022 -0
- liger_kernel/ops/geglu.py +6 -4
- liger_kernel/ops/group_norm.py +7 -7
- liger_kernel/ops/grpo_loss.py +312 -0
- liger_kernel/ops/jsd.py +2 -1
- liger_kernel/ops/kl_div.py +9 -5
- liger_kernel/ops/layer_norm.py +146 -78
- liger_kernel/ops/llama4_rope.py +225 -0
- liger_kernel/ops/multi_token_attention.py +207 -0
- liger_kernel/ops/poly_norm.py +390 -0
- liger_kernel/ops/rms_norm.py +398 -99
- liger_kernel/ops/rope.py +1 -1
- liger_kernel/ops/softmax.py +201 -0
- liger_kernel/ops/sparsemax.py +179 -0
- liger_kernel/ops/swiglu.py +1 -1
- liger_kernel/ops/tiled_mlp.py +136 -0
- liger_kernel/ops/utils.py +14 -0
- liger_kernel/transformers/__init__.py +208 -17
- liger_kernel/transformers/auto_model.py +21 -0
- liger_kernel/transformers/cross_entropy.py +9 -4
- liger_kernel/transformers/dyt.py +6 -4
- liger_kernel/transformers/experimental/__init__.py +5 -0
- liger_kernel/transformers/experimental/embedding.py +1 -1
- liger_kernel/transformers/fsdp.py +55 -0
- liger_kernel/transformers/functional.py +122 -20
- liger_kernel/transformers/fused_add_rms_norm.py +39 -0
- liger_kernel/transformers/fused_linear_cross_entropy.py +16 -5
- liger_kernel/transformers/fused_linear_jsd.py +1 -1
- liger_kernel/transformers/fused_neighborhood_attention.py +234 -0
- liger_kernel/transformers/geglu.py +1 -1
- liger_kernel/transformers/group_norm.py +1 -1
- liger_kernel/transformers/grpo_loss.py +153 -0
- liger_kernel/transformers/jsd.py +1 -1
- liger_kernel/transformers/kl_div.py +1 -1
- liger_kernel/transformers/layer_norm.py +1 -1
- liger_kernel/transformers/llama4_rope.py +93 -0
- liger_kernel/transformers/model/exaone4.py +136 -0
- liger_kernel/transformers/model/falcon_h1.py +122 -0
- liger_kernel/transformers/model/gemma.py +57 -27
- liger_kernel/transformers/model/gemma2.py +65 -28
- liger_kernel/transformers/model/gemma3.py +331 -0
- liger_kernel/transformers/model/glm4.py +141 -0
- liger_kernel/transformers/model/glm4v.py +163 -0
- liger_kernel/transformers/model/glm4v_moe.py +172 -0
- liger_kernel/transformers/model/gpt_oss.py +211 -0
- liger_kernel/transformers/model/hunyuan_v1.py +134 -0
- liger_kernel/transformers/model/internvl.py +157 -0
- liger_kernel/transformers/model/llama.py +109 -27
- liger_kernel/transformers/model/llama4.py +121 -0
- liger_kernel/transformers/model/llava.py +111 -136
- liger_kernel/transformers/model/loss_utils.py +50 -12
- liger_kernel/transformers/model/mistral.py +51 -34
- liger_kernel/transformers/model/mixtral.py +50 -29
- liger_kernel/transformers/model/mllama.py +46 -24
- liger_kernel/transformers/model/olmo2.py +47 -22
- liger_kernel/transformers/model/olmo3.py +142 -0
- liger_kernel/transformers/model/output_classes.py +147 -0
- liger_kernel/transformers/model/paligemma.py +50 -14
- liger_kernel/transformers/model/phi3.py +47 -172
- liger_kernel/transformers/model/qwen2.py +55 -23
- liger_kernel/transformers/model/qwen2_5_vl.py +62 -103
- liger_kernel/transformers/model/qwen2_vl.py +59 -108
- liger_kernel/transformers/model/qwen3.py +136 -0
- liger_kernel/transformers/model/qwen3_moe.py +152 -0
- liger_kernel/transformers/model/qwen3_next.py +146 -0
- liger_kernel/transformers/model/qwen3_vl.py +150 -0
- liger_kernel/transformers/model/qwen3_vl_moe.py +126 -0
- liger_kernel/transformers/model/smollm3.py +199 -0
- liger_kernel/transformers/model/smolvlm.py +158 -0
- liger_kernel/transformers/monkey_patch.py +2018 -244
- liger_kernel/transformers/multi_token_attention.py +64 -0
- liger_kernel/transformers/poly_norm.py +42 -0
- liger_kernel/transformers/qwen2vl_mrope.py +1 -1
- liger_kernel/transformers/rms_norm.py +54 -6
- liger_kernel/transformers/rope.py +45 -1
- liger_kernel/transformers/softmax.py +12 -0
- liger_kernel/transformers/sparsemax.py +16 -0
- liger_kernel/transformers/swiglu.py +39 -1
- liger_kernel/transformers/tiled_mlp.py +125 -0
- liger_kernel/transformers/trainer/orpo_trainer.py +1 -53
- liger_kernel/transformers/tvd.py +1 -1
- liger_kernel/utils.py +63 -0
- {liger_kernel_nightly-0.5.5.dev20250402185702.dist-info → liger_kernel_nightly-0.6.4.dev20260112233432.dist-info}/METADATA +73 -39
- liger_kernel_nightly-0.6.4.dev20260112233432.dist-info/RECORD +132 -0
- liger_kernel_nightly-0.5.5.dev20250402185702.dist-info/RECORD +0 -80
- {liger_kernel_nightly-0.5.5.dev20250402185702.dist-info → liger_kernel_nightly-0.6.4.dev20260112233432.dist-info}/LICENSE +0 -0
- {liger_kernel_nightly-0.5.5.dev20250402185702.dist-info → liger_kernel_nightly-0.6.4.dev20260112233432.dist-info}/NOTICE +0 -0
- {liger_kernel_nightly-0.5.5.dev20250402185702.dist-info → liger_kernel_nightly-0.6.4.dev20260112233432.dist-info}/WHEEL +0 -0
- {liger_kernel_nightly-0.5.5.dev20250402185702.dist-info → liger_kernel_nightly-0.6.4.dev20260112233432.dist-info}/top_level.txt +0 -0
liger_kernel/ops/geglu.py
CHANGED
@@ -7,8 +7,9 @@ import triton.language as tl
 from liger_kernel.ops.utils import calculate_settings
 from liger_kernel.ops.utils import compare_version
 from liger_kernel.ops.utils import ensure_contiguous
+from liger_kernel.utils import is_npu_available
 
-if compare_version("triton", operator.ge, "3.0.0"):
+if compare_version("triton", operator.ge, "3.0.0") and not is_npu_available():
     try:
         # typical import path with dispatch available
         from triton.language.extra.libdevice import tanh
@@ -40,7 +41,7 @@ def _geglu_tanh_forward_kernel(a, b, c, stride, n_cols: tl.constexpr, BLOCK_SIZE
     tanh_arg = sqrt_2_over_pi * (a_row + 0.044715 * a_cubed)
     tanh_result = tanh(tanh_arg)
     geglu_a = 0.5 * a_row * (1 + tanh_result)
-    c_row = geglu_a * b_row
+    c_row = geglu_a.cast(b_row.dtype) * b_row
     tl.store(c + col_offsets, c_row, mask=mask)
 
 
@@ -66,8 +67,9 @@ def _geglu_tanh_backward_kernel(dc, a, b, stride, n_cols: tl.constexpr, BLOCK_SI
     tanh_arg = sqrt_2_over_pi * (a_row + 0.044715 * a_cubed)
     tanh_result = tanh(tanh_arg)
     geglu_a = 0.5 * a_row * (1 + tanh_result)
+    geglu_a = geglu_a.to(dc_row.dtype).to(tl.float32)
 
-    db_row = dc_row * geglu_a
+    db_row = dc_row.cast(tl.float32) * geglu_a
 
     # Gradient w.r.t. a can be computed with:
     # b * (0.5 * (1 + tanh(z)) + 0.5 * a * (1 - tanh(z)^2) * (sqrt(2/pi) * (1 + 3 * 0.044715 * a^2)))
@@ -78,7 +80,7 @@ def _geglu_tanh_backward_kernel(dc, a, b, stride, n_cols: tl.constexpr, BLOCK_SI
     da_row = dc_row * b_row * (term1 + term2)
 
     tl.store(a + col_offsets, da_row, mask=mask)
-    tl.store(b + col_offsets, db_row, mask=mask)
+    tl.store(b + col_offsets, db_row.to(dc_row.dtype), mask=mask)
 
 
 def geglu_forward(a, b):
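For context, here is a minimal unfused PyTorch sketch of the GEGLU-tanh math these kernels fuse. The helper name `geglu_tanh_reference` is illustrative (not part of the package); the explicit cast back to `b`'s dtype mirrors the dtype handling added in this diff.

```python
import torch

def geglu_tanh_reference(a: torch.Tensor, b: torch.Tensor) -> torch.Tensor:
    # gate(a) = 0.5 * a * (1 + tanh(sqrt(2/pi) * (a + 0.044715 * a^3))), computed in fp32
    a_f32 = a.float()
    sqrt_2_over_pi = 0.7978845608028654
    gate = 0.5 * a_f32 * (1.0 + torch.tanh(sqrt_2_over_pi * (a_f32 + 0.044715 * a_f32.pow(3))))
    # cast the fp32 gate back to b's dtype before the product, as the patched kernel does
    return gate.to(b.dtype) * b
```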
liger_kernel/ops/group_norm.py
CHANGED
@@ -6,8 +6,9 @@ import triton.language as tl
 
 from liger_kernel.ops.utils import compare_version
 from liger_kernel.ops.utils import ensure_contiguous
+from liger_kernel.utils import is_npu_available
 
-if compare_version("triton", operator.ge, "3.0.0"):
+if compare_version("triton", operator.ge, "3.0.0") and not is_npu_available():
     try:
         # typical import path with dispatch available
         from triton.language.extra.libdevice import rsqrt
@@ -77,15 +78,14 @@ def _group_norm_forward_kernel(
     for channel_idx in tl.range(group_idx * channels_per_group, (group_idx + 1) * channels_per_group):
         W = tl.load(W_ptr + channel_idx)
         B = tl.load(B_ptr + channel_idx)
-
+        # Calculate channel offset within the group
+        channel_offset = (channel_idx - group_idx * channels_per_group) * hidden_size_per_channel
+        for i in tl.range(0, hidden_size_per_channel, BLOCK_SIZE):
             hidden_size_offsets = i + block_range
             mask = hidden_size_offsets < hidden_size_per_channel
-            X = tl.load(X_ptr + hidden_size_offsets, mask=mask, other=m)
+            X = tl.load(X_ptr + channel_offset + hidden_size_offsets, mask=mask, other=m)
             Y = (X - m) * rstd * W + B
-            tl.store(Y_ptr + hidden_size_offsets, Y, mask=mask)
-
-        X_ptr += hidden_size_per_channel
-        Y_ptr += hidden_size_per_channel
+            tl.store(Y_ptr + channel_offset + hidden_size_offsets, Y, mask=mask)
 
     tl.store(Mean_ptr + batch_idx * Mean_row_stride + group_idx * Mean_col_stride, m)
     tl.store(RSTD_ptr + batch_idx * RSTD_row_stride + group_idx * RSTD_col_stride, rstd)
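To make the indexing change concrete, a small PyTorch sketch of the patched inner loops follows; `group_norm_apply_reference` is a hypothetical helper (not part of the package). Each channel's slice of the flattened group is addressed through an explicit `channel_offset` instead of advancing the X/Y pointers after every channel.

```python
import torch

def group_norm_apply_reference(X_group, W, B, mean, rstd, channels_per_group, group_idx):
    # X_group: flattened (channels_per_group * hidden_size_per_channel,) slice of one (batch, group)
    hidden_size_per_channel = X_group.numel() // channels_per_group
    Y_group = torch.empty_like(X_group)
    for channel_idx in range(group_idx * channels_per_group, (group_idx + 1) * channels_per_group):
        w, b = W[channel_idx], B[channel_idx]
        # explicit per-channel base offset, replacing the old "advance X_ptr/Y_ptr" pattern
        channel_offset = (channel_idx - group_idx * channels_per_group) * hidden_size_per_channel
        sl = slice(channel_offset, channel_offset + hidden_size_per_channel)
        Y_group[sl] = (X_group[sl] - mean) * rstd * w + b
    return Y_group
```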
liger_kernel/ops/grpo_loss.py
ADDED
@@ -0,0 +1,312 @@
+import torch
+import triton
+import triton.language as tl
+
+
+@triton.jit
+def _selective_log_softmax_kernel(
+    LOGITS,
+    INPUT_IDS,
+    LOG_P,
+    MASK,
+    TEMPERATURE,
+    stride_input_ids_b,
+    L: tl.constexpr,
+    N: tl.constexpr,
+    BLOCK_N: tl.constexpr = 4096,
+):
+    off_b = tl.program_id(0).cast(tl.int64)
+    off_l = tl.program_id(1).cast(tl.int64)
+
+    LOGITS += off_b * (L + 1) * N + off_l * N
+    INPUT_IDS += off_b * stride_input_ids_b + off_l
+    LOG_P += off_b * L + off_l
+
+    if MASK is not None:
+        MASK += off_b * stride_input_ids_b + off_l
+        not_skip = tl.load(MASK)
+        if not_skip == 0:
+            return
+
+    m_i = float("-inf")
+    l_i = 0.0
+    for start in range(0, N, BLOCK_N):
+        cols = start + tl.arange(0, BLOCK_N)
+        logits = tl.load(LOGITS + cols, mask=cols < N, other=float("-inf")).to(tl.float32) / TEMPERATURE
+        new_m_i = tl.maximum(m_i, tl.max(logits))
+        alpha = tl.exp(m_i - new_m_i)
+        l_i = l_i * alpha + tl.sum(tl.exp(logits - new_m_i))
+        m_i = new_m_i
+    lse = m_i + tl.log(l_i)
+
+    ids = tl.load(INPUT_IDS)
+    x = tl.load(LOGITS + ids).to(tl.float32) / TEMPERATURE
+    logp = x - lse
+    tl.store(LOG_P, logp)
+
+
+# compute old_logp and ref_logp; this reduces ~10 GB of peak memory and does not require grad
+@torch.no_grad
+def fused_selective_log_softmax(logits: torch.Tensor, input_ids: torch.Tensor, temperature: float = 0.9, mask=None):
+    assert logits.is_contiguous()
+    B, L_ADD_1, N = logits.shape
+    L = L_ADD_1 - 1
+    input_ids = input_ids[:, -L:]
+    if mask is not None:
+        mask = mask[:, -L:]
+    log_p = torch.zeros(B, L, dtype=torch.float32, device=logits.device)
+    kwargs = {"BLOCK_N": 2048, "num_stages": 4, "num_warps": 1}
+    _selective_log_softmax_kernel[(B, L)](
+        logits, input_ids, log_p, mask, temperature, input_ids.stride(0), L, N, **kwargs
+    )
+    return log_p
+
+
+# @triton.autotune([triton.Config({"BLOCK_N":BLOCK_N}, num_stages=ns, num_warps=nw)
+# for BLOCK_N in [2048, 4096, 8192]
+# for ns in [1, 2, 4]
+# for nw in [1, 2, 4, 8, 16]],
+# key=['N'])
+@triton.jit
+def _grpo_loss_fwd_kernel(
+    LOGITS,
+    OLD_LOGP,
+    REF_LOGP,
+    INPUT_IDS,
+    COMPLETION_MASK,
+    ADVANTAGES,
+    LOSS,
+    LSE,
+    KL,
+    IS_CLIPPED,
+    TEMPERATURE,
+    BETA: tl.constexpr,
+    EPS_LOW,
+    EPS_HIGH,
+    L: tl.constexpr,
+    N: tl.constexpr,
+    BLOCK_N: tl.constexpr = 4096,
+):
+    off_b = tl.program_id(0).cast(tl.int64)
+    off_l = tl.program_id(1).cast(tl.int64)
+
+    if COMPLETION_MASK is not None:
+        COMPLETION_MASK += off_b * L + off_l
+        not_skip = tl.load(COMPLETION_MASK)
+        if not_skip == 0:
+            return
+
+    LOGITS += off_b * (L + 1) * N + off_l * N
+    INPUT_IDS += off_b * L + off_l
+    ADVANTAGES += off_b
+    LOSS += off_b * L + off_l
+    LSE += off_b * L + off_l
+    IS_CLIPPED += off_b * L + off_l
+
+    m_i = float("-inf")
+    l_i = 0.0
+    for start in range(0, N, BLOCK_N):
+        cols = start + tl.arange(0, BLOCK_N)
+        logits = tl.load(LOGITS + cols, mask=cols < N, other=float("-inf")).to(tl.float32) / TEMPERATURE
+        new_m_i = tl.maximum(m_i, tl.max(logits))
+        alpha = tl.exp(m_i - new_m_i)
+        l_i = l_i * alpha + tl.sum(tl.exp(logits - new_m_i))
+        m_i = new_m_i
+    lse = m_i + tl.log(l_i)
+
+    idx = tl.load(INPUT_IDS)
+    x = tl.load(LOGITS + idx).to(tl.float32) / TEMPERATURE
+    logp = x - lse
+    if OLD_LOGP is None:
+        old_logp = logp
+    else:
+        OLD_LOGP += off_b * L + off_l
+        old_logp = tl.load(OLD_LOGP).to(tl.float32)
+    coef_1 = tl.exp(logp - old_logp)
+    coef_2 = tl.clamp(coef_1, 1 - EPS_LOW, 1 + EPS_HIGH)
+    advantage = tl.load(ADVANTAGES).to(tl.float32)
+    per_token_loss1 = coef_1 * advantage
+    per_token_loss2 = coef_2 * advantage
+    per_token_loss = -tl.minimum(per_token_loss1, per_token_loss2)
+    is_low_clipped = (coef_1 < 1 - EPS_LOW) & (advantage < 0)
+    is_high_clipped = (coef_1 > 1 + EPS_HIGH) & (advantage > 0)
+    is_clipped = is_low_clipped | is_high_clipped
+
+    if BETA != 0.0:
+        REF_LOGP += off_b * L + off_l
+        KL += off_b * L + off_l
+        ref_logp = tl.load(REF_LOGP).to(tl.float32)
+        kl = tl.exp(ref_logp - logp) - (ref_logp - logp) - 1
+        per_token_loss += BETA * kl
+        tl.store(KL, kl)
+
+    tl.store(LOSS, per_token_loss)
+    tl.store(LSE, lse)
+    tl.store(IS_CLIPPED, is_clipped)
+
+
+# @triton.autotune([triton.Config({"BLOCK_N":BLOCK_N}, num_stages=ns, num_warps=nw)
+# for BLOCK_N in [2048, 4096, 8192]
+# for ns in [1, 2, 4]
+# for nw in [1, 2, 4, 8, 16]],
+# key=['N'])
+@triton.jit
+def _grpo_loss_bwd_kernel(
+    DLOSS,
+    DLOGITS,
+    LOGITS,
+    OLD_LOGP,
+    REF_LOGP,
+    INPUT_IDS,
+    ADVANTAGES,
+    COMPLETION_MASK,
+    LSE,
+    TEMPERATURE,
+    BETA: tl.constexpr,
+    EPS_LOW,
+    EPS_HIGH,
+    loss_stride0,
+    loss_stride1,
+    L: tl.constexpr,
+    N: tl.constexpr,
+    BLOCK_N: tl.constexpr = 4096,
+):
+    off_b = tl.program_id(0).cast(tl.int64)
+    off_l = tl.program_id(1).cast(tl.int64)
+
+    DLOGITS += off_b * (L + 1) * N + off_l * N
+    if COMPLETION_MASK is not None:
+        COMPLETION_MASK += off_b * L + off_l
+        not_skip = tl.load(COMPLETION_MASK)
+        if not_skip == 0:
+            for start in range(0, N, BLOCK_N):
+                cols = tl.arange(0, BLOCK_N) + start
+                tl.store(DLOGITS + cols, 0.0, mask=cols < N)
+            return
+
+    LOGITS += off_b * (L + 1) * N + off_l * N
+    DLOSS += off_b * loss_stride0 + off_l * loss_stride1
+    INPUT_IDS += off_b * L + off_l
+    ADVANTAGES += off_b
+    LSE += off_b * L + off_l
+
+    dloss = tl.load(DLOSS).to(tl.float32)
+    lse = tl.load(LSE).to(tl.float32)
+
+    idx = tl.load(INPUT_IDS)
+    x = tl.load(LOGITS + idx).to(tl.float32) / TEMPERATURE
+    logp = x - lse
+    if OLD_LOGP is None:
+        old_logp = logp
+    else:
+        OLD_LOGP += off_b * L + off_l
+        old_logp = tl.load(OLD_LOGP).to(tl.float32)
+    coef_1 = tl.exp(logp - old_logp)
+    coef_2 = tl.clamp(coef_1, 1 - EPS_LOW, 1 + EPS_HIGH)
+    advantage = tl.load(ADVANTAGES).to(tl.float32)
+    per_token_loss1 = coef_1 * advantage
+    per_token_loss2 = coef_2 * advantage
+    mask = per_token_loss2 >= per_token_loss1
+
+    dlogp = -per_token_loss1 * mask
+    if BETA != 0.0:
+        REF_LOGP += off_b * L + off_l
+        ref_logp = tl.load(REF_LOGP).to(tl.float32)
+        dlogp += BETA * (1 - tl.exp(ref_logp - logp))
+
+    dlogp = dlogp * dloss / TEMPERATURE
+    tl.debug_barrier()
+    for start_n in tl.range(0, N, BLOCK_N):
+        cols = start_n + tl.arange(0, BLOCK_N)
+        logits = tl.load(LOGITS + cols, mask=cols < N, other=-float("inf")).to(tl.float32) / TEMPERATURE
+        probs = tl.exp(logits - lse)
+        dlogits = tl.where(cols == idx, 1 - probs, -probs) * dlogp
+        tl.store(DLOGITS + cols, dlogits, mask=cols < N)
+
+
+class GrpoLossFunction(torch.autograd.Function):
+    @staticmethod
+    def forward(
+        ctx,
+        logits,
+        old_logp,
+        ref_logp,
+        completion_ids,
+        advantages,
+        completion_mask,
+        temperature,
+        beta,
+        eps_low,
+        eps_high,
+        inplace,
+    ):
+        assert logits.is_contiguous() and completion_ids.is_contiguous()
+        assert old_logp is None or old_logp.is_contiguous()
+        assert (ref_logp is not None and ref_logp.is_contiguous()) if beta != 0.0 else True
+
+        B, L_ADD_1, N = logits.shape
+        L = L_ADD_1 - 1
+
+        if completion_mask is not None:
+            assert completion_mask.is_contiguous()
+
+        loss = torch.zeros(B, L, device=logits.device, dtype=torch.float32)
+        lse = torch.zeros_like(loss)
+        is_clipped = torch.zeros_like(loss)
+        kl = torch.zeros_like(loss) if beta != 0.0 else None
+        kwargs = {"BLOCK_N": 2048, "num_stages": 2, "num_warps": 1}
+        _grpo_loss_fwd_kernel[(B, L)](
+            logits,
+            old_logp,
+            ref_logp,
+            completion_ids,
+            completion_mask,
+            advantages,
+            loss,
+            lse,
+            kl,
+            is_clipped,
+            temperature,
+            beta,
+            eps_low,
+            eps_high,
+            L,
+            N,
+            **kwargs,
+        )
+        ctx.save_for_backward(logits, old_logp, ref_logp, completion_ids, advantages, completion_mask, lse)
+        ctx.infos = (temperature, beta, eps_low, eps_high, inplace)
+        # return loss
+        return loss, kl, is_clipped
+
+    @staticmethod
+    def backward(ctx, *args):
+        dloss = args[0]
+        # print(dloss.shape)
+        logits, old_logp, ref_logp, completion_ids, advantages, completion_mask, lse = ctx.saved_tensors
+        temperature, beta, eps_low, eps_high, inplace = ctx.infos
+        B, L_ADD_1, N = logits.shape
+        L = L_ADD_1 - 1
+        dlogits = logits.data if inplace else torch.empty_like(logits)
+        kwargs = {"BLOCK_N": 4096, "num_stages": 1, "num_warps": 16}
+        _grpo_loss_bwd_kernel[(B, L)](
+            dloss,
+            dlogits,
+            logits,
+            old_logp,
+            ref_logp,
+            completion_ids,
+            advantages,
+            completion_mask,
+            lse,
+            temperature,
+            beta,
+            eps_low,
+            eps_high,
+            *dloss.stride(),
+            L,
+            N,
+            **kwargs,
+        )
+        dlogits[:, -1, :] = 0
+        return dlogits, None, None, None, None, None, None, None, None, None, None
liger_kernel/ops/jsd.py
CHANGED
@@ -5,6 +5,7 @@ import triton
 import triton.language as tl
 
 from liger_kernel.ops.utils import ensure_contiguous
+from liger_kernel.utils import infer_device
 
 
 @triton.jit
@@ -92,7 +93,7 @@ def _jsd_kernel(
     tl.store(dX_ptr + offsets, dX, mask=mask)
 
 
-MAX_FUSED_SIZE = 65536
+MAX_FUSED_SIZE = 4096 if infer_device() == "xpu" else 65536
 
 
 def jsd_forward(_input, target, shift_labels, beta, ignore_index, has_label):
liger_kernel/ops/kl_div.py
CHANGED
@@ -6,6 +6,7 @@ import triton.language as tl
 
 from liger_kernel.ops.utils import ensure_contiguous
 from liger_kernel.ops.utils import is_hip
+from liger_kernel.utils import infer_device
 
 
 def get_num_warps(BLOCK_SIZE):
@@ -20,7 +21,12 @@ def get_num_warps(BLOCK_SIZE):
     return num_warps
 
 
-MAX_FUSED_SIZE = 65536 // 4  # 65536 // 4 or 8 works the best
+if infer_device() == "xpu":
+    MAX_FUSED_SIZE = 8192
+elif infer_device() == "npu":
+    MAX_FUSED_SIZE = 8192
+else:
+    MAX_FUSED_SIZE = 65536 // 4  # 65536 // 4 or 8 works the best
 
 REDUCTION_LITERAL = Literal["none", "sum", "mean", "batchmean"]
 
@@ -115,9 +121,8 @@ def _kldiv_kernel_backward(
 
 def kldiv_forward_triton(y_pred, y_true, log_target, reduction, eps):  # [BT, V]
     BT, V = y_pred.shape
-
     BLOCK_SIZE = min(MAX_FUSED_SIZE, triton.next_power_of_2(V))
-    num_warps = get_num_warps(BLOCK_SIZE)
+    num_warps = 32 if infer_device() == "xpu" else get_num_warps(BLOCK_SIZE)
 
     grid = (BT,)
     reduction = _str_to_reduction_mode[reduction]
@@ -155,9 +160,8 @@ def kldiv_forward_triton(y_pred, y_true, log_target, reduction, eps): # [BT, V]
 
 def kldiv_backward_triton(target, grad_output, new_grads, log_target):
     BT, V = target.shape
-
     BLOCK_SIZE = min(MAX_FUSED_SIZE, triton.next_power_of_2(V))
-    num_warps = get_num_warps(BLOCK_SIZE)
+    num_warps = 32 if infer_device() == "xpu" else get_num_warps(BLOCK_SIZE)
 
     grid = (BT,)
 
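A small sketch of the launch-parameter selection this diff introduces: a device-dependent cap on the fused block size plus a warp-count override on XPU. `pick_kldiv_launch_params` is an illustrative wrapper (not part of the package), and `get_num_warps` stands for the existing helper in kl_div.py whose body is unchanged by this diff.

```python
import triton

def pick_kldiv_launch_params(V: int, device_type: str, get_num_warps):
    # Cap BLOCK_SIZE by a device-dependent maximum, then derive num_warps.
    if device_type in ("xpu", "npu"):
        max_fused_size = 8192           # smaller cap for XPU/NPU
    else:
        max_fused_size = 65536 // 4     # "65536 // 4 or 8 works the best" per the source comment
    block_size = min(max_fused_size, triton.next_power_of_2(V))
    num_warps = 32 if device_type == "xpu" else get_num_warps(block_size)
    return block_size, num_warps
```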