liger-kernel-nightly 0.4.0.dev20241107052928__py3-none-any.whl → 0.6.3.dev20251121010306__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release: this version of liger-kernel-nightly has been flagged as potentially problematic.
- liger_kernel/__init__.py +0 -0
- liger_kernel/chunked_loss/README.md +25 -0
- liger_kernel/chunked_loss/__init__.py +8 -0
- liger_kernel/chunked_loss/cosine_similarity_loss.py +136 -0
- liger_kernel/chunked_loss/cpo_loss.py +157 -0
- liger_kernel/chunked_loss/dpo_loss.py +229 -0
- liger_kernel/chunked_loss/functional.py +17 -0
- liger_kernel/chunked_loss/fused_linear_distillation.py +292 -0
- liger_kernel/chunked_loss/fused_linear_ppo.py +350 -0
- liger_kernel/chunked_loss/fused_linear_preference.py +433 -0
- liger_kernel/chunked_loss/fused_linear_unpaired_preference.py +341 -0
- liger_kernel/chunked_loss/grpo_loss.py +304 -0
- liger_kernel/chunked_loss/jsd_loss.py +200 -0
- liger_kernel/chunked_loss/kto_loss.py +210 -0
- liger_kernel/chunked_loss/orpo_loss.py +144 -0
- liger_kernel/chunked_loss/simpo_loss.py +165 -0
- liger_kernel/env_report.py +21 -4
- liger_kernel/ops/cross_entropy.py +235 -84
- liger_kernel/ops/dyt.py +157 -0
- liger_kernel/ops/experimental/embedding.py +1 -3
- liger_kernel/ops/experimental/mm_int8int2.py +3 -9
- liger_kernel/ops/fused_add_rms_norm.py +412 -0
- liger_kernel/ops/fused_linear_cross_entropy.py +197 -75
- liger_kernel/ops/fused_linear_jsd.py +17 -34
- liger_kernel/ops/fused_neighborhood_attention.py +1022 -0
- liger_kernel/ops/geglu.py +7 -18
- liger_kernel/ops/group_norm.py +305 -0
- liger_kernel/ops/grpo_loss.py +310 -0
- liger_kernel/ops/jsd.py +46 -21
- liger_kernel/ops/kl_div.py +23 -19
- liger_kernel/ops/layer_norm.py +150 -86
- liger_kernel/ops/llama4_rope.py +225 -0
- liger_kernel/ops/multi_token_attention.py +207 -0
- liger_kernel/ops/poly_norm.py +386 -0
- liger_kernel/ops/qwen2vl_mrope.py +222 -0
- liger_kernel/ops/rms_norm.py +314 -84
- liger_kernel/ops/rope.py +32 -34
- liger_kernel/ops/softmax.py +201 -0
- liger_kernel/ops/sparsemax.py +179 -0
- liger_kernel/ops/swiglu.py +5 -9
- liger_kernel/ops/tiled_mlp.py +136 -0
- liger_kernel/ops/tvd.py +207 -0
- liger_kernel/ops/utils.py +8 -4
- liger_kernel/transformers/__init__.py +199 -24
- liger_kernel/transformers/auto_model.py +6 -13
- liger_kernel/transformers/cross_entropy.py +33 -20
- liger_kernel/transformers/dyt.py +22 -0
- liger_kernel/transformers/experimental/__init__.py +5 -0
- liger_kernel/transformers/experimental/embedding.py +1 -3
- liger_kernel/transformers/fsdp.py +55 -0
- liger_kernel/transformers/functional.py +291 -13
- liger_kernel/transformers/fused_add_rms_norm.py +39 -0
- liger_kernel/transformers/fused_linear_cross_entropy.py +43 -14
- liger_kernel/transformers/fused_linear_jsd.py +1 -4
- liger_kernel/transformers/fused_neighborhood_attention.py +234 -0
- liger_kernel/transformers/geglu.py +1 -4
- liger_kernel/transformers/group_norm.py +50 -0
- liger_kernel/transformers/grpo_loss.py +98 -0
- liger_kernel/transformers/jsd.py +2 -7
- liger_kernel/transformers/kl_div.py +1 -3
- liger_kernel/transformers/layer_norm.py +3 -9
- liger_kernel/transformers/llama4_rope.py +93 -0
- liger_kernel/transformers/model/falcon_h1.py +122 -0
- liger_kernel/transformers/model/gemma.py +77 -77
- liger_kernel/transformers/model/gemma2.py +283 -0
- liger_kernel/transformers/model/gemma3.py +331 -0
- liger_kernel/transformers/model/glm4.py +141 -0
- liger_kernel/transformers/model/glm4v.py +163 -0
- liger_kernel/transformers/model/glm4v_moe.py +172 -0
- liger_kernel/transformers/model/internvl.py +157 -0
- liger_kernel/transformers/model/llama.py +128 -79
- liger_kernel/transformers/model/llama4.py +121 -0
- liger_kernel/transformers/model/llava.py +344 -0
- liger_kernel/transformers/model/loss_utils.py +95 -0
- liger_kernel/transformers/model/mistral.py +68 -64
- liger_kernel/transformers/model/mixtral.py +75 -91
- liger_kernel/transformers/model/mllama.py +63 -68
- liger_kernel/transformers/model/olmo2.py +141 -0
- liger_kernel/transformers/model/output_classes.py +147 -0
- liger_kernel/transformers/model/paligemma.py +432 -0
- liger_kernel/transformers/model/phi3.py +59 -213
- liger_kernel/transformers/model/qwen2.py +75 -72
- liger_kernel/transformers/model/qwen2_5_vl.py +163 -0
- liger_kernel/transformers/model/qwen2_vl.py +78 -98
- liger_kernel/transformers/model/qwen3.py +136 -0
- liger_kernel/transformers/model/qwen3_moe.py +152 -0
- liger_kernel/transformers/model/qwen3_next.py +146 -0
- liger_kernel/transformers/model/qwen3_vl.py +150 -0
- liger_kernel/transformers/model/qwen3_vl_moe.py +126 -0
- liger_kernel/transformers/model/smollm3.py +199 -0
- liger_kernel/transformers/model/smolvlm.py +158 -0
- liger_kernel/transformers/monkey_patch.py +2106 -289
- liger_kernel/transformers/multi_token_attention.py +64 -0
- liger_kernel/transformers/poly_norm.py +42 -0
- liger_kernel/transformers/qwen2vl_mrope.py +20 -0
- liger_kernel/transformers/rms_norm.py +57 -6
- liger_kernel/transformers/rope.py +45 -2
- liger_kernel/transformers/softmax.py +12 -0
- liger_kernel/transformers/sparsemax.py +16 -0
- liger_kernel/transformers/swiglu.py +23 -8
- liger_kernel/transformers/tiled_mlp.py +133 -0
- liger_kernel/transformers/trainer/__init__.py +4 -0
- liger_kernel/transformers/trainer/orpo_trainer.py +130 -0
- liger_kernel/transformers/tvd.py +13 -0
- liger_kernel/triton/__init__.py +1 -3
- liger_kernel/triton/monkey_patch.py +1 -3
- liger_kernel/utils.py +71 -0
- {liger_kernel_nightly-0.4.0.dev20241107052928.dist-info → liger_kernel_nightly-0.6.3.dev20251121010306.dist-info}/METADATA +150 -137
- liger_kernel_nightly-0.6.3.dev20251121010306.dist-info/RECORD +116 -0
- {liger_kernel_nightly-0.4.0.dev20241107052928.dist-info → liger_kernel_nightly-0.6.3.dev20251121010306.dist-info}/WHEEL +1 -1
- liger_kernel_nightly-0.4.0.dev20241107052928.dist-info/RECORD +0 -48
- {liger_kernel_nightly-0.4.0.dev20241107052928.dist-info → liger_kernel_nightly-0.6.3.dev20251121010306.dist-info}/LICENSE +0 -0
- {liger_kernel_nightly-0.4.0.dev20241107052928.dist-info → liger_kernel_nightly-0.6.3.dev20251121010306.dist-info}/NOTICE +0 -0
- {liger_kernel_nightly-0.4.0.dev20241107052928.dist-info → liger_kernel_nightly-0.6.3.dev20251121010306.dist-info}/top_level.txt +0 -0
liger_kernel/chunked_loss/jsd_loss.py
@@ -0,0 +1,200 @@
+import math
+
+from typing import Tuple
+from typing import Union
+
+import torch
+import torch.nn.functional as F
+
+from liger_kernel.chunked_loss.fused_linear_distillation import LigerFusedLinearDistillationBase
+
+
+class LigerFusedLinearJSDFunction(LigerFusedLinearDistillationBase):
+    @staticmethod
+    def distillation_loss_fn(student_logits, teacher_logits, beta=0.5):
+        """
+        Compute JSD loss (Jensen-Shannon Divergence Loss).
+        Args:
+            student_logits (torch.Tensor): Logits of student tokens. Shape: (batch_size * seq_len,).
+            teacher_logits (torch.Tensor): Logits of teacher tokens. Shape: (batch_size * seq_len,).
+            beta (float): Coefficient beta of generalized JSD in the interval [0, 1]. Default: `0.5`.
+        Returns:
+            torch.Tensor: Jensen-Shannon Divergence loss
+        """
+        student_log_probs = F.log_softmax(student_logits, dim=-1)
+        teacher_log_probs = F.log_softmax(teacher_logits, dim=-1)
+
+        if beta == 0:
+            jsd_loss = F.kl_div(student_log_probs, teacher_log_probs, reduction="sum", log_target=True)
+        elif beta == 1:
+            jsd_loss = F.kl_div(teacher_log_probs, student_log_probs, reduction="sum", log_target=True)
+        else:
+            # Compute probabilities (only required for mean calculation)
+            log_mean_probs = torch.logsumexp(
+                torch.stack([student_log_probs + math.log(1 - beta), teacher_log_probs + math.log(beta)], dim=0), dim=0
+            )
+
+            student_kl = F.kl_div(log_mean_probs, student_log_probs, reduction="sum", log_target=True)
+            teacher_kl = F.kl_div(log_mean_probs, teacher_log_probs, reduction="sum", log_target=True)
+
+            # JSD is the weighted average of the KL divergences
+            jsd_loss = beta * teacher_kl + (1 - beta) * student_kl
+        return jsd_loss
+
+    @classmethod
+    def forward(
+        cls,
+        ctx,
+        student_input: torch.Tensor,
+        student_weight: torch.Tensor,
+        teacher_input: torch.Tensor,
+        teacher_weight: torch.Tensor,
+        true_labels: torch.LongTensor,
+        student_bias: torch.Tensor,
+        teacher_bias: torch.Tensor,
+        weight_hard_loss: float = 0.5,
+        weight_soft_loss: float = 0.5,
+        beta: float = 0.5,
+        ignore_index: int = -100,
+        temperature: float = 1.0,
+        compiled: bool = True,
+        chunk_size: int = 1024,
+        return_soft_hard_loss: bool = False,
+    ):
+        """
+        Fused linear layer with JSD distillation loss.
+        Args:
+            student_input (torch.Tensor): Student input tensor. Shape: (batch_size * seq_len, hidden_size_student)
+            student_weight (torch.Tensor): Student weight tensor. Shape: (vocab_size, hidden_size_student)
+            teacher_input (torch.Tensor): Teacher input tensor. Shape: (batch_size * seq_len, hidden_size_teacher)
+            teacher_weight (torch.Tensor): Teacher weight tensor. Shape: (vocab_size, hidden_size_teacher)
+            true_labels (torch.LongTensor): Target tensor. Shape: (batch_size * seq_len,)
+            weight_hard_loss (float): Weight for hard loss.
+            weight_soft_loss (float): Weight for soft loss.
+            beta (float): Coefficient beta of generalized JSD in the interval [0, 1]. Default: `0.5`.
+            ignore_index (int): Index to ignore in loss computation
+            temperature (float): Temperature for softening/sharpening distributions
+            compiled (bool): Whether to use torch compile
+            chunk_size (int): Size of chunks for processing.
+            return_soft_hard_loss (bool): Whether to return soft and hard losses separately. Default: False.
+        Returns:
+            torch.Tensor: Computed loss, or tuple (loss, soft_loss, hard_loss) if return_soft_hard_loss=True
+        """
+        return super().forward(
+            cls=cls,
+            ctx=ctx,
+            student_input=student_input,
+            student_weight=student_weight,
+            teacher_input=teacher_input,
+            teacher_weight=teacher_weight,
+            target=true_labels,
+            student_bias=student_bias,
+            teacher_bias=teacher_bias,
+            chunk_size=chunk_size,
+            weight_hard_loss=weight_hard_loss,
+            weight_soft_loss=weight_soft_loss,
+            beta=beta,
+            ignore_index=ignore_index,
+            temperature=temperature,
+            compiled=compiled,
+            return_soft_hard_loss=return_soft_hard_loss,
+        )
+
+    @staticmethod
+    def backward(ctx, grad_output, *args):
+        grads = LigerFusedLinearDistillationBase.backward(ctx, grad_output, *args)[:6]
+
+        return (
+            *grads,
+            None,  # teacher_bias
+            None,  # weight_hard_loss
+            None,  # weight_soft_loss
+            None,  # beta
+            None,  # ignore_index
+            None,  # temperature
+            None,  # compiled
+            None,  # chunk_size
+            None,  # return_soft_hard_loss
+        )
+
+
+class LigerFusedLinearJSDLoss(torch.nn.Module):
+    """
+    Fused linear layer with JSD distillation loss.
+    """
+
+    def __init__(
+        self,
+        weight_hard_loss: float = 0.5,
+        weight_soft_loss: float = 0.5,
+        beta: float = 0.5,
+        ignore_index: int = -100,
+        temperature: float = 1.0,
+        compiled: bool = True,
+        chunk_size: int = 1024,
+        return_soft_hard_loss: bool = False,
+    ):
+        """
+        Args:
+            weight_hard_loss (float): Weight for hard loss.
+            weight_soft_loss (float): Weight for soft loss.
+            ignore_index (int): Index to ignore in the loss
+            temperature (float): Temperature for softening distributions
+            compiled (bool): Whether to use torch compile
+            beta (float): Coefficient beta of generalized JSD in the interval [0, 1]. Default: `0.5`.
+            chunk_size (int): Size of chunks for processing.
+            return_soft_hard_loss (bool): Whether to return soft and hard losses separately. Default: False.
+        """
+        super().__init__()
+        assert temperature != 0, "Temperature cannot be 0."
+        self.weight_hard_loss = weight_hard_loss
+        self.weight_soft_loss = weight_soft_loss
+        self.ignore_index = ignore_index
+        self.temperature = temperature
+        self.compiled = compiled
+        self.beta = beta
+        self.chunk_size = chunk_size
+        self.return_soft_hard_loss = return_soft_hard_loss
+
+    def forward(
+        self,
+        student_input: torch.Tensor,
+        student_weight: torch.Tensor,
+        teacher_input: torch.Tensor,
+        teacher_weight: torch.Tensor,
+        true_labels: torch.LongTensor,
+        student_bias: torch.Tensor = None,
+        teacher_bias: torch.Tensor = None,
+    ) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor, torch.Tensor]]:
+        """
+        Compute the JSD distillation loss.
+
+        Args:
+            student_input (torch.Tensor): Student input tensor
+            student_weight (torch.Tensor): Student weight tensor
+            teacher_input (torch.Tensor): Teacher input tensor
+            teacher_weight (torch.Tensor): Teacher weight tensor
+            true_labels (torch.LongTensor): Target labels tensor
+
+        Returns:
+            torch.Tensor or Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
+                If return_soft_hard_loss is False: Computed combined loss
+                If return_soft_hard_loss is True: Tuple of (combined_loss, soft_loss, hard_loss)
+        """
+        return LigerFusedLinearJSDFunction.apply(
+            student_input,
+            student_weight,
+            teacher_input,
+            teacher_weight,
+            true_labels,
+            student_bias,
+            teacher_bias,
+            self.weight_hard_loss,
+            self.weight_soft_loss,
+            self.beta,
+            self.ignore_index,
+            self.temperature,
+            self.compiled,
+            self.chunk_size,
+            self.return_soft_hard_loss,
+        )
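
To make the jsd_loss.py module above concrete, here is a minimal usage sketch (illustration only, not part of the diff). It follows the shapes documented in the forward docstring; the sizes, device handling, and random data are hypothetical, and the fused path is normally run on a GPU.

import torch

from liger_kernel.chunked_loss.jsd_loss import LigerFusedLinearJSDLoss

# Hypothetical sizes, chosen only for illustration.
batch_seq_len, student_hidden, teacher_hidden, vocab_size = 8, 64, 128, 1000
device = "cuda" if torch.cuda.is_available() else "cpu"

loss_fn = LigerFusedLinearJSDLoss(weight_hard_loss=0.5, weight_soft_loss=0.5, beta=0.5)

# Hidden states already flattened to (batch_size * seq_len, hidden_size), as the docstring describes.
student_input = torch.randn(batch_seq_len, student_hidden, device=device, requires_grad=True)
student_weight = torch.randn(vocab_size, student_hidden, device=device, requires_grad=True)
teacher_input = torch.randn(batch_seq_len, teacher_hidden, device=device)
teacher_weight = torch.randn(vocab_size, teacher_hidden, device=device)
true_labels = torch.randint(0, vocab_size, (batch_seq_len,), device=device)

loss = loss_fn(student_input, student_weight, teacher_input, teacher_weight, true_labels)
loss.backward()
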
liger_kernel/chunked_loss/kto_loss.py
@@ -0,0 +1,210 @@
+import torch
+import torch.nn.functional as F
+
+from liger_kernel.chunked_loss.fused_linear_unpaired_preference import LigerFusedLinearUnpairedPreferenceBase
+
+
+class LigerFusedLinearKTOFunction(LigerFusedLinearUnpairedPreferenceBase):
+    @staticmethod
+    def preference_loss_fn(
+        log_prob_chunk,
+        preference_labels_chunk,
+        full_target,
+        ref_log_prob_chunk=None,
+        beta=0.1,
+        kl=None,
+    ):
+        """
+        Implements the Kahneman-Tversky Optimization (KTO) loss function.
+        Paper: "KTO: Model Alignment as Prospect Theory-Guided Optimization"
+        https://arxiv.org/abs/2402.01306
+
+        KTO loss is inspired by prospect theory (https://en.wikipedia.org/wiki/Prospect_theory)
+        from behavioral economics, which models how humans make decisions under uncertainty.
+        The loss function is asymmetric, treating gains and losses differently, similar to
+        human decision-making patterns.
+
+        Formula:
+        When y is chosen:
+        L_KTO = 1 - σ(β * (log[π(x)/π₀(x)] - KL(π||π₀)_y))
+        When y is rejected:
+        L_KTO = 1 - σ(β * (KL(π||π₀)_y - log[π(x)/π₀(x)]))
+
+        Where:
+        - σ: Sigmoid function
+        - β: Temperature parameter controlling the strength of the preference signal
+        - π(x): Policy (current model)
+        - π₀(x): Reference policy (reference model)
+        - KL(π||π₀)_y: KL divergence estimated using the rejected response y
+
+        The loss encourages the model to:
+        1. Assign higher probability to chosen responses
+        2. Assign lower probability to rejected responses
+        3. Maintain reasonable distance from the reference model
+
+        Args:
+            log_prob_chunk: Log probabilities for the chunk (batch_size,)
+            preference_labels_chunk: Preference labels for the chunk (batch_size,)
+            full_target: Non chunked full target tensor
+            ref_log_prob_chunk: Reference log probs for the chunk (batch_size,)
+            beta: Weight for the KTO loss
+            kl: KL divergence between the policy model and the reference model for the chosen responses. Shape: (batch_size,)
+        Returns:
+            - loss: The KTO loss value
+        """
+        if ref_log_prob_chunk is not None:
+            logratios_chunk = log_prob_chunk - ref_log_prob_chunk
+        else:
+            logratios_chunk = log_prob_chunk
+        multiplier_chunk = torch.where(preference_labels_chunk, 1, -1)
+        if kl is not None:
+            losses = 1 - F.sigmoid(beta * (logratios_chunk - kl) * multiplier_chunk)
+        else:
+            losses = 1 - F.sigmoid(beta * logratios_chunk * multiplier_chunk)
+
+        rewards = beta * logratios_chunk
+        chosen_rewards_sum = (rewards * preference_labels_chunk.unsqueeze(1)).sum()
+        rejected_rewards_sum = (rewards * (~preference_labels_chunk).unsqueeze(1)).sum()
+
+        return losses.sum() / (full_target.shape[0]), chosen_rewards_sum, rejected_rewards_sum
+
+    @classmethod
+    def forward(
+        cls,
+        ctx,
+        _input,
+        weight,
+        target,
+        preference_labels,
+        bias=None,
+        ref_input=None,
+        ref_weight=None,
+        ref_bias=None,
+        kl=None,
+        ignore_index=-100,
+        beta=0.1,
+        compiled=True,
+        use_ref_model=True,
+        average_log_prob=False,
+        chunk_size=1,
+    ):
+        """
+        Fused linear layer with KTO loss.
+        Args:
+            _input (torch.Tensor): Input tensor. Shape: (batch_size * seq_len, hidden_size)
+            weight (torch.Tensor): Weight tensor. Shape: (vocab_size, hidden_size)
+            target (torch.LongTensor): Target tensor. Shape: (batch_size * seq_len,)
+            preference_labels (torch.Tensor): Preference labels tensor. Shape: (batch_size,)
+            bias (torch.Tensor, optional): Bias tensor. Shape: (vocab_size,)
+            ref_input (torch.Tensor, optional): Reference model input tensor. Shape: (batch_size * seq_len, hidden_size)
+            ref_weight (torch.Tensor, optional): Reference model weight tensor. Shape: (vocab_size, hidden_size)
+            ref_bias (torch.Tensor, optional): Reference model bias tensor. Shape: (vocab_size,)
+            kl (torch.Tensor, optional): KL divergence tensor. Shape: (batch_size,)
+            ignore_index (int): Index to ignore in loss computation
+            beta (float): Temperature parameter for the KTO loss
+            compiled (bool): Whether to use torch compile
+            use_ref_model (bool): Whether to use a reference model
+            average_log_prob (bool): Whether to average the log probability per non-masked token
+            chunk_size (int): Size of chunks for processing
+        Returns:
+            torch.Tensor: Computed loss
+        """
+        return super().forward(
+            cls=cls,
+            ctx=ctx,
+            _input=_input,
+            weight=weight,
+            target=target,
+            preference_labels=preference_labels,
+            bias=bias,
+            ignore_index=ignore_index,
+            beta=beta,
+            compiled=compiled,
+            use_ref_model=use_ref_model,
+            ref_input=ref_input,
+            ref_weight=ref_weight,
+            ref_bias=ref_bias,
+            average_log_prob=average_log_prob,
+            kl=kl,
+            chunk_size=chunk_size,
+        )
+
+    @staticmethod
+    def backward(ctx, *grad_output):
+        grads = LigerFusedLinearUnpairedPreferenceBase.backward(ctx, grad_output)[:5]
+        return (
+            *grads,
+            None,
+            None,
+            None,
+            None,
+            None,
+            None,
+            None,
+            None,
+            None,
+            None,
+            None,
+        )
+
+
+class LigerFusedLinearKTOLoss(torch.nn.Module):
+    """
+    Fused linear layer with Kahneman-Tversky Optimization (KTO) loss.
+    """
+
+    def __init__(
+        self,
+        ignore_index: int = -100,
+        beta: float = 0.1,
+        compiled: bool = True,
+        use_ref_model: bool = False,
+        average_log_prob: bool = False,
+        chunk_size: int = 1,
+    ):
+        """
+        Args:
+            ignore_index (int): Index to ignore in the loss calculation
+            beta (float): Temperature parameter for the KTO loss
+            compiled (bool): Whether to use compiled operations
+            use_ref_model (bool): Whether to use a reference model for the DPO loss.
+            average_log_prob (bool): Whether to average the log probability per non-masked token
+            chunk_size (int): Size of chunks for processing
+        """
+        super().__init__()
+        self.ignore_index = ignore_index
+        self.beta = beta
+        self.compiled = compiled
+        self.use_ref_model = use_ref_model
+        self.average_log_prob = average_log_prob
+        self.chunk_size = chunk_size
+
+    def forward(
+        self,
+        _input,
+        lin_weight,
+        target,
+        bias=None,
+        preference_labels=None,
+        ref_input=None,
+        ref_weight=None,
+        ref_bias=None,
+        kl=None,
+    ):
+        return LigerFusedLinearKTOFunction.apply(
+            _input,
+            lin_weight,
+            target,
+            preference_labels,
+            bias,
+            ref_input,
+            ref_weight,
+            ref_bias,
+            kl,
+            self.ignore_index,
+            self.beta,
+            self.compiled,
+            self.use_ref_model,
+            self.average_log_prob,
+            self.chunk_size,
+        )
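
As a plain-PyTorch illustration of the asymmetric formula in the kto_loss.py docstring above (not part of the diff, and not the fused/chunked implementation), the per-example KTO loss computed by preference_loss_fn reduces to the following; the function name is hypothetical.

import torch

def kto_loss_reference(policy_logps, ref_logps, preference_labels, kl, beta=0.1):
    # Log ratio between policy and reference model: log[pi(x) / pi_0(x)]
    logratios = policy_logps - ref_logps
    # +1 for chosen responses, -1 for rejected ones: the asymmetric part of KTO
    multiplier = torch.where(preference_labels, 1.0, -1.0)
    # 1 - sigmoid(beta * (logratio - KL)) for chosen, 1 - sigmoid(beta * (KL - logratio)) for rejected
    return 1 - torch.sigmoid(beta * (logratios - kl) * multiplier)

The fused version above additionally sums these per-example losses, normalizes by the full target batch size, and folds the vocabulary projection into the same chunked computation.
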
liger_kernel/chunked_loss/orpo_loss.py
@@ -0,0 +1,144 @@
+import torch
+import torch.nn.functional as F
+
+from liger_kernel.chunked_loss.fused_linear_preference import LigerFusedLinearPreferenceBase
+
+
+class LigerFusedLinearORPOFunction(LigerFusedLinearPreferenceBase):
+    @staticmethod
+    def preference_loss_fn(chosen_logps, rejected_logps, full_target, beta=0.1):
+        """
+        Paper: https://arxiv.org/pdf/2403.07691
+
+        Formula:
+        Compute odds-ratio loss: L_OR = -log(σ(log(odds_θ(y_w|x) / odds_θ(y_l|x))))
+        where odds_θ(y|x) = P_θ(y|x) / (1 - P_θ(y|x))
+
+        Where:
+        - P_θ(y|x): Policy (model) probability
+        - y_w: Chosen sequence
+        - y_l: Rejected sequence
+        - σ: Sigmoid function
+        - β: Weight for the odds ratio loss
+        - odds_θ: Odds function for the policy
+
+        Args:
+            chosen_logps (torch.Tensor): Avg log probabilities of chosen tokens. Shape: (batch_size,).
+            rejected_logps (torch.Tensor): Avg log probabilities of rejected tokens. Shape: (batch_size,).
+            full_target (torch.Tensor): Non chunked full target tensor
+            beta (float): Weight for the odds ratio loss.
+        """
+        log_odds = (chosen_logps - rejected_logps) - (
+            torch.log1p(-torch.exp(chosen_logps)) - torch.log1p(-torch.exp(rejected_logps))
+        )
+        ratio = F.logsigmoid(log_odds)
+        loss = -beta * ratio.sum() / (full_target.shape[0] // 2)
+
+        chosen_rewards = beta * chosen_logps
+        rejected_rewards = beta * rejected_logps
+
+        log_odds_ratio = torch.sum(ratio) / (full_target.shape[0] // 2)
+        log_odds_chosen = torch.sum(log_odds) / (full_target.shape[0] // 2)
+
+        return loss, chosen_rewards, rejected_rewards, log_odds_ratio, log_odds_chosen
+
+    @classmethod
+    def forward(
+        cls,
+        ctx,
+        _input,
+        weight,
+        target,
+        bias=None,
+        ignore_index=-100,
+        beta=0.1,
+        compute_nll_loss=True,
+        nll_target=None,
+        compiled=True,
+        chunk_size=1,
+    ):
+        """
+        Fused linear layer with ORPO loss.
+        Args:
+            _input (torch.Tensor): Input tensor. Shape: (batch_size * seq_len, hidden_size)
+            weight (torch.Tensor): Weight tensor. Shape: (vocab_size, hidden_size)
+            target (torch.LongTensor): Target tensor. Shape: (batch_size * seq_len,)
+            bias (torch.Tensor, optional): Bias tensor. Shape: (vocab_size,)
+            ignore_index (int): Index to ignore in loss computation
+            beta (float): Weight for the odds ratio loss
+            compute_nll_loss (bool): Whether to compute the NLL loss
+            nll_target (torch.LongTensor, optional): Target tensor for NLL loss. Shape: (batch_size * seq_len,)
+            compiled (bool): Whether to use torch compile
+            chunk_size (int): Size of chunks for processing
+        Returns:
+            torch.Tensor: Computed loss
+        """
+        return super().forward(
+            cls=cls,
+            ctx=ctx,
+            _input=_input,
+            weight=weight,
+            target=target,
+            bias=bias,
+            ignore_index=ignore_index,
+            beta=beta,
+            compute_nll_loss=compute_nll_loss,
+            nll_target=nll_target,
+            compiled=compiled,
+            chunk_size=chunk_size,
+        )
+
+    @staticmethod
+    def backward(ctx, *grad_output):
+        grads = LigerFusedLinearPreferenceBase.backward(ctx, grad_output)[:4]
+        return *grads, None, None, None, None, None, None
+
+
+class LigerFusedLinearORPOLoss(torch.nn.Module):
+    """
+    Fused linear layer with ORPO (Odds-Ratio Preference Optimization) loss.
+    """
+
+    def __init__(
+        self,
+        ignore_index: int = -100,
+        beta: float = 0.1,
+        compute_nll_loss: bool = True,
+        compiled: bool = True,
+        chunk_size: int = 1,
+    ):
+        """
+        Args:
+            ignore_index (int): Index to ignore in the loss.
+            beta (float): Weight for the odds ratio loss.
+            compute_nll_loss (bool): Whether to compute the NLL loss.
+            compiled (bool): Whether to use the torch compiled kernel.
+            chunk_size (int): Size of chunks for processing.
+        """
+        super().__init__()
+        self.ignore_index = ignore_index
+        self.beta = beta
+        self.compute_nll_loss = compute_nll_loss
+        self.compiled = compiled
+        self.chunk_size = chunk_size
+
+    def forward(
+        self,
+        lin_weight,
+        _input,
+        target,
+        bias=None,
+        nll_target=None,
+    ):
+        return LigerFusedLinearORPOFunction.apply(
+            _input,
+            lin_weight,
+            target,
+            bias,
+            self.ignore_index,
+            self.beta,
+            self.compute_nll_loss,
+            nll_target,
+            self.compiled,
+            self.chunk_size,
+        )
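
For reference (illustration only, not part of the diff), the odds-ratio term that orpo_loss.py's preference_loss_fn derives from the averaged chosen/rejected log probabilities can be written in plain PyTorch as below; the helper name is hypothetical.

import torch
import torch.nn.functional as F

def orpo_odds_ratio_loss_reference(chosen_logps, rejected_logps, beta=0.1):
    # log(odds(y_w) / odds(y_l)) with odds(y) = p(y) / (1 - p(y)), computed in log space
    log_odds = (chosen_logps - rejected_logps) - (
        torch.log1p(-torch.exp(chosen_logps)) - torch.log1p(-torch.exp(rejected_logps))
    )
    # L_OR = -log(sigmoid(log_odds)), scaled by beta and averaged over the preference pairs
    return -beta * F.logsigmoid(log_odds).mean()

Dividing the summed term by full_target.shape[0] // 2 in the fused code corresponds (once accumulated across chunks) to this mean over preference pairs, since the full target stacks the chosen and rejected sequences along the batch dimension.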