liger-kernel-nightly 0.5.6.dev20250403190551__py3-none-any.whl → 0.6.4.dev20251212103629__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- liger_kernel/chunked_loss/__init__.py +1 -0
- liger_kernel/chunked_loss/cosine_similarity_loss.py +136 -0
- liger_kernel/chunked_loss/dpo_loss.py +61 -3
- liger_kernel/chunked_loss/functional.py +2 -0
- liger_kernel/chunked_loss/fused_linear_distillation.py +13 -2
- liger_kernel/chunked_loss/fused_linear_ppo.py +35 -0
- liger_kernel/chunked_loss/fused_linear_preference.py +0 -1
- liger_kernel/chunked_loss/grpo_loss.py +76 -5
- liger_kernel/chunked_loss/jsd_loss.py +25 -9
- liger_kernel/ops/__init__.py +141 -0
- liger_kernel/ops/backends/README.md +151 -0
- liger_kernel/ops/backends/__init__.py +13 -0
- liger_kernel/ops/backends/_ascend/__init__.py +5 -0
- liger_kernel/ops/backends/_ascend/ops/__init__.py +15 -0
- liger_kernel/ops/backends/registry.py +61 -0
- liger_kernel/ops/cross_entropy.py +124 -64
- liger_kernel/ops/dyt.py +115 -180
- liger_kernel/ops/fused_add_rms_norm.py +416 -0
- liger_kernel/ops/fused_linear_cross_entropy.py +115 -22
- liger_kernel/ops/fused_neighborhood_attention.py +1022 -0
- liger_kernel/ops/geglu.py +3 -2
- liger_kernel/ops/group_norm.py +2 -1
- liger_kernel/ops/grpo_loss.py +312 -0
- liger_kernel/ops/jsd.py +2 -1
- liger_kernel/ops/kl_div.py +13 -6
- liger_kernel/ops/layer_norm.py +146 -78
- liger_kernel/ops/llama4_rope.py +225 -0
- liger_kernel/ops/multi_token_attention.py +207 -0
- liger_kernel/ops/poly_norm.py +390 -0
- liger_kernel/ops/rms_norm.py +283 -56
- liger_kernel/ops/rope.py +1 -1
- liger_kernel/ops/softmax.py +201 -0
- liger_kernel/ops/sparsemax.py +179 -0
- liger_kernel/ops/swiglu.py +1 -1
- liger_kernel/ops/tiled_mlp.py +136 -0
- liger_kernel/ops/utils.py +2 -0
- liger_kernel/transformers/__init__.py +205 -19
- liger_kernel/transformers/cross_entropy.py +9 -4
- liger_kernel/transformers/dyt.py +6 -4
- liger_kernel/transformers/experimental/__init__.py +5 -0
- liger_kernel/transformers/experimental/embedding.py +1 -1
- liger_kernel/transformers/fsdp.py +55 -0
- liger_kernel/transformers/functional.py +122 -20
- liger_kernel/transformers/fused_add_rms_norm.py +39 -0
- liger_kernel/transformers/fused_linear_cross_entropy.py +16 -5
- liger_kernel/transformers/fused_linear_jsd.py +1 -1
- liger_kernel/transformers/fused_neighborhood_attention.py +234 -0
- liger_kernel/transformers/geglu.py +1 -1
- liger_kernel/transformers/group_norm.py +1 -1
- liger_kernel/transformers/grpo_loss.py +153 -0
- liger_kernel/transformers/jsd.py +1 -1
- liger_kernel/transformers/kl_div.py +1 -1
- liger_kernel/transformers/layer_norm.py +1 -1
- liger_kernel/transformers/llama4_rope.py +93 -0
- liger_kernel/transformers/model/falcon_h1.py +122 -0
- liger_kernel/transformers/model/gemma.py +50 -25
- liger_kernel/transformers/model/gemma2.py +55 -23
- liger_kernel/transformers/model/gemma3.py +117 -120
- liger_kernel/transformers/model/glm4.py +141 -0
- liger_kernel/transformers/model/glm4v.py +163 -0
- liger_kernel/transformers/model/glm4v_moe.py +172 -0
- liger_kernel/transformers/model/gpt_oss.py +211 -0
- liger_kernel/transformers/model/hunyuan_v1.py +134 -0
- liger_kernel/transformers/model/internvl.py +157 -0
- liger_kernel/transformers/model/llama.py +102 -25
- liger_kernel/transformers/model/llama4.py +121 -0
- liger_kernel/transformers/model/llava.py +111 -136
- liger_kernel/transformers/model/loss_utils.py +50 -12
- liger_kernel/transformers/model/mistral.py +36 -23
- liger_kernel/transformers/model/mixtral.py +45 -25
- liger_kernel/transformers/model/mllama.py +39 -22
- liger_kernel/transformers/model/olmo2.py +40 -20
- liger_kernel/transformers/model/olmo3.py +142 -0
- liger_kernel/transformers/model/output_classes.py +147 -0
- liger_kernel/transformers/model/paligemma.py +50 -14
- liger_kernel/transformers/model/phi3.py +47 -177
- liger_kernel/transformers/model/qwen2.py +48 -21
- liger_kernel/transformers/model/qwen2_5_vl.py +62 -103
- liger_kernel/transformers/model/qwen2_vl.py +59 -108
- liger_kernel/transformers/model/qwen3.py +136 -0
- liger_kernel/transformers/model/qwen3_moe.py +152 -0
- liger_kernel/transformers/model/qwen3_next.py +146 -0
- liger_kernel/transformers/model/qwen3_vl.py +150 -0
- liger_kernel/transformers/model/qwen3_vl_moe.py +126 -0
- liger_kernel/transformers/model/smollm3.py +199 -0
- liger_kernel/transformers/model/smolvlm.py +158 -0
- liger_kernel/transformers/monkey_patch.py +1678 -160
- liger_kernel/transformers/multi_token_attention.py +64 -0
- liger_kernel/transformers/poly_norm.py +42 -0
- liger_kernel/transformers/qwen2vl_mrope.py +1 -1
- liger_kernel/transformers/rms_norm.py +48 -5
- liger_kernel/transformers/rope.py +45 -1
- liger_kernel/transformers/softmax.py +12 -0
- liger_kernel/transformers/sparsemax.py +16 -0
- liger_kernel/transformers/swiglu.py +39 -1
- liger_kernel/transformers/tiled_mlp.py +133 -0
- liger_kernel/transformers/trainer/orpo_trainer.py +1 -53
- liger_kernel/transformers/tvd.py +1 -1
- liger_kernel/utils.py +36 -0
- {liger_kernel_nightly-0.5.6.dev20250403190551.dist-info → liger_kernel_nightly-0.6.4.dev20251212103629.dist-info}/METADATA +68 -38
- liger_kernel_nightly-0.6.4.dev20251212103629.dist-info/RECORD +124 -0
- liger_kernel/transformers/gema3_rms.py +0 -8
- liger_kernel_nightly-0.5.6.dev20250403190551.dist-info/RECORD +0 -82
- {liger_kernel_nightly-0.5.6.dev20250403190551.dist-info → liger_kernel_nightly-0.6.4.dev20251212103629.dist-info}/LICENSE +0 -0
- {liger_kernel_nightly-0.5.6.dev20250403190551.dist-info → liger_kernel_nightly-0.6.4.dev20251212103629.dist-info}/NOTICE +0 -0
- {liger_kernel_nightly-0.5.6.dev20250403190551.dist-info → liger_kernel_nightly-0.6.4.dev20251212103629.dist-info}/WHEEL +0 -0
- {liger_kernel_nightly-0.5.6.dev20250403190551.dist-info → liger_kernel_nightly-0.6.4.dev20251212103629.dist-info}/top_level.txt +0 -0
liger_kernel/ops/geglu.py
CHANGED
@@ -7,8 +7,9 @@ import triton.language as tl
 from liger_kernel.ops.utils import calculate_settings
 from liger_kernel.ops.utils import compare_version
 from liger_kernel.ops.utils import ensure_contiguous
+from liger_kernel.utils import is_npu_available

-if compare_version("triton", operator.ge, "3.0.0"):
+if compare_version("triton", operator.ge, "3.0.0") and not is_npu_available():
     try:
         # typical import path with dispatch available
         from triton.language.extra.libdevice import tanh
@@ -40,7 +41,7 @@ def _geglu_tanh_forward_kernel(a, b, c, stride, n_cols: tl.constexpr, BLOCK_SIZE
     tanh_arg = sqrt_2_over_pi * (a_row + 0.044715 * a_cubed)
     tanh_result = tanh(tanh_arg)
     geglu_a = 0.5 * a_row * (1 + tanh_result)
-    c_row = geglu_a * b_row
+    c_row = geglu_a.cast(b_row.dtype) * b_row
     tl.store(c + col_offsets, c_row, mask=mask)
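For orientation, the tanh-approximated GEGLU that `_geglu_tanh_forward_kernel` computes can be sketched in plain PyTorch as below; the explicit cast mirrors the `geglu_a.cast(b_row.dtype)` change above, which keeps the fp32 gate from upcasting the product when `b` is fp16/bf16. This is an illustrative sketch, not code from the package.

```python
import math

import torch


def geglu_tanh_reference(a: torch.Tensor, b: torch.Tensor) -> torch.Tensor:
    # tanh-approximated GELU applied to `a`, gated by `b`
    a32 = a.float()
    inner = math.sqrt(2.0 / math.pi) * (a32 + 0.044715 * a32.pow(3))
    gelu_a = 0.5 * a32 * (1.0 + torch.tanh(inner))
    # cast back to b's dtype before multiplying, matching the kernel change above
    return gelu_a.to(b.dtype) * b
```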
liger_kernel/ops/group_norm.py
CHANGED
@@ -6,8 +6,9 @@ import triton.language as tl

 from liger_kernel.ops.utils import compare_version
 from liger_kernel.ops.utils import ensure_contiguous
+from liger_kernel.utils import is_npu_available

-if compare_version("triton", operator.ge, "3.0.0"):
+if compare_version("triton", operator.ge, "3.0.0") and not is_npu_available():
     try:
         # typical import path with dispatch available
         from triton.language.extra.libdevice import rsqrt
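Both geglu.py and group_norm.py now gate the `triton.language.extra.libdevice` import on `is_npu_available()`, presumably because the libdevice extras are not available in the NPU Triton backend. A minimal sketch of the resulting import pattern follows; the `tl.math` fallback branch is an assumption for illustration and is not shown in this diff.

```python
import operator

from liger_kernel.ops.utils import compare_version
from liger_kernel.utils import is_npu_available

if compare_version("triton", operator.ge, "3.0.0") and not is_npu_available():
    try:
        # typical import path with dispatch available
        from triton.language.extra.libdevice import rsqrt
    except ModuleNotFoundError:
        from triton.language.math import rsqrt  # assumed fallback for this sketch
else:
    from triton.language.math import rsqrt  # assumed fallback for this sketch
```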
liger_kernel/ops/grpo_loss.py
ADDED
@@ -0,0 +1,312 @@
+import torch
+import triton
+import triton.language as tl
+
+
+@triton.jit
+def _selective_log_softmax_kernel(
+    LOGITS,
+    INPUT_IDS,
+    LOG_P,
+    MASK,
+    TEMPERATURE,
+    stride_input_ids_b,
+    L: tl.constexpr,
+    N: tl.constexpr,
+    BLOCK_N: tl.constexpr = 4096,
+):
+    off_b = tl.program_id(0).cast(tl.int64)
+    off_l = tl.program_id(1).cast(tl.int64)
+
+    LOGITS += off_b * (L + 1) * N + off_l * N
+    INPUT_IDS += off_b * stride_input_ids_b + off_l
+    LOG_P += off_b * L + off_l
+
+    if MASK is not None:
+        MASK += off_b * stride_input_ids_b + off_l
+        not_skip = tl.load(MASK)
+        if not_skip == 0:
+            return
+
+    m_i = float("-inf")
+    l_i = 0.0
+    for start in range(0, N, BLOCK_N):
+        cols = start + tl.arange(0, BLOCK_N)
+        logits = tl.load(LOGITS + cols, mask=cols < N, other=float("-inf")).to(tl.float32) / TEMPERATURE
+        new_m_i = tl.maximum(m_i, tl.max(logits))
+        alpha = tl.exp(m_i - new_m_i)
+        l_i = l_i * alpha + tl.sum(tl.exp(logits - new_m_i))
+        m_i = new_m_i
+    lse = m_i + tl.log(l_i)
+
+    ids = tl.load(INPUT_IDS)
+    x = tl.load(LOGITS + ids).to(tl.float32) / TEMPERATURE
+    logp = x - lse
+    tl.store(LOG_P, logp)
+
+
+# compue old_logp and ref_logp, it reduce 10G peak Memory. it does not requires grad
+@torch.no_grad
+def fused_selective_log_softmax(logits: torch.Tensor, input_ids: torch.Tensor, temperature: float = 0.9, mask=None):
+    assert logits.is_contiguous()
+    B, L_ADD_1, N = logits.shape
+    L = L_ADD_1 - 1
+    input_ids = input_ids[:, -L:]
+    if mask is not None:
+        mask = mask[:, -L:]
+    log_p = torch.zeros(B, L, dtype=torch.float32, device=logits.device)
+    kwargs = {"BLOCK_N": 2048, "num_stages": 4, "num_warps": 1}
+    _selective_log_softmax_kernel[(B, L)](
+        logits, input_ids, log_p, mask, temperature, input_ids.stride(0), L, N, **kwargs
+    )
+    return log_p
+
+
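The kernel above computes log-softmax only at the sampled token, streaming over the vocabulary with an online logsumexp so the full `(B, L, V)` log-probability tensor is never materialized. An unfused PyTorch equivalent (illustrative only; the optional `mask`, which leaves skipped positions at zero in the fused version, is omitted) looks roughly like:

```python
import torch


def selective_log_softmax_reference(logits: torch.Tensor, input_ids: torch.Tensor,
                                    temperature: float = 0.9) -> torch.Tensor:
    # logits: (B, L+1, V); input_ids: (B, >=L). Returns per-token log-probs of shape (B, L).
    B, L_plus_1, V = logits.shape
    L = L_plus_1 - 1
    ids = input_ids[:, -L:]
    scaled = logits[:, :L, :].float() / temperature  # the final position is unused, as in the kernel
    logp = torch.log_softmax(scaled, dim=-1)         # (B, L, V) -- the allocation the fused kernel avoids
    return torch.gather(logp, dim=-1, index=ids.unsqueeze(-1)).squeeze(-1)
```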
+# @triton.autotune([triton.Config({"BLOCK_N":BLOCK_N}, num_stages=ns, num_warps=nw)
+# for BLOCK_N in [2048, 4096, 8192]
+# for ns in [1, 2, 4]
+# for nw in [1, 2, 4, 8, 16]],
+# key=['N'])
+@triton.jit
+def _grpo_loss_fwd_kernel(
+    LOGITS,
+    OLD_LOGP,
+    REF_LOGP,
+    INPUT_IDS,
+    COMPLETION_MASK,
+    ADVANTAGES,
+    LOSS,
+    LSE,
+    KL,
+    IS_CLIPPED,
+    TEMPERATURE,
+    BETA: tl.constexpr,
+    EPS_LOW,
+    EPS_HIGH,
+    L: tl.constexpr,
+    N: tl.constexpr,
+    BLOCK_N: tl.constexpr = 4096,
+):
+    off_b = tl.program_id(0).cast(tl.int64)
+    off_l = tl.program_id(1).cast(tl.int64)
+
+    if COMPLETION_MASK is not None:
+        COMPLETION_MASK += off_b * L + off_l
+        not_skip = tl.load(COMPLETION_MASK)
+        if not_skip == 0:
+            return
+
+    LOGITS += off_b * (L + 1) * N + off_l * N
+    INPUT_IDS += off_b * L + off_l
+    ADVANTAGES += off_b
+    LOSS += off_b * L + off_l
+    LSE += off_b * L + off_l
+    IS_CLIPPED += off_b * L + off_l
+
+    m_i = float("-inf")
+    l_i = 0.0
+    for start in range(0, N, BLOCK_N):
+        cols = start + tl.arange(0, BLOCK_N)
+        logits = tl.load(LOGITS + cols, mask=cols < N, other=float("-inf")).to(tl.float32) / TEMPERATURE
+        new_m_i = tl.maximum(m_i, tl.max(logits))
+        alpha = tl.exp(m_i - new_m_i)
+        l_i = l_i * alpha + tl.sum(tl.exp(logits - new_m_i))
+        m_i = new_m_i
+    lse = m_i + tl.log(l_i)
+
+    idx = tl.load(INPUT_IDS)
+    x = tl.load(LOGITS + idx).to(tl.float32) / TEMPERATURE
+    logp = x - lse
+    if OLD_LOGP is None:
+        old_logp = logp
+    else:
+        OLD_LOGP += off_b * L + off_l
+        old_logp = tl.load(OLD_LOGP).to(tl.float32)
+    coef_1 = tl.exp(logp - old_logp)
+    coef_2 = tl.clamp(coef_1, 1 - EPS_LOW, 1 + EPS_HIGH)
+    advantage = tl.load(ADVANTAGES).to(tl.float32)
+    per_token_loss1 = coef_1 * advantage
+    per_token_loss2 = coef_2 * advantage
+    per_token_loss = -tl.minimum(per_token_loss1, per_token_loss2)
+    is_low_clipped = (coef_1 < 1 - EPS_LOW) & (advantage < 0)
+    is_high_clipped = (coef_1 > 1 + EPS_HIGH) & (advantage > 0)
+    is_clipped = is_low_clipped | is_high_clipped
+
+    if BETA != 0.0:
+        REF_LOGP += off_b * L + off_l
+        KL += off_b * L + off_l
+        ref_logp = tl.load(REF_LOGP).to(tl.float32)
+        kl = tl.exp(ref_logp - logp) - (ref_logp - logp) - 1
+        per_token_loss += BETA * kl
+        tl.store(KL, kl)
+
+    tl.store(LOSS, per_token_loss)
+    tl.store(LSE, lse)
+    tl.store(IS_CLIPPED, is_clipped)
+
+
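Per token, the forward kernel evaluates a PPO-style clipped surrogate on the probability ratio against `old_logp`, plus an optional KL penalty (the k3 estimator) against the reference policy when `beta != 0`. An unfused sketch of the same arithmetic over precomputed log-probabilities, for reference only:

```python
import torch


def grpo_per_token_loss(logp, old_logp, ref_logp, advantages,
                        eps_low=0.2, eps_high=0.2, beta=0.0):
    # logp, old_logp, ref_logp: (B, L); advantages: (B,), broadcast over the sequence
    adv = advantages.unsqueeze(-1)
    ratio = torch.exp(logp - old_logp)                            # coef_1
    clipped = torch.clamp(ratio, 1.0 - eps_low, 1.0 + eps_high)   # coef_2
    loss = -torch.minimum(ratio * adv, clipped * adv)
    if beta != 0.0:
        kl = torch.exp(ref_logp - logp) - (ref_logp - logp) - 1.0  # k3 KL estimator
        loss = loss + beta * kl
    return loss
```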
+# @triton.autotune([triton.Config({"BLOCK_N":BLOCK_N}, num_stages=ns, num_warps=nw)
+# for BLOCK_N in [2048, 4096, 8192]
+# for ns in [1, 2, 4]
+# for nw in [1, 2, 4, 8, 16]],
+# key=['N'])
+@triton.jit
+def _grpo_loss_bwd_kernel(
+    DLOSS,
+    DLOGITS,
+    LOGITS,
+    OLD_LOGP,
+    REF_LOGP,
+    INPUT_IDS,
+    ADVANTAGES,
+    COMPLETION_MASK,
+    LSE,
+    TEMPERATURE,
+    BETA: tl.constexpr,
+    EPS_LOW,
+    EPS_HIGH,
+    loss_stride0,
+    loss_stride1,
+    L: tl.constexpr,
+    N: tl.constexpr,
+    BLOCK_N: tl.constexpr = 4096,
+):
+    off_b = tl.program_id(0).cast(tl.int64)
+    off_l = tl.program_id(1).cast(tl.int64)
+
+    DLOGITS += off_b * (L + 1) * N + off_l * N
+    if COMPLETION_MASK is not None:
+        COMPLETION_MASK += off_b * L + off_l
+        not_skip = tl.load(COMPLETION_MASK)
+        if not_skip == 0:
+            for start in range(0, N, BLOCK_N):
+                cols = tl.arange(0, BLOCK_N) + start
+                tl.store(DLOGITS + cols, 0.0, mask=cols < N)
+            return
+
+    LOGITS += off_b * (L + 1) * N + off_l * N
+    DLOSS += off_b * loss_stride0 + off_l * loss_stride1
+    INPUT_IDS += off_b * L + off_l
+    ADVANTAGES += off_b
+    LSE += off_b * L + off_l
+
+    dloss = tl.load(DLOSS).to(tl.float32)
+    lse = tl.load(LSE).to(tl.float32)
+
+    idx = tl.load(INPUT_IDS)
+    x = tl.load(LOGITS + idx).to(tl.float32) / TEMPERATURE
+    logp = x - lse
+    if OLD_LOGP is None:
+        old_logp = logp
+    else:
+        OLD_LOGP += off_b * L + off_l
+        old_logp = tl.load(OLD_LOGP).to(tl.float32)
+    coef_1 = tl.exp(logp - old_logp)
+    coef_2 = tl.clamp(coef_1, 1 - EPS_LOW, 1 + EPS_HIGH)
+    advantage = tl.load(ADVANTAGES).to(tl.float32)
+    per_token_loss1 = coef_1 * advantage
+    per_token_loss2 = coef_2 * advantage
+    mask = per_token_loss2 >= per_token_loss1
+
+    dlogp = -per_token_loss1 * mask
+    if BETA != 0.0:
+        REF_LOGP += off_b * L + off_l
+        ref_logp = tl.load(REF_LOGP).to(tl.float32)
+        dlogp += BETA * (1 - tl.exp(ref_logp - logp))
+
+    dlogp = dlogp * dloss / TEMPERATURE
+    tl.debug_barrier()
+    for start_n in tl.range(0, N, BLOCK_N):
+        cols = start_n + tl.arange(0, BLOCK_N)
+        logits = tl.load(LOGITS + cols, mask=cols < N, other=-float("inf")).to(tl.float32) / TEMPERATURE
+        probs = tl.exp(logits - lse)
+        dlogits = tl.where(cols == idx, 1 - probs, -probs) * dlogp
+        tl.store(DLOGITS + cols, dlogits, mask=cols < N)
+
+
+class GrpoLossFunction(torch.autograd.Function):
+    @staticmethod
+    def forward(
+        ctx,
+        logits,
+        old_logp,
+        ref_logp,
+        completion_ids,
+        advantages,
+        completion_mask,
+        temperature,
+        beta,
+        eps_low,
+        eps_high,
+        inplace,
+    ):
+        assert logits.is_contiguous() and completion_ids.is_contiguous()
+        assert old_logp is None or old_logp.is_contiguous()
+        assert (ref_logp is not None and ref_logp.is_contiguous()) if beta != 0.0 else True
+
+        B, L_ADD_1, N = logits.shape
+        L = L_ADD_1 - 1
+
+        if completion_mask is not None:
+            assert completion_mask.is_contiguous()
+
+        loss = torch.zeros(B, L, device=logits.device, dtype=torch.float32)
+        lse = torch.zeros_like(loss)
+        is_clipped = torch.zeros_like(loss)
+        kl = torch.zeros_like(loss) if beta != 0.0 else None
+        kwargs = {"BLOCK_N": 2048, "num_stages": 2, "num_warps": 1}
+        _grpo_loss_fwd_kernel[(B, L)](
+            logits,
+            old_logp,
+            ref_logp,
+            completion_ids,
+            completion_mask,
+            advantages,
+            loss,
+            lse,
+            kl,
+            is_clipped,
+            temperature,
+            beta,
+            eps_low,
+            eps_high,
+            L,
+            N,
+            **kwargs,
+        )
+        ctx.save_for_backward(logits, old_logp, ref_logp, completion_ids, advantages, completion_mask, lse)
+        ctx.infos = (temperature, beta, eps_low, eps_high, inplace)
+        # return loss
+        return loss, kl, is_clipped
+
+    @staticmethod
+    def backward(ctx, *args):
+        dloss = args[0]
+        # print(dloss.shape)
+        logits, old_logp, ref_logp, completion_ids, advantages, completion_mask, lse = ctx.saved_tensors
+        temperature, beta, eps_low, eps_high, inplace = ctx.infos
+        B, L_ADD_1, N = logits.shape
+        L = L_ADD_1 - 1
+        dlogits = logits.data if inplace else torch.empty_like(logits)
+        kwargs = {"BLOCK_N": 4096, "num_stages": 1, "num_warps": 16}
+        _grpo_loss_bwd_kernel[(B, L)](
+            dloss,
+            dlogits,
+            logits,
+            old_logp,
+            ref_logp,
+            completion_ids,
+            advantages,
+            completion_mask,
+            lse,
+            temperature,
+            beta,
+            eps_low,
+            eps_high,
+            *dloss.stride(),
+            L,
+            N,
+            **kwargs,
+        )
+        dlogits[:, -1, :] = 0
+        return dlogits, None, None, None, None, None, None, None, None, None, None
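A rough call-site sketch for the new autograd function, assuming a CUDA device and placeholder shapes; the positional arguments follow the `forward` signature above (logits, old_logp, ref_logp, completion_ids, advantages, completion_mask, temperature, beta, eps_low, eps_high, inplace). The higher-level wrapper added in `liger_kernel/transformers/grpo_loss.py` is likely the intended entry point; this only illustrates the op-level API.

```python
import torch

from liger_kernel.ops.grpo_loss import GrpoLossFunction, fused_selective_log_softmax

B, L, V = 2, 16, 32000
logits = torch.randn(B, L + 1, V, device="cuda", dtype=torch.bfloat16, requires_grad=True)
completion_ids = torch.randint(0, V, (B, L), device="cuda")
advantages = torch.randn(B, device="cuda")

# Reference log-probs are computed without autograd via the fused helper.
ref_logp = fused_selective_log_softmax(logits.detach(), completion_ids, temperature=1.0)

per_token_loss, kl, is_clipped = GrpoLossFunction.apply(
    logits, None, ref_logp, completion_ids, advantages, None,
    1.0, 0.04, 0.2, 0.2, False,  # temperature, beta, eps_low, eps_high, inplace
)
per_token_loss.mean().backward()
```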
liger_kernel/ops/jsd.py
CHANGED
@@ -5,6 +5,7 @@ import triton
 import triton.language as tl

 from liger_kernel.ops.utils import ensure_contiguous
+from liger_kernel.utils import infer_device


 @triton.jit
@@ -92,7 +93,7 @@ def _jsd_kernel(
     tl.store(dX_ptr + offsets, dX, mask=mask)


-MAX_FUSED_SIZE = 65536
+MAX_FUSED_SIZE = 4096 if infer_device() == "xpu" else 65536


 def jsd_forward(_input, target, shift_labels, beta, ignore_index, has_label):
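`MAX_FUSED_SIZE` bounds the Triton block width used when `jsd_forward` tiles the vocabulary dimension, and the XPU path now uses a smaller cap. The usual derivation (a sketch; the exact line inside `jsd_forward` is not shown in this hunk) is:

```python
import triton

MAX_FUSED_SIZE = 65536   # 4096 on XPU, per the change above
V = 128256               # vocabulary size (illustrative)
BLOCK_SIZE = min(MAX_FUSED_SIZE, triton.next_power_of_2(V))
```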
liger_kernel/ops/kl_div.py
CHANGED
@@ -6,6 +6,7 @@ import triton.language as tl

 from liger_kernel.ops.utils import ensure_contiguous
 from liger_kernel.ops.utils import is_hip
+from liger_kernel.utils import infer_device


 def get_num_warps(BLOCK_SIZE):
@@ -115,9 +116,12 @@ def _kldiv_kernel_backward(

 def kldiv_forward_triton(y_pred, y_true, log_target, reduction, eps):  # [BT, V]
     BT, V = y_pred.shape
-
-
-
+    BLOCK_SIZE = (
+        min(8192, triton.next_power_of_2(V))
+        if infer_device() == "xpu"
+        else min(MAX_FUSED_SIZE, triton.next_power_of_2(V))
+    )
+    num_warps = 32 if infer_device() == "xpu" else get_num_warps(BLOCK_SIZE)

     grid = (BT,)
     reduction = _str_to_reduction_mode[reduction]
@@ -155,9 +159,12 @@ def kldiv_forward_triton(y_pred, y_true, log_target, reduction, eps):  # [BT, V]

 def kldiv_backward_triton(target, grad_output, new_grads, log_target):
     BT, V = target.shape
-
-
-
+    BLOCK_SIZE = (
+        min(8192, triton.next_power_of_2(V))
+        if infer_device() == "xpu"
+        else min(MAX_FUSED_SIZE, triton.next_power_of_2(V))
+    )
+    num_warps = 32 if infer_device() == "xpu" else get_num_warps(BLOCK_SIZE)

     grid = (BT,)