liger-kernel-nightly 0.4.0.dev20241107052928__py3-none-any.whl → 0.6.3.dev20251121010306__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release: this version of liger-kernel-nightly might be problematic.
- liger_kernel/__init__.py +0 -0
- liger_kernel/chunked_loss/README.md +25 -0
- liger_kernel/chunked_loss/__init__.py +8 -0
- liger_kernel/chunked_loss/cosine_similarity_loss.py +136 -0
- liger_kernel/chunked_loss/cpo_loss.py +157 -0
- liger_kernel/chunked_loss/dpo_loss.py +229 -0
- liger_kernel/chunked_loss/functional.py +17 -0
- liger_kernel/chunked_loss/fused_linear_distillation.py +292 -0
- liger_kernel/chunked_loss/fused_linear_ppo.py +350 -0
- liger_kernel/chunked_loss/fused_linear_preference.py +433 -0
- liger_kernel/chunked_loss/fused_linear_unpaired_preference.py +341 -0
- liger_kernel/chunked_loss/grpo_loss.py +304 -0
- liger_kernel/chunked_loss/jsd_loss.py +200 -0
- liger_kernel/chunked_loss/kto_loss.py +210 -0
- liger_kernel/chunked_loss/orpo_loss.py +144 -0
- liger_kernel/chunked_loss/simpo_loss.py +165 -0
- liger_kernel/env_report.py +21 -4
- liger_kernel/ops/cross_entropy.py +235 -84
- liger_kernel/ops/dyt.py +157 -0
- liger_kernel/ops/experimental/embedding.py +1 -3
- liger_kernel/ops/experimental/mm_int8int2.py +3 -9
- liger_kernel/ops/fused_add_rms_norm.py +412 -0
- liger_kernel/ops/fused_linear_cross_entropy.py +197 -75
- liger_kernel/ops/fused_linear_jsd.py +17 -34
- liger_kernel/ops/fused_neighborhood_attention.py +1022 -0
- liger_kernel/ops/geglu.py +7 -18
- liger_kernel/ops/group_norm.py +305 -0
- liger_kernel/ops/grpo_loss.py +310 -0
- liger_kernel/ops/jsd.py +46 -21
- liger_kernel/ops/kl_div.py +23 -19
- liger_kernel/ops/layer_norm.py +150 -86
- liger_kernel/ops/llama4_rope.py +225 -0
- liger_kernel/ops/multi_token_attention.py +207 -0
- liger_kernel/ops/poly_norm.py +386 -0
- liger_kernel/ops/qwen2vl_mrope.py +222 -0
- liger_kernel/ops/rms_norm.py +314 -84
- liger_kernel/ops/rope.py +32 -34
- liger_kernel/ops/softmax.py +201 -0
- liger_kernel/ops/sparsemax.py +179 -0
- liger_kernel/ops/swiglu.py +5 -9
- liger_kernel/ops/tiled_mlp.py +136 -0
- liger_kernel/ops/tvd.py +207 -0
- liger_kernel/ops/utils.py +8 -4
- liger_kernel/transformers/__init__.py +199 -24
- liger_kernel/transformers/auto_model.py +6 -13
- liger_kernel/transformers/cross_entropy.py +33 -20
- liger_kernel/transformers/dyt.py +22 -0
- liger_kernel/transformers/experimental/__init__.py +5 -0
- liger_kernel/transformers/experimental/embedding.py +1 -3
- liger_kernel/transformers/fsdp.py +55 -0
- liger_kernel/transformers/functional.py +291 -13
- liger_kernel/transformers/fused_add_rms_norm.py +39 -0
- liger_kernel/transformers/fused_linear_cross_entropy.py +43 -14
- liger_kernel/transformers/fused_linear_jsd.py +1 -4
- liger_kernel/transformers/fused_neighborhood_attention.py +234 -0
- liger_kernel/transformers/geglu.py +1 -4
- liger_kernel/transformers/group_norm.py +50 -0
- liger_kernel/transformers/grpo_loss.py +98 -0
- liger_kernel/transformers/jsd.py +2 -7
- liger_kernel/transformers/kl_div.py +1 -3
- liger_kernel/transformers/layer_norm.py +3 -9
- liger_kernel/transformers/llama4_rope.py +93 -0
- liger_kernel/transformers/model/falcon_h1.py +122 -0
- liger_kernel/transformers/model/gemma.py +77 -77
- liger_kernel/transformers/model/gemma2.py +283 -0
- liger_kernel/transformers/model/gemma3.py +331 -0
- liger_kernel/transformers/model/glm4.py +141 -0
- liger_kernel/transformers/model/glm4v.py +163 -0
- liger_kernel/transformers/model/glm4v_moe.py +172 -0
- liger_kernel/transformers/model/internvl.py +157 -0
- liger_kernel/transformers/model/llama.py +128 -79
- liger_kernel/transformers/model/llama4.py +121 -0
- liger_kernel/transformers/model/llava.py +344 -0
- liger_kernel/transformers/model/loss_utils.py +95 -0
- liger_kernel/transformers/model/mistral.py +68 -64
- liger_kernel/transformers/model/mixtral.py +75 -91
- liger_kernel/transformers/model/mllama.py +63 -68
- liger_kernel/transformers/model/olmo2.py +141 -0
- liger_kernel/transformers/model/output_classes.py +147 -0
- liger_kernel/transformers/model/paligemma.py +432 -0
- liger_kernel/transformers/model/phi3.py +59 -213
- liger_kernel/transformers/model/qwen2.py +75 -72
- liger_kernel/transformers/model/qwen2_5_vl.py +163 -0
- liger_kernel/transformers/model/qwen2_vl.py +78 -98
- liger_kernel/transformers/model/qwen3.py +136 -0
- liger_kernel/transformers/model/qwen3_moe.py +152 -0
- liger_kernel/transformers/model/qwen3_next.py +146 -0
- liger_kernel/transformers/model/qwen3_vl.py +150 -0
- liger_kernel/transformers/model/qwen3_vl_moe.py +126 -0
- liger_kernel/transformers/model/smollm3.py +199 -0
- liger_kernel/transformers/model/smolvlm.py +158 -0
- liger_kernel/transformers/monkey_patch.py +2106 -289
- liger_kernel/transformers/multi_token_attention.py +64 -0
- liger_kernel/transformers/poly_norm.py +42 -0
- liger_kernel/transformers/qwen2vl_mrope.py +20 -0
- liger_kernel/transformers/rms_norm.py +57 -6
- liger_kernel/transformers/rope.py +45 -2
- liger_kernel/transformers/softmax.py +12 -0
- liger_kernel/transformers/sparsemax.py +16 -0
- liger_kernel/transformers/swiglu.py +23 -8
- liger_kernel/transformers/tiled_mlp.py +133 -0
- liger_kernel/transformers/trainer/__init__.py +4 -0
- liger_kernel/transformers/trainer/orpo_trainer.py +130 -0
- liger_kernel/transformers/tvd.py +13 -0
- liger_kernel/triton/__init__.py +1 -3
- liger_kernel/triton/monkey_patch.py +1 -3
- liger_kernel/utils.py +71 -0
- {liger_kernel_nightly-0.4.0.dev20241107052928.dist-info → liger_kernel_nightly-0.6.3.dev20251121010306.dist-info}/METADATA +150 -137
- liger_kernel_nightly-0.6.3.dev20251121010306.dist-info/RECORD +116 -0
- {liger_kernel_nightly-0.4.0.dev20241107052928.dist-info → liger_kernel_nightly-0.6.3.dev20251121010306.dist-info}/WHEEL +1 -1
- liger_kernel_nightly-0.4.0.dev20241107052928.dist-info/RECORD +0 -48
- {liger_kernel_nightly-0.4.0.dev20241107052928.dist-info → liger_kernel_nightly-0.6.3.dev20251121010306.dist-info}/LICENSE +0 -0
- {liger_kernel_nightly-0.4.0.dev20241107052928.dist-info → liger_kernel_nightly-0.6.3.dev20251121010306.dist-info}/NOTICE +0 -0
- {liger_kernel_nightly-0.4.0.dev20241107052928.dist-info → liger_kernel_nightly-0.6.3.dev20251121010306.dist-info}/top_level.txt +0 -0
liger_kernel/__init__.py ADDED
File without changes
liger_kernel/chunked_loss/README.md ADDED
@@ -0,0 +1,25 @@
+# Liger FlexChunkLoss: Alignment and Distillation loss
+
+Liger FlexChunkLoss offers a versatile interface, delivering up to 80% memory savings and a 10% throughput boost for post-training loss functions, including alignment (DPO, ORPO, CPO, KTO) and very soon, distillation. Its flexible design supports custom losses, ensuring efficiency gains across diverse use cases.
+
+### User interface
+
+FlexChunkLoss offers two flexible usage options:
+
+1. **Via `Liger[Custom Loss]Trainer`**
+   For example, by simply replacing the HuggingFace `ORPOTrainer` with `LigerORPOTrainer` in your code, you can leverage our optimized ORPO implementation and immediately benefit from improved performance.
+
+2. **Using `nn.Module` Implementations of Custom Loss Functions**
+   Explore the [LigerORPOTrainer implementation](https://github.com/linkedin/Liger-Kernel/blob/main/src/liger_kernel/transformers/orpo_trainer.py) to see how the modular design integrates custom loss functions seamlessly.
+
+### What's under the hood?
+
+We employ chunking and fused kernel optimizations to enhance performance. By fusing the final linear layer with loss computation and calculating backward gradients during the forward pass, we significantly reduce the need for storing intermediate activations. All operations are implemented in PyTorch, leveraging `torch.compile` to streamline kernel execution without relying on extensive low-level optimizations. Additionally, we minimize `torch.compile` recompilations to reduce overhead and ensure consistent performance gains.
+
+### Extending to custom loss functions
+
+We provide two base classes: `LigerFusedLinearPreferenceBase` for alignment use cases and `LigerFusedLinearDistillationBase` for distillation use cases. These base classes manage chunking, kernel fusions, and Torch compilation.
+
+To implement a custom loss function, you need to create a subclass that defines the custom preference or distillation loss function, capable of processing a given input chunk. The base class will take care of the optimizations, handling most of the heavy lifting for you.
+
+For a working example, refer to the [ORPO loss implementation](https://github.com/linkedin/Liger-Kernel/blob/main/src/liger_kernel/chunked_loss/orpo_loss.py).
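The README's "Extending to custom loss functions" section is the main extension point shipped in this release. As a rough illustration of that pattern, here is a minimal, hypothetical subclass modeled on the CPO and DPO implementations that appear later in this diff. The exact keyword set accepted by `LigerFusedLinearPreferenceBase.forward` lives in `fused_linear_preference.py`, which is not shown here, so treat this as a sketch rather than a verbatim recipe; the `MySigmoidPreference*` names are illustrative and not part of the package.

```python
# Hypothetical sketch: a custom sigmoid preference loss built on the chunked-loss base
# class, following the CPO/DPO pattern shown later in this diff.
import torch.nn.functional as F

from liger_kernel.chunked_loss.fused_linear_preference import LigerFusedLinearPreferenceBase


class MySigmoidPreferenceFunction(LigerFusedLinearPreferenceBase):
    @staticmethod
    def preference_loss_fn(chosen_logps, rejected_logps, full_target, beta=0.1):
        # Per-chunk preference loss; the base class supplies chosen/rejected log-probs
        # computed through the fused linear layer and we normalize by the number of pairs.
        logits = beta * (chosen_logps - rejected_logps)
        loss = -F.logsigmoid(logits).sum() / (full_target.shape[0] // 2)
        return loss, beta * chosen_logps, beta * rejected_logps

    @classmethod
    def forward(cls, ctx, _input, weight, target, bias=None, beta=0.1, compiled=True):
        # Chunking, the fused projection, and torch.compile handling are delegated to
        # the base class; only preference_loss_fn above is custom.
        return super().forward(
            cls=cls,
            ctx=ctx,
            _input=_input,
            weight=weight,
            target=target,
            bias=bias,
            beta=beta,
            compiled=compiled,
        )

    @staticmethod
    def backward(ctx, *grad_output):
        grads = LigerFusedLinearPreferenceBase.backward(ctx, grad_output)[:4]
        return *grads, None, None  # no gradients for beta and compiled
```

A thin `torch.nn.Module` wrapper around `MySigmoidPreferenceFunction.apply`, mirroring `LigerFusedLinearCPOLoss` below, would make such a custom loss drop-in usable in a trainer.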
liger_kernel/chunked_loss/__init__.py ADDED
@@ -0,0 +1,8 @@
+from liger_kernel.chunked_loss.cosine_similarity_loss import LigerFusedLinearCosineSimilarityLoss  # noqa: F401
+from liger_kernel.chunked_loss.cpo_loss import LigerFusedLinearCPOLoss  # noqa: F401
+from liger_kernel.chunked_loss.dpo_loss import LigerFusedLinearDPOLoss  # noqa: F401
+from liger_kernel.chunked_loss.grpo_loss import LigerFusedLinearGRPOLoss  # noqa: F401
+from liger_kernel.chunked_loss.jsd_loss import LigerFusedLinearJSDLoss  # noqa: F401
+from liger_kernel.chunked_loss.kto_loss import LigerFusedLinearKTOLoss  # noqa: F401
+from liger_kernel.chunked_loss.orpo_loss import LigerFusedLinearORPOLoss  # noqa: F401
+from liger_kernel.chunked_loss.simpo_loss import LigerFusedLinearSimPOLoss  # noqa: F401
liger_kernel/chunked_loss/cosine_similarity_loss.py ADDED
@@ -0,0 +1,136 @@
+from typing import Tuple
+from typing import Union
+
+import torch
+import torch.nn.functional as F
+
+from liger_kernel.chunked_loss.fused_linear_distillation import LigerFusedLinearDistillationBase
+
+
+class LigerFusedLinearCosineSimilarityFunction(LigerFusedLinearDistillationBase):
+    @staticmethod
+    def distillation_loss_fn(student_logits, teacher_logits, beta=1.0):
+        """
+        Compute Cosine loss (Cosine Similarity Loss).
+        Args:
+            student_logits (torch.Tensor): Logits of student tokens. Shape: (batch_size * seq_len,).
+            teacher_logits (torch.Tensor): Logits of teacher tokens. Shape: (batch_size * seq_len,).
+            beta: Coefficient beta of generalized Cosine Similarity in the interval [0, 1]. Default: `1.0` (float): .
+        Returns:
+            torch.Tensor: cosine similarity loss
+        """
+        student_norm = F.normalize(student_logits, p=2, dim=-1)
+        teacher_norm = F.normalize(teacher_logits, p=2, dim=-1)
+
+        cosine_sim = F.cosine_similarity(student_norm, teacher_norm, dim=-1)
+        loss = beta * (1 - cosine_sim)
+        return loss.sum()
+
+    @classmethod
+    def forward(
+        cls,
+        ctx,
+        student_input: torch.Tensor,
+        student_weight: torch.Tensor,
+        teacher_input: torch.Tensor,
+        teacher_weight: torch.Tensor,
+        true_labels: torch.LongTensor,
+        student_bias: torch.Tensor,
+        teacher_bias: torch.Tensor,
+        weight_hard_loss: float = 0.5,
+        weight_soft_loss: float = 0.5,
+        beta: float = 0.5,
+        ignore_index: int = -100,
+        temperature: float = 1.0,
+        compiled: bool = True,
+        chunk_size: int = 1024,
+        return_soft_hard_loss: bool = False,
+    ) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor, torch.Tensor]]:
+        return super().forward(
+            cls=cls,
+            ctx=ctx,
+            student_input=student_input,
+            student_weight=student_weight,
+            teacher_input=teacher_input,
+            teacher_weight=teacher_weight,
+            target=true_labels,
+            student_bias=student_bias,
+            teacher_bias=teacher_bias,
+            chunk_size=chunk_size,
+            weight_hard_loss=weight_hard_loss,
+            weight_soft_loss=weight_soft_loss,
+            beta=beta,
+            ignore_index=ignore_index,
+            temperature=temperature,
+            compiled=compiled,
+            return_soft_hard_loss=return_soft_hard_loss,
+        )
+
+    @staticmethod
+    def backward(ctx, grad_output, *args):
+        grads = LigerFusedLinearDistillationBase.backward(ctx, grad_output, *args)[:6]
+
+        return (
+            *grads,
+            None,  # teacher_bias
+            None,  # weight_hard_loss
+            None,  # weight_soft_loss
+            None,  # beta
+            None,  # ignore_index
+            None,  # temperature
+            None,  # compiled
+            None,  # chunk_size
+            None,  # return_soft_hard_loss
+        )
+
+
+class LigerFusedLinearCosineSimilarityLoss(torch.nn.Module):
+    def __init__(
+        self,
+        weight_hard_loss: float = 0.5,
+        weight_soft_loss: float = 0.5,
+        beta: float = 0.5,
+        ignore_index: int = -100,
+        temperature: float = 1.0,
+        compiled: bool = True,
+        chunk_size: int = 1024,
+        return_soft_hard_loss: bool = False,
+    ):
+        super().__init__()
+        assert temperature != 0, "Temperature cannot be 0."
+        self.weight_hard_loss = weight_hard_loss
+        self.weight_soft_loss = weight_soft_loss
+        self.ignore_index = ignore_index
+        self.temperature = temperature
+        self.compiled = compiled
+        self.beta = beta
+        self.chunk_size = chunk_size
+        self.return_soft_hard_loss = return_soft_hard_loss
+
+    def forward(
+        self,
+        student_input: torch.Tensor,
+        student_weight: torch.Tensor,
+        teacher_input: torch.Tensor,
+        teacher_weight: torch.Tensor,
+        true_labels: torch.LongTensor,
+        student_bias: torch.Tensor = None,
+        teacher_bias: torch.Tensor = None,
+    ) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor, torch.Tensor]]:
+        return LigerFusedLinearCosineSimilarityFunction.apply(
+            student_input,
+            student_weight,
+            teacher_input,
+            teacher_weight,
+            true_labels,
+            student_bias,
+            teacher_bias,
+            self.weight_hard_loss,
+            self.weight_soft_loss,
+            self.beta,
+            self.ignore_index,
+            self.temperature,
+            self.compiled,
+            self.chunk_size,
+            self.return_soft_hard_loss,
+        )
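For orientation, a hypothetical smoke test of the module above. Shapes follow the docstrings (`(batch_size * seq_len, hidden)` inputs, `(vocab, hidden)` projection weights) and are otherwise illustrative; whether a given device works and what the default `compiled=True` path requires depends on `LigerFusedLinearDistillationBase`, which is not part of this diff.

```python
# Hypothetical usage sketch for LigerFusedLinearCosineSimilarityLoss; all sizes are illustrative.
import torch

from liger_kernel.chunked_loss import LigerFusedLinearCosineSimilarityLoss

rows, hidden, vocab = 8, 64, 128  # rows = batch_size * seq_len
student_input = torch.randn(rows, hidden, requires_grad=True)
student_weight = torch.randn(vocab, hidden, requires_grad=True)
teacher_input = torch.randn(rows, hidden)
teacher_weight = torch.randn(vocab, hidden)
true_labels = torch.randint(0, vocab, (rows,))

loss_fn = LigerFusedLinearCosineSimilarityLoss(weight_hard_loss=0.5, weight_soft_loss=0.5, beta=0.5)
loss = loss_fn(student_input, student_weight, teacher_input, teacher_weight, true_labels)
loss.backward()  # gradients flow to the student tensors; teacher tensors are treated as constants
```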
liger_kernel/chunked_loss/cpo_loss.py ADDED
@@ -0,0 +1,157 @@
+import torch
+import torch.nn.functional as F
+
+from liger_kernel.chunked_loss.fused_linear_preference import LigerFusedLinearPreferenceBase
+
+
+class LigerFusedLinearCPOFunction(LigerFusedLinearPreferenceBase):
+    @staticmethod
+    def preference_loss_fn(chosen_logps, rejected_logps, full_target, beta=0.1, label_smoothing=0.0):
+        """
+        Paper: https://arxiv.org/pdf/2401.08417
+
+        Formula:
+        L(π_θ; U) = -E_(x,y_w,y_l)~D[log σ(β log π_θ(y_w|x) - β log π_θ(y_l|x))]
+
+        Where:
+        - π_θ(y|x): Policy (model) probability
+        - y_w: Chosen sequence
+        - y_l: Rejected sequence
+        - σ: Sigmoid function
+        - β: Temperature parameter
+        - E: Expected value over the dataset D
+        - D: Dataset of preferences
+
+        Args:
+            chosen_logps (torch.Tensor): Avg log probabilities of chosen tokens. Shape: (batch_size,).
+            rejected_logps (torch.Tensor): Avg log probabilities of rejected tokens. Shape: (batch_size,).
+            full_target (torch.Tensor): Non chunked full target tensor
+            beta (float): Weight for the CPO loss
+            label_smoothing (float): Label smoothing factor, will reduce to Equation above when label_smoothing -> 0.
+        """
+        logits = beta * (chosen_logps - rejected_logps)
+        loss = (-F.logsigmoid(logits) * (1 - label_smoothing) - F.logsigmoid(-logits) * label_smoothing).sum() / (
+            full_target.shape[0] // 2
+        )
+
+        chosen_rewards = beta * chosen_logps
+        rejected_rewards = beta * rejected_logps
+
+        return loss, chosen_rewards, rejected_rewards
+
+    @classmethod
+    def forward(
+        cls,
+        ctx,
+        _input,
+        weight,
+        target,
+        bias=None,
+        ignore_index=-100,
+        beta=0.1,
+        alpha=1.0,
+        label_smoothing=0.0,
+        compute_nll_loss=True,
+        compiled=True,
+        average_log_prob=False,
+        chunk_size=1,
+    ):
+        """
+        Fused linear layer with CPO loss.
+        Args:
+            _input (torch.Tensor): Input tensor. Shape: (batch_size * seq_len, hidden_size)
+            weight (torch.Tensor): Weight tensor. Shape: (vocab_size, hidden_size)
+            target (torch.LongTensor): Target tensor. Shape: (batch_size * seq_len,)
+            bias (torch.Tensor, optional): Bias tensor. Shape: (vocab_size,)
+            ignore_index (int): Index to ignore in loss computation
+            beta (float): Weight for the odds ratio loss
+            alpha (float): Weight for the alpha parameter
+            label_smoothing (float): Label smoothing factor
+            compute_nll_loss (bool): Whether to compute the NLL loss
+            compiled (bool): Whether to use torch compile
+            average_log_prob (bool): Whether to average the log probability per non-masked token
+            chunk_size (int): Size of chunks for processing.
+        Returns:
+            torch.Tensor: Computed loss
+        """
+        return super().forward(
+            cls=cls,
+            ctx=ctx,
+            _input=_input,
+            weight=weight,
+            target=target,
+            bias=bias,
+            ignore_index=ignore_index,
+            alpha=alpha,
+            beta=beta,
+            label_smoothing=label_smoothing,
+            compute_nll_loss=compute_nll_loss,
+            average_log_prob=average_log_prob,
+            compiled=compiled,
+            chunk_size=chunk_size,
+        )
+
+    @staticmethod
+    def backward(ctx, *grad_output):
+        grads = LigerFusedLinearPreferenceBase.backward(ctx, grad_output)[:4]
+        return *grads, None, None, None, None, None, None, None, None
+
+
+class LigerFusedLinearCPOLoss(torch.nn.Module):
+    """
+    Fused linear layer with CPO loss.
+    """
+
+    def __init__(
+        self,
+        ignore_index: int = -100,
+        beta: float = 0.1,
+        alpha: float = 1.0,
+        label_smoothing: float = 0.0,
+        compute_nll_loss: bool = True,
+        compiled: bool = True,
+        average_log_prob: bool = False,
+        chunk_size: int = 1,
+    ):
+        """
+        Args:
+            ignore_index (int): Index to ignore in the loss.
+            beta (float): Weight for the odds ratio loss.
+            alpha (float): Weight for the alpha parameter.
+            label_smoothing (float): Label smoothing factor.
+            compute_nll_loss (bool): Whether to compute the NLL loss.
+            compiled (bool): Whether to use the torch compiled kernel.
+            average_log_prob (bool): Whether to average the log probability per non-masked token.
+            chunk_size (int): Size of chunks for processing.
+        """
+        super().__init__()
+        self.ignore_index = ignore_index
+        self.beta = beta
+        self.alpha = alpha
+        self.label_smoothing = label_smoothing
+        self.compute_nll_loss = compute_nll_loss
+        self.compiled = compiled
+        self.average_log_prob = average_log_prob
+        self.chunk_size = chunk_size
+
+    def forward(
+        self,
+        lin_weight,
+        _input,
+        target,
+        bias=None,
+    ):
+        return LigerFusedLinearCPOFunction.apply(
+            _input,
+            lin_weight,
+            target,
+            bias,
+            self.ignore_index,
+            self.beta,
+            self.alpha,
+            self.label_smoothing,
+            self.compute_nll_loss,
+            self.compiled,
+            self.average_log_prob,
+            self.chunk_size,
+        )
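A hypothetical usage sketch for the module above. Two details worth noting from the code: the module's `forward` takes the projection weight first (`forward(lin_weight, _input, target, bias=None)`), and the normalization by `full_target.shape[0] // 2` implies the batch is packed as chosen rows followed by rejected rows, a convention handled by `LigerFusedLinearPreferenceBase` (not shown in this diff).

```python
# Hypothetical usage sketch for LigerFusedLinearCPOLoss; sizes are illustrative.
import torch

from liger_kernel.chunked_loss import LigerFusedLinearCPOLoss

num_pairs, seq_len, hidden, vocab = 2, 16, 64, 128
rows = 2 * num_pairs * seq_len  # chosen rows followed by rejected rows
_input = torch.randn(rows, hidden, requires_grad=True)
lin_weight = torch.randn(vocab, hidden, requires_grad=True)
target = torch.randint(0, vocab, (rows,))

cpo = LigerFusedLinearCPOLoss(beta=0.1, label_smoothing=0.0)
out = cpo(lin_weight, _input, target)
loss = out[0] if isinstance(out, tuple) else out  # the base class may also return aux outputs
loss.backward()
```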
liger_kernel/chunked_loss/dpo_loss.py ADDED
@@ -0,0 +1,229 @@
+import torch
+import torch.nn.functional as F
+
+from liger_kernel.chunked_loss.fused_linear_preference import LigerFusedLinearPreferenceBase
+
+
+class LigerFusedLinearDPOFunction(LigerFusedLinearPreferenceBase):
+    @staticmethod
+    def preference_loss_fn(
+        chosen_logps,
+        rejected_logps,
+        full_target,
+        ref_chosen_logps=None,
+        ref_rejected_logps=None,
+        beta=0.1,
+        loss_type="sigmoid",
+    ):
+        """
+        Paper: https://arxiv.org/pdf/2305.18290
+
+        Formula:
+        L_DPO = -E[ log_sigmoid( β * (log(π(y_w|x)/π_ref(y_w|x)) - log(π(y_l|x)/π_ref(y_l|x))) ) ]
+
+        Where:
+        - π(y|x): Policy (model) probability
+        - π_ref(y|x): Reference model probability
+        - y_w: Chosen sequence
+        - y_l: Rejected sequence
+        - β: Weight for the direct preference loss
+        - E: Expected value over the dataset
+
+        Args:
+            chosen_logps: Log probabilities of chosen tokens (batch_size,)
+            rejected_logps: Log probabilities of rejected tokens (batch_size,)
+            full_target: Non chunked full target tensor
+            ref_chosen_logps: Reference log probs of chosen tokens (batch_size,)
+            ref_rejected_logps: Reference log probs of rejected tokens (batch_size,)
+            beta: Weight for the direct preference loss
+        """
+
+        if ref_chosen_logps is None:
+            ref_chosen_logps = torch.tensor(0.0, device=chosen_logps.device)
+        if ref_rejected_logps is None:
+            ref_rejected_logps = torch.tensor(0.0, device=rejected_logps.device)
+
+        chosen_logratios = chosen_logps - ref_chosen_logps
+        rejected_logratios = rejected_logps - ref_rejected_logps
+
+        chosen_rewards = beta * chosen_logratios
+        rejected_rewards = beta * rejected_logratios
+
+        if loss_type == "sigmoid":
+            logits_diff = beta * (chosen_logratios - rejected_logratios)
+            loss = -F.logsigmoid(logits_diff).sum() / (full_target.shape[0] // 2)
+
+        elif loss_type == "apo_zero":
+            # Eqn (7) of the APO paper (https://huggingface.co/papers/2408.06266)
+            # Use this loss when you believe the chosen outputs are better than your model's default output
+            losses_chosen = 1 - F.sigmoid(beta * chosen_logratios)  # Increase chosen likelihood
+            losses_rejected = F.sigmoid(beta * rejected_logratios)
+            losses = losses_chosen + losses_rejected
+            loss = losses.sum() / (full_target.shape[0] // 2)
+
+        elif loss_type == "apo_down":
+            # Eqn (8) of the APO paper (https://huggingface.co/papers/2408.06266)
+            # Use this loss when you believe the chosen outputs are worse than your model's default output.
+            # Decrease chosen likelihood and decrease rejected likelihood more
+            losses_chosen = F.sigmoid(beta * chosen_logratios)
+            losses_rejected = 1 - F.sigmoid(beta * (chosen_logratios - rejected_logratios))
+            losses = losses_chosen + losses_rejected
+            loss = losses.sum() / (full_target.shape[0] // 2)
+
+        elif loss_type == "sppo_hard":
+            # In the paper (https://huggingface.co/papers/2405.00675), SPPO employs a soft probability approach,
+            # estimated using the PairRM score. The probability calculation is conducted outside of the trainer class.
+            # The version described here is the hard probability version, where P in Equation (4.7) of Algorithm 1 is
+            # set to 1 for the winner and 0 for the loser.
+            a = chosen_logps - ref_chosen_logps
+            b = rejected_logps - ref_rejected_logps
+            losses = (a - 0.5 / beta) ** 2 + (b + 0.5 / beta) ** 2
+            loss = losses.sum() / (full_target.shape[0] // 2)
+
+        elif loss_type == "nca_pair":
+            losses = (
+                -F.logsigmoid(chosen_rewards)
+                - 0.5 * F.logsigmoid(-chosen_rewards)
+                - 0.5 * F.logsigmoid(-rejected_rewards)
+            )
+            loss = losses.sum() / (full_target.shape[0] // 2)
+
+        else:
+            raise ValueError(
+                f"Unsupported loss_type: {loss_type}. Supported types are: sigmoid, apo_zero, apo_down, sppo_hard, nca_pair"
+            )
+
+        return loss, chosen_rewards, rejected_rewards
+
+    @classmethod
+    def forward(
+        cls,
+        ctx,
+        _input,
+        weight,
+        target,
+        bias=None,
+        ref_input=None,
+        ref_weight=None,
+        ref_bias=None,
+        ignore_index=-100,
+        beta=0.1,
+        compute_nll_loss=False,
+        compiled=True,
+        use_ref_model=True,
+        average_log_prob=False,
+        chunk_size=1,
+        loss_type="sigmoid",
+    ):
+        """
+        Fused linear layer with DPO loss.
+        Args:
+            _input (torch.Tensor): Input tensor. Shape: (batch_size * seq_len, hidden_size)
+            weight (torch.Tensor): Weight tensor. Shape: (vocab_size, hidden_size)
+            target (torch.LongTensor): Target tensor. Shape: (batch_size * seq_len,)
+            bias (torch.Tensor, optional): Bias tensor. Shape: (vocab_size,)
+            ref_input (torch.Tensor, optional): Reference model input tensor. Shape: (batch_size * seq_len, hidden_size)
+            ref_weight (torch.Tensor, optional): Reference model weight tensor. Shape: (vocab_size, hidden_size)
+            ref_bias (torch.Tensor, optional): Reference model bias tensor. Shape: (vocab_size,)
+            ignore_index (int): Index to ignore in loss computation
+            beta (float): Weight for the odds ratio loss
+            compute_nll_loss (bool): Whether to compute the NLL loss
+            compiled (bool): Whether to use torch compile
+            use_ref_model (bool): Whether to use a reference model
+            average_log_prob (bool): Whether to average the log probability per non-masked token
+            chunk_size (int): Size of chunks for processing.
+        Returns:
+            torch.Tensor: Computed loss
+        """
+        return super().forward(
+            cls=cls,
+            ctx=ctx,
+            _input=_input,
+            weight=weight,
+            target=target,
+            bias=bias,
+            ignore_index=ignore_index,
+            beta=beta,
+            compute_nll_loss=compute_nll_loss,
+            compiled=compiled,
+            use_ref_model=use_ref_model,
+            ref_input=ref_input,
+            ref_weight=ref_weight,
+            ref_bias=ref_bias,
+            average_log_prob=average_log_prob,
+            chunk_size=chunk_size,
+            loss_type=loss_type,
+        )
+
+    @staticmethod
+    def backward(ctx, *grad_output):
+        grads = LigerFusedLinearPreferenceBase.backward(ctx, grad_output)[:4]
+        return *grads, None, None, None, None, None, None, None, None, None, None, None
+
+
+class LigerFusedLinearDPOLoss(torch.nn.Module):
+    """
+    Fused linear layer with DPO loss.
+    """
+
+    def __init__(
+        self,
+        ignore_index: int = -100,
+        beta: float = 0.1,
+        compute_nll_loss: bool = False,
+        compiled: bool = True,
+        use_ref_model: bool = True,
+        average_log_prob: bool = False,
+        chunk_size: int = 1,
+        loss_type: str = "sigmoid",
+    ):
+        """
+        Args:
+            ignore_index (int): Index to ignore in the loss.
+            beta (float): Weight for the odds ratio loss.
+            compute_nll_loss (bool): Whether to compute the NLL loss.
+            compiled (bool): Whether to use the torch compiled kernel.
+            use_ref_model (bool): Whether to use a reference model for the DPO loss.
+            average_log_prob (bool): Whether to average the log probability per non-masked token.
+            chunk_size (int): Size of chunks for processing.
+        """
+        super().__init__()
+        self.ignore_index = ignore_index
+        self.beta = beta
+        self.compute_nll_loss = compute_nll_loss
+        self.compiled = compiled
+        self.use_ref_model = use_ref_model
+        self.average_log_prob = average_log_prob
+        self.chunk_size = chunk_size
+        self.loss_type = loss_type
+        supported_loss_types = {"sigmoid", "apo_zero", "apo_down", "sppo_hard", "nca_pair"}
+        if self.loss_type not in supported_loss_types:
+            raise ValueError(f"Unsupported loss_type: {self.loss_type}. Supported types are: {supported_loss_types}")
+
+    def forward(
+        self,
+        lin_weight,
+        _input,
+        target,
+        bias=None,
+        ref_input=None,
+        ref_weight=None,
+        ref_bias=None,
+    ):
+        return LigerFusedLinearDPOFunction.apply(
+            _input,
+            lin_weight,
+            target,
+            bias,
+            ref_input,
+            ref_weight,
+            ref_bias,
+            self.ignore_index,
+            self.beta,
+            self.compute_nll_loss,
+            self.compiled,
+            self.use_ref_model,
+            self.average_log_prob,
+            self.chunk_size,
+            self.loss_type,
+        )
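A hypothetical usage sketch for the DPO module above, including the reference-model tensors. The argument order follows the module's `forward` (`lin_weight, _input, target, bias, ref_input, ref_weight, ref_bias`), and rows are again assumed to be packed as chosen followed by rejected; per-sequence log-probabilities are derived by the preference base class, which is not part of this diff.

```python
# Hypothetical usage sketch for LigerFusedLinearDPOLoss with a reference model; sizes are illustrative.
import torch

from liger_kernel.chunked_loss import LigerFusedLinearDPOLoss

num_pairs, seq_len, hidden, vocab = 2, 16, 64, 128
rows = 2 * num_pairs * seq_len
_input = torch.randn(rows, hidden, requires_grad=True)       # policy hidden states
ref_input = torch.randn(rows, hidden)                         # reference-model hidden states
lin_weight = torch.randn(vocab, hidden, requires_grad=True)   # policy lm_head weight
ref_weight = torch.randn(vocab, hidden)                        # reference lm_head weight
target = torch.randint(0, vocab, (rows,))

dpo = LigerFusedLinearDPOLoss(beta=0.1, loss_type="sigmoid", use_ref_model=True)
out = dpo(lin_weight, _input, target, None, ref_input, ref_weight)
loss = out[0] if isinstance(out, tuple) else out  # the base class may also return aux outputs
loss.backward()
```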
liger_kernel/chunked_loss/functional.py ADDED
@@ -0,0 +1,17 @@
+from liger_kernel.chunked_loss.cosine_similarity_loss import LigerFusedLinearCosineSimilarityFunction
+from liger_kernel.chunked_loss.cpo_loss import LigerFusedLinearCPOFunction
+from liger_kernel.chunked_loss.dpo_loss import LigerFusedLinearDPOFunction
+from liger_kernel.chunked_loss.grpo_loss import LigerFusedLinearGRPOFunction
+from liger_kernel.chunked_loss.jsd_loss import LigerFusedLinearJSDFunction
+from liger_kernel.chunked_loss.kto_loss import LigerFusedLinearKTOFunction
+from liger_kernel.chunked_loss.orpo_loss import LigerFusedLinearORPOFunction
+from liger_kernel.chunked_loss.simpo_loss import LigerFusedLinearSimPOFunction
+
+liger_fused_linear_orpo = LigerFusedLinearORPOFunction.apply
+liger_fused_linear_dpo = LigerFusedLinearDPOFunction.apply
+liger_fused_linear_jsd = LigerFusedLinearJSDFunction.apply
+liger_fused_linear_cosine = LigerFusedLinearCosineSimilarityFunction.apply
+liger_fused_linear_cpo = LigerFusedLinearCPOFunction.apply
+liger_fused_linear_simpo = LigerFusedLinearSimPOFunction.apply
+liger_fused_linear_kto = LigerFusedLinearKTOFunction.apply
+liger_fused_linear_grpo = LigerFusedLinearGRPOFunction.apply
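These aliases are bare `torch.autograd.Function.apply` handles, so they take positional arguments only, in the order of each Function's `forward` signature (minus `cls`/`ctx`). A hypothetical call mirroring what `LigerFusedLinearCPOLoss.forward` passes above:

```python
# Hypothetical sketch of the functional interface; argument order mirrors
# LigerFusedLinearCPOFunction.forward as shown earlier in this diff.
import torch

from liger_kernel.chunked_loss.functional import liger_fused_linear_cpo

rows, hidden, vocab = 32, 64, 128
_input = torch.randn(rows, hidden, requires_grad=True)
weight = torch.randn(vocab, hidden, requires_grad=True)
target = torch.randint(0, vocab, (rows,))

# _input, weight, target, bias, ignore_index, beta, alpha, label_smoothing,
# compute_nll_loss, compiled, average_log_prob, chunk_size
out = liger_fused_linear_cpo(_input, weight, target, None, -100, 0.1, 1.0, 0.0, True, True, False, 1)
```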