liger-kernel-nightly 0.4.0.dev20241107052928__py3-none-any.whl → 0.6.3.dev20251121010306__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release: this version of liger-kernel-nightly has been flagged as potentially problematic. See the package registry's advisory page for more details.
- liger_kernel/__init__.py +0 -0
- liger_kernel/chunked_loss/README.md +25 -0
- liger_kernel/chunked_loss/__init__.py +8 -0
- liger_kernel/chunked_loss/cosine_similarity_loss.py +136 -0
- liger_kernel/chunked_loss/cpo_loss.py +157 -0
- liger_kernel/chunked_loss/dpo_loss.py +229 -0
- liger_kernel/chunked_loss/functional.py +17 -0
- liger_kernel/chunked_loss/fused_linear_distillation.py +292 -0
- liger_kernel/chunked_loss/fused_linear_ppo.py +350 -0
- liger_kernel/chunked_loss/fused_linear_preference.py +433 -0
- liger_kernel/chunked_loss/fused_linear_unpaired_preference.py +341 -0
- liger_kernel/chunked_loss/grpo_loss.py +304 -0
- liger_kernel/chunked_loss/jsd_loss.py +200 -0
- liger_kernel/chunked_loss/kto_loss.py +210 -0
- liger_kernel/chunked_loss/orpo_loss.py +144 -0
- liger_kernel/chunked_loss/simpo_loss.py +165 -0
- liger_kernel/env_report.py +21 -4
- liger_kernel/ops/cross_entropy.py +235 -84
- liger_kernel/ops/dyt.py +157 -0
- liger_kernel/ops/experimental/embedding.py +1 -3
- liger_kernel/ops/experimental/mm_int8int2.py +3 -9
- liger_kernel/ops/fused_add_rms_norm.py +412 -0
- liger_kernel/ops/fused_linear_cross_entropy.py +197 -75
- liger_kernel/ops/fused_linear_jsd.py +17 -34
- liger_kernel/ops/fused_neighborhood_attention.py +1022 -0
- liger_kernel/ops/geglu.py +7 -18
- liger_kernel/ops/group_norm.py +305 -0
- liger_kernel/ops/grpo_loss.py +310 -0
- liger_kernel/ops/jsd.py +46 -21
- liger_kernel/ops/kl_div.py +23 -19
- liger_kernel/ops/layer_norm.py +150 -86
- liger_kernel/ops/llama4_rope.py +225 -0
- liger_kernel/ops/multi_token_attention.py +207 -0
- liger_kernel/ops/poly_norm.py +386 -0
- liger_kernel/ops/qwen2vl_mrope.py +222 -0
- liger_kernel/ops/rms_norm.py +314 -84
- liger_kernel/ops/rope.py +32 -34
- liger_kernel/ops/softmax.py +201 -0
- liger_kernel/ops/sparsemax.py +179 -0
- liger_kernel/ops/swiglu.py +5 -9
- liger_kernel/ops/tiled_mlp.py +136 -0
- liger_kernel/ops/tvd.py +207 -0
- liger_kernel/ops/utils.py +8 -4
- liger_kernel/transformers/__init__.py +199 -24
- liger_kernel/transformers/auto_model.py +6 -13
- liger_kernel/transformers/cross_entropy.py +33 -20
- liger_kernel/transformers/dyt.py +22 -0
- liger_kernel/transformers/experimental/__init__.py +5 -0
- liger_kernel/transformers/experimental/embedding.py +1 -3
- liger_kernel/transformers/fsdp.py +55 -0
- liger_kernel/transformers/functional.py +291 -13
- liger_kernel/transformers/fused_add_rms_norm.py +39 -0
- liger_kernel/transformers/fused_linear_cross_entropy.py +43 -14
- liger_kernel/transformers/fused_linear_jsd.py +1 -4
- liger_kernel/transformers/fused_neighborhood_attention.py +234 -0
- liger_kernel/transformers/geglu.py +1 -4
- liger_kernel/transformers/group_norm.py +50 -0
- liger_kernel/transformers/grpo_loss.py +98 -0
- liger_kernel/transformers/jsd.py +2 -7
- liger_kernel/transformers/kl_div.py +1 -3
- liger_kernel/transformers/layer_norm.py +3 -9
- liger_kernel/transformers/llama4_rope.py +93 -0
- liger_kernel/transformers/model/falcon_h1.py +122 -0
- liger_kernel/transformers/model/gemma.py +77 -77
- liger_kernel/transformers/model/gemma2.py +283 -0
- liger_kernel/transformers/model/gemma3.py +331 -0
- liger_kernel/transformers/model/glm4.py +141 -0
- liger_kernel/transformers/model/glm4v.py +163 -0
- liger_kernel/transformers/model/glm4v_moe.py +172 -0
- liger_kernel/transformers/model/internvl.py +157 -0
- liger_kernel/transformers/model/llama.py +128 -79
- liger_kernel/transformers/model/llama4.py +121 -0
- liger_kernel/transformers/model/llava.py +344 -0
- liger_kernel/transformers/model/loss_utils.py +95 -0
- liger_kernel/transformers/model/mistral.py +68 -64
- liger_kernel/transformers/model/mixtral.py +75 -91
- liger_kernel/transformers/model/mllama.py +63 -68
- liger_kernel/transformers/model/olmo2.py +141 -0
- liger_kernel/transformers/model/output_classes.py +147 -0
- liger_kernel/transformers/model/paligemma.py +432 -0
- liger_kernel/transformers/model/phi3.py +59 -213
- liger_kernel/transformers/model/qwen2.py +75 -72
- liger_kernel/transformers/model/qwen2_5_vl.py +163 -0
- liger_kernel/transformers/model/qwen2_vl.py +78 -98
- liger_kernel/transformers/model/qwen3.py +136 -0
- liger_kernel/transformers/model/qwen3_moe.py +152 -0
- liger_kernel/transformers/model/qwen3_next.py +146 -0
- liger_kernel/transformers/model/qwen3_vl.py +150 -0
- liger_kernel/transformers/model/qwen3_vl_moe.py +126 -0
- liger_kernel/transformers/model/smollm3.py +199 -0
- liger_kernel/transformers/model/smolvlm.py +158 -0
- liger_kernel/transformers/monkey_patch.py +2106 -289
- liger_kernel/transformers/multi_token_attention.py +64 -0
- liger_kernel/transformers/poly_norm.py +42 -0
- liger_kernel/transformers/qwen2vl_mrope.py +20 -0
- liger_kernel/transformers/rms_norm.py +57 -6
- liger_kernel/transformers/rope.py +45 -2
- liger_kernel/transformers/softmax.py +12 -0
- liger_kernel/transformers/sparsemax.py +16 -0
- liger_kernel/transformers/swiglu.py +23 -8
- liger_kernel/transformers/tiled_mlp.py +133 -0
- liger_kernel/transformers/trainer/__init__.py +4 -0
- liger_kernel/transformers/trainer/orpo_trainer.py +130 -0
- liger_kernel/transformers/tvd.py +13 -0
- liger_kernel/triton/__init__.py +1 -3
- liger_kernel/triton/monkey_patch.py +1 -3
- liger_kernel/utils.py +71 -0
- {liger_kernel_nightly-0.4.0.dev20241107052928.dist-info → liger_kernel_nightly-0.6.3.dev20251121010306.dist-info}/METADATA +150 -137
- liger_kernel_nightly-0.6.3.dev20251121010306.dist-info/RECORD +116 -0
- {liger_kernel_nightly-0.4.0.dev20241107052928.dist-info → liger_kernel_nightly-0.6.3.dev20251121010306.dist-info}/WHEEL +1 -1
- liger_kernel_nightly-0.4.0.dev20241107052928.dist-info/RECORD +0 -48
- {liger_kernel_nightly-0.4.0.dev20241107052928.dist-info → liger_kernel_nightly-0.6.3.dev20251121010306.dist-info}/LICENSE +0 -0
- {liger_kernel_nightly-0.4.0.dev20241107052928.dist-info → liger_kernel_nightly-0.6.3.dev20251121010306.dist-info}/NOTICE +0 -0
- {liger_kernel_nightly-0.4.0.dev20241107052928.dist-info → liger_kernel_nightly-0.6.3.dev20251121010306.dist-info}/top_level.txt +0 -0
|
@@ -0,0 +1,433 @@
|
|
|
1
|
+
from abc import abstractmethod
|
|
2
|
+
from functools import partial
|
|
3
|
+
|
|
4
|
+
import torch
|
|
5
|
+
|
|
6
|
+
from torch.nn import functional as F
|
|
7
|
+
|
|
8
|
+
|
|
9
|
+
class LigerFusedLinearPreferenceBase(torch.autograd.Function):
    """Base autograd Function for a fused linear layer + preference (alignment) loss.

    Subclasses provide ``preference_loss_fn`` (e.g. DPO/ORPO/CPO-style losses) and
    call :meth:`forward` with their own class object so the right loss resolves.
    The input batch is expected to be the chosen examples stacked on top of the
    rejected examples along dim 0; the batch is processed in chunks so the full
    (batch, seq, vocab) logits tensor is never materialized at once.
    """

    @abstractmethod
    def preference_loss_fn(*args, **kwargs):
        """
        To be extended by subclasses.

        Accessed as ``cls.preference_loss_fn`` (a plain function, effectively
        static) and called as
        ``preference_loss_fn(chosen_logps, rejected_logps, full_target, beta=..., **loss_kwargs)``.
        May return either a scalar loss or a tuple ``(loss, *aux_outputs)``.
        """
        raise NotImplementedError("Preference loss function must be implemented.")

    @staticmethod
    def forward(
        cls,
        ctx,
        _input,
        weight,
        target,
        bias=None,
        chunk_size=1,
        ignore_index=-100,
        alpha=1.0,
        beta=0.1,
        compute_nll_loss=True,
        nll_target=None,
        compiled=True,
        use_ref_model=False,
        ref_input=None,
        ref_weight=None,
        ref_bias=None,
        average_log_prob=True,
        **loss_kwargs,
    ):
        """
        Base class for fused linear layer with preference loss.
        Expects _input to be stacked with chosen and rejected inputs on the batch dimension.

        The mental model is:

        forward()
        ├── Loop over chunks
        └── compute_loss()
            ├── chunk_forward()         # Compute logits and log probs
            └── preference_loss_fn()    # Calculate preference loss

        Args:
            cls: Concrete subclass; only ``cls.preference_loss_fn`` is read, so the
                subclass passes itself in explicitly (this is a @staticmethod, not a
                classmethod).
            ctx: torch.autograd.Function context used to stash gradients for backward.
            _input (torch.Tensor): Input tensor. Shape: (batch_size, seq_len, hidden_size).
            weight (torch.Tensor): Weight tensor. Shape: (vocab_size, hidden_size).
            target (torch.Tensor): Target tensor. Shape: (batch_size, seq_len).
            bias (torch.Tensor, optional): Bias tensor. Shape: (vocab_size,).
            chunk_size (int): Size of a chunk (# of batches of stacked chosen and rejected inputs).
            ignore_index (int): Index to ignore for loss computation.
            alpha (float): Weight for the NLL loss.
            beta (float): Weight for the preference loss.
            compute_nll_loss (bool): Whether to compute NLL loss.
            nll_target (torch.Tensor, optional): Target tensor for NLL loss. Shape: (batch_size, seq_len). If not provided the target is used.
            compiled (bool): Whether to use torch compile for chunk accumulation.
            use_ref_model (bool): Whether to use a reference model for the alignment loss.
            ref_input (torch.Tensor, optional): Reference-model input, stacked the same
                way as ``_input`` (chosen then rejected). Required when use_ref_model.
            ref_weight (torch.Tensor): Reference weight tensor. Shape: (vocab_size, hidden_size).
            ref_bias (torch.Tensor, optional): Reference bias tensor. Shape: (vocab_size,).
            average_log_prob (bool): Whether to average log probabilities or to sum them over the completion.
            loss_kwargs (dict): Other possible arguments that a loss function might need

        Returns:
            (loss, (chosen_logps, rejected_logps, chosen_logits_mean,
            rejected_logits_mean, nll_loss, *aggregated_aux_outputs))
        """
        # TODO: Tune CHUNK_SIZE to fully utilize the GPU
        CHUNK_SIZE = chunk_size

        # Gradients to be accumulated across chunks (input grads are collected
        # per-chunk and concatenated at the end).
        grad_weight = torch.zeros_like(weight)
        grad_chosen_inputs = []
        grad_rejected_inputs = []
        grad_bias = torch.zeros_like(bias) if bias is not None else None

        # Loss to be accumulated
        loss_acc = torch.zeros((), device=_input.device)

        # Metrics to be recorded
        policy_chosen_logps = []
        policy_rejected_logps = []
        policy_chosen_logits_mean = torch.zeros((), device=_input.device)
        policy_rejected_logits_mean = torch.zeros((), device=_input.device)
        policy_nll_loss = torch.zeros((), device=_input.device)
        aggregated_aux_outputs = []  # aggregated aux outputs from all chunks

        # Bind all chunk-invariant arguments once; only per-chunk tensors are
        # passed positionally below (their positions matter for argnums).
        compute_loss = partial(
            LigerFusedLinearPreferenceBase._compute_loss,
            preference_loss_fn=cls.preference_loss_fn,
            ignore_index=ignore_index,
            alpha=alpha,
            beta=beta,
            compute_nll_loss=compute_nll_loss,
            full_target=target,
            use_ref_model=use_ref_model,
            ref_weight=ref_weight,
            ref_bias=ref_bias,
            full_nll_target=nll_target,
            average_log_prob=average_log_prob,
            **loss_kwargs,
        )

        def fused_fwd_bwd(input_chunk, target_chunk, ref_input_chunk, chosen_nll_target_chunk):
            """
            Fused forward and backward pass for a chunk of input and target.

            argnums selects which positional args get gradients:
            (0, 1, 3) = (input_chunk, weight, bias); (0, 1) when there is no bias.
            Returns ((grads...), (loss, aux_tuple)) per torch.func.grad_and_value
            with has_aux=True.
            """
            if bias is not None:
                return torch.func.grad_and_value(compute_loss, argnums=(0, 1, 3), has_aux=True)(
                    input_chunk,
                    weight,
                    target_chunk,
                    bias,
                    ref_input_chunk=ref_input_chunk,
                    chosen_nll_target_chunk=chosen_nll_target_chunk,
                )
            else:
                return torch.func.grad_and_value(compute_loss, argnums=(0, 1), has_aux=True)(
                    input_chunk,
                    weight,
                    target_chunk,
                    ref_input_chunk=ref_input_chunk,
                    chosen_nll_target_chunk=chosen_nll_target_chunk,
                )

        def accumulate_chunk(input_chunk, target_chunk, ref_input_chunk=None, chosen_nll_target_chunk=None):
            # Run the fused fwd/bwd for one chunk and fold its gradients, loss,
            # and metrics into the enclosing-scope accumulators.
            if bias is not None:
                (
                    (chunk_grad_input, chunk_grad_weight, chunk_grad_bias),
                    (
                        chunk_loss,
                        (
                            chunk_chosen_logps,
                            chunk_rejected_logps,
                            chunk_chosen_logits_mean,
                            chunk_rejected_logits_mean,
                            chunk_nll_loss,
                            *aux_outputs,
                        ),
                    ),
                ) = fused_fwd_bwd(input_chunk, target_chunk, ref_input_chunk, chosen_nll_target_chunk)
                grad_bias.add_(chunk_grad_bias)  # accumulate bias gradient
            else:
                (
                    (chunk_grad_input, chunk_grad_weight),
                    (
                        chunk_loss,
                        (
                            chunk_chosen_logps,
                            chunk_rejected_logps,
                            chunk_chosen_logits_mean,
                            chunk_rejected_logits_mean,
                            chunk_nll_loss,
                            *aux_outputs,
                        ),
                    ),
                ) = fused_fwd_bwd(input_chunk, target_chunk, ref_input_chunk, chosen_nll_target_chunk)

            # Accumulate gradients.
            # NOTE: `chosen_target_chunk` is not a parameter here — it is the
            # current loop variable from the `for ... in zip(...)` loop below,
            # resolved via closure at call time. The chunk's first
            # chosen_target_chunk.shape[0] rows are chosen, the rest rejected.
            grad_weight.add_(chunk_grad_weight)
            grad_chosen_inputs.append(chunk_grad_input[: chosen_target_chunk.shape[0]])
            grad_rejected_inputs.append(chunk_grad_input[chosen_target_chunk.shape[0] :])

            # Accumulate loss
            loss_acc.add_(chunk_loss)

            # Accumulate metrics
            policy_chosen_logps.append(chunk_chosen_logps)
            policy_rejected_logps.append(chunk_rejected_logps)
            policy_chosen_logits_mean.add_(chunk_chosen_logits_mean)
            policy_rejected_logits_mean.add_(chunk_rejected_logits_mean)
            policy_nll_loss.add_(chunk_nll_loss)

            # aux_outputs
            # Initialize storage for aux_outputs on the first chunk: scalars are
            # summed in place, non-scalars are collected and concatenated later.
            if len(aggregated_aux_outputs) == 0:
                for aux in aux_outputs:
                    if aux.ndim == 0:
                        aggregated_aux_outputs.append(torch.zeros((), device=aux.device))
                    else:
                        aggregated_aux_outputs.append([])

            # Process each aux_output
            for i, aux in enumerate(aux_outputs):
                if aux.ndim == 0:
                    aggregated_aux_outputs[i].add_(aux)
                else:
                    aggregated_aux_outputs[i].append(aux)

        if compiled:
            fused_fwd_bwd = torch.compile(fused_fwd_bwd)

        # Split the stacked batch into chosen/rejected halves, then chunk each
        # half so every chunk pairs CHUNK_SIZE chosen with CHUNK_SIZE rejected rows.
        len_chosen = target.shape[0] // 2
        chunks = max(1, _input.shape[0] // (2 * CHUNK_SIZE))
        _chosen_input_chunks = torch.chunk(_input[:len_chosen], chunks=chunks, dim=0)
        _chosen_target_chunks = torch.chunk(target[:len_chosen], chunks=chunks, dim=0)
        _rejected_input_chunks = torch.chunk(_input[len_chosen:], chunks=chunks, dim=0)
        _rejected_target_chunks = torch.chunk(target[len_chosen:], chunks=chunks, dim=0)

        if nll_target is not None:
            _chosen_nll_target_chunks = torch.chunk(nll_target[:len_chosen], chunks=chunks, dim=0)

        if use_ref_model:
            _ref_chosen_input_chunks = torch.chunk(ref_input[:len_chosen], chunks=chunks, dim=0)
            _ref_rejected_input_chunks = torch.chunk(ref_input[len_chosen:], chunks=chunks, dim=0)

        for (
            chosen_input_chunk,
            rejected_input_chunk,
            chosen_target_chunk,
            rejected_target_chunk,
            ref_chosen_input_chunk,
            ref_rejected_input_chunk,
            chosen_nll_target_chunk,
        ) in zip(
            _chosen_input_chunks,
            _rejected_input_chunks,
            _chosen_target_chunks,
            _rejected_target_chunks,
            # Optional streams are padded with None so the zip stays aligned.
            (_ref_chosen_input_chunks if use_ref_model else [None] * len(_chosen_input_chunks)),
            (_ref_rejected_input_chunks if use_ref_model else [None] * len(_rejected_input_chunks)),
            (_chosen_nll_target_chunks if nll_target is not None else [None] * len(_chosen_input_chunks)),
        ):
            # Re-stack chosen-then-rejected inside the chunk, mirroring the
            # layout of the full batch.
            input_chunk = torch.cat([chosen_input_chunk, rejected_input_chunk], dim=0)
            ref_input_chunk = (
                torch.cat([ref_chosen_input_chunk, ref_rejected_input_chunk], dim=0) if use_ref_model else None
            )
            target_chunk = torch.cat([chosen_target_chunk, rejected_target_chunk], dim=0)

            # mark input_chunk, target_chunk, and target dimension 1 as dynamic to prevent torch.compile recompilation
            torch._dynamo.mark_dynamic(input_chunk, 1)
            torch._dynamo.mark_dynamic(target_chunk, 1)
            torch._dynamo.mark_dynamic(target, 1)
            torch._dynamo.mark_dynamic(ref_input_chunk, 1) if use_ref_model else None
            torch._dynamo.mark_dynamic(chosen_nll_target_chunk, 1) if nll_target is not None else None

            # accumulate loss, gradients, and metrics
            accumulate_chunk(input_chunk, target_chunk, ref_input_chunk, chosen_nll_target_chunk)

        # combine grad_chosen_inputs and grad_rejected_inputs so grad_inputs
        # matches _input's chosen-then-rejected stacking
        grad_inputs = grad_chosen_inputs + grad_rejected_inputs
        policy_chosen_logps = torch.cat(policy_chosen_logps, dim=0)
        policy_rejected_logps = torch.cat(policy_rejected_logps, dim=0)

        # Aggregate aux outputs lists into tensors
        for i, aux in enumerate(aggregated_aux_outputs):
            if isinstance(aux, list):
                aggregated_aux_outputs[i] = torch.cat(aux, dim=0)

        # Gradients were already computed chunk-by-chunk; backward only rescales
        # and returns them.
        ctx.save_for_backward(
            torch.cat(grad_inputs, dim=0),
            grad_weight,
            grad_bias,
        )
        return_vars = (
            policy_chosen_logps,
            policy_rejected_logps,
            policy_chosen_logits_mean,
            policy_rejected_logits_mean,
            policy_nll_loss,
        )
        return loss_acc, (*return_vars, *aggregated_aux_outputs)

    @staticmethod
    def backward(ctx, *grad_output):
        """Return the pre-computed gradients, rescaled by the incoming loss grad.

        grad_output[0][0] is the gradient w.r.t. the scalar loss; the metrics
        outputs receive no gradient. The multiply is skipped in the common case
        where the upstream gradient is exactly 1.0. The trailing Nones cover the
        non-differentiable forward inputs (target, chunk_size, ...); subclasses
        are expected to pad further as needed for their own signatures.
        """
        grad_input, grad_weight, grad_bias = ctx.saved_tensors
        if torch.ne(grad_output[0][0], torch.tensor(1.0, device=grad_output[0][0].device)):
            grad_input = grad_input * grad_output[0][0]
            grad_weight = grad_weight * grad_output[0][0]
            grad_bias = grad_bias * grad_output[0][0] if grad_bias is not None else None

        return grad_input, grad_weight, None, grad_bias, None, None, None, None

    @staticmethod
    def chunk_forward(
        input_chunk,
        weight,
        target_chunk,
        bias=None,
        ignore_index=-100,
        compute_nll_loss=True,
        chosen_nll_target_chunk=None,
        average_log_prob=True,
    ):
        """Project a chunk to logits and compute per-sequence log-probs.

        The chunk is chosen-then-rejected stacked on dim 0; the first half of
        rows is treated as chosen. NLL loss (sum-reduced, un-normalized here) is
        computed over the chosen half only. Returns
        (chosen_logps, rejected_logps, chosen_logits, rejected_logits, chosen_nll_loss).
        """
        len_chosen_chunk = target_chunk.shape[0] // 2
        logits_chunk = input_chunk @ weight.t()
        if bias is not None:
            logits_chunk = logits_chunk + bias
        # log-softmax in fp32 for numerical stability regardless of input dtype
        log_probs_chunk = F.log_softmax(logits_chunk.float(), dim=-1)

        chosen_nll_loss = 0.0
        if compute_nll_loss:
            nll_labels = (
                chosen_nll_target_chunk if chosen_nll_target_chunk is not None else target_chunk[:len_chosen_chunk]
            )
            chosen_nll_loss = F.nll_loss(
                log_probs_chunk[:len_chosen_chunk].view(-1, log_probs_chunk.shape[-1]),
                nll_labels.view(-1),
                reduction="sum",
                ignore_index=ignore_index,
            )

        # Mask out ignored positions; label 0 is a placeholder for gather at
        # masked positions (its contribution is zeroed by loss_mask).
        loss_mask = target_chunk != ignore_index
        label_chunk = torch.where(loss_mask, target_chunk, 0)

        per_token_logps = log_probs_chunk.gather(-1, label_chunk.unsqueeze(-1)).squeeze(-1)
        if average_log_prob:
            log_prob = (per_token_logps * loss_mask).sum(-1) / loss_mask.sum(-1)
        else:
            log_prob = (per_token_logps * loss_mask).sum(-1)

        chosen_logps = log_prob[:len_chosen_chunk]
        rejected_logps = log_prob[len_chosen_chunk:]

        chosen_logits = logits_chunk[:len_chosen_chunk]
        rejected_logits = logits_chunk[len_chosen_chunk:]

        return (
            chosen_logps,
            rejected_logps,
            chosen_logits,
            rejected_logits,
            chosen_nll_loss,
        )

    @staticmethod
    def _compute_loss(
        input_chunk,
        weight,
        target_chunk,
        bias=None,
        preference_loss_fn=None,
        full_target=None,
        ignore_index=-100,
        alpha=1.0,
        beta=0.1,
        compute_nll_loss=True,
        use_ref_model=False,
        ref_input_chunk=None,
        ref_weight=None,
        ref_bias=None,
        full_nll_target=None,
        chosen_nll_target_chunk=None,
        average_log_prob=True,
        **loss_kwargs,
    ):
        """
        Compute the total loss for a chunk of input and target, while using an alignment/preference loss function.
        Args:
            preference_loss_fn (callable): Loss function to compute the loss on a chunk of input/target.
            input_chunk (torch.Tensor): Chunk of input tensor. Shape: (2 * chunk_size, sequence_length, hidden_size).
            weight (torch.Tensor): Weight tensor. Shape: (vocab_size, hidden_size).
            target_chunk (torch.Tensor): Chunk of target tensor. Shape: (2 * chunk_size, sequence_length).
            bias (torch.Tensor, optional): Bias tensor. Shape: (vocab_size,).
            full_target (torch.Tensor): Full target tensor. Shape: (batch_size, sequence_length).
            ignore_index (int): Index to ignore for loss computation.
            alpha (float): Weight for the NLL loss.
            beta (float): Weight for the preference loss.
            compute_nll_loss (bool): Whether to compute NLL loss.
            use_ref_model (bool): Whether to use a reference model for the alignment loss.
            ref_input_chunk (torch.Tensor, optional): Chunk of reference-model input. Same shape as input_chunk.
            ref_weight (torch.Tensor): Reference weight tensor. Shape: (vocab_size, hidden_size).
            ref_bias (torch.Tensor, optional): Reference bias tensor. Shape: (vocab_size,).
            full_nll_target (torch.Tensor, optional): Full target tensor for NLL loss. Shape: (batch_size, sequence_length).
            chosen_nll_target_chunk (torch.Tensor, optional): Target tensor for NLL loss. Shape: (chunk_size, sequence_length) If not provided the target_chunk is used.
            average_log_prob (bool): Whether to average log probabilities or the sum.
            loss_kwargs (dict): Additional arguments for the loss function.
        """
        (
            chosen_logps,
            rejected_logps,
            chosen_logits,
            rejected_logits,
            chosen_nll_loss,
        ) = LigerFusedLinearPreferenceBase.chunk_forward(
            input_chunk,
            weight,
            target_chunk,
            bias=bias,
            ignore_index=ignore_index,
            compute_nll_loss=compute_nll_loss,
            chosen_nll_target_chunk=chosen_nll_target_chunk,
            average_log_prob=average_log_prob,
        )
        # Normalize the sum-reduced chunk NLL by the number of non-ignored
        # tokens in the chosen half of the FULL batch, so per-chunk losses sum
        # to the full-batch mean.
        if full_nll_target is not None:
            chosen_nll_loss = chosen_nll_loss / (full_nll_target[: full_nll_target.shape[0] // 2] != ignore_index).sum()
        else:
            chosen_nll_loss = chosen_nll_loss / (full_target[: full_target.shape[0] // 2] != ignore_index).sum()

        # Mean over all logit elements of the full chosen/rejected halves
        # (batch/2 * seq_len * vocab), again so chunk contributions sum to the
        # full-batch mean.
        chosen_logits_mean = chosen_logits.sum() / (full_target.shape[0] // 2 * input_chunk.shape[1] * weight.shape[0])
        rejected_logits_mean = rejected_logits.sum() / (
            full_target.shape[0] // 2 * input_chunk.shape[1] * weight.shape[0]
        )

        if use_ref_model:
            # Reference log-probs are constants w.r.t. the policy parameters.
            with torch.no_grad():
                (
                    ref_chosen_logps,
                    ref_rejected_logps,
                    _,
                    _,
                    _,
                ) = LigerFusedLinearPreferenceBase.chunk_forward(
                    ref_input_chunk,
                    ref_weight,
                    target_chunk,
                    ref_bias,
                    ignore_index=ignore_index,
                    compute_nll_loss=False,  # We don't need NLL loss for the reference model
                    chosen_nll_target_chunk=None,
                    average_log_prob=average_log_prob,
                )
            loss_kwargs["ref_chosen_logps"] = ref_chosen_logps
            loss_kwargs["ref_rejected_logps"] = ref_rejected_logps

        preference_loss_outputs = preference_loss_fn(
            chosen_logps, rejected_logps, full_target, beta=beta, **loss_kwargs
        )
        # Loss functions may return a bare loss or (loss, *aux_outputs).
        if isinstance(preference_loss_outputs, tuple):
            preference_loss, *aux_outputs = preference_loss_outputs
        else:
            preference_loss, aux_outputs = preference_loss_outputs, []

        loss = alpha * chosen_nll_loss + preference_loss
        return_vars = (
            chosen_logps,
            rejected_logps,
            chosen_logits_mean,
            rejected_logits_mean,
            chosen_nll_loss,
        )
        return loss, (*return_vars, *aux_outputs)
|