liger-kernel-nightly 0.5.2.dev20250101082227__py3-none-any.whl → 0.5.2.dev20250108072837__py3-none-any.whl
Sign up to get free protection for your applications and to get access to all the features.
- liger_kernel/chunked_loss/cpo_loss.py +5 -1
- liger_kernel/chunked_loss/simpo_loss.py +4 -1
- liger_kernel/ops/cross_entropy.py +3 -2
- {liger_kernel_nightly-0.5.2.dev20250101082227.dist-info → liger_kernel_nightly-0.5.2.dev20250108072837.dist-info}/METADATA +1 -1
- {liger_kernel_nightly-0.5.2.dev20250101082227.dist-info → liger_kernel_nightly-0.5.2.dev20250108072837.dist-info}/RECORD +9 -9
- {liger_kernel_nightly-0.5.2.dev20250101082227.dist-info → liger_kernel_nightly-0.5.2.dev20250108072837.dist-info}/LICENSE +0 -0
- {liger_kernel_nightly-0.5.2.dev20250101082227.dist-info → liger_kernel_nightly-0.5.2.dev20250108072837.dist-info}/NOTICE +0 -0
- {liger_kernel_nightly-0.5.2.dev20250101082227.dist-info → liger_kernel_nightly-0.5.2.dev20250108072837.dist-info}/WHEEL +0 -0
- {liger_kernel_nightly-0.5.2.dev20250101082227.dist-info → liger_kernel_nightly-0.5.2.dev20250108072837.dist-info}/top_level.txt +0 -0
@@ -33,7 +33,11 @@ class LigerFusedLinearCPOFunction(LigerFusedLinearPreferenceBase):
|
|
33
33
|
loss = (-F.logsigmoid(logits) * (1 - label_smoothing) - F.logsigmoid(-logits) * label_smoothing).sum() / (
|
34
34
|
full_target.shape[0] // 2
|
35
35
|
)
|
36
|
-
|
36
|
+
|
37
|
+
chosen_rewards = beta * chosen_logps
|
38
|
+
rejected_rewards = beta * rejected_logps
|
39
|
+
|
40
|
+
return loss, chosen_rewards, rejected_rewards
|
37
41
|
|
38
42
|
@staticmethod
|
39
43
|
def forward(
|
@@ -42,7 +42,10 @@ class LigerFusedLinearSimPOFunction(LigerFusedLinearPreferenceBase):
|
|
42
42
|
full_target.shape[0] // 2
|
43
43
|
)
|
44
44
|
|
45
|
-
|
45
|
+
chosen_rewards = beta * chosen_logps
|
46
|
+
rejected_rewards = beta * rejected_logps
|
47
|
+
|
48
|
+
return loss, chosen_rewards, rejected_rewards
|
46
49
|
|
47
50
|
@staticmethod
|
48
51
|
def forward(
|
@@ -95,7 +95,8 @@ def liger_cross_entropy_kernel(
|
|
95
95
|
return
|
96
96
|
|
97
97
|
loss_ptr += program_id * loss_stride
|
98
|
-
|
98
|
+
if RETURN_Z_LOSS == _TRUE:
|
99
|
+
z_loss_ptr += program_id * loss_stride
|
99
100
|
|
100
101
|
if HAS_WEIGHT:
|
101
102
|
weight_y = tl.load(weight_ptr + y).cast(tl.float32)
|
@@ -296,7 +297,7 @@ def cross_entropy_forward(
|
|
296
297
|
if return_z_loss == _TRUE.value:
|
297
298
|
z_loss_1d = torch.zeros(n_rows, dtype=_input.dtype, device=_input.device)
|
298
299
|
else:
|
299
|
-
z_loss_1d =
|
300
|
+
z_loss_1d = None # set None when return_z_loss == False
|
300
301
|
|
301
302
|
target_mask = target != ignore_index
|
302
303
|
n_non_ignore = target_mask.sum().item()
|
@@ -3,15 +3,15 @@ liger_kernel/env_report.py,sha256=uhdEC8OydxoZlb7B6YYcAaBF3crGFdIck-4cxaW4NJY,17
|
|
3
3
|
liger_kernel/utils.py,sha256=HJa-xVKOohDn6pLVIx-Fv0V9h0QAL3qZGQNRICI-OpI,249
|
4
4
|
liger_kernel/chunked_loss/README.md,sha256=K6rucm6nqHpWCmxUOhBYcE3apwQxAy0TfRUippR7Icw,2243
|
5
5
|
liger_kernel/chunked_loss/__init__.py,sha256=R2wCcz4Y0kTAve926DH3k182XKezpXeACMHj05g9Mm8,346
|
6
|
-
liger_kernel/chunked_loss/cpo_loss.py,sha256=
|
6
|
+
liger_kernel/chunked_loss/cpo_loss.py,sha256=MCR4TzuBoJEaU0IJ7dIreLacQeXLKETV5CegNjhCD9M,3646
|
7
7
|
liger_kernel/chunked_loss/dpo_loss.py,sha256=VYZMOafdvE8xlhvTtwjrz81tIzxR1mHF4lXdsADnIQg,4373
|
8
8
|
liger_kernel/chunked_loss/functional.py,sha256=9Gr-YXIuEzEJkBUhDx3G2fuQayckLor7cC7svhmPML4,549
|
9
9
|
liger_kernel/chunked_loss/fused_linear_distillation.py,sha256=uQtwtu-kaUZJTjNhAnIr3O794oUlUZ98XR5shYtwP5k,10440
|
10
10
|
liger_kernel/chunked_loss/fused_linear_preference.py,sha256=25sTgvphLKAR0jyJcrsJPKK1abFpTKrajSyAx8nJ3bc,16134
|
11
11
|
liger_kernel/chunked_loss/orpo_loss.py,sha256=jbZxx-EjPK71A6CSyNzTOAIEQgAUjfvwSViw6R_pPXQ,3510
|
12
|
-
liger_kernel/chunked_loss/simpo_loss.py,sha256=
|
12
|
+
liger_kernel/chunked_loss/simpo_loss.py,sha256=3TTc7U79Orjgi-Wu81WZkWk5MgsdqKXIOBHgIvDazPw,3865
|
13
13
|
liger_kernel/ops/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
|
14
|
-
liger_kernel/ops/cross_entropy.py,sha256=
|
14
|
+
liger_kernel/ops/cross_entropy.py,sha256=zi2xsa8ky7M1vySUAGjXMQDFQFkKmGQV-myRIIQM13M,19210
|
15
15
|
liger_kernel/ops/fused_linear_cross_entropy.py,sha256=j7cgR95rFAwtPsWZ00PfMwis5F7dtO3EVEw0rZ1GPJk,10231
|
16
16
|
liger_kernel/ops/fused_linear_jsd.py,sha256=eKqaADj7LgWfoYqyH03tjrmhNTfJOF1Dhx_bWzBTnTU,9600
|
17
17
|
liger_kernel/ops/geglu.py,sha256=axGvCIvlBzuluoAIrWTsp2iZM4BFKNInkPov8YVvH9E,4126
|
@@ -58,9 +58,9 @@ liger_kernel/transformers/trainer/__init__.py,sha256=p7yQfklV8-467qSz_ZMimkbDF7H
|
|
58
58
|
liger_kernel/transformers/trainer/orpo_trainer.py,sha256=MId1S_MfA3pPVQA1rkiKxp-jZDNz8VmvZzXC-Kugol4,7662
|
59
59
|
liger_kernel/triton/__init__.py,sha256=qCiCamzCRv6lpV8IqpAc9YMdNKC7GKurClWceQPnlis,92
|
60
60
|
liger_kernel/triton/monkey_patch.py,sha256=Rd0hUHAzDkFfHvnX7-PBaNK5EKnZhtfM_h-fgQH9HPY,1568
|
61
|
-
liger_kernel_nightly-0.5.2.dev20250101082227.dist-info/LICENSE,sha256=OhzLDHJ0to4a8sodVLELZiCFylZ1NAAYLs-HrjPy0ag,1312
|
62
|
-
liger_kernel_nightly-0.5.2.
|
63
|
-
liger_kernel_nightly-0.5.2.dev20250101082227.dist-info/NOTICE,sha256=njwnoPZLh9AN8SJQzxvCGLHi-8X__AvWRze6joNXIY8,2066
|
64
|
-
liger_kernel_nightly-0.5.2.dev20250101082227.dist-info/WHEEL,sha256=P9jw-gEje8ByB7_hXoICnHtVCrEwMQh-630tKvQWehc,91
|
65
|
-
liger_kernel_nightly-0.5.2.dev20250101082227.dist-info/top_level.txt,sha256=2eghu4hA3LnkM7ElW92tQ8zegWKgSbeo-k-aGe1YnvY,13
|
66
|
-
liger_kernel_nightly-0.5.2.dev20250101082227.dist-info/RECORD,,
|
61
|
+
liger_kernel_nightly-0.5.2.dev20250108072837.dist-info/LICENSE,sha256=OhzLDHJ0to4a8sodVLELZiCFylZ1NAAYLs-HrjPy0ag,1312
|
62
|
+
liger_kernel_nightly-0.5.2.dev20250108072837.dist-info/METADATA,sha256=HwmQEBRYnwwbdkzuW53_qsmTSSbi8qu20cVOHsq6B_s,21055
|
63
|
+
liger_kernel_nightly-0.5.2.dev20250108072837.dist-info/NOTICE,sha256=njwnoPZLh9AN8SJQzxvCGLHi-8X__AvWRze6joNXIY8,2066
|
64
|
+
liger_kernel_nightly-0.5.2.dev20250108072837.dist-info/WHEEL,sha256=P9jw-gEje8ByB7_hXoICnHtVCrEwMQh-630tKvQWehc,91
|
65
|
+
liger_kernel_nightly-0.5.2.dev20250108072837.dist-info/top_level.txt,sha256=2eghu4hA3LnkM7ElW92tQ8zegWKgSbeo-k-aGe1YnvY,13
|
66
|
+
liger_kernel_nightly-0.5.2.dev20250108072837.dist-info/RECORD,,
|
File without changes
|
File without changes
|
File without changes
|