liger-kernel-nightly 0.5.9.dev20250519011716__py3-none-any.whl → 0.5.9.dev20250519015630__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- liger_kernel/ops/grpo_loss.py +310 -0
- liger_kernel/transformers/grpo_loss.py +98 -0
- {liger_kernel_nightly-0.5.9.dev20250519011716.dist-info → liger_kernel_nightly-0.5.9.dev20250519015630.dist-info}/METADATA +1 -1
- {liger_kernel_nightly-0.5.9.dev20250519011716.dist-info → liger_kernel_nightly-0.5.9.dev20250519015630.dist-info}/RECORD +8 -6
- {liger_kernel_nightly-0.5.9.dev20250519011716.dist-info → liger_kernel_nightly-0.5.9.dev20250519015630.dist-info}/LICENSE +0 -0
- {liger_kernel_nightly-0.5.9.dev20250519011716.dist-info → liger_kernel_nightly-0.5.9.dev20250519015630.dist-info}/NOTICE +0 -0
- {liger_kernel_nightly-0.5.9.dev20250519011716.dist-info → liger_kernel_nightly-0.5.9.dev20250519015630.dist-info}/WHEEL +0 -0
- {liger_kernel_nightly-0.5.9.dev20250519011716.dist-info → liger_kernel_nightly-0.5.9.dev20250519015630.dist-info}/top_level.txt +0 -0
liger_kernel/ops/grpo_loss.py
@@ -0,0 +1,310 @@
import torch
import triton
import triton.language as tl


@triton.jit
def _selective_log_softmax_kernel(
    LOGITS,
    INPUT_IDS,
    LOG_P,
    MASK,
    TEMPERATURE,
    stride_input_ids_b,
    L: tl.constexpr,
    N: tl.constexpr,
    BLOCK_N: tl.constexpr = 4096,
):
    off_b = tl.program_id(0).cast(tl.int64)
    off_l = tl.program_id(1).cast(tl.int64)

    LOGITS += off_b * (L + 1) * N + off_l * N
    INPUT_IDS += off_b * stride_input_ids_b + off_l
    LOG_P += off_b * L + off_l

    if MASK is not None:
        MASK += off_b * stride_input_ids_b + off_l
        not_skip = tl.load(MASK)
        if not_skip == 0:
            return

    m_i = float("-inf")
    l_i = 0.0
    for start in range(0, N, BLOCK_N):
        cols = start + tl.arange(0, BLOCK_N)
        logits = tl.load(LOGITS + cols, mask=cols < N, other=float("-inf")).to(tl.float32) / TEMPERATURE
        new_m_i = tl.maximum(m_i, tl.max(logits))
        alpha = tl.exp(m_i - new_m_i)
        l_i = l_i * alpha + tl.sum(tl.exp(logits - new_m_i))
        m_i = new_m_i
    lse = m_i + tl.log(l_i)

    ids = tl.load(INPUT_IDS)
    x = tl.load(LOGITS + ids).to(tl.float32) / TEMPERATURE
    logp = x - lse
    tl.store(LOG_P, logp)


# Computes old_logp and ref_logp without autograd; the fused kernel cuts roughly 10 GB of peak memory.
@torch.no_grad
def fused_selective_log_softmax(logits: torch.Tensor, input_ids: torch.Tensor, temperature: float = 0.9, mask=None):
    assert logits.is_contiguous()
    B, L_ADD_1, N = logits.shape
    L = L_ADD_1 - 1
    input_ids = input_ids[:, -L:]
    if mask is not None:
        mask = mask[:, -L:]
    log_p = torch.zeros(B, L, dtype=torch.float32, device=logits.device)
    kwargs = {"BLOCK_N": 2048, "num_stages": 4, "num_warps": 1}
    _selective_log_softmax_kernel[(B, L)](
        logits, input_ids, log_p, mask, temperature, input_ids.stride(0), L, N, **kwargs
    )
    return log_p


# @triton.autotune([triton.Config({"BLOCK_N": BLOCK_N}, num_stages=ns, num_warps=nw)
#                   for BLOCK_N in [2048, 4096, 8192]
#                   for ns in [1, 2, 4]
#                   for nw in [1, 2, 4, 8, 16]],
#                  key=['N'])
@triton.jit
def _grpo_loss_fwd_kernel(
    LOGITS,
    OLD_LOGP,
    REF_LOGP,
    INPUT_IDS,
    COMPLETION_MASK,
    ADVANTAGES,
    LOSS,
    LSE,
    KL,
    IS_CLIPPED,
    TEMPERATURE,
    BETA: tl.constexpr,
    EPS_LOW,
    EPS_HIGH,
    L: tl.constexpr,
    N: tl.constexpr,
    BLOCK_N: tl.constexpr = 4096,
):
    off_b = tl.program_id(0).cast(tl.int64)
    off_l = tl.program_id(1).cast(tl.int64)

    if COMPLETION_MASK is not None:
        COMPLETION_MASK += off_b * L + off_l
        not_skip = tl.load(COMPLETION_MASK)
        if not_skip == 0:
            return

    LOGITS += off_b * (L + 1) * N + off_l * N
    INPUT_IDS += off_b * L + off_l
    ADVANTAGES += off_b
    LOSS += off_b * L + off_l
    LSE += off_b * L + off_l
    IS_CLIPPED += off_b * L + off_l

    m_i = float("-inf")
    l_i = 0.0
    for start in range(0, N, BLOCK_N):
        cols = start + tl.arange(0, BLOCK_N)
        logits = tl.load(LOGITS + cols, mask=cols < N, other=float("-inf")).to(tl.float32) / TEMPERATURE
        new_m_i = tl.maximum(m_i, tl.max(logits))
        alpha = tl.exp(m_i - new_m_i)
        l_i = l_i * alpha + tl.sum(tl.exp(logits - new_m_i))
        m_i = new_m_i
    lse = m_i + tl.log(l_i)

    idx = tl.load(INPUT_IDS)
    x = tl.load(LOGITS + idx).to(tl.float32) / TEMPERATURE
    logp = x - lse
    if OLD_LOGP is None:
        old_logp = logp
    else:
        OLD_LOGP += off_b * L + off_l
        old_logp = tl.load(OLD_LOGP).to(tl.float32)
    coef_1 = tl.exp(logp - old_logp)
    coef_2 = tl.clamp(coef_1, 1 - EPS_LOW, 1 + EPS_HIGH)
    advantage = tl.load(ADVANTAGES).to(tl.float32)
    per_token_loss1 = coef_1 * advantage
    per_token_loss2 = coef_2 * advantage
    per_token_loss = -tl.minimum(per_token_loss1, per_token_loss2)
    is_clipped = per_token_loss1 < per_token_loss2

    if BETA != 0.0:
        REF_LOGP += off_b * L + off_l
        KL += off_b * L + off_l
        ref_logp = tl.load(REF_LOGP).to(tl.float32)
        kl = tl.exp(ref_logp - logp) - (ref_logp - logp) - 1
        per_token_loss += BETA * kl
        tl.store(KL, kl)

    tl.store(LOSS, per_token_loss)
    tl.store(LSE, lse)
    tl.store(IS_CLIPPED, is_clipped)


# @triton.autotune([triton.Config({"BLOCK_N": BLOCK_N}, num_stages=ns, num_warps=nw)
#                   for BLOCK_N in [2048, 4096, 8192]
#                   for ns in [1, 2, 4]
#                   for nw in [1, 2, 4, 8, 16]],
#                  key=['N'])
@triton.jit
def _grpo_loss_bwd_kernel(
    DLOSS,
    DLOGITS,
    LOGITS,
    OLD_LOGP,
    REF_LOGP,
    INPUT_IDS,
    ADVANTAGES,
    COMPLETION_MASK,
    LSE,
    TEMPERATURE,
    BETA: tl.constexpr,
    EPS_LOW,
    EPS_HIGH,
    loss_stride0,
    loss_stride1,
    L: tl.constexpr,
    N: tl.constexpr,
    BLOCK_N: tl.constexpr = 4096,
):
    off_b = tl.program_id(0).cast(tl.int64)
    off_l = tl.program_id(1).cast(tl.int64)

    DLOGITS += off_b * (L + 1) * N + off_l * N
    if COMPLETION_MASK is not None:
        COMPLETION_MASK += off_b * L + off_l
        not_skip = tl.load(COMPLETION_MASK)
        if not_skip == 0:
            for start in range(0, N, BLOCK_N):
                cols = tl.arange(0, BLOCK_N) + start
                tl.store(DLOGITS + cols, 0.0, mask=cols < N)
            return

    LOGITS += off_b * (L + 1) * N + off_l * N
    DLOSS += off_b * loss_stride0 + off_l * loss_stride1
    INPUT_IDS += off_b * L + off_l
    ADVANTAGES += off_b
    LSE += off_b * L + off_l

    dloss = tl.load(DLOSS).to(tl.float32)
    lse = tl.load(LSE).to(tl.float32)

    idx = tl.load(INPUT_IDS)
    x = tl.load(LOGITS + idx).to(tl.float32) / TEMPERATURE
    logp = x - lse
    if OLD_LOGP is None:
        old_logp = logp
    else:
        OLD_LOGP += off_b * L + off_l
        old_logp = tl.load(OLD_LOGP).to(tl.float32)
    coef_1 = tl.exp(logp - old_logp)
    coef_2 = tl.clamp(coef_1, 1 - EPS_LOW, 1 + EPS_HIGH)
    advantage = tl.load(ADVANTAGES).to(tl.float32)
    per_token_loss1 = coef_1 * advantage
    per_token_loss2 = coef_2 * advantage
    mask = per_token_loss2 >= per_token_loss1

    dlogp = -per_token_loss1 * mask
    if BETA != 0.0:
        REF_LOGP += off_b * L + off_l
        ref_logp = tl.load(REF_LOGP).to(tl.float32)
        dlogp += BETA * (1 - tl.exp(ref_logp - logp))

    dlogp = dlogp * dloss / TEMPERATURE
    tl.debug_barrier()
    for start_n in tl.range(0, N, BLOCK_N):
        cols = start_n + tl.arange(0, BLOCK_N)
        logits = tl.load(LOGITS + cols, mask=cols < N, other=-float("inf")).to(tl.float32) / TEMPERATURE
        probs = tl.exp(logits - lse)
        dlogits = tl.where(cols == idx, 1 - probs, -probs) * dlogp
        tl.store(DLOGITS + cols, dlogits, mask=cols < N)


class GrpoLossFunction(torch.autograd.Function):
    @staticmethod
    def forward(
        ctx,
        logits,
        old_logp,
        ref_logp,
        completion_ids,
        advantages,
        completion_mask,
        temperature,
        beta,
        eps_low,
        eps_high,
        inplace,
    ):
        assert logits.is_contiguous() and completion_ids.is_contiguous()
        assert old_logp is None or old_logp.is_contiguous()
        assert (ref_logp is not None and ref_logp.is_contiguous()) if beta != 0.0 else True

        B, L_ADD_1, N = logits.shape
        L = L_ADD_1 - 1

        if completion_mask is not None:
            assert completion_mask.is_contiguous()

        loss = torch.zeros(B, L, device=logits.device, dtype=torch.float32)
        lse = torch.zeros_like(loss)
        is_clipped = torch.zeros_like(loss)
        kl = torch.zeros_like(loss) if beta != 0.0 else None
        kwargs = {"BLOCK_N": 2048, "num_stages": 2, "num_warps": 1}
        _grpo_loss_fwd_kernel[(B, L)](
            logits,
            old_logp,
            ref_logp,
            completion_ids,
            completion_mask,
            advantages,
            loss,
            lse,
            kl,
            is_clipped,
            temperature,
            beta,
            eps_low,
            eps_high,
            L,
            N,
            **kwargs,
        )
        ctx.save_for_backward(logits, old_logp, ref_logp, completion_ids, advantages, completion_mask, lse)
        ctx.infos = (temperature, beta, eps_low, eps_high, inplace)
        # return loss
        return loss, kl, is_clipped

    @staticmethod
    def backward(ctx, *args):
        dloss = args[0]
        # print(dloss.shape)
        logits, old_logp, ref_logp, completion_ids, advantages, completion_mask, lse = ctx.saved_tensors
        temperature, beta, eps_low, eps_high, inplace = ctx.infos
        B, L_ADD_1, N = logits.shape
        L = L_ADD_1 - 1
        dlogits = logits.data if inplace else torch.empty_like(logits)
        kwargs = {"BLOCK_N": 4096, "num_stages": 1, "num_warps": 16}
        _grpo_loss_bwd_kernel[(B, L)](
            dloss,
            dlogits,
            logits,
            old_logp,
            ref_logp,
            completion_ids,
            advantages,
            completion_mask,
            lse,
            temperature,
            beta,
            eps_low,
            eps_high,
            *dloss.stride(),
            L,
            N,
            **kwargs,
        )
        dlogits[:, -1, :] = 0
        return dlogits, None, None, None, None, None, None, None, None, None, None
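For orientation, here is a minimal smoke-test sketch of the selective log-softmax entry point added above. The sizes (B, L, N), the CUDA device, and the random inputs are assumptions chosen only for illustration; they are not part of the package.

import torch
from liger_kernel.ops.grpo_loss import fused_selective_log_softmax

# Assumed toy sizes: B sequences, L completion tokens, vocabulary of size N.
# The function expects logits for L + 1 positions (the trailing one is unused).
B, L, N = 2, 8, 128
logits = torch.randn(B, L + 1, N, device="cuda")
input_ids = torch.randint(0, N, (B, L + 1), device="cuda")

# Per-token log-probabilities of the sampled ids, shape (B, L), computed under no_grad.
logp = fused_selective_log_softmax(logits, input_ids, temperature=0.9)
assert logp.shape == (B, L)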
liger_kernel/transformers/grpo_loss.py
@@ -0,0 +1,98 @@
from liger_kernel.ops.grpo_loss import GrpoLossFunction


def triton_grpo_loss(
    logits,
    old_logp,
    ref_logp,
    completion_ids,
    advantages,
    completion_mask=None,
    temperature=0.9,
    beta=0.04,
    eps_low=0.2,
    eps_high=0.4,
    inplace=True,
):
    assert logits is not None and completion_ids is not None and advantages is not None, (
        "must provide logits, completion_ids and advantages"
    )

    return GrpoLossFunction.apply(
        logits,
        old_logp,
        ref_logp,
        completion_ids,
        advantages,
        completion_mask,
        temperature,
        beta,
        eps_low,
        eps_high,
        inplace,
    )


# This is a demo of how to use grpo_loss in GRPOTrainer. The trl version must be 0.16.
"""
import torch
import trl
assert trl.__version__.startswith("0.16"), "please pip install trl==0.16"
from trl.extras.profiling import profiling_decorator
from liger_kernel.ops.grpo_loss import fused_selective_log_softmax

@profiling_decorator
def _get_per_token_logps(self, model, input_ids, attention_mask, logits_to_keep):
    # We add 1 to `logits_to_keep` because the last logit of the sequence is later excluded
    logits = model(input_ids=input_ids, attention_mask=attention_mask, logits_to_keep=logits_to_keep + 1).logits
    return fused_selective_log_softmax(logits, input_ids, self.temperature, mask=attention_mask)

@profiling_decorator
def compute_loss(self, model, inputs, return_outputs=False, num_items_in_batch=None):
    if return_outputs:
        raise ValueError("The GRPOTrainer does not support returning outputs")
    # Compute the per-token log probabilities for the model

    prompt_ids, prompt_mask = inputs["prompt_ids"], inputs["prompt_mask"]
    completion_ids, completion_mask = inputs["completion_ids"], inputs["completion_mask"]
    input_ids = torch.cat([prompt_ids, completion_ids], dim=1)
    attention_mask = torch.cat([prompt_mask, completion_mask], dim=1)
    logits_to_keep = completion_ids.size(1)  # we only need to compute the logits for the completion tokens
    logits = model(input_ids=input_ids, attention_mask=attention_mask, logits_to_keep=logits_to_keep + 1).logits

    ref_per_token_logps = inputs["ref_per_token_logps"]
    advantages = inputs["advantages"]
    old_per_token_logps = inputs["old_per_token_logps"]

    per_token_loss, per_token_kl, is_clipped = triton_grpo_loss(
        logits,
        old_per_token_logps,
        ref_per_token_logps,
        completion_ids,
        advantages,
        completion_mask,
        self.temperature,
        self.beta,
        self.epsilon_low,
        self.epsilon_high,
    )
    loss = (per_token_loss * completion_mask).sum() / completion_mask.sum()

    # Log the metrics
    mode = "eval" if self.control.should_evaluate else "train"

    if self.beta != 0.0:
        mean_kl = (per_token_kl * completion_mask).sum() / completion_mask.sum()
        self._metrics[mode]["kl"].append(self.accelerator.gather_for_metrics(mean_kl).mean().item())

    clip_ratio = (is_clipped * completion_mask).sum() / completion_mask.sum()
    self._metrics[mode]["clip_ratio"].append(self.accelerator.gather_for_metrics(clip_ratio).mean().item())
    return loss

trl.GRPOTrainer._get_per_token_logps = _get_per_token_logps
trl.GRPOTrainer.compute_loss = compute_loss
trigger = None
"""

# Add this line as the first line of grpo.py in open-r1:
"""
from liger_kernel.transformers.grpo_loss import trigger
"""
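As a complement to the trainer demo in the docstring above, here is a hedged end-to-end sketch of the wrapper with dummy tensors standing in for policy logits, reference log-probs, advantages, and the completion mask. The sizes, the CUDA device, computing ref_logp from the same logits, and the explicit backward call are assumptions for illustration only; in real training the reference log-probs come from a frozen reference model.

import torch
from liger_kernel.ops.grpo_loss import fused_selective_log_softmax
from liger_kernel.transformers.grpo_loss import triton_grpo_loss

B, L, N = 2, 8, 128  # assumed toy sizes
logits = torch.randn(B, L + 1, N, device="cuda", requires_grad=True)
completion_ids = torch.randint(0, N, (B, L), device="cuda")
advantages = torch.randn(B, device="cuda")
completion_mask = torch.ones(B, L, dtype=torch.int32, device="cuda")

# Reference-policy log-probs are required whenever beta != 0; here they are
# derived from the same logits purely as a stand-in.
ref_logp = fused_selective_log_softmax(logits.detach(), completion_ids, temperature=0.9)

per_token_loss, per_token_kl, is_clipped = triton_grpo_loss(
    logits,
    None,            # old_logp=None -> on-policy update, importance ratio starts at 1
    ref_logp,
    completion_ids,
    advantages,
    completion_mask,
    temperature=0.9,
    beta=0.04,
    inplace=False,   # keep logits intact instead of reusing their storage for the gradient
)
loss = (per_token_loss * completion_mask).sum() / completion_mask.sum()
loss.backward()      # gradients flow to logits through the fused backward kernel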
{liger_kernel_nightly-0.5.9.dev20250519011716.dist-info → liger_kernel_nightly-0.5.9.dev20250519015630.dist-info}/RECORD
@@ -22,6 +22,7 @@ liger_kernel/ops/fused_linear_cross_entropy.py,sha256=5fbGhN85n3zf0uIdJ7PYHWIRzT
 liger_kernel/ops/fused_linear_jsd.py,sha256=CSoprxb-YcJy-YUKiTcYkxN8sb9h2kdk_iHuncvSV5c,9683
 liger_kernel/ops/geglu.py,sha256=axGvCIvlBzuluoAIrWTsp2iZM4BFKNInkPov8YVvH9E,4126
 liger_kernel/ops/group_norm.py,sha256=qD4D4lSjSgVtO52EBNLC2iTseALRgPgqXE50U2woggk,10837
+liger_kernel/ops/grpo_loss.py,sha256=anRnv7k1-AV3pCC6_TqP0GMg78YYUfRAJrbpx6PVhl0,9448
 liger_kernel/ops/jsd.py,sha256=onHp5T3MbvJaVz5Vup7Ww6EQp_HTaZeayTjJk6FgQMY,7042
 liger_kernel/ops/kl_div.py,sha256=ZjGdDLKWksHT9dZ0xF_TDgAkj5cuMTwwT5tr9E-_24o,8734
 liger_kernel/ops/layer_norm.py,sha256=vWCyOm-F2GMAilB-ozJcFeUQQLCJoTE_uiXq-_0uYuI,8356
@@ -44,6 +45,7 @@ liger_kernel/transformers/fused_linear_jsd.py,sha256=bZ4otCvWBuOnA5XdQL-FzZVItJl
 liger_kernel/transformers/geglu.py,sha256=mrgqzIUVd6lN7fkDKLkw5YaESDxDtFgbot430WwPVOQ,1107
 liger_kernel/transformers/gema3_rms.py,sha256=LTmZOXe6WEnv6ZroW-kU1TE2B36-z5v8OLmKr3XEVFo,353
 liger_kernel/transformers/group_norm.py,sha256=6qMAWOprr4SzP0YhNVNGQIBpM5aUHplUD2VuGJrMBz0,2173
+liger_kernel/transformers/grpo_loss.py,sha256=uAkUNKSnUGEOqa82L9w2e6AI1kcmG8K45-QxyaT8zhM,3897
 liger_kernel/transformers/jsd.py,sha256=DGqRnxIZxsvxo0_tbbxX3b-sDbDjC_yKufyRIHCcScY,2979
 liger_kernel/transformers/kl_div.py,sha256=WLffFbh1EExD2Eb1F7lN11fo9JJC-0751WJjZAF1Fj8,409
 liger_kernel/transformers/layer_norm.py,sha256=c9pk3PEasOKYR0rhe5e5nNrnYKVCEW4VC8S6LpCq9EQ,906
@@ -79,9 +81,9 @@ liger_kernel/transformers/trainer/__init__.py,sha256=p7yQfklV8-467qSz_ZMimkbDF7H
 liger_kernel/transformers/trainer/orpo_trainer.py,sha256=pdekW7l6Qg_aqa5SYKYlSWUF8m3lkOFvFLcIMEHrz9s,8338
 liger_kernel/triton/__init__.py,sha256=qCiCamzCRv6lpV8IqpAc9YMdNKC7GKurClWceQPnlis,92
 liger_kernel/triton/monkey_patch.py,sha256=Rd0hUHAzDkFfHvnX7-PBaNK5EKnZhtfM_h-fgQH9HPY,1568
-liger_kernel_nightly-0.5.9.
-liger_kernel_nightly-0.5.9.
-liger_kernel_nightly-0.5.9.
-liger_kernel_nightly-0.5.9.
-liger_kernel_nightly-0.5.9.
-liger_kernel_nightly-0.5.9.
+liger_kernel_nightly-0.5.9.dev20250519015630.dist-info/LICENSE,sha256=OhzLDHJ0to4a8sodVLELZiCFylZ1NAAYLs-HrjPy0ag,1312
+liger_kernel_nightly-0.5.9.dev20250519015630.dist-info/METADATA,sha256=_HRxosGQvS3kYalXZIxjmOinoXb0PoA0kSVBH3SbuHg,23970
+liger_kernel_nightly-0.5.9.dev20250519015630.dist-info/NOTICE,sha256=njwnoPZLh9AN8SJQzxvCGLHi-8X__AvWRze6joNXIY8,2066
+liger_kernel_nightly-0.5.9.dev20250519015630.dist-info/WHEEL,sha256=iAkIy5fosb7FzIOwONchHf19Qu7_1wCWyFNR5gu9nU0,91
+liger_kernel_nightly-0.5.9.dev20250519015630.dist-info/top_level.txt,sha256=2eghu4hA3LnkM7ElW92tQ8zegWKgSbeo-k-aGe1YnvY,13
+liger_kernel_nightly-0.5.9.dev20250519015630.dist-info/RECORD,,