liger-kernel-nightly 0.4.2.dev20241209224333__tar.gz → 0.4.2.dev20241209234352__tar.gz
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {liger_kernel_nightly-0.4.2.dev20241209224333/src/liger_kernel_nightly.egg-info → liger_kernel_nightly-0.4.2.dev20241209234352}/PKG-INFO +1 -1
- {liger_kernel_nightly-0.4.2.dev20241209224333 → liger_kernel_nightly-0.4.2.dev20241209234352}/pyproject.toml +1 -1
- {liger_kernel_nightly-0.4.2.dev20241209224333 → liger_kernel_nightly-0.4.2.dev20241209234352}/src/liger_kernel/chunked_loss/cpo_loss.py +16 -10
- {liger_kernel_nightly-0.4.2.dev20241209224333 → liger_kernel_nightly-0.4.2.dev20241209234352}/src/liger_kernel/chunked_loss/dpo_loss.py +20 -12
- {liger_kernel_nightly-0.4.2.dev20241209224333 → liger_kernel_nightly-0.4.2.dev20241209234352}/src/liger_kernel/chunked_loss/orpo_loss.py +15 -9
- {liger_kernel_nightly-0.4.2.dev20241209224333 → liger_kernel_nightly-0.4.2.dev20241209234352}/src/liger_kernel/chunked_loss/simpo_loss.py +17 -11
- {liger_kernel_nightly-0.4.2.dev20241209224333 → liger_kernel_nightly-0.4.2.dev20241209234352/src/liger_kernel_nightly.egg-info}/PKG-INFO +1 -1
- {liger_kernel_nightly-0.4.2.dev20241209224333 → liger_kernel_nightly-0.4.2.dev20241209234352}/LICENSE +0 -0
- {liger_kernel_nightly-0.4.2.dev20241209224333 → liger_kernel_nightly-0.4.2.dev20241209234352}/NOTICE +0 -0
- {liger_kernel_nightly-0.4.2.dev20241209224333 → liger_kernel_nightly-0.4.2.dev20241209234352}/README.md +0 -0
- {liger_kernel_nightly-0.4.2.dev20241209224333 → liger_kernel_nightly-0.4.2.dev20241209234352}/setup.cfg +0 -0
- {liger_kernel_nightly-0.4.2.dev20241209224333 → liger_kernel_nightly-0.4.2.dev20241209234352}/src/liger_kernel/__init__.py +0 -0
- {liger_kernel_nightly-0.4.2.dev20241209224333 → liger_kernel_nightly-0.4.2.dev20241209234352}/src/liger_kernel/chunked_loss/__init__.py +0 -0
- {liger_kernel_nightly-0.4.2.dev20241209224333 → liger_kernel_nightly-0.4.2.dev20241209234352}/src/liger_kernel/chunked_loss/functional.py +0 -0
- {liger_kernel_nightly-0.4.2.dev20241209224333 → liger_kernel_nightly-0.4.2.dev20241209234352}/src/liger_kernel/chunked_loss/fused_linear_distillation.py +0 -0
- {liger_kernel_nightly-0.4.2.dev20241209224333 → liger_kernel_nightly-0.4.2.dev20241209234352}/src/liger_kernel/chunked_loss/fused_linear_preference.py +0 -0
- {liger_kernel_nightly-0.4.2.dev20241209224333 → liger_kernel_nightly-0.4.2.dev20241209234352}/src/liger_kernel/env_report.py +0 -0
- {liger_kernel_nightly-0.4.2.dev20241209224333 → liger_kernel_nightly-0.4.2.dev20241209234352}/src/liger_kernel/ops/__init__.py +0 -0
- {liger_kernel_nightly-0.4.2.dev20241209224333 → liger_kernel_nightly-0.4.2.dev20241209234352}/src/liger_kernel/ops/cross_entropy.py +0 -0
- {liger_kernel_nightly-0.4.2.dev20241209224333 → liger_kernel_nightly-0.4.2.dev20241209234352}/src/liger_kernel/ops/experimental/embedding.py +0 -0
- {liger_kernel_nightly-0.4.2.dev20241209224333 → liger_kernel_nightly-0.4.2.dev20241209234352}/src/liger_kernel/ops/experimental/mm_int8int2.py +0 -0
- {liger_kernel_nightly-0.4.2.dev20241209224333 → liger_kernel_nightly-0.4.2.dev20241209234352}/src/liger_kernel/ops/fused_linear_cross_entropy.py +0 -0
- {liger_kernel_nightly-0.4.2.dev20241209224333 → liger_kernel_nightly-0.4.2.dev20241209234352}/src/liger_kernel/ops/fused_linear_jsd.py +0 -0
- {liger_kernel_nightly-0.4.2.dev20241209224333 → liger_kernel_nightly-0.4.2.dev20241209234352}/src/liger_kernel/ops/geglu.py +0 -0
- {liger_kernel_nightly-0.4.2.dev20241209224333 → liger_kernel_nightly-0.4.2.dev20241209234352}/src/liger_kernel/ops/group_norm.py +0 -0
- {liger_kernel_nightly-0.4.2.dev20241209224333 → liger_kernel_nightly-0.4.2.dev20241209234352}/src/liger_kernel/ops/jsd.py +0 -0
- {liger_kernel_nightly-0.4.2.dev20241209224333 → liger_kernel_nightly-0.4.2.dev20241209234352}/src/liger_kernel/ops/kl_div.py +0 -0
- {liger_kernel_nightly-0.4.2.dev20241209224333 → liger_kernel_nightly-0.4.2.dev20241209234352}/src/liger_kernel/ops/layer_norm.py +0 -0
- {liger_kernel_nightly-0.4.2.dev20241209224333 → liger_kernel_nightly-0.4.2.dev20241209234352}/src/liger_kernel/ops/qwen2vl_mrope.py +0 -0
- {liger_kernel_nightly-0.4.2.dev20241209224333 → liger_kernel_nightly-0.4.2.dev20241209234352}/src/liger_kernel/ops/rms_norm.py +0 -0
- {liger_kernel_nightly-0.4.2.dev20241209224333 → liger_kernel_nightly-0.4.2.dev20241209234352}/src/liger_kernel/ops/rope.py +0 -0
- {liger_kernel_nightly-0.4.2.dev20241209224333 → liger_kernel_nightly-0.4.2.dev20241209234352}/src/liger_kernel/ops/swiglu.py +0 -0
- {liger_kernel_nightly-0.4.2.dev20241209224333 → liger_kernel_nightly-0.4.2.dev20241209234352}/src/liger_kernel/ops/utils.py +0 -0
- {liger_kernel_nightly-0.4.2.dev20241209224333 → liger_kernel_nightly-0.4.2.dev20241209234352}/src/liger_kernel/transformers/__init__.py +0 -0
- {liger_kernel_nightly-0.4.2.dev20241209224333 → liger_kernel_nightly-0.4.2.dev20241209234352}/src/liger_kernel/transformers/auto_model.py +0 -0
- {liger_kernel_nightly-0.4.2.dev20241209224333 → liger_kernel_nightly-0.4.2.dev20241209234352}/src/liger_kernel/transformers/cross_entropy.py +0 -0
- {liger_kernel_nightly-0.4.2.dev20241209224333 → liger_kernel_nightly-0.4.2.dev20241209234352}/src/liger_kernel/transformers/experimental/embedding.py +0 -0
- {liger_kernel_nightly-0.4.2.dev20241209224333 → liger_kernel_nightly-0.4.2.dev20241209234352}/src/liger_kernel/transformers/functional.py +0 -0
- {liger_kernel_nightly-0.4.2.dev20241209224333 → liger_kernel_nightly-0.4.2.dev20241209234352}/src/liger_kernel/transformers/fused_linear_cross_entropy.py +0 -0
- {liger_kernel_nightly-0.4.2.dev20241209224333 → liger_kernel_nightly-0.4.2.dev20241209234352}/src/liger_kernel/transformers/fused_linear_jsd.py +0 -0
- {liger_kernel_nightly-0.4.2.dev20241209224333 → liger_kernel_nightly-0.4.2.dev20241209234352}/src/liger_kernel/transformers/geglu.py +0 -0
- {liger_kernel_nightly-0.4.2.dev20241209224333 → liger_kernel_nightly-0.4.2.dev20241209234352}/src/liger_kernel/transformers/group_norm.py +0 -0
- {liger_kernel_nightly-0.4.2.dev20241209224333 → liger_kernel_nightly-0.4.2.dev20241209234352}/src/liger_kernel/transformers/jsd.py +0 -0
- {liger_kernel_nightly-0.4.2.dev20241209224333 → liger_kernel_nightly-0.4.2.dev20241209234352}/src/liger_kernel/transformers/kl_div.py +0 -0
- {liger_kernel_nightly-0.4.2.dev20241209224333 → liger_kernel_nightly-0.4.2.dev20241209234352}/src/liger_kernel/transformers/layer_norm.py +0 -0
- {liger_kernel_nightly-0.4.2.dev20241209224333 → liger_kernel_nightly-0.4.2.dev20241209234352}/src/liger_kernel/transformers/model/__init__.py +0 -0
- {liger_kernel_nightly-0.4.2.dev20241209224333 → liger_kernel_nightly-0.4.2.dev20241209234352}/src/liger_kernel/transformers/model/gemma.py +0 -0
- {liger_kernel_nightly-0.4.2.dev20241209224333 → liger_kernel_nightly-0.4.2.dev20241209234352}/src/liger_kernel/transformers/model/gemma2.py +0 -0
- {liger_kernel_nightly-0.4.2.dev20241209224333 → liger_kernel_nightly-0.4.2.dev20241209234352}/src/liger_kernel/transformers/model/llama.py +0 -0
- {liger_kernel_nightly-0.4.2.dev20241209224333 → liger_kernel_nightly-0.4.2.dev20241209234352}/src/liger_kernel/transformers/model/mistral.py +0 -0
- {liger_kernel_nightly-0.4.2.dev20241209224333 → liger_kernel_nightly-0.4.2.dev20241209234352}/src/liger_kernel/transformers/model/mixtral.py +0 -0
- {liger_kernel_nightly-0.4.2.dev20241209224333 → liger_kernel_nightly-0.4.2.dev20241209234352}/src/liger_kernel/transformers/model/mllama.py +0 -0
- {liger_kernel_nightly-0.4.2.dev20241209224333 → liger_kernel_nightly-0.4.2.dev20241209234352}/src/liger_kernel/transformers/model/phi3.py +0 -0
- {liger_kernel_nightly-0.4.2.dev20241209224333 → liger_kernel_nightly-0.4.2.dev20241209234352}/src/liger_kernel/transformers/model/qwen2.py +0 -0
- {liger_kernel_nightly-0.4.2.dev20241209224333 → liger_kernel_nightly-0.4.2.dev20241209234352}/src/liger_kernel/transformers/model/qwen2_vl.py +0 -0
- {liger_kernel_nightly-0.4.2.dev20241209224333 → liger_kernel_nightly-0.4.2.dev20241209234352}/src/liger_kernel/transformers/monkey_patch.py +0 -0
- {liger_kernel_nightly-0.4.2.dev20241209224333 → liger_kernel_nightly-0.4.2.dev20241209234352}/src/liger_kernel/transformers/orpo_trainer.py +0 -0
- {liger_kernel_nightly-0.4.2.dev20241209224333 → liger_kernel_nightly-0.4.2.dev20241209234352}/src/liger_kernel/transformers/qwen2vl_mrope.py +0 -0
- {liger_kernel_nightly-0.4.2.dev20241209224333 → liger_kernel_nightly-0.4.2.dev20241209234352}/src/liger_kernel/transformers/rms_norm.py +0 -0
- {liger_kernel_nightly-0.4.2.dev20241209224333 → liger_kernel_nightly-0.4.2.dev20241209234352}/src/liger_kernel/transformers/rope.py +0 -0
- {liger_kernel_nightly-0.4.2.dev20241209224333 → liger_kernel_nightly-0.4.2.dev20241209234352}/src/liger_kernel/transformers/swiglu.py +0 -0
- {liger_kernel_nightly-0.4.2.dev20241209224333 → liger_kernel_nightly-0.4.2.dev20241209234352}/src/liger_kernel/transformers/trainer_integration.py +0 -0
- {liger_kernel_nightly-0.4.2.dev20241209224333 → liger_kernel_nightly-0.4.2.dev20241209234352}/src/liger_kernel/triton/__init__.py +0 -0
- {liger_kernel_nightly-0.4.2.dev20241209224333 → liger_kernel_nightly-0.4.2.dev20241209234352}/src/liger_kernel/triton/monkey_patch.py +0 -0
- {liger_kernel_nightly-0.4.2.dev20241209224333 → liger_kernel_nightly-0.4.2.dev20241209234352}/src/liger_kernel/utils.py +0 -0
- {liger_kernel_nightly-0.4.2.dev20241209224333 → liger_kernel_nightly-0.4.2.dev20241209234352}/src/liger_kernel_nightly.egg-info/SOURCES.txt +0 -0
- {liger_kernel_nightly-0.4.2.dev20241209224333 → liger_kernel_nightly-0.4.2.dev20241209234352}/src/liger_kernel_nightly.egg-info/dependency_links.txt +0 -0
- {liger_kernel_nightly-0.4.2.dev20241209224333 → liger_kernel_nightly-0.4.2.dev20241209234352}/src/liger_kernel_nightly.egg-info/requires.txt +0 -0
- {liger_kernel_nightly-0.4.2.dev20241209224333 → liger_kernel_nightly-0.4.2.dev20241209234352}/src/liger_kernel_nightly.egg-info/top_level.txt +0 -0
{liger_kernel_nightly-0.4.2.dev20241209224333 → liger_kernel_nightly-0.4.2.dev20241209234352}/pyproject.toml
@@ -4,7 +4,7 @@ build-backend = "setuptools.build_meta"
 
 [project]
 name = "liger_kernel_nightly"
-version = "0.4.2.dev20241209224333"
+version = "0.4.2.dev20241209234352"
 description = "Efficient Triton kernels for LLM Training"
 urls = { "Homepage" = "https://github.com/linkedin/Liger-Kernel" }
 readme = { file = "README.md", content-type = "text/markdown" }
{liger_kernel_nightly-0.4.2.dev20241209224333 → liger_kernel_nightly-0.4.2.dev20241209234352}/src/liger_kernel/chunked_loss/cpo_loss.py
@@ -11,11 +11,25 @@ class LigerFusedLinearCPOFunction(LigerFusedLinearPreferenceBase):
     @staticmethod
     def preference_loss_fn(chosen_logps, rejected_logps, full_target, beta=0.1):
         """
-
+        Paper: https://arxiv.org/pdf/2401.08417
+
+        Formula:
+        L(π_θ; U) = -E_(x,y_w,y_l)~D[log σ(β log π_θ(y_w|x) - β log π_θ(y_l|x))]
+
+        Where:
+        - π_θ(y|x): Policy (model) probability
+        - y_w: Chosen sequence
+        - y_l: Rejected sequence
+        - σ: Sigmoid function
+        - β: Temperature parameter
+        - E: Expected value over the dataset D
+        - D: Dataset of preferences
+
         Args:
             chosen_logps (torch.Tensor): Avg log probabilities of chosen tokens. Shape: (batch_size,).
             rejected_logps (torch.Tensor): Avg log probabilities of rejected tokens. Shape: (batch_size,).
-
+            full_target (torch.Tensor): Non chunked full target tensor
+            beta (float): Weight for the CPO loss
         """
         logits = beta * (chosen_logps - rejected_logps)
         loss = F.logsigmoid(logits).sum() / (full_target.shape[0] // 2)
@@ -34,12 +48,6 @@ class LigerFusedLinearCPOFunction(LigerFusedLinearPreferenceBase):
         compute_nll_loss=True,
         compiled=True,
     ):
-        """
-        Fused linear layer with CPO (Odds-Ratio Preference Optimization) loss.
-        Handles both the forward and backward pass of the final linear layer with CPO loss.
-        Inspired from LigerFusedLinearCrossEntropyFunction (https://arxiv.org/abs/2410.10989) which fuses final linear layer and CE loss.
-        """
-
         return LigerFusedLinearPreferenceBase.forward(
             ctx,
             _input,
@@ -56,9 +64,7 @@ class LigerFusedLinearCPOFunction(LigerFusedLinearPreferenceBase):
 
     @staticmethod
     def backward(ctx, *grad_output):
-        # Get gradients for _input, weight, bias, and target from the base class
         grads = LigerFusedLinearPreferenceBase.backward(ctx, grad_output)[:4]
-        # Return these gradients, followed by None for the remaining inputs
         return *grads, None, None, None, None, None
 
 
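To make the newly documented CPO formula concrete, here is a minimal eager-mode sketch of the preference term, mirroring the logits/loss lines visible in the first cpo_loss.py hunk above. The helper name and the toy tensors are illustrative only; this is not the packaged chunked/fused implementation, and note that the docstring formula carries a leading minus sign that does not appear in the shown loss line, so whatever combines and signs the terms lives outside this hunk.

import torch
import torch.nn.functional as F

def cpo_preference_loss(chosen_logps, rejected_logps, full_target, beta=0.1):
    # Illustrative helper, not part of liger_kernel.
    # logits = β(log π_θ(y_w|x) - log π_θ(y_l|x)), per the docstring formula.
    logits = beta * (chosen_logps - rejected_logps)
    # Mirrors the diffed line: summed log-sigmoid, scaled by the number of
    # preference pairs (full_target stacks the chosen and rejected halves).
    return F.logsigmoid(logits).sum() / (full_target.shape[0] // 2)

# Toy usage with two preference pairs (average log-probs per sequence).
chosen = torch.tensor([-0.8, -1.0])
rejected = torch.tensor([-1.3, -1.6])
full_target = torch.zeros(4, 16, dtype=torch.long)  # 2 chosen + 2 rejected rows
print(cpo_preference_loss(chosen, rejected, full_target))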
{liger_kernel_nightly-0.4.2.dev20241209224333 → liger_kernel_nightly-0.4.2.dev20241209234352}/src/liger_kernel/chunked_loss/dpo_loss.py
@@ -18,14 +18,28 @@ class LigerFusedLinearDPOFunction(LigerFusedLinearPreferenceBase):
         beta=0.1,
     ):
         """
-
+        Paper: https://arxiv.org/pdf/2305.18290
+
+        Formula:
+        L_DPO = -E[ log_sigmoid( β * (log(π(y_w|x)/π_ref(y_w|x)) - log(π(y_l|x)/π_ref(y_l|x))) ) ]
+
+        Where:
+        - π(y|x): Policy (model) probability
+        - π_ref(y|x): Reference model probability
+        - y_w: Chosen sequence
+        - y_l: Rejected sequence
+        - β: Weight for the direct preference loss
+        - E: Expected value over the dataset
+
         Args:
-            chosen_logps
-            rejected_logps
-
-
-
+            chosen_logps: Log probabilities of chosen tokens (batch_size,)
+            rejected_logps: Log probabilities of rejected tokens (batch_size,)
+            full_target: Non chunked full target tensor
+            ref_chosen_logps: Reference log probs of chosen tokens (batch_size,)
+            ref_rejected_logps: Reference log probs of rejected tokens (batch_size,)
+            beta: Weight for the direct preference loss
         """
+
         if ref_chosen_logps is None:
             ref_chosen_logps = torch.tensor(0.0, device=chosen_logps.device)
         if ref_rejected_logps is None:
@@ -53,10 +67,6 @@ class LigerFusedLinearDPOFunction(LigerFusedLinearPreferenceBase):
         compiled=True,
         use_ref_model=True,
     ):
-        """
-        Fused linear layer with DPO (Direct Preference Optimization) loss.
-        Handles both the forward and backward pass of the final linear layer with DPO loss.
-        """
         return LigerFusedLinearPreferenceBase.forward(
             ctx=ctx,
             _input=_input,
@@ -75,9 +85,7 @@ class LigerFusedLinearDPOFunction(LigerFusedLinearPreferenceBase):
 
     @staticmethod
     def backward(ctx, *grad_output):
-        # Get gradients for _input, weight, bias, and target from the base class
         grads = LigerFusedLinearPreferenceBase.backward(ctx, grad_output)[:4]
-        # Return these gradients, followed by None for the remaining inputs
         return *grads, None, None, None, None, None, None, None
 
 
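The DPO docstring reads directly as a small sketch as well. The helper below is illustrative: the zero fallbacks mirror the "if ref_chosen_logps is None" lines shown in the hunk, while the leading minus sign and the pair-count normalization are assumptions carried over from the documented formula and from the other losses in this diff, not copied from the kernel body.

import torch
import torch.nn.functional as F

def dpo_preference_loss(
    chosen_logps, rejected_logps, full_target,
    ref_chosen_logps=None, ref_rejected_logps=None, beta=0.1,
):
    # Illustrative helper, not part of liger_kernel.
    if ref_chosen_logps is None:
        ref_chosen_logps = torch.tensor(0.0, device=chosen_logps.device)
    if ref_rejected_logps is None:
        ref_rejected_logps = torch.tensor(0.0, device=rejected_logps.device)
    # β * [(log π(y_w|x) - log π_ref(y_w|x)) - (log π(y_l|x) - log π_ref(y_l|x))]
    logits = beta * (
        (chosen_logps - ref_chosen_logps) - (rejected_logps - ref_rejected_logps)
    )
    # L_DPO = -E[log σ(logits)], summed and scaled by the number of preference pairs.
    return -F.logsigmoid(logits).sum() / (full_target.shape[0] // 2)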
{liger_kernel_nightly-0.4.2.dev20241209224333 → liger_kernel_nightly-0.4.2.dev20241209234352}/src/liger_kernel/chunked_loss/orpo_loss.py
@@ -11,10 +11,24 @@ class LigerFusedLinearORPOFunction(LigerFusedLinearPreferenceBase):
     @staticmethod
     def preference_loss_fn(chosen_logps, rejected_logps, full_target, beta=0.1):
         """
-
+        Paper: https://arxiv.org/pdf/2403.07691
+
+        Formula:
+        Compute odds-ratio loss: L_OR = -log(σ(log(odds_θ(y_w|x) / odds_θ(y_l|x))))
+        where odds_θ(y|x) = P_θ(y|x) / (1 - P_θ(y|x))
+
+        Where:
+        - P_θ(y|x): Policy (model) probability
+        - y_w: Chosen sequence
+        - y_l: Rejected sequence
+        - σ: Sigmoid function
+        - β: Weight for the odds ratio loss
+        - odds_θ: Odds function for the policy
+
         Args:
             chosen_logps (torch.Tensor): Avg log probabilities of chosen tokens. Shape: (batch_size,).
             rejected_logps (torch.Tensor): Avg log probabilities of rejected tokens. Shape: (batch_size,).
+            full_target (torch.Tensor): Non chunked full target tensor
             beta (float): Weight for the odds ratio loss.
         """
         log_odds = (chosen_logps - rejected_logps) - (
@@ -44,12 +58,6 @@ class LigerFusedLinearORPOFunction(LigerFusedLinearPreferenceBase):
         compute_nll_loss=True,
         compiled=True,
     ):
-        """
-        Fused linear layer with ORPO (Odds-Ratio Preference Optimization) loss.
-        Handles both the forward and backward pass of the final linear layer with ORPO loss.
-        Inspired from LigerFusedLinearCrossEntropyFunction (https://arxiv.org/abs/2410.10989) which fuses final linear layer and CE loss.
-        """
-
         return LigerFusedLinearPreferenceBase.forward(
             ctx=ctx,
             _input=_input,
@@ -65,9 +73,7 @@ class LigerFusedLinearORPOFunction(LigerFusedLinearPreferenceBase):
 
     @staticmethod
    def backward(ctx, *grad_output):
-        # Get gradients for _input, weight, bias, and target from the base class
         grads = LigerFusedLinearPreferenceBase.backward(ctx, grad_output)[:4]
-        # Return these gradients, followed by None for the remaining inputs
         return *grads, None, None, None, None
 
 
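The odds-ratio term can be written out the same way. Only the first line of the computation ("log_odds = (chosen_logps - rejected_logps) - (") appears in the hunk, so the log1p(-exp(...)) completion and the β-weighted log-sigmoid below are a hand-written illustration of the documented formula under those assumptions, not a copy of the kernel.

import torch
import torch.nn.functional as F

def orpo_odds_ratio_term(chosen_logps, rejected_logps, full_target, beta=0.1):
    # Illustrative helper, not part of liger_kernel.
    # log odds_θ(y|x) = log p - log(1 - p); inputs are average log-probs (<= 0).
    log_odds = (chosen_logps - rejected_logps) - (
        torch.log1p(-torch.exp(chosen_logps)) - torch.log1p(-torch.exp(rejected_logps))
    )
    # β-weighted log σ(log-odds difference), summed over the batch and scaled
    # by the number of preference pairs; the formula's leading minus sign is
    # applied outside this sketch, as in the other losses.
    return beta * F.logsigmoid(log_odds).sum() / (full_target.shape[0] // 2)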
{liger_kernel_nightly-0.4.2.dev20241209224333 → liger_kernel_nightly-0.4.2.dev20241209234352}/src/liger_kernel/chunked_loss/simpo_loss.py
@@ -13,12 +13,26 @@ class LigerFusedLinearSimPOFunction(LigerFusedLinearPreferenceBase):
         chosen_logps, rejected_logps, full_target, beta=0.1, gamma=0.5
     ):
         """
-
+        Paper: https://arxiv.org/pdf/2405.14734
+
+        Formula:
+        L_SimPO(π_θ) = -E [log σ(β/|y_w| log π_θ(y_w|x) - β/|y_l| log π_θ(y_l|x) - γ)]
+
+        Where:
+        - π_θ(y|x): Policy (model) probability
+        - y_w: Chosen sequence
+        - y_l: Rejected sequence
+        - |y_w|, |y_l|: Sequence lengths
+        - σ: Sigmoid function
+        - β: beta weight
+        - γ: gemma margin term
+
         Args:
             chosen_logps (torch.Tensor): Avg log probabilities of chosen tokens. Shape: (batch_size,).
             rejected_logps (torch.Tensor): Avg log probabilities of rejected tokens. Shape: (batch_size,).
-
-
+            full_target: Non chunked full target tensor
+            beta (float): beta weight
+            gamma (float): gemma margin term
         """
         logits = beta * (chosen_logps - rejected_logps) - gamma
         loss = F.logsigmoid(logits).sum() / (full_target.shape[0] // 2)
@@ -38,12 +52,6 @@ class LigerFusedLinearSimPOFunction(LigerFusedLinearPreferenceBase):
         compiled=True,
         gamma=0.5,
     ):
-        """
-        Fused linear layer with SimPO (Simple Preference Optimization) loss. https://arxiv.org/pdf/2405.14734
-        Handles both the forward and backward pass of the final linear layer with SimPO loss.
-        Inspired from LigerFusedLinearCrossEntropyFunction (https://arxiv.org/abs/2410.10989) which fuses final linear layer and CE loss.
-        """
-
         return LigerFusedLinearPreferenceBase.forward(
             ctx,
             _input,
@@ -61,9 +69,7 @@ class LigerFusedLinearSimPOFunction(LigerFusedLinearPreferenceBase):
 
     @staticmethod
     def backward(ctx, *grad_output):
-        # Get gradients for _input, weight, bias, and target from the base class
         grads = LigerFusedLinearPreferenceBase.backward(ctx, grad_output)[:4]
-        # Return these gradients, followed by None for the remaining inputs
         return *grads, None, None, None, None, None, None
 
 
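Finally, SimPO's margin objective collapses to the following once chosen_logps and rejected_logps are length-averaged, as the Args section states. The sketch mirrors the logits/loss lines visible in the hunk; the helper name is illustrative, and the docstring formula's leading minus sign is not applied inside the shown lines, so it is omitted here as well.

import torch
import torch.nn.functional as F

def simpo_preference_loss(chosen_logps, rejected_logps, full_target, beta=0.1, gamma=0.5):
    # Illustrative helper, not part of liger_kernel.
    # β·avg_logp(y_w) - β·avg_logp(y_l) - γ  (the β/|y| scaling collapses to β
    # because the inputs are already per-token averages).
    logits = beta * (chosen_logps - rejected_logps) - gamma
    # Summed log-sigmoid over the batch, scaled by the number of preference pairs,
    # matching the diffed loss line.
    return F.logsigmoid(logits).sum() / (full_target.shape[0] // 2)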