liger-kernel-nightly 0.4.0.dev20241107052928__py3-none-any.whl → 0.6.3.dev20251121010306__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of liger-kernel-nightly might be problematic. Click here for more details.

Files changed (114):
  1. liger_kernel/__init__.py +0 -0
  2. liger_kernel/chunked_loss/README.md +25 -0
  3. liger_kernel/chunked_loss/__init__.py +8 -0
  4. liger_kernel/chunked_loss/cosine_similarity_loss.py +136 -0
  5. liger_kernel/chunked_loss/cpo_loss.py +157 -0
  6. liger_kernel/chunked_loss/dpo_loss.py +229 -0
  7. liger_kernel/chunked_loss/functional.py +17 -0
  8. liger_kernel/chunked_loss/fused_linear_distillation.py +292 -0
  9. liger_kernel/chunked_loss/fused_linear_ppo.py +350 -0
  10. liger_kernel/chunked_loss/fused_linear_preference.py +433 -0
  11. liger_kernel/chunked_loss/fused_linear_unpaired_preference.py +341 -0
  12. liger_kernel/chunked_loss/grpo_loss.py +304 -0
  13. liger_kernel/chunked_loss/jsd_loss.py +200 -0
  14. liger_kernel/chunked_loss/kto_loss.py +210 -0
  15. liger_kernel/chunked_loss/orpo_loss.py +144 -0
  16. liger_kernel/chunked_loss/simpo_loss.py +165 -0
  17. liger_kernel/env_report.py +21 -4
  18. liger_kernel/ops/cross_entropy.py +235 -84
  19. liger_kernel/ops/dyt.py +157 -0
  20. liger_kernel/ops/experimental/embedding.py +1 -3
  21. liger_kernel/ops/experimental/mm_int8int2.py +3 -9
  22. liger_kernel/ops/fused_add_rms_norm.py +412 -0
  23. liger_kernel/ops/fused_linear_cross_entropy.py +197 -75
  24. liger_kernel/ops/fused_linear_jsd.py +17 -34
  25. liger_kernel/ops/fused_neighborhood_attention.py +1022 -0
  26. liger_kernel/ops/geglu.py +7 -18
  27. liger_kernel/ops/group_norm.py +305 -0
  28. liger_kernel/ops/grpo_loss.py +310 -0
  29. liger_kernel/ops/jsd.py +46 -21
  30. liger_kernel/ops/kl_div.py +23 -19
  31. liger_kernel/ops/layer_norm.py +150 -86
  32. liger_kernel/ops/llama4_rope.py +225 -0
  33. liger_kernel/ops/multi_token_attention.py +207 -0
  34. liger_kernel/ops/poly_norm.py +386 -0
  35. liger_kernel/ops/qwen2vl_mrope.py +222 -0
  36. liger_kernel/ops/rms_norm.py +314 -84
  37. liger_kernel/ops/rope.py +32 -34
  38. liger_kernel/ops/softmax.py +201 -0
  39. liger_kernel/ops/sparsemax.py +179 -0
  40. liger_kernel/ops/swiglu.py +5 -9
  41. liger_kernel/ops/tiled_mlp.py +136 -0
  42. liger_kernel/ops/tvd.py +207 -0
  43. liger_kernel/ops/utils.py +8 -4
  44. liger_kernel/transformers/__init__.py +199 -24
  45. liger_kernel/transformers/auto_model.py +6 -13
  46. liger_kernel/transformers/cross_entropy.py +33 -20
  47. liger_kernel/transformers/dyt.py +22 -0
  48. liger_kernel/transformers/experimental/__init__.py +5 -0
  49. liger_kernel/transformers/experimental/embedding.py +1 -3
  50. liger_kernel/transformers/fsdp.py +55 -0
  51. liger_kernel/transformers/functional.py +291 -13
  52. liger_kernel/transformers/fused_add_rms_norm.py +39 -0
  53. liger_kernel/transformers/fused_linear_cross_entropy.py +43 -14
  54. liger_kernel/transformers/fused_linear_jsd.py +1 -4
  55. liger_kernel/transformers/fused_neighborhood_attention.py +234 -0
  56. liger_kernel/transformers/geglu.py +1 -4
  57. liger_kernel/transformers/group_norm.py +50 -0
  58. liger_kernel/transformers/grpo_loss.py +98 -0
  59. liger_kernel/transformers/jsd.py +2 -7
  60. liger_kernel/transformers/kl_div.py +1 -3
  61. liger_kernel/transformers/layer_norm.py +3 -9
  62. liger_kernel/transformers/llama4_rope.py +93 -0
  63. liger_kernel/transformers/model/falcon_h1.py +122 -0
  64. liger_kernel/transformers/model/gemma.py +77 -77
  65. liger_kernel/transformers/model/gemma2.py +283 -0
  66. liger_kernel/transformers/model/gemma3.py +331 -0
  67. liger_kernel/transformers/model/glm4.py +141 -0
  68. liger_kernel/transformers/model/glm4v.py +163 -0
  69. liger_kernel/transformers/model/glm4v_moe.py +172 -0
  70. liger_kernel/transformers/model/internvl.py +157 -0
  71. liger_kernel/transformers/model/llama.py +128 -79
  72. liger_kernel/transformers/model/llama4.py +121 -0
  73. liger_kernel/transformers/model/llava.py +344 -0
  74. liger_kernel/transformers/model/loss_utils.py +95 -0
  75. liger_kernel/transformers/model/mistral.py +68 -64
  76. liger_kernel/transformers/model/mixtral.py +75 -91
  77. liger_kernel/transformers/model/mllama.py +63 -68
  78. liger_kernel/transformers/model/olmo2.py +141 -0
  79. liger_kernel/transformers/model/output_classes.py +147 -0
  80. liger_kernel/transformers/model/paligemma.py +432 -0
  81. liger_kernel/transformers/model/phi3.py +59 -213
  82. liger_kernel/transformers/model/qwen2.py +75 -72
  83. liger_kernel/transformers/model/qwen2_5_vl.py +163 -0
  84. liger_kernel/transformers/model/qwen2_vl.py +78 -98
  85. liger_kernel/transformers/model/qwen3.py +136 -0
  86. liger_kernel/transformers/model/qwen3_moe.py +152 -0
  87. liger_kernel/transformers/model/qwen3_next.py +146 -0
  88. liger_kernel/transformers/model/qwen3_vl.py +150 -0
  89. liger_kernel/transformers/model/qwen3_vl_moe.py +126 -0
  90. liger_kernel/transformers/model/smollm3.py +199 -0
  91. liger_kernel/transformers/model/smolvlm.py +158 -0
  92. liger_kernel/transformers/monkey_patch.py +2106 -289
  93. liger_kernel/transformers/multi_token_attention.py +64 -0
  94. liger_kernel/transformers/poly_norm.py +42 -0
  95. liger_kernel/transformers/qwen2vl_mrope.py +20 -0
  96. liger_kernel/transformers/rms_norm.py +57 -6
  97. liger_kernel/transformers/rope.py +45 -2
  98. liger_kernel/transformers/softmax.py +12 -0
  99. liger_kernel/transformers/sparsemax.py +16 -0
  100. liger_kernel/transformers/swiglu.py +23 -8
  101. liger_kernel/transformers/tiled_mlp.py +133 -0
  102. liger_kernel/transformers/trainer/__init__.py +4 -0
  103. liger_kernel/transformers/trainer/orpo_trainer.py +130 -0
  104. liger_kernel/transformers/tvd.py +13 -0
  105. liger_kernel/triton/__init__.py +1 -3
  106. liger_kernel/triton/monkey_patch.py +1 -3
  107. liger_kernel/utils.py +71 -0
  108. {liger_kernel_nightly-0.4.0.dev20241107052928.dist-info → liger_kernel_nightly-0.6.3.dev20251121010306.dist-info}/METADATA +150 -137
  109. liger_kernel_nightly-0.6.3.dev20251121010306.dist-info/RECORD +116 -0
  110. {liger_kernel_nightly-0.4.0.dev20241107052928.dist-info → liger_kernel_nightly-0.6.3.dev20251121010306.dist-info}/WHEEL +1 -1
  111. liger_kernel_nightly-0.4.0.dev20241107052928.dist-info/RECORD +0 -48
  112. {liger_kernel_nightly-0.4.0.dev20241107052928.dist-info → liger_kernel_nightly-0.6.3.dev20251121010306.dist-info}/LICENSE +0 -0
  113. {liger_kernel_nightly-0.4.0.dev20241107052928.dist-info → liger_kernel_nightly-0.6.3.dev20251121010306.dist-info}/NOTICE +0 -0
  114. {liger_kernel_nightly-0.4.0.dev20241107052928.dist-info → liger_kernel_nightly-0.6.3.dev20251121010306.dist-info}/top_level.txt +0 -0
@@ -0,0 +1,433 @@
1
+ from abc import abstractmethod
2
+ from functools import partial
3
+
4
+ import torch
5
+
6
+ from torch.nn import functional as F
7
+
8
+
9
class LigerFusedLinearPreferenceBase(torch.autograd.Function):
    """Base autograd Function for a fused linear layer + chunked preference loss.

    Subclasses (e.g. DPO/ORPO/CPO-style losses) override ``preference_loss_fn``
    with a static function computing the preference loss from chosen/rejected
    log-probabilities.
    """

    @staticmethod
    @abstractmethod
    def preference_loss_fn(*args, **kwargs):
        """Compute the preference loss for one chunk.

        To be extended by subclasses; the base implementation is abstract.

        Raises:
            NotImplementedError: always, unless overridden by a subclass.
        """
        raise NotImplementedError("Preference loss function must be implemented.")
16
+
17
    @staticmethod
    def forward(
        cls,
        ctx,
        _input,
        weight,
        target,
        bias=None,
        chunk_size=1,
        ignore_index=-100,
        alpha=1.0,
        beta=0.1,
        compute_nll_loss=True,
        nll_target=None,
        compiled=True,
        use_ref_model=False,
        ref_input=None,
        ref_weight=None,
        ref_bias=None,
        average_log_prob=True,
        **loss_kwargs,
    ):
        """
        Base class for fused linear layer with preference loss.
        Expects _input to be stacked with chosen and rejected inputs on the batch dimension.

        The mental model is:

        forward()
        ├── Loop over chunks
        └── compute_loss()
            ├── chunk_forward()          # Compute logits and log probs
            └── preference_loss_fn()     # Calculate preference loss

        NOTE(review): despite the @staticmethod decorator, ``cls`` is an explicit
        first parameter — callers (subclasses) are expected to pass the subclass
        explicitly so ``cls.preference_loss_fn`` resolves to the override.

        Args:
            cls: Subclass providing ``preference_loss_fn``.
            ctx: Autograd context used to stash pre-computed gradients.
            _input (torch.Tensor): Input tensor. Shape: (batch_size, seq_len, hidden_size).
            weight (torch.Tensor): Weight tensor. Shape: (vocab_size, hidden_size).
            target (torch.Tensor): Target tensor. Shape: (batch_size, seq_len).
            bias (torch.Tensor, optional): Bias tensor. Shape: (vocab_size,).
            chunk_size (int): Size of a chunk (# of batches of stacked chosen and rejected inputs).
            ignore_index (int): Index to ignore for loss computation.
            alpha (float): Weight for the NLL loss.
            beta (float): Weight for the preference loss.
            compute_nll_loss (bool): Whether to compute NLL loss.
            nll_target (torch.Tensor, optional): Target tensor for NLL loss. Shape: (batch_size, seq_len). If not provided the target is used.
            compiled (bool): Whether to use torch compile for chunk accumulation.
            use_ref_model (bool): Whether to use a reference model for the alignment loss.
            ref_input (torch.Tensor, optional): Reference-model input; required when use_ref_model is True.
            ref_weight (torch.Tensor): Reference weight tensor. Shape: (vocab_size, hidden_size).
            ref_bias (torch.Tensor, optional): Reference bias tensor. Shape: (vocab_size,).
            average_log_prob (bool): Whether to average log probabilities or to sum them over the completion.
            loss_kwargs (dict): Other possible arguments that a loss function might need
        """
        # TODO: Tune CHUNK_SIZE to fully utilize the GPU
        CHUNK_SIZE = chunk_size

        # Gradients to be accumulated across chunks; backward() only rescales these.
        grad_weight = torch.zeros_like(weight)
        grad_chosen_inputs = []
        grad_rejected_inputs = []
        grad_bias = torch.zeros_like(bias) if bias is not None else None

        # Loss to be accumulated
        loss_acc = torch.zeros((), device=_input.device)

        # Metrics to be recorded
        policy_chosen_logps = []
        policy_rejected_logps = []
        policy_chosen_logits_mean = torch.zeros((), device=_input.device)
        policy_rejected_logits_mean = torch.zeros((), device=_input.device)
        policy_nll_loss = torch.zeros((), device=_input.device)
        aggregated_aux_outputs = []  # aggregated aux outputs from all chunks

        # Bind all chunk-invariant arguments once; the per-chunk positional args
        # (input/weight/target/bias) stay free so torch.func can differentiate them.
        compute_loss = partial(
            LigerFusedLinearPreferenceBase._compute_loss,
            preference_loss_fn=cls.preference_loss_fn,
            ignore_index=ignore_index,
            alpha=alpha,
            beta=beta,
            compute_nll_loss=compute_nll_loss,
            full_target=target,
            use_ref_model=use_ref_model,
            ref_weight=ref_weight,
            ref_bias=ref_bias,
            full_nll_target=nll_target,
            average_log_prob=average_log_prob,
            **loss_kwargs,
        )

        def fused_fwd_bwd(input_chunk, target_chunk, ref_input_chunk, chosen_nll_target_chunk):
            """
            Fused forward and backward pass for a chunk of input and target.
            """
            # argnums selects which positional args get gradients:
            # (0, 1, 3) -> input, weight, bias; (0, 1) -> input, weight.
            if bias is not None:
                return torch.func.grad_and_value(compute_loss, argnums=(0, 1, 3), has_aux=True)(
                    input_chunk,
                    weight,
                    target_chunk,
                    bias,
                    ref_input_chunk=ref_input_chunk,
                    chosen_nll_target_chunk=chosen_nll_target_chunk,
                )
            else:
                return torch.func.grad_and_value(compute_loss, argnums=(0, 1), has_aux=True)(
                    input_chunk,
                    weight,
                    target_chunk,
                    ref_input_chunk=ref_input_chunk,
                    chosen_nll_target_chunk=chosen_nll_target_chunk,
                )

        def accumulate_chunk(input_chunk, target_chunk, ref_input_chunk=None, chosen_nll_target_chunk=None):
            # Run the fused fwd+bwd for one chunk and fold its gradients, loss,
            # and metrics into the enclosing accumulators (in place).
            if bias is not None:
                (
                    (chunk_grad_input, chunk_grad_weight, chunk_grad_bias),
                    (
                        chunk_loss,
                        (
                            chunk_chosen_logps,
                            chunk_rejected_logps,
                            chunk_chosen_logits_mean,
                            chunk_rejected_logits_mean,
                            chunk_nll_loss,
                            *aux_outputs,
                        ),
                    ),
                ) = fused_fwd_bwd(input_chunk, target_chunk, ref_input_chunk, chosen_nll_target_chunk)
                grad_bias.add_(chunk_grad_bias)  # accumulate bias gradient
            else:
                (
                    (chunk_grad_input, chunk_grad_weight),
                    (
                        chunk_loss,
                        (
                            chunk_chosen_logps,
                            chunk_rejected_logps,
                            chunk_chosen_logits_mean,
                            chunk_rejected_logits_mean,
                            chunk_nll_loss,
                            *aux_outputs,
                        ),
                    ),
                ) = fused_fwd_bwd(input_chunk, target_chunk, ref_input_chunk, chosen_nll_target_chunk)

            # Accumulate gradients
            # NOTE(review): chosen_target_chunk is read from the enclosing loop's
            # current iteration variable, not passed in — the first half of the
            # chunk is "chosen", the second half "rejected".
            grad_weight.add_(chunk_grad_weight)
            grad_chosen_inputs.append(chunk_grad_input[: chosen_target_chunk.shape[0]])
            grad_rejected_inputs.append(chunk_grad_input[chosen_target_chunk.shape[0] :])

            # Accumulate loss
            loss_acc.add_(chunk_loss)

            # Accumulate metrics
            policy_chosen_logps.append(chunk_chosen_logps)
            policy_rejected_logps.append(chunk_rejected_logps)
            policy_chosen_logits_mean.add_(chunk_chosen_logits_mean)
            policy_rejected_logits_mean.add_(chunk_rejected_logits_mean)
            policy_nll_loss.add_(chunk_nll_loss)

            # aux_outputs
            # Initialize storage for aux_outputs on the first chunk: scalars are
            # summed into a 0-dim tensor, non-scalars are collected in a list.
            if len(aggregated_aux_outputs) == 0:
                for aux in aux_outputs:
                    if aux.ndim == 0:
                        aggregated_aux_outputs.append(torch.zeros((), device=aux.device))
                    else:
                        aggregated_aux_outputs.append([])

            # Process each aux_output
            for i, aux in enumerate(aux_outputs):
                if aux.ndim == 0:
                    aggregated_aux_outputs[i].add_(aux)
                else:
                    aggregated_aux_outputs[i].append(aux)

        if compiled:
            fused_fwd_bwd = torch.compile(fused_fwd_bwd)

        # The batch dimension stacks [chosen; rejected], so split at the midpoint
        # and chunk each half identically.
        len_chosen = target.shape[0] // 2
        chunks = max(1, _input.shape[0] // (2 * CHUNK_SIZE))
        _chosen_input_chunks = torch.chunk(_input[:len_chosen], chunks=chunks, dim=0)
        _chosen_target_chunks = torch.chunk(target[:len_chosen], chunks=chunks, dim=0)
        _rejected_input_chunks = torch.chunk(_input[len_chosen:], chunks=chunks, dim=0)
        _rejected_target_chunks = torch.chunk(target[len_chosen:], chunks=chunks, dim=0)

        if nll_target is not None:
            _chosen_nll_target_chunks = torch.chunk(nll_target[:len_chosen], chunks=chunks, dim=0)

        if use_ref_model:
            _ref_chosen_input_chunks = torch.chunk(ref_input[:len_chosen], chunks=chunks, dim=0)
            _ref_rejected_input_chunks = torch.chunk(ref_input[len_chosen:], chunks=chunks, dim=0)

        for (
            chosen_input_chunk,
            rejected_input_chunk,
            chosen_target_chunk,
            rejected_target_chunk,
            ref_chosen_input_chunk,
            ref_rejected_input_chunk,
            chosen_nll_target_chunk,
        ) in zip(
            _chosen_input_chunks,
            _rejected_input_chunks,
            _chosen_target_chunks,
            _rejected_target_chunks,
            (_ref_chosen_input_chunks if use_ref_model else [None] * len(_chosen_input_chunks)),
            (_ref_rejected_input_chunks if use_ref_model else [None] * len(_rejected_input_chunks)),
            (_chosen_nll_target_chunks if nll_target is not None else [None] * len(_chosen_input_chunks)),
        ):
            input_chunk = torch.cat([chosen_input_chunk, rejected_input_chunk], dim=0)
            ref_input_chunk = (
                torch.cat([ref_chosen_input_chunk, ref_rejected_input_chunk], dim=0) if use_ref_model else None
            )
            target_chunk = torch.cat([chosen_target_chunk, rejected_target_chunk], dim=0)

            # mark input_chunk, target_chunk, and target dimension 1 as dynamic to prevent torch.compile recompilation
            torch._dynamo.mark_dynamic(input_chunk, 1)
            torch._dynamo.mark_dynamic(target_chunk, 1)
            torch._dynamo.mark_dynamic(target, 1)
            torch._dynamo.mark_dynamic(ref_input_chunk, 1) if use_ref_model else None
            torch._dynamo.mark_dynamic(chosen_nll_target_chunk, 1) if nll_target is not None else None

            # accumulate loss, gradients, and metrics
            accumulate_chunk(input_chunk, target_chunk, ref_input_chunk, chosen_nll_target_chunk)

        # combine grad_chosen_inputs and grad_rejected_inputs
        grad_inputs = grad_chosen_inputs + grad_rejected_inputs
        policy_chosen_logps = torch.cat(policy_chosen_logps, dim=0)
        policy_rejected_logps = torch.cat(policy_rejected_logps, dim=0)

        # Aggregate aux outputs lists into tensors
        for i, aux in enumerate(aggregated_aux_outputs):
            if isinstance(aux, list):
                aggregated_aux_outputs[i] = torch.cat(aux, dim=0)

        # Stash the pre-computed gradients for backward(), which only rescales them.
        ctx.save_for_backward(
            torch.cat(grad_inputs, dim=0),
            grad_weight,
            grad_bias,
        )
        return_vars = (
            policy_chosen_logps,
            policy_rejected_logps,
            policy_chosen_logits_mean,
            policy_rejected_logits_mean,
            policy_nll_loss,
        )
        return loss_acc, (*return_vars, *aggregated_aux_outputs)
265
+
266
+ @staticmethod
267
+ def backward(ctx, *grad_output):
268
+ grad_input, grad_weight, grad_bias = ctx.saved_tensors
269
+ if torch.ne(grad_output[0][0], torch.tensor(1.0, device=grad_output[0][0].device)):
270
+ grad_input = grad_input * grad_output[0][0]
271
+ grad_weight = grad_weight * grad_output[0][0]
272
+ grad_bias = grad_bias * grad_output[0][0] if grad_bias is not None else None
273
+
274
+ return grad_input, grad_weight, None, grad_bias, None, None, None, None
275
+
276
+ @staticmethod
277
+ def chunk_forward(
278
+ input_chunk,
279
+ weight,
280
+ target_chunk,
281
+ bias=None,
282
+ ignore_index=-100,
283
+ compute_nll_loss=True,
284
+ chosen_nll_target_chunk=None,
285
+ average_log_prob=True,
286
+ ):
287
+ len_chosen_chunk = target_chunk.shape[0] // 2
288
+ logits_chunk = input_chunk @ weight.t()
289
+ if bias is not None:
290
+ logits_chunk = logits_chunk + bias
291
+ log_probs_chunk = F.log_softmax(logits_chunk.float(), dim=-1)
292
+
293
+ chosen_nll_loss = 0.0
294
+ if compute_nll_loss:
295
+ nll_labels = (
296
+ chosen_nll_target_chunk if chosen_nll_target_chunk is not None else target_chunk[:len_chosen_chunk]
297
+ )
298
+ chosen_nll_loss = F.nll_loss(
299
+ log_probs_chunk[:len_chosen_chunk].view(-1, log_probs_chunk.shape[-1]),
300
+ nll_labels.view(-1),
301
+ reduction="sum",
302
+ ignore_index=ignore_index,
303
+ )
304
+
305
+ loss_mask = target_chunk != ignore_index
306
+ label_chunk = torch.where(loss_mask, target_chunk, 0)
307
+
308
+ per_token_logps = log_probs_chunk.gather(-1, label_chunk.unsqueeze(-1)).squeeze(-1)
309
+ if average_log_prob:
310
+ log_prob = (per_token_logps * loss_mask).sum(-1) / loss_mask.sum(-1)
311
+ else:
312
+ log_prob = (per_token_logps * loss_mask).sum(-1)
313
+
314
+ chosen_logps = log_prob[:len_chosen_chunk]
315
+ rejected_logps = log_prob[len_chosen_chunk:]
316
+
317
+ chosen_logits = logits_chunk[:len_chosen_chunk]
318
+ rejected_logits = logits_chunk[len_chosen_chunk:]
319
+
320
+ return (
321
+ chosen_logps,
322
+ rejected_logps,
323
+ chosen_logits,
324
+ rejected_logits,
325
+ chosen_nll_loss,
326
+ )
327
+
328
    @staticmethod
    def _compute_loss(
        input_chunk,
        weight,
        target_chunk,
        bias=None,
        preference_loss_fn=None,
        full_target=None,
        ignore_index=-100,
        alpha=1.0,
        beta=0.1,
        compute_nll_loss=True,
        use_ref_model=False,
        ref_input_chunk=None,
        ref_weight=None,
        ref_bias=None,
        full_nll_target=None,
        chosen_nll_target_chunk=None,
        average_log_prob=True,
        **loss_kwargs,
    ):
        """
        Compute the total loss for a chunk of input and target, while using an alignment/preference loss function.

        This function is differentiated via torch.func.grad_and_value in forward(),
        so it must stay a pure function of its tensor arguments.

        Args:
            preference_loss_fn (callable): Loss function to compute the loss on a chunk of input/target.
            input_chunk (torch.Tensor): Chunk of input tensor. Shape: (2 * chunk_size, sequence_length, hidden_size).
            weight (torch.Tensor): Weight tensor. Shape: (vocab_size, hidden_size).
            target_chunk (torch.Tensor): Chunk of target tensor. Shape: (2 * chunk_size, sequence_length).
            bias (torch.Tensor, optional): Bias tensor. Shape: (vocab_size,).
            full_target (torch.Tensor): Full target tensor. Shape: (batch_size, sequence_length).
            ignore_index (int): Index to ignore for loss computation.
            alpha (float): Weight for the NLL loss.
            beta (float): Weight for the preference loss.
            compute_nll_loss (bool): Whether to compute NLL loss.
            use_ref_model (bool): Whether to use a reference model for the alignment loss.
            ref_input_chunk (torch.Tensor, optional): Chunk of reference-model input; used when use_ref_model is True.
            ref_weight (torch.Tensor): Reference weight tensor. Shape: (vocab_size, hidden_size).
            ref_bias (torch.Tensor, optional): Reference bias tensor. Shape: (vocab_size,).
            full_nll_target (torch.Tensor, optional): Full target tensor for NLL loss. Shape: (batch_size, sequence_length).
            chosen_nll_target_chunk (torch.Tensor, optional): Target tensor for NLL loss. Shape: (chunk_size, sequence_length) If not provided the target_chunk is used.
            average_log_prob (bool): Whether to average log probabilities or the sum.
            loss_kwargs (dict): Additional arguments for the loss function.

        Returns:
            tuple: (loss, (chosen_logps, rejected_logps, chosen_logits_mean,
            rejected_logits_mean, chosen_nll_loss, *aux_outputs)).
        """
        (
            chosen_logps,
            rejected_logps,
            chosen_logits,
            rejected_logits,
            chosen_nll_loss,
        ) = LigerFusedLinearPreferenceBase.chunk_forward(
            input_chunk,
            weight,
            target_chunk,
            bias=bias,
            ignore_index=ignore_index,
            compute_nll_loss=compute_nll_loss,
            chosen_nll_target_chunk=chosen_nll_target_chunk,
            average_log_prob=average_log_prob,
        )
        # Normalize the chunk's summed NLL by the count of non-ignored tokens in
        # the FULL chosen half, so per-chunk losses sum to the full-batch mean.
        if full_nll_target is not None:
            chosen_nll_loss = chosen_nll_loss / (full_nll_target[: full_nll_target.shape[0] // 2] != ignore_index).sum()
        else:
            chosen_nll_loss = chosen_nll_loss / (full_target[: full_target.shape[0] // 2] != ignore_index).sum()

        # Per-chunk partial means; denominators use the full batch so that the
        # accumulated sums in forward() equal full-batch logit means.
        chosen_logits_mean = chosen_logits.sum() / (full_target.shape[0] // 2 * input_chunk.shape[1] * weight.shape[0])
        rejected_logits_mean = rejected_logits.sum() / (
            full_target.shape[0] // 2 * input_chunk.shape[1] * weight.shape[0]
        )

        if use_ref_model:
            # Reference model is frozen — no gradients through its log-probs.
            with torch.no_grad():
                (
                    ref_chosen_logps,
                    ref_rejected_logps,
                    _,
                    _,
                    _,
                ) = LigerFusedLinearPreferenceBase.chunk_forward(
                    ref_input_chunk,
                    ref_weight,
                    target_chunk,
                    ref_bias,
                    ignore_index=ignore_index,
                    compute_nll_loss=False,  # We don't need NLL loss for the reference model
                    chosen_nll_target_chunk=None,
                    average_log_prob=average_log_prob,
                )
            loss_kwargs["ref_chosen_logps"] = ref_chosen_logps
            loss_kwargs["ref_rejected_logps"] = ref_rejected_logps

        # Loss functions may return a bare loss or (loss, *aux_outputs).
        preference_loss_outputs = preference_loss_fn(
            chosen_logps, rejected_logps, full_target, beta=beta, **loss_kwargs
        )
        if isinstance(preference_loss_outputs, tuple):
            preference_loss, *aux_outputs = preference_loss_outputs
        else:
            preference_loss, aux_outputs = preference_loss_outputs, []

        loss = alpha * chosen_nll_loss + preference_loss
        return_vars = (
            chosen_logps,
            rejected_logps,
            chosen_logits_mean,
            rejected_logits_mean,
            chosen_nll_loss,
        )
        return loss, (*return_vars, *aux_outputs)