liger-kernel-nightly 0.0.1.dev20240819184814__py3-none-any.whl → 0.6.4.dev20251212103629__py3-none-any.whl
This diff compares the contents of two publicly released versions of the package as they appear in a supported public registry. It is provided for informational purposes only and reflects the published wheels exactly as distributed.
- liger_kernel/__init__.py +0 -0
- liger_kernel/chunked_loss/README.md +25 -0
- liger_kernel/chunked_loss/__init__.py +8 -0
- liger_kernel/chunked_loss/cosine_similarity_loss.py +136 -0
- liger_kernel/chunked_loss/cpo_loss.py +157 -0
- liger_kernel/chunked_loss/dpo_loss.py +229 -0
- liger_kernel/chunked_loss/functional.py +17 -0
- liger_kernel/chunked_loss/fused_linear_distillation.py +292 -0
- liger_kernel/chunked_loss/fused_linear_ppo.py +366 -0
- liger_kernel/chunked_loss/fused_linear_preference.py +433 -0
- liger_kernel/chunked_loss/fused_linear_unpaired_preference.py +341 -0
- liger_kernel/chunked_loss/grpo_loss.py +307 -0
- liger_kernel/chunked_loss/jsd_loss.py +200 -0
- liger_kernel/chunked_loss/kto_loss.py +210 -0
- liger_kernel/chunked_loss/orpo_loss.py +144 -0
- liger_kernel/chunked_loss/simpo_loss.py +165 -0
- liger_kernel/env_report.py +63 -0
- liger_kernel/ops/__init__.py +141 -0
- liger_kernel/ops/backends/README.md +151 -0
- liger_kernel/ops/backends/__init__.py +13 -0
- liger_kernel/ops/backends/_ascend/__init__.py +5 -0
- liger_kernel/ops/backends/_ascend/ops/__init__.py +15 -0
- liger_kernel/ops/backends/registry.py +61 -0
- liger_kernel/ops/cross_entropy.py +383 -114
- liger_kernel/ops/dyt.py +160 -0
- liger_kernel/ops/experimental/embedding.py +141 -0
- liger_kernel/ops/experimental/mm_int8int2.py +349 -0
- liger_kernel/ops/fused_add_rms_norm.py +416 -0
- liger_kernel/ops/fused_linear_cross_entropy.py +346 -132
- liger_kernel/ops/fused_linear_jsd.py +228 -0
- liger_kernel/ops/fused_neighborhood_attention.py +1022 -0
- liger_kernel/ops/geglu.py +66 -64
- liger_kernel/ops/group_norm.py +306 -0
- liger_kernel/ops/grpo_loss.py +312 -0
- liger_kernel/ops/jsd.py +201 -0
- liger_kernel/ops/kl_div.py +262 -0
- liger_kernel/ops/layer_norm.py +320 -0
- liger_kernel/ops/llama4_rope.py +225 -0
- liger_kernel/ops/multi_token_attention.py +207 -0
- liger_kernel/ops/poly_norm.py +390 -0
- liger_kernel/ops/qwen2vl_mrope.py +222 -0
- liger_kernel/ops/rms_norm.py +484 -88
- liger_kernel/ops/rope.py +122 -117
- liger_kernel/ops/softmax.py +201 -0
- liger_kernel/ops/sparsemax.py +179 -0
- liger_kernel/ops/swiglu.py +68 -65
- liger_kernel/ops/tiled_mlp.py +136 -0
- liger_kernel/ops/tvd.py +207 -0
- liger_kernel/ops/utils.py +82 -3
- liger_kernel/transformers/__init__.py +218 -6
- liger_kernel/transformers/auto_model.py +38 -0
- liger_kernel/transformers/cross_entropy.py +52 -7
- liger_kernel/transformers/dyt.py +22 -0
- liger_kernel/transformers/experimental/__init__.py +5 -0
- liger_kernel/transformers/experimental/embedding.py +26 -0
- liger_kernel/transformers/fsdp.py +55 -0
- liger_kernel/transformers/functional.py +301 -0
- liger_kernel/transformers/fused_add_rms_norm.py +39 -0
- liger_kernel/transformers/fused_linear_cross_entropy.py +59 -10
- liger_kernel/transformers/fused_linear_jsd.py +95 -0
- liger_kernel/transformers/fused_neighborhood_attention.py +234 -0
- liger_kernel/transformers/geglu.py +6 -7
- liger_kernel/transformers/group_norm.py +50 -0
- liger_kernel/transformers/grpo_loss.py +153 -0
- liger_kernel/transformers/jsd.py +70 -0
- liger_kernel/transformers/kl_div.py +12 -0
- liger_kernel/transformers/layer_norm.py +24 -0
- liger_kernel/transformers/llama4_rope.py +93 -0
- liger_kernel/transformers/model/falcon_h1.py +122 -0
- liger_kernel/transformers/model/gemma.py +261 -0
- liger_kernel/transformers/model/gemma2.py +283 -0
- liger_kernel/transformers/model/gemma3.py +332 -0
- liger_kernel/transformers/model/glm4.py +141 -0
- liger_kernel/transformers/model/glm4v.py +163 -0
- liger_kernel/transformers/model/glm4v_moe.py +172 -0
- liger_kernel/transformers/model/gpt_oss.py +211 -0
- liger_kernel/transformers/model/hunyuan_v1.py +134 -0
- liger_kernel/transformers/model/internvl.py +157 -0
- liger_kernel/transformers/model/llama.py +221 -41
- liger_kernel/transformers/model/llama4.py +121 -0
- liger_kernel/transformers/model/llava.py +344 -0
- liger_kernel/transformers/model/loss_utils.py +95 -0
- liger_kernel/transformers/model/mistral.py +145 -0
- liger_kernel/transformers/model/mixtral.py +293 -0
- liger_kernel/transformers/model/mllama.py +269 -0
- liger_kernel/transformers/model/olmo2.py +141 -0
- liger_kernel/transformers/model/olmo3.py +142 -0
- liger_kernel/transformers/model/output_classes.py +147 -0
- liger_kernel/transformers/model/paligemma.py +433 -0
- liger_kernel/transformers/model/phi3.py +120 -0
- liger_kernel/transformers/model/qwen2.py +259 -0
- liger_kernel/transformers/model/qwen2_5_vl.py +163 -0
- liger_kernel/transformers/model/qwen2_vl.py +159 -0
- liger_kernel/transformers/model/qwen3.py +136 -0
- liger_kernel/transformers/model/qwen3_moe.py +152 -0
- liger_kernel/transformers/model/qwen3_next.py +146 -0
- liger_kernel/transformers/model/qwen3_vl.py +150 -0
- liger_kernel/transformers/model/qwen3_vl_moe.py +126 -0
- liger_kernel/transformers/model/smollm3.py +199 -0
- liger_kernel/transformers/model/smolvlm.py +158 -0
- liger_kernel/transformers/monkey_patch.py +2816 -21
- liger_kernel/transformers/multi_token_attention.py +64 -0
- liger_kernel/transformers/poly_norm.py +42 -0
- liger_kernel/transformers/qwen2vl_mrope.py +20 -0
- liger_kernel/transformers/rms_norm.py +75 -5
- liger_kernel/transformers/rope.py +47 -3
- liger_kernel/transformers/softmax.py +12 -0
- liger_kernel/transformers/sparsemax.py +16 -0
- liger_kernel/transformers/swiglu.py +62 -6
- liger_kernel/transformers/tiled_mlp.py +133 -0
- liger_kernel/transformers/trainer/__init__.py +4 -0
- liger_kernel/transformers/trainer/orpo_trainer.py +130 -0
- liger_kernel/transformers/trainer_integration.py +2 -45
- liger_kernel/transformers/tvd.py +13 -0
- liger_kernel/triton/__init__.py +1 -3
- liger_kernel/triton/monkey_patch.py +1 -5
- liger_kernel/utils.py +96 -0
- liger_kernel_nightly-0.6.4.dev20251212103629.dist-info/METADATA +447 -0
- liger_kernel_nightly-0.6.4.dev20251212103629.dist-info/NOTICE +58 -0
- liger_kernel_nightly-0.6.4.dev20251212103629.dist-info/RECORD +124 -0
- {liger_kernel_nightly-0.0.1.dev20240819184814.dist-info → liger_kernel_nightly-0.6.4.dev20251212103629.dist-info}/WHEEL +1 -1
- liger_kernel_nightly-0.0.1.dev20240819184814.dist-info/METADATA +0 -21
- liger_kernel_nightly-0.0.1.dev20240819184814.dist-info/NOTICE +0 -4
- liger_kernel_nightly-0.0.1.dev20240819184814.dist-info/RECORD +0 -27
- {liger_kernel_nightly-0.0.1.dev20240819184814.dist-info → liger_kernel_nightly-0.6.4.dev20251212103629.dist-info}/LICENSE +0 -0
- {liger_kernel_nightly-0.0.1.dev20240819184814.dist-info → liger_kernel_nightly-0.6.4.dev20251212103629.dist-info}/top_level.txt +0 -0
liger_kernel/ops/grpo_loss.py
ADDED

@@ -0,0 +1,312 @@
+import torch
+import triton
+import triton.language as tl
+
+
+@triton.jit
+def _selective_log_softmax_kernel(
+    LOGITS,
+    INPUT_IDS,
+    LOG_P,
+    MASK,
+    TEMPERATURE,
+    stride_input_ids_b,
+    L: tl.constexpr,
+    N: tl.constexpr,
+    BLOCK_N: tl.constexpr = 4096,
+):
+    off_b = tl.program_id(0).cast(tl.int64)
+    off_l = tl.program_id(1).cast(tl.int64)
+
+    LOGITS += off_b * (L + 1) * N + off_l * N
+    INPUT_IDS += off_b * stride_input_ids_b + off_l
+    LOG_P += off_b * L + off_l
+
+    if MASK is not None:
+        MASK += off_b * stride_input_ids_b + off_l
+        not_skip = tl.load(MASK)
+        if not_skip == 0:
+            return
+
+    m_i = float("-inf")
+    l_i = 0.0
+    for start in range(0, N, BLOCK_N):
+        cols = start + tl.arange(0, BLOCK_N)
+        logits = tl.load(LOGITS + cols, mask=cols < N, other=float("-inf")).to(tl.float32) / TEMPERATURE
+        new_m_i = tl.maximum(m_i, tl.max(logits))
+        alpha = tl.exp(m_i - new_m_i)
+        l_i = l_i * alpha + tl.sum(tl.exp(logits - new_m_i))
+        m_i = new_m_i
+    lse = m_i + tl.log(l_i)
+
+    ids = tl.load(INPUT_IDS)
+    x = tl.load(LOGITS + ids).to(tl.float32) / TEMPERATURE
+    logp = x - lse
+    tl.store(LOG_P, logp)
+
+
+# Compute old_logp and ref_logp without grad; this saves roughly 10 GB of peak memory.
+@torch.no_grad
+def fused_selective_log_softmax(logits: torch.Tensor, input_ids: torch.Tensor, temperature: float = 0.9, mask=None):
+    assert logits.is_contiguous()
+    B, L_ADD_1, N = logits.shape
+    L = L_ADD_1 - 1
+    input_ids = input_ids[:, -L:]
+    if mask is not None:
+        mask = mask[:, -L:]
+    log_p = torch.zeros(B, L, dtype=torch.float32, device=logits.device)
+    kwargs = {"BLOCK_N": 2048, "num_stages": 4, "num_warps": 1}
+    _selective_log_softmax_kernel[(B, L)](
+        logits, input_ids, log_p, mask, temperature, input_ids.stride(0), L, N, **kwargs
+    )
+    return log_p
+
+
+# @triton.autotune([triton.Config({"BLOCK_N":BLOCK_N}, num_stages=ns, num_warps=nw)
+#                   for BLOCK_N in [2048, 4096, 8192]
+#                   for ns in [1, 2, 4]
+#                   for nw in [1, 2, 4, 8, 16]],
+#                  key=['N'])
+@triton.jit
+def _grpo_loss_fwd_kernel(
+    LOGITS,
+    OLD_LOGP,
+    REF_LOGP,
+    INPUT_IDS,
+    COMPLETION_MASK,
+    ADVANTAGES,
+    LOSS,
+    LSE,
+    KL,
+    IS_CLIPPED,
+    TEMPERATURE,
+    BETA: tl.constexpr,
+    EPS_LOW,
+    EPS_HIGH,
+    L: tl.constexpr,
+    N: tl.constexpr,
+    BLOCK_N: tl.constexpr = 4096,
+):
+    off_b = tl.program_id(0).cast(tl.int64)
+    off_l = tl.program_id(1).cast(tl.int64)
+
+    if COMPLETION_MASK is not None:
+        COMPLETION_MASK += off_b * L + off_l
+        not_skip = tl.load(COMPLETION_MASK)
+        if not_skip == 0:
+            return
+
+    LOGITS += off_b * (L + 1) * N + off_l * N
+    INPUT_IDS += off_b * L + off_l
+    ADVANTAGES += off_b
+    LOSS += off_b * L + off_l
+    LSE += off_b * L + off_l
+    IS_CLIPPED += off_b * L + off_l
+
+    m_i = float("-inf")
+    l_i = 0.0
+    for start in range(0, N, BLOCK_N):
+        cols = start + tl.arange(0, BLOCK_N)
+        logits = tl.load(LOGITS + cols, mask=cols < N, other=float("-inf")).to(tl.float32) / TEMPERATURE
+        new_m_i = tl.maximum(m_i, tl.max(logits))
+        alpha = tl.exp(m_i - new_m_i)
+        l_i = l_i * alpha + tl.sum(tl.exp(logits - new_m_i))
+        m_i = new_m_i
+    lse = m_i + tl.log(l_i)
+
+    idx = tl.load(INPUT_IDS)
+    x = tl.load(LOGITS + idx).to(tl.float32) / TEMPERATURE
+    logp = x - lse
+    if OLD_LOGP is None:
+        old_logp = logp
+    else:
+        OLD_LOGP += off_b * L + off_l
+        old_logp = tl.load(OLD_LOGP).to(tl.float32)
+    coef_1 = tl.exp(logp - old_logp)
+    coef_2 = tl.clamp(coef_1, 1 - EPS_LOW, 1 + EPS_HIGH)
+    advantage = tl.load(ADVANTAGES).to(tl.float32)
+    per_token_loss1 = coef_1 * advantage
+    per_token_loss2 = coef_2 * advantage
+    per_token_loss = -tl.minimum(per_token_loss1, per_token_loss2)
+    is_low_clipped = (coef_1 < 1 - EPS_LOW) & (advantage < 0)
+    is_high_clipped = (coef_1 > 1 + EPS_HIGH) & (advantage > 0)
+    is_clipped = is_low_clipped | is_high_clipped
+
+    if BETA != 0.0:
+        REF_LOGP += off_b * L + off_l
+        KL += off_b * L + off_l
+        ref_logp = tl.load(REF_LOGP).to(tl.float32)
+        kl = tl.exp(ref_logp - logp) - (ref_logp - logp) - 1
+        per_token_loss += BETA * kl
+        tl.store(KL, kl)
+
+    tl.store(LOSS, per_token_loss)
+    tl.store(LSE, lse)
+    tl.store(IS_CLIPPED, is_clipped)
+
+
+# @triton.autotune([triton.Config({"BLOCK_N":BLOCK_N}, num_stages=ns, num_warps=nw)
+#                   for BLOCK_N in [2048, 4096, 8192]
+#                   for ns in [1, 2, 4]
+#                   for nw in [1, 2, 4, 8, 16]],
+#                  key=['N'])
+@triton.jit
+def _grpo_loss_bwd_kernel(
+    DLOSS,
+    DLOGITS,
+    LOGITS,
+    OLD_LOGP,
+    REF_LOGP,
+    INPUT_IDS,
+    ADVANTAGES,
+    COMPLETION_MASK,
+    LSE,
+    TEMPERATURE,
+    BETA: tl.constexpr,
+    EPS_LOW,
+    EPS_HIGH,
+    loss_stride0,
+    loss_stride1,
+    L: tl.constexpr,
+    N: tl.constexpr,
+    BLOCK_N: tl.constexpr = 4096,
+):
+    off_b = tl.program_id(0).cast(tl.int64)
+    off_l = tl.program_id(1).cast(tl.int64)
+
+    DLOGITS += off_b * (L + 1) * N + off_l * N
+    if COMPLETION_MASK is not None:
+        COMPLETION_MASK += off_b * L + off_l
+        not_skip = tl.load(COMPLETION_MASK)
+        if not_skip == 0:
+            for start in range(0, N, BLOCK_N):
+                cols = tl.arange(0, BLOCK_N) + start
+                tl.store(DLOGITS + cols, 0.0, mask=cols < N)
+            return
+
+    LOGITS += off_b * (L + 1) * N + off_l * N
+    DLOSS += off_b * loss_stride0 + off_l * loss_stride1
+    INPUT_IDS += off_b * L + off_l
+    ADVANTAGES += off_b
+    LSE += off_b * L + off_l
+
+    dloss = tl.load(DLOSS).to(tl.float32)
+    lse = tl.load(LSE).to(tl.float32)
+
+    idx = tl.load(INPUT_IDS)
+    x = tl.load(LOGITS + idx).to(tl.float32) / TEMPERATURE
+    logp = x - lse
+    if OLD_LOGP is None:
+        old_logp = logp
+    else:
+        OLD_LOGP += off_b * L + off_l
+        old_logp = tl.load(OLD_LOGP).to(tl.float32)
+    coef_1 = tl.exp(logp - old_logp)
+    coef_2 = tl.clamp(coef_1, 1 - EPS_LOW, 1 + EPS_HIGH)
+    advantage = tl.load(ADVANTAGES).to(tl.float32)
+    per_token_loss1 = coef_1 * advantage
+    per_token_loss2 = coef_2 * advantage
+    mask = per_token_loss2 >= per_token_loss1
+
+    dlogp = -per_token_loss1 * mask
+    if BETA != 0.0:
+        REF_LOGP += off_b * L + off_l
+        ref_logp = tl.load(REF_LOGP).to(tl.float32)
+        dlogp += BETA * (1 - tl.exp(ref_logp - logp))
+
+    dlogp = dlogp * dloss / TEMPERATURE
+    tl.debug_barrier()
+    for start_n in tl.range(0, N, BLOCK_N):
+        cols = start_n + tl.arange(0, BLOCK_N)
+        logits = tl.load(LOGITS + cols, mask=cols < N, other=-float("inf")).to(tl.float32) / TEMPERATURE
+        probs = tl.exp(logits - lse)
+        dlogits = tl.where(cols == idx, 1 - probs, -probs) * dlogp
+        tl.store(DLOGITS + cols, dlogits, mask=cols < N)
+
+
+class GrpoLossFunction(torch.autograd.Function):
+    @staticmethod
+    def forward(
+        ctx,
+        logits,
+        old_logp,
+        ref_logp,
+        completion_ids,
+        advantages,
+        completion_mask,
+        temperature,
+        beta,
+        eps_low,
+        eps_high,
+        inplace,
+    ):
+        assert logits.is_contiguous() and completion_ids.is_contiguous()
+        assert old_logp is None or old_logp.is_contiguous()
+        assert (ref_logp is not None and ref_logp.is_contiguous()) if beta != 0.0 else True
+
+        B, L_ADD_1, N = logits.shape
+        L = L_ADD_1 - 1
+
+        if completion_mask is not None:
+            assert completion_mask.is_contiguous()
+
+        loss = torch.zeros(B, L, device=logits.device, dtype=torch.float32)
+        lse = torch.zeros_like(loss)
+        is_clipped = torch.zeros_like(loss)
+        kl = torch.zeros_like(loss) if beta != 0.0 else None
+        kwargs = {"BLOCK_N": 2048, "num_stages": 2, "num_warps": 1}
+        _grpo_loss_fwd_kernel[(B, L)](
+            logits,
+            old_logp,
+            ref_logp,
+            completion_ids,
+            completion_mask,
+            advantages,
+            loss,
+            lse,
+            kl,
+            is_clipped,
+            temperature,
+            beta,
+            eps_low,
+            eps_high,
+            L,
+            N,
+            **kwargs,
+        )
+        ctx.save_for_backward(logits, old_logp, ref_logp, completion_ids, advantages, completion_mask, lse)
+        ctx.infos = (temperature, beta, eps_low, eps_high, inplace)
+        # return loss
+        return loss, kl, is_clipped
+
+    @staticmethod
+    def backward(ctx, *args):
+        dloss = args[0]
+        # print(dloss.shape)
+        logits, old_logp, ref_logp, completion_ids, advantages, completion_mask, lse = ctx.saved_tensors
+        temperature, beta, eps_low, eps_high, inplace = ctx.infos
+        B, L_ADD_1, N = logits.shape
+        L = L_ADD_1 - 1
+        dlogits = logits.data if inplace else torch.empty_like(logits)
+        kwargs = {"BLOCK_N": 4096, "num_stages": 1, "num_warps": 16}
+        _grpo_loss_bwd_kernel[(B, L)](
+            dloss,
+            dlogits,
+            logits,
+            old_logp,
+            ref_logp,
+            completion_ids,
+            advantages,
+            completion_mask,
+            lse,
+            temperature,
+            beta,
+            eps_low,
+            eps_high,
+            *dloss.stride(),
+            L,
+            N,
+            **kwargs,
+        )
+        dlogits[:, -1, :] = 0
+        return dlogits, None, None, None, None, None, None, None, None, None, None
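The hunk above adds two Triton kernels plus the GrpoLossFunction autograd wrapper; the per-token loss it returns is unreduced, so the caller masks and averages it, and the no-grad helper fused_selective_log_softmax can precompute the reference/old log-probabilities. As a quick orientation, here is a minimal sketch of driving these functions directly on dummy tensors. The shapes, hyperparameter values, and the CUDA device are illustrative assumptions rather than values prescribed by the library; in practice the higher-level wrapper in liger_kernel/transformers/grpo_loss.py listed above is presumably the intended entry point.

    # Minimal sketch (assumed shapes and hyperparameters) exercising the functions defined in this hunk.
    import torch
    from liger_kernel.ops.grpo_loss import GrpoLossFunction, fused_selective_log_softmax

    B, L, N = 2, 8, 32000                                                 # batch, completion length, vocab (illustrative)
    logits = torch.randn(B, L + 1, N, device="cuda", requires_grad=True)  # kernels expect one extra position
    completion_ids = torch.randint(0, N, (B, L), device="cuda")
    advantages = torch.randn(B, device="cuda")
    completion_mask = torch.ones(B, L, dtype=torch.int32, device="cuda")

    # Reference log-probs are computed once, without grad, by the fused helper.
    ref_logp = fused_selective_log_softmax(logits.detach(), completion_ids, temperature=1.0, mask=completion_mask)

    per_token_loss, kl, is_clipped = GrpoLossFunction.apply(
        logits, None, ref_logp, completion_ids, advantages, completion_mask,
        1.0,    # temperature
        0.04,   # beta (KL coefficient); with beta == 0.0 the KL term and ref_logp are skipped
        0.2,    # eps_low
        0.2,    # eps_high
        False,  # inplace: if True, the gradient reuses the logits storage
    )
    loss = (per_token_loss * completion_mask).sum() / completion_mask.sum()
    loss.backward()                                                       # populates logits.grad via _grpo_loss_bwd_kernel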
liger_kernel/ops/jsd.py
ADDED
@@ -0,0 +1,201 @@
+from typing import Optional
+
+import torch
+import triton
+import triton.language as tl
+
+from liger_kernel.ops.utils import ensure_contiguous
+from liger_kernel.utils import infer_device
+
+
+@triton.jit
+def _jsd_kernel(
+    X_ptr,  # input in logspace, X = log Q
+    X_stride,
+    Y_ptr,  # ground truth in logspace, Y = log P
+    Y_stride,
+    loss_ptr,
+    loss_stride,
+    dX_ptr,
+    dX_stride,
+    label_ptr,
+    beta: tl.constexpr,
+    n_non_ignore: int,
+    ignore_index: tl.constexpr,
+    n_cols,
+    BLOCK_SIZE: tl.constexpr,
+    HAS_LABEL: tl.constexpr,
+):
+    # JSD(P || Q) = (KL(P || M) + KL(Q || M)) / 2, M = (1/2) * (P + Q) = (1/2) * (e ^ Y + e ^ X)
+    #             = sum(P * log P + Q * log Q - 2 * M * log M) / 2
+    #             = sum(e ^ Y * Y + e ^ X * X - 2 * M * log M) / 2
+    # grad_x_i = 0.5 * Q * (X - log_M)
+    pid = tl.program_id(0).to(tl.int64)
+    X_ptr += pid * X_stride
+    dX_ptr += pid * dX_stride
+    Y_ptr += pid * Y_stride
+    loss_ptr += pid * loss_stride
+    label_ptr += pid
+
+    if HAS_LABEL:
+        label = tl.load(label_ptr)
+        if label == ignore_index:
+            for i in range(0, n_cols, BLOCK_SIZE):
+                offsets = i + tl.arange(0, BLOCK_SIZE)
+                tl.store(dX_ptr + offsets, 0.0, mask=offsets < n_cols)
+            return
+
+    for i in range(0, n_cols, BLOCK_SIZE):
+        offsets = i + tl.arange(0, BLOCK_SIZE)
+        mask = offsets < n_cols
+        X = tl.load(X_ptr + offsets, mask=mask, other=float("-inf")).to(tl.float32)
+        Y = tl.load(Y_ptr + offsets, mask=mask, other=float("-inf")).to(tl.float32)
+
+        if beta == 0.0:  # forward KL
+            Y_max = tl.max(Y, axis=0)
+            Y_shifted = Y - Y_max
+            Y_prob = tl.exp(Y_shifted) * tl.exp(Y_max)  # Compensate for the shift
+            loss = Y_prob * (Y - X)
+            dX = -Y_prob
+        elif beta == 1.0:  # reverse KL
+            X_max = tl.max(X, axis=0)
+            X_shifted = X - X_max
+            X_prob = tl.exp(X_shifted) * tl.exp(X_max)  # Compensate for the shift
+            loss = X_prob * (X - Y)
+            dX = loss + X_prob
+        else:
+            max_val = tl.maximum(tl.max(X, axis=0), tl.max(Y, axis=0))
+            X_shifted = X - max_val
+            Y_shifted = Y - max_val
+
+            # Pre-compute exp(max_val) since it's used twice
+            exp_max = tl.exp(max_val)
+
+            # Compute exp terms with compensation
+            Q = tl.exp(X_shifted) * exp_max  # = exp(X)
+            P = tl.exp(Y_shifted) * exp_max  # = exp(Y)
+
+            # Pre-compute common terms
+            beta_P = beta * P
+            one_minus_beta_Q = (1 - beta) * Q
+            M = beta_P + one_minus_beta_Q
+            log_M = tl.log(M)  # No need to compensate as M is already in original scale
+
+            loss = beta_P * Y + one_minus_beta_Q * X - M * log_M
+            dX = one_minus_beta_Q * (X - log_M)
+
+        # Pre-compute scaling factor
+        scale = 1.0 / n_non_ignore
+        loss = loss * scale
+        dX = dX * scale
+
+        tl.store(loss_ptr + offsets, loss, mask=mask)
+        tl.store(dX_ptr + offsets, dX, mask=mask)
+
+
+MAX_FUSED_SIZE = 4096 if infer_device() == "xpu" else 65536
+
+
+def jsd_forward(_input, target, shift_labels, beta, ignore_index, has_label):
+    BT, V = _input.shape
+    n_rows = BT
+    BLOCK_SIZE = min(MAX_FUSED_SIZE, triton.next_power_of_2(V))
+    # non reduction loss
+    loss = torch.zeros(_input.shape, dtype=torch.float32, device=_input.device)
+    dX = torch.empty_like(_input)
+
+    if has_label:
+        n_non_ignore = (shift_labels != ignore_index).sum().item()
+    else:
+        n_non_ignore = BT
+
+    _jsd_kernel[(n_rows,)](
+        X_ptr=_input,  # input in logspace, X = log Q
+        X_stride=_input.stride(-2),
+        Y_ptr=target,  # ground truth in logspace, Y = log P
+        Y_stride=target.stride(-2),
+        loss_ptr=loss,
+        loss_stride=loss.stride(-2),
+        dX_ptr=dX,
+        dX_stride=dX.stride(-2),
+        label_ptr=(shift_labels if has_label else torch.empty(1, device=_input.device)),  # dummy ptr if no label
+        beta=beta,
+        n_non_ignore=n_non_ignore,
+        ignore_index=ignore_index,
+        n_cols=V,
+        BLOCK_SIZE=BLOCK_SIZE,
+        HAS_LABEL=has_label,
+    )
+
+    loss = torch.sum(loss)
+    return loss.to(_input.dtype), dX
+
+
+def jsd_backward(dX, grad_output):
+    # If jsd is the last layer, grad_output is 1.0. Skip the mul to save time
+    if torch.equal(grad_output, torch.tensor(1.0, device=grad_output.device)):
+        return dX
+    else:
+        return grad_output * dX
+
+
+class LigerJSDFunction(torch.autograd.Function):
+    r"""
+    This class implements the forward and backward pass for the generalized Jensen-Shannon Divergence.
+    .. math::
+        JSD(\beta)(P || Q)
+            = \beta * KLDiv(P || (\beta * P + (1 - \beta) * Q)) + (1 - \beta) * KLDiv(Q || (\beta * P + (1 - \beta) * Q))
+
+    .. note::
+        As with all the other losses in PyTorch, this function expects the first argument,
+        :attr:`_input`, to be the predictions, the output of the student model, in log-space
+        and the second, :attr:`target`, to be the observations, the output of the teacher model, in log-space.
+        This differs from the standard mathematical notation :math:`JSD(P || Q)` where
+        :math:`P` denotes the teacher model and :math:`Q` denotes the student model.
+    """
+
+    @staticmethod
+    @ensure_contiguous
+    def forward(
+        ctx,
+        _input: torch.Tensor,
+        target: torch.Tensor,
+        shift_labels: Optional[torch.Tensor] = None,
+        beta: float = 0.5,
+        ignore_index: int = -100,
+    ) -> torch.Tensor:
+        """
+        Args:
+            _input (torch.Tensor): predicted values with shape (BT, V) in logspace
+            target (torch.Tensor): ground truth values with shape (BT, V) in logspace
+            shift_labels (Optional[torch.LongTensor]): indicator of next predicted vocab with shape (BT) where each value is in [0, V-1].
+            beta (float): coefficient beta of generalized JSD in the interval [0, 1]. It implements forward/reverse KL when beta equals 0 and 1 respectively. Default: `0.5`
+            ignore_index (int): the index to ignore. Default: -100
+
+        Returns:
+            loss (torch.Tensor): generalized JSD
+        """
+        has_label = False
+        if shift_labels is not None:
+            assert shift_labels.shape == (_input.shape[0],), (
+                f"the shape of shift_labels must be (BT,). Got: {shift_labels.shape}"
+            )
+            shift_labels = shift_labels.contiguous()
+            has_label = True
+
+        loss, dX = jsd_forward(_input, target, shift_labels, beta, ignore_index, has_label)
+        ctx.save_for_backward(dX)
+        return loss
+
+    @staticmethod
+    @ensure_contiguous
+    def backward(ctx, grad_output: torch.Tensor) -> torch.Tensor:
+        (dX,) = ctx.saved_tensors
+        dX = jsd_backward(dX, grad_output)
+        return (
+            dX,
+            None,
+            None,
+            None,
+            None,
+        )