liger-kernel-nightly 0.5.2.dev20241220220835__py3-none-any.whl → 0.5.2.dev20241223032015__py3-none-any.whl
Sign up to get free protection for your applications and to get access to all the features.
- liger_kernel/chunked_loss/cpo_loss.py +2 -2
- liger_kernel/chunked_loss/fused_linear_preference.py +1 -1
- liger_kernel/chunked_loss/orpo_loss.py +1 -1
- liger_kernel/chunked_loss/simpo_loss.py +2 -2
- liger_kernel/ops/cross_entropy.py +2 -2
- liger_kernel/ops/kl_div.py +4 -4
- liger_kernel/ops/rms_norm.py +3 -3
- {liger_kernel_nightly-0.5.2.dev20241220220835.dist-info → liger_kernel_nightly-0.5.2.dev20241223032015.dist-info}/METADATA +1 -1
- {liger_kernel_nightly-0.5.2.dev20241220220835.dist-info → liger_kernel_nightly-0.5.2.dev20241223032015.dist-info}/RECORD +13 -13
- {liger_kernel_nightly-0.5.2.dev20241220220835.dist-info → liger_kernel_nightly-0.5.2.dev20241223032015.dist-info}/LICENSE +0 -0
- {liger_kernel_nightly-0.5.2.dev20241220220835.dist-info → liger_kernel_nightly-0.5.2.dev20241223032015.dist-info}/NOTICE +0 -0
- {liger_kernel_nightly-0.5.2.dev20241220220835.dist-info → liger_kernel_nightly-0.5.2.dev20241223032015.dist-info}/WHEEL +0 -0
- {liger_kernel_nightly-0.5.2.dev20241220220835.dist-info → liger_kernel_nightly-0.5.2.dev20241223032015.dist-info}/top_level.txt +0 -0
@@ -36,8 +36,8 @@ class LigerFusedLinearCPOFunction(LigerFusedLinearPreferenceBase):
|
|
36
36
|
"""
|
37
37
|
logits = beta * (chosen_logps - rejected_logps)
|
38
38
|
loss = (
|
39
|
-
F.logsigmoid(logits) * (1 - label_smoothing)
|
40
|
-
|
39
|
+
- F.logsigmoid(logits) * (1 - label_smoothing)
|
40
|
+
- F.logsigmoid(-logits) * label_smoothing
|
41
41
|
).sum() / (full_target.shape[0] // 2)
|
42
42
|
|
43
43
|
return loss
|
@@ -408,7 +408,7 @@ class LigerFusedLinearPreferenceBase(torch.autograd.Function):
|
|
408
408
|
else:
|
409
409
|
preference_loss, aux_outputs = preference_loss_outputs, []
|
410
410
|
|
411
|
-
loss = alpha * chosen_nll_loss
|
411
|
+
loss = alpha * chosen_nll_loss + preference_loss
|
412
412
|
return_vars = (
|
413
413
|
chosen_logps,
|
414
414
|
rejected_logps,
|
@@ -36,7 +36,7 @@ class LigerFusedLinearORPOFunction(LigerFusedLinearPreferenceBase):
|
|
36
36
|
- torch.log1p(-torch.exp(rejected_logps))
|
37
37
|
)
|
38
38
|
ratio = F.logsigmoid(log_odds)
|
39
|
-
loss = beta * ratio.sum() / (full_target.shape[0] // 2)
|
39
|
+
loss = -beta * ratio.sum() / (full_target.shape[0] // 2)
|
40
40
|
|
41
41
|
chosen_rewards = beta * chosen_logps
|
42
42
|
rejected_rewards = beta * rejected_logps
|
@@ -42,8 +42,8 @@ class LigerFusedLinearSimPOFunction(LigerFusedLinearPreferenceBase):
|
|
42
42
|
"""
|
43
43
|
logits = beta * (chosen_logps - rejected_logps) - gamma
|
44
44
|
loss = (
|
45
|
-
F.logsigmoid(logits) * (1 - label_smoothing)
|
46
|
-
|
45
|
+
- F.logsigmoid(logits) * (1 - label_smoothing)
|
46
|
+
- F.logsigmoid(-logits) * label_smoothing
|
47
47
|
).sum() / (full_target.shape[0] // 2)
|
48
48
|
|
49
49
|
return loss
|
@@ -17,8 +17,8 @@ if compare_version("triton", operator.ge, "3.0.0"):
|
|
17
17
|
else:
|
18
18
|
from triton.language.math import tanh
|
19
19
|
|
20
|
-
_TRUE = tl.constexpr(1)
|
21
|
-
_FALSE = tl.constexpr(0)
|
20
|
+
_TRUE: tl.constexpr = tl.constexpr(1)
|
21
|
+
_FALSE: tl.constexpr = tl.constexpr(0)
|
22
22
|
|
23
23
|
|
24
24
|
@triton.jit
|
liger_kernel/ops/kl_div.py
CHANGED
@@ -23,10 +23,10 @@ MAX_FUSED_SIZE = 65536 // 4 # 65536 // 4 or 8 works the best
|
|
23
23
|
|
24
24
|
REDUCTION_LITERAL = Literal["none", "sum", "mean", "batchmean"]
|
25
25
|
|
26
|
-
_REDUCTION_MODE_NONE = tl.constexpr(0)
|
27
|
-
_REDUCTION_MODE_SUM = tl.constexpr(1)
|
28
|
-
_REDUCTION_MODE_MEAN = tl.constexpr(2)
|
29
|
-
_REDUCTION_MODE_BATCHMEAN = tl.constexpr(3)
|
26
|
+
_REDUCTION_MODE_NONE: tl.constexpr = tl.constexpr(0)
|
27
|
+
_REDUCTION_MODE_SUM: tl.constexpr = tl.constexpr(1)
|
28
|
+
_REDUCTION_MODE_MEAN: tl.constexpr = tl.constexpr(2)
|
29
|
+
_REDUCTION_MODE_BATCHMEAN: tl.constexpr = tl.constexpr(3)
|
30
30
|
|
31
31
|
_str_to_reduction_mode = {
|
32
32
|
"none": _REDUCTION_MODE_NONE.value,
|
liger_kernel/ops/rms_norm.py
CHANGED
@@ -35,9 +35,9 @@ else:
|
|
35
35
|
from triton.language.math import rsqrt
|
36
36
|
|
37
37
|
|
38
|
-
_CASTING_MODE_NONE = tl.constexpr(-1)
|
39
|
-
_CASTING_MODE_LLAMA = tl.constexpr(0)
|
40
|
-
_CASTING_MODE_GEMMA = tl.constexpr(1)
|
38
|
+
_CASTING_MODE_NONE: tl.constexpr = tl.constexpr(-1)
|
39
|
+
_CASTING_MODE_LLAMA: tl.constexpr = tl.constexpr(0)
|
40
|
+
_CASTING_MODE_GEMMA: tl.constexpr = tl.constexpr(1)
|
41
41
|
|
42
42
|
|
43
43
|
@triton.jit
|
@@ -3,24 +3,24 @@ liger_kernel/env_report.py,sha256=ok9PMXtO-8uLj_feCJI4h9hz2NtolZ2AG_OJTW5qmo4,18
|
|
3
3
|
liger_kernel/utils.py,sha256=HJa-xVKOohDn6pLVIx-Fv0V9h0QAL3qZGQNRICI-OpI,249
|
4
4
|
liger_kernel/chunked_loss/README.md,sha256=K6rucm6nqHpWCmxUOhBYcE3apwQxAy0TfRUippR7Icw,2243
|
5
5
|
liger_kernel/chunked_loss/__init__.py,sha256=R2wCcz4Y0kTAve926DH3k182XKezpXeACMHj05g9Mm8,346
|
6
|
-
liger_kernel/chunked_loss/cpo_loss.py,sha256=
|
6
|
+
liger_kernel/chunked_loss/cpo_loss.py,sha256=3PdSp1gju1u0ffFGpUufbZPIva8aI3SW1TfqkJOpw1g,3554
|
7
7
|
liger_kernel/chunked_loss/dpo_loss.py,sha256=jbTno1pKEc-HxAGFY3NSycBzdWyTacyRCzH3FhrMUMo,4383
|
8
8
|
liger_kernel/chunked_loss/functional.py,sha256=9Gr-YXIuEzEJkBUhDx3G2fuQayckLor7cC7svhmPML4,549
|
9
9
|
liger_kernel/chunked_loss/fused_linear_distillation.py,sha256=2BH6DCPjsR2zS6zcwFPcIIZRhLF8SohjGdKsAJ_301o,10222
|
10
|
-
liger_kernel/chunked_loss/fused_linear_preference.py,sha256=
|
11
|
-
liger_kernel/chunked_loss/orpo_loss.py,sha256=
|
12
|
-
liger_kernel/chunked_loss/simpo_loss.py,sha256=
|
10
|
+
liger_kernel/chunked_loss/fused_linear_preference.py,sha256=vvratrj8rba8NaGbO2ffbUfWMVEvDMxDCo6SI8nCtbo,16376
|
11
|
+
liger_kernel/chunked_loss/orpo_loss.py,sha256=xHsKjlCWQVew7_hhpyUp3a1wd0tdpgx-zQAezNjk3Q4,3532
|
12
|
+
liger_kernel/chunked_loss/simpo_loss.py,sha256=_5gXIkEAT0Kt_AufziQlYhBjzDJVSQVk7oSDHcrw1xw,3759
|
13
13
|
liger_kernel/ops/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
|
14
|
-
liger_kernel/ops/cross_entropy.py,sha256=
|
14
|
+
liger_kernel/ops/cross_entropy.py,sha256=3oPrw6KzIVc11gSyfdrLnj0WJB4qOYjE1tC8HJeFFpg,15888
|
15
15
|
liger_kernel/ops/fused_linear_cross_entropy.py,sha256=Tnw4gyAYVVdnCOqhOuLEzbUQ3goOTnoAfk3pqSIM5ac,9301
|
16
16
|
liger_kernel/ops/fused_linear_jsd.py,sha256=nOv4zwfxHqqepKEmMsQuz-B3H-gRjyo8uClpmqSGLYA,9693
|
17
17
|
liger_kernel/ops/geglu.py,sha256=MQL4zyzneZqZYUGPvb1QjI_EYT9_pKfSDgR25WD9jrI,4127
|
18
18
|
liger_kernel/ops/group_norm.py,sha256=VaRErVJGR4JqgXXvuIjNGTn3E2egjLtU1y3ymwIf4d8,10961
|
19
19
|
liger_kernel/ops/jsd.py,sha256=Ap2b0_geCl6fqBXLI1IS6Yn6GlO-8LgPmnOW3y47dus,6151
|
20
|
-
liger_kernel/ops/kl_div.py,sha256=
|
20
|
+
liger_kernel/ops/kl_div.py,sha256=vBz1ieu_sPcFbgG_wL0SwrbSQ6xVDK51_FNo-yf7CjY,8430
|
21
21
|
liger_kernel/ops/layer_norm.py,sha256=_CZggw3GNEIUx5weDzadFit5I-Lzosoo8prgeJzcViY,7589
|
22
22
|
liger_kernel/ops/qwen2vl_mrope.py,sha256=GvP4Cg-2ClYyiqbe7bB_OMvnlZooBmqP2-9V8RMPde4,8598
|
23
|
-
liger_kernel/ops/rms_norm.py,sha256=
|
23
|
+
liger_kernel/ops/rms_norm.py,sha256=bleuRC9IS_P3zEX07b0LZ_cpgeTH8l5sdvkelucpRgM,11792
|
24
24
|
liger_kernel/ops/rope.py,sha256=jrzaA9-6Orn44y_IIam9_YNPQxOFK2FrIRNfFea4EtU,8513
|
25
25
|
liger_kernel/ops/swiglu.py,sha256=Fwxtd76rhHKT9ShQAGca9RsnASplAVxtYKHmiT73_yA,2994
|
26
26
|
liger_kernel/ops/utils.py,sha256=_VQvd1PX5JXm5xaiBrk2gANp3qr4kM7qYG3ypkBwkMs,3850
|
@@ -58,9 +58,9 @@ liger_kernel/transformers/trainer/__init__.py,sha256=c4OQVJmhNOloj0JYSEc0j_cQuBb
|
|
58
58
|
liger_kernel/transformers/trainer/orpo_trainer.py,sha256=O2k2vdHl-O1S-U61aEmyUFu3QrEuNAipQa2oUBb3HAA,7679
|
59
59
|
liger_kernel/triton/__init__.py,sha256=yfRe0zMb47QnqjecZWG7LnanfCTzeku7SgWRAwNVmzU,101
|
60
60
|
liger_kernel/triton/monkey_patch.py,sha256=5BcGKTtdqeYchypBIBopGIWPx1-cFALz7sOKoEsqXJ0,1584
|
61
|
-
liger_kernel_nightly-0.5.2.
|
62
|
-
liger_kernel_nightly-0.5.2.
|
63
|
-
liger_kernel_nightly-0.5.2.
|
64
|
-
liger_kernel_nightly-0.5.2.
|
65
|
-
liger_kernel_nightly-0.5.2.
|
66
|
-
liger_kernel_nightly-0.5.2.
|
61
|
+
liger_kernel_nightly-0.5.2.dev20241223032015.dist-info/LICENSE,sha256=OhzLDHJ0to4a8sodVLELZiCFylZ1NAAYLs-HrjPy0ag,1312
|
62
|
+
liger_kernel_nightly-0.5.2.dev20241223032015.dist-info/METADATA,sha256=glSPMysElXhTUr1u74GrG_xjFSIek9GtE9AlPR6GkLs,21055
|
63
|
+
liger_kernel_nightly-0.5.2.dev20241223032015.dist-info/NOTICE,sha256=njwnoPZLh9AN8SJQzxvCGLHi-8X__AvWRze6joNXIY8,2066
|
64
|
+
liger_kernel_nightly-0.5.2.dev20241223032015.dist-info/WHEEL,sha256=P9jw-gEje8ByB7_hXoICnHtVCrEwMQh-630tKvQWehc,91
|
65
|
+
liger_kernel_nightly-0.5.2.dev20241223032015.dist-info/top_level.txt,sha256=2eghu4hA3LnkM7ElW92tQ8zegWKgSbeo-k-aGe1YnvY,13
|
66
|
+
liger_kernel_nightly-0.5.2.dev20241223032015.dist-info/RECORD,,
|
File without changes
|
File without changes
|
File without changes
|