liger-kernel-nightly 0.5.2.dev20250108102127__py3-none-any.whl → 0.5.2.dev20250109023714__py3-none-any.whl

liger_kernel/chunked_loss/fused_linear_preference.py

@@ -27,6 +27,7 @@ class LigerFusedLinearPreferenceBase(torch.autograd.Function):
         alpha=1.0,
         beta=0.1,
         compute_nll_loss=True,
+        nll_target=None,
         compiled=True,
         use_ref_model=False,
         ref_input=None,
@@ -58,6 +59,7 @@ class LigerFusedLinearPreferenceBase(torch.autograd.Function):
             alpha (float): Weight for the NLL loss.
             beta (float): Weight for the preference loss.
             compute_nll_loss (bool): Whether to compute NLL loss.
+            nll_target (torch.Tensor, optional): Target tensor for NLL loss. Shape: (batch_size, seq_len). If not provided, the target is used.
             compiled (bool): Whether to use torch compile for chunk accumulation.
             use_ref_model (bool): Whether to use a reference model for the alignment loss.
             ref_weight (torch.Tensor): Reference weight tensor. Shape: (vocab_size, hidden_size).
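For context, the point of the new nll_target argument is that the NLL term can now be scored against labels that differ from the preference target. A minimal sketch of the distinction, with made-up token ids and a hypothetical prompt_len (the real construction lives in the LigerORPOTrainer hunk further down):

import torch

ignore_index = -100
input_ids = torch.tensor([[101, 7592, 2088, 102, 0, 0]])   # made-up token ids; trailing zeros are padding
attention_mask = torch.tensor([[1, 1, 1, 1, 0, 0]])
prompt_len = 2                                             # hypothetical prompt length

# Preference target: prompt and padding masked out with ignore_index.
target = input_ids.clone()
target[:, :prompt_len] = ignore_index
target[attention_mask == 0] = ignore_index

# NLL target: only padding masked, so the NLL term also covers the prompt.
nll_target = torch.where(attention_mask == 1, input_ids, torch.tensor(ignore_index))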
@@ -96,11 +98,12 @@ class LigerFusedLinearPreferenceBase(torch.autograd.Function):
             use_ref_model=use_ref_model,
             ref_weight=ref_weight,
             ref_bias=ref_bias,
+            full_nll_target=nll_target,
             average_log_prob=average_log_prob,
             **loss_kwargs,
         )
 
-        def fused_fwd_bwd(input_chunk, target_chunk, ref_input_chunk):
+        def fused_fwd_bwd(input_chunk, target_chunk, ref_input_chunk, chosen_nll_target_chunk):
             """
             Fused forward and backward pass for a chunk of input and target.
             """
@@ -111,13 +114,18 @@ class LigerFusedLinearPreferenceBase(torch.autograd.Function):
                     target_chunk,
                     bias,
                     ref_input_chunk=ref_input_chunk,
+                    chosen_nll_target_chunk=chosen_nll_target_chunk,
                 )
             else:
                 return torch.func.grad_and_value(compute_loss, argnums=(0, 1), has_aux=True)(
-                    input_chunk, weight, target_chunk, ref_input_chunk=ref_input_chunk
+                    input_chunk,
+                    weight,
+                    target_chunk,
+                    ref_input_chunk=ref_input_chunk,
+                    chosen_nll_target_chunk=chosen_nll_target_chunk,
                 )
 
-        def accumulate_chunk(input_chunk, target_chunk, ref_input_chunk=None):
+        def accumulate_chunk(input_chunk, target_chunk, ref_input_chunk=None, chosen_nll_target_chunk=None):
             if bias is not None:
                 (
                     (chunk_grad_input, chunk_grad_weight, chunk_grad_bias),
@@ -132,7 +140,7 @@ class LigerFusedLinearPreferenceBase(torch.autograd.Function):
                             *aux_outputs,
                         ),
                     ),
-                ) = fused_fwd_bwd(input_chunk, target_chunk, ref_input_chunk)
+                ) = fused_fwd_bwd(input_chunk, target_chunk, ref_input_chunk, chosen_nll_target_chunk)
                 grad_bias.add_(chunk_grad_bias)  # accumulate bias gradient
             else:
                 (
@@ -148,7 +156,7 @@ class LigerFusedLinearPreferenceBase(torch.autograd.Function):
                             *aux_outputs,
                         ),
                     ),
-                ) = fused_fwd_bwd(input_chunk, target_chunk, ref_input_chunk)
+                ) = fused_fwd_bwd(input_chunk, target_chunk, ref_input_chunk, chosen_nll_target_chunk)
 
             # Accumulate gradients
             grad_weight.add_(chunk_grad_weight)
@@ -191,6 +199,9 @@ class LigerFusedLinearPreferenceBase(torch.autograd.Function):
         _rejected_input_chunks = torch.chunk(_input[len_chosen:], chunks=chunks, dim=0)
         _rejected_target_chunks = torch.chunk(target[len_chosen:], chunks=chunks, dim=0)
 
+        if nll_target is not None:
+            _chosen_nll_target_chunks = torch.chunk(nll_target[:len_chosen], chunks=chunks, dim=0)
+
         if use_ref_model:
             _ref_chosen_input_chunks = torch.chunk(ref_input[:len_chosen], chunks=chunks, dim=0)
             _ref_rejected_input_chunks = torch.chunk(ref_input[len_chosen:], chunks=chunks, dim=0)
@@ -202,6 +213,7 @@ class LigerFusedLinearPreferenceBase(torch.autograd.Function):
             rejected_target_chunk,
             ref_chosen_input_chunk,
             ref_rejected_input_chunk,
+            chosen_nll_target_chunk,
         ) in zip(
             _chosen_input_chunks,
             _rejected_input_chunks,
@@ -209,6 +221,7 @@ class LigerFusedLinearPreferenceBase(torch.autograd.Function):
             _rejected_target_chunks,
             (_ref_chosen_input_chunks if use_ref_model else [None] * len(_chosen_input_chunks)),
             (_ref_rejected_input_chunks if use_ref_model else [None] * len(_rejected_input_chunks)),
+            (_chosen_nll_target_chunks if nll_target is not None else [None] * len(_chosen_input_chunks)),
             strict=False,
         ):
             input_chunk = torch.cat([chosen_input_chunk, rejected_input_chunk], dim=0)
@@ -222,9 +235,10 @@ class LigerFusedLinearPreferenceBase(torch.autograd.Function):
             torch._dynamo.mark_dynamic(target_chunk, 1)
             torch._dynamo.mark_dynamic(target, 1)
             torch._dynamo.mark_dynamic(ref_input_chunk, 1) if use_ref_model else None
+            torch._dynamo.mark_dynamic(chosen_nll_target_chunk, 1) if nll_target is not None else None
 
             # accumulate loss, gradients, and metrics
-            accumulate_chunk(input_chunk, target_chunk, ref_input_chunk)
+            accumulate_chunk(input_chunk, target_chunk, ref_input_chunk, chosen_nll_target_chunk)
 
         # combine grad_chosen_inputs and grad_rejected_inputs
         grad_inputs = grad_chosen_inputs + grad_rejected_inputs
@@ -258,7 +272,7 @@ class LigerFusedLinearPreferenceBase(torch.autograd.Function):
         grad_weight = grad_weight * grad_output[0][0]
         grad_bias = grad_bias * grad_output[0][0] if grad_bias is not None else None
 
-        return grad_input, grad_weight, None, grad_bias, None, None, None
+        return grad_input, grad_weight, None, grad_bias, None, None, None, None
 
     @staticmethod
     def chunk_forward(
@@ -268,6 +282,7 @@ class LigerFusedLinearPreferenceBase(torch.autograd.Function):
         bias=None,
         ignore_index=-100,
         compute_nll_loss=True,
+        chosen_nll_target_chunk=None,
         average_log_prob=True,
     ):
         len_chosen_chunk = target_chunk.shape[0] // 2
@@ -278,9 +293,12 @@ class LigerFusedLinearPreferenceBase(torch.autograd.Function):
 
         chosen_nll_loss = 0.0
         if compute_nll_loss:
+            nll_labels = (
+                chosen_nll_target_chunk if chosen_nll_target_chunk is not None else target_chunk[:len_chosen_chunk]
+            )
             chosen_nll_loss = F.nll_loss(
                 log_probs_chunk[:len_chosen_chunk].view(-1, log_probs_chunk.shape[-1]),
-                target_chunk[:len_chosen_chunk].view(-1),
+                nll_labels.view(-1),
                 reduction="sum",
                 ignore_index=ignore_index,
             )
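Read in isolation, the chunk_forward change does one thing: when a chosen_nll_target_chunk is passed, it replaces the chosen half of target_chunk as the labels fed to F.nll_loss. A self-contained sketch with made-up chunk shapes:

import torch
import torch.nn.functional as F

batch, seq_len, vocab = 2, 5, 11                                  # hypothetical chunk sizes
log_probs_chunk = torch.randn(2 * batch, seq_len, vocab).log_softmax(dim=-1)
target_chunk = torch.randint(0, vocab, (2 * batch, seq_len))      # chosen half first, then rejected
chosen_nll_target_chunk = torch.randint(0, vocab, (batch, seq_len))

len_chosen_chunk = target_chunk.shape[0] // 2
nll_labels = chosen_nll_target_chunk if chosen_nll_target_chunk is not None else target_chunk[:len_chosen_chunk]
chosen_nll_loss = F.nll_loss(
    log_probs_chunk[:len_chosen_chunk].reshape(-1, vocab),        # (batch * seq_len, vocab)
    nll_labels.reshape(-1),                                       # (batch * seq_len,)
    reduction="sum",
    ignore_index=-100,
)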
@@ -324,6 +342,8 @@ class LigerFusedLinearPreferenceBase(torch.autograd.Function):
         ref_input_chunk=None,
         ref_weight=None,
         ref_bias=None,
+        full_nll_target=None,
+        chosen_nll_target_chunk=None,
         average_log_prob=True,
         **loss_kwargs,
     ):
@@ -343,6 +363,8 @@ class LigerFusedLinearPreferenceBase(torch.autograd.Function):
             use_ref_model (bool): Whether to use a reference model for the alignment loss.
             ref_weight (torch.Tensor): Reference weight tensor. Shape: (vocab_size, hidden_size).
             ref_bias (torch.Tensor, optional): Reference bias tensor. Shape: (vocab_size,).
+            full_nll_target (torch.Tensor, optional): Full target tensor for NLL loss. Shape: (batch_size, sequence_length).
+            chosen_nll_target_chunk (torch.Tensor, optional): Target tensor for NLL loss. Shape: (chunk_size, sequence_length). If not provided, target_chunk is used.
             average_log_prob (bool): Whether to average log probabilities or the sum.
             loss_kwargs (dict): Additional arguments for the loss function.
         """
@@ -359,9 +381,14 @@ class LigerFusedLinearPreferenceBase(torch.autograd.Function):
             bias=bias,
             ignore_index=ignore_index,
             compute_nll_loss=compute_nll_loss,
+            chosen_nll_target_chunk=chosen_nll_target_chunk,
             average_log_prob=average_log_prob,
         )
-        chosen_nll_loss = chosen_nll_loss / (full_target[: full_target.shape[0] // 2] != ignore_index).sum()
+        if full_nll_target is not None:
+            chosen_nll_loss = chosen_nll_loss / (full_nll_target[: full_nll_target.shape[0] // 2] != ignore_index).sum()
+        else:
+            chosen_nll_loss = chosen_nll_loss / (full_target[: full_target.shape[0] // 2] != ignore_index).sum()
+
         chosen_logits_mean = chosen_logits.sum() / (full_target.shape[0] // 2 * input_chunk.shape[1] * weight.shape[0])
         rejected_logits_mean = rejected_logits.sum() / (
             full_target.shape[0] // 2 * input_chunk.shape[1] * weight.shape[0]
@@ -372,9 +399,9 @@ class LigerFusedLinearPreferenceBase(torch.autograd.Function):
             (
                 ref_chosen_logps,
                 ref_rejected_logps,
-                ref_chosen_logits,
-                ref_rejected_logits,
-                ref_chosen_nll_loss,
+                _,
+                _,
+                _,
             ) = LigerFusedLinearPreferenceBase.chunk_forward(
                 ref_input_chunk,
                 ref_weight,
@@ -382,6 +409,7 @@ class LigerFusedLinearPreferenceBase(torch.autograd.Function):
                 ref_bias,
                 ignore_index=ignore_index,
                 compute_nll_loss=False,  # We don't need NLL loss for the reference model
+                chosen_nll_target_chunk=None,
                 average_log_prob=average_log_prob,
             )
         loss_kwargs["ref_chosen_logps"] = ref_chosen_logps
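The normalization change in _compute_loss is easiest to see on its own: the summed chosen NLL is divided by the count of non-ignored chosen tokens, and that count now comes from full_nll_target when one is supplied. A small sketch with toy tensors (the helper name nll_denominator is mine, not the library's):

import torch

ignore_index = -100
full_target = torch.tensor([[-100, 5, 7], [-100, 2, 4]])          # chosen half first; prompt masked
full_nll_target = torch.tensor([[1, 5, 7], [3, 2, 4]])            # prompt tokens kept

def nll_denominator(full_target, full_nll_target=None, ignore_index=-100):
    # Count non-ignored tokens in the chosen half of whichever labels apply.
    ref = full_nll_target if full_nll_target is not None else full_target
    return (ref[: ref.shape[0] // 2] != ignore_index).sum()

print(nll_denominator(full_target))                   # tensor(2): prompt excluded
print(nll_denominator(full_target, full_nll_target))  # tensor(3): prompt counted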
liger_kernel/chunked_loss/orpo_loss.py

@@ -52,6 +52,7 @@ class LigerFusedLinearORPOFunction(LigerFusedLinearPreferenceBase):
         ignore_index=-100,
         beta=0.1,
         compute_nll_loss=True,
+        nll_target=None,
         compiled=True,
     ):
         return LigerFusedLinearPreferenceBase.forward(
@@ -64,13 +65,14 @@ class LigerFusedLinearORPOFunction(LigerFusedLinearPreferenceBase):
             ignore_index=ignore_index,
             beta=beta,
             compute_nll_loss=compute_nll_loss,
+            nll_target=nll_target,
             compiled=compiled,
         )
 
     @staticmethod
     def backward(ctx, *grad_output):
         grads = LigerFusedLinearPreferenceBase.backward(ctx, grad_output)[:4]
-        return *grads, None, None, None, None
+        return *grads, None, None, None, None, None
 
 
 class LigerFusedLinearORPOLoss(torch.nn.Module):
@@ -96,7 +98,7 @@ class LigerFusedLinearORPOLoss(torch.nn.Module):
         self.compute_nll_loss = compute_nll_loss
         self.compiled = compiled
 
-    def forward(self, lin_weight, _input, target, bias=None):
+    def forward(self, lin_weight, _input, target, bias=None, nll_target=None):
         return LigerFusedLinearORPOFunction.apply(
             _input,
             lin_weight,
@@ -105,5 +107,6 @@ class LigerFusedLinearORPOLoss(torch.nn.Module):
             self.ignore_index,
             self.beta,
             self.compute_nll_loss,
+            nll_target,
             self.compiled,
         )
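Putting the orpo_loss.py changes together, a module-level call might look like the following sketch; shapes are invented, and the import path is assumed from the package layout in the RECORD below.

import torch
from liger_kernel.chunked_loss import LigerFusedLinearORPOLoss  # import path assumed

B, T, H, V = 4, 6, 32, 100                  # toy sizes; first B // 2 rows are the "chosen" half
lin_weight = torch.randn(V, H, requires_grad=True)
_input = torch.randn(B, T, H)               # last hidden states
target = torch.randint(0, V, (B, T))        # preference labels (prompt usually set to ignore_index)
nll_target = torch.randint(0, V, (B, T))    # separate labels for the NLL term

orpo_loss_fn = LigerFusedLinearORPOLoss(ignore_index=-100, beta=0.1)
loss, aux_outputs = orpo_loss_fn(lin_weight, _input, target, bias=None, nll_target=nll_target)
loss.backward()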
liger_kernel/transformers/trainer/orpo_trainer.py

@@ -93,6 +93,13 @@ class LigerORPOTrainer(ORPOTrainer):
         if self.aux_loss_enabled:
             model_kwargs["output_router_logits"] = True
 
+        if self.is_encoder_decoder:
+            labels = concatenated_batch["concatenated_labels"].clone()
+        else:
+            labels = concatenated_batch["concatenated_input_ids"].clone()
+            attention_mask = concatenated_batch["concatenated_attention_mask"]
+            labels = torch.where(attention_mask == 1, labels, self.label_pad_token_id)
+
         if isinstance(model, FullyShardedDataParallel):
             outputs = _FSDPForwardRedirection()(
                 model,
@@ -114,15 +121,20 @@ class LigerORPOTrainer(ORPOTrainer):
 
         orpo_loss_fn = LigerFusedLinearORPOLoss(ignore_index=self.label_pad_token_id, beta=self.beta)
 
-        def orpo_partial(lm_head, last_hidden_state, concatenated_labels):
-            return orpo_loss_fn(lm_head.weight, last_hidden_state, concatenated_labels, lm_head.bias)
+        def orpo_partial(lm_head, last_hidden_state, concatenated_labels, nll_target):
+            return orpo_loss_fn(
+                lm_head.weight, last_hidden_state, concatenated_labels, lm_head.bias, nll_target=nll_target
+            )
 
         orpo_loss, aux_outputs = _FSDPForwardRedirection()(
             model,
             orpo_partial,
             model.lm_head,
-            outputs.last_hidden_state,
-            concatenated_batch["concatenated_labels"],
+            outputs.last_hidden_state[:, :-1] if not self.is_encoder_decoder else outputs.last_hidden_state,
+            concatenated_batch["concatenated_labels"][:, 1:]
+            if not self.is_encoder_decoder
+            else concatenated_batch["concatenated_labels"],
+            labels[:, 1:] if not self.is_encoder_decoder else labels,
         )
         # if aux_loss_enabled, add the aux_loss to the orpo_loss
         if self.aux_loss_enabled:
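The trainer-side change, sketched below with stand-in tensors: for decoder-only models, hidden states are truncated by one position and both label tensors are shifted by one, so position t scores token t + 1; the NLL labels come from the raw input ids with padding replaced by label_pad_token_id.

import torch

label_pad_token_id = -100
input_ids = torch.tensor([[11, 12, 13, 14, 0]])    # stand-in for concatenated_input_ids
attention_mask = torch.tensor([[1, 1, 1, 1, 0]])   # stand-in for concatenated_attention_mask

# NLL labels: raw input ids with padding replaced, as in the new trainer code.
labels = torch.where(attention_mask == 1, input_ids.clone(), torch.tensor(label_pad_token_id))

hidden = torch.randn(1, 5, 8)                      # made-up last_hidden_state
shifted_hidden = hidden[:, :-1]                    # position t predicts token t + 1
shifted_labels = labels[:, 1:]                     # token t + 1 is the target at position t
assert shifted_hidden.shape[1] == shifted_labels.shape[1]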
METADATA

@@ -1,6 +1,6 @@
 Metadata-Version: 2.1
 Name: liger_kernel_nightly
-Version: 0.5.2.dev20250108102127
+Version: 0.5.2.dev20250109023714
 Summary: Efficient Triton kernels for LLM Training
 License: BSD 2-CLAUSE LICENSE
 Copyright 2024 LinkedIn Corporation
RECORD

@@ -7,8 +7,8 @@ liger_kernel/chunked_loss/cpo_loss.py,sha256=OdBR8WYdHTKpLI_c9DcuwqKSWPeAAeTyREz
 liger_kernel/chunked_loss/dpo_loss.py,sha256=VYZMOafdvE8xlhvTtwjrz81tIzxR1mHF4lXdsADnIQg,4373
 liger_kernel/chunked_loss/functional.py,sha256=9Gr-YXIuEzEJkBUhDx3G2fuQayckLor7cC7svhmPML4,549
 liger_kernel/chunked_loss/fused_linear_distillation.py,sha256=uQtwtu-kaUZJTjNhAnIr3O794oUlUZ98XR5shYtwP5k,10440
-liger_kernel/chunked_loss/fused_linear_preference.py,sha256=eQCZmQ3xOL3jpZ7RhOfx_pqR9sNEX6RHx8DtIgyXEHc,16656
-liger_kernel/chunked_loss/orpo_loss.py,sha256=jbZxx-EjPK71A6CSyNzTOAIEQgAUjfvwSViw6R_pPXQ,3510
+liger_kernel/chunked_loss/fused_linear_preference.py,sha256=idK9V9NivoVITqVpiG0fEGUHSvinYWkn9-EYXZjR-KQ,18356
+liger_kernel/chunked_loss/orpo_loss.py,sha256=yjcrrbVeemLYodoSKT-FMSnaPtyKAZ3aOrvPD6tTY6Y,3617
 liger_kernel/chunked_loss/simpo_loss.py,sha256=3TTc7U79Orjgi-Wu81WZkWk5MgsdqKXIOBHgIvDazPw,3865
 liger_kernel/ops/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
 liger_kernel/ops/cross_entropy.py,sha256=SRzAF9Ek84pBVFy3wqQZs7AhRoorKRIgQ-Td_rtl1Kk,18606
@@ -55,12 +55,12 @@ liger_kernel/transformers/model/phi3.py,sha256=biRa8fph9qdnQmkD9I21t5XIjpIt1i6UK
 liger_kernel/transformers/model/qwen2.py,sha256=14UuPjxB-tjqWn85Tn4fqBFvVhVsth5iPEt8kJSMiew,9581
 liger_kernel/transformers/model/qwen2_vl.py,sha256=rZg3nU3YgF6wkB1UJ0a9IACSIlVOSCyLltyqw951MQQ,8609
 liger_kernel/transformers/trainer/__init__.py,sha256=p7yQfklV8-467qSz_ZMimkbDF7HHWHwku25A-GYL0WU,193
-liger_kernel/transformers/trainer/orpo_trainer.py,sha256=MId1S_MfA3pPVQA1rkiKxp-jZDNz8VmvZzXC-Kugol4,7662
+liger_kernel/transformers/trainer/orpo_trainer.py,sha256=pdekW7l6Qg_aqa5SYKYlSWUF8m3lkOFvFLcIMEHrz9s,8338
 liger_kernel/triton/__init__.py,sha256=qCiCamzCRv6lpV8IqpAc9YMdNKC7GKurClWceQPnlis,92
 liger_kernel/triton/monkey_patch.py,sha256=Rd0hUHAzDkFfHvnX7-PBaNK5EKnZhtfM_h-fgQH9HPY,1568
-liger_kernel_nightly-0.5.2.dev20250108102127.dist-info/LICENSE,sha256=OhzLDHJ0to4a8sodVLELZiCFylZ1NAAYLs-HrjPy0ag,1312
-liger_kernel_nightly-0.5.2.dev20250108102127.dist-info/METADATA,sha256=XHrJlebOzBW0f6tV-rb0iahG9LNI-f86Ar7s-upwoxo,21055
-liger_kernel_nightly-0.5.2.dev20250108102127.dist-info/NOTICE,sha256=njwnoPZLh9AN8SJQzxvCGLHi-8X__AvWRze6joNXIY8,2066
-liger_kernel_nightly-0.5.2.dev20250108102127.dist-info/WHEEL,sha256=P9jw-gEje8ByB7_hXoICnHtVCrEwMQh-630tKvQWehc,91
-liger_kernel_nightly-0.5.2.dev20250108102127.dist-info/top_level.txt,sha256=2eghu4hA3LnkM7ElW92tQ8zegWKgSbeo-k-aGe1YnvY,13
-liger_kernel_nightly-0.5.2.dev20250108102127.dist-info/RECORD,,
+liger_kernel_nightly-0.5.2.dev20250109023714.dist-info/LICENSE,sha256=OhzLDHJ0to4a8sodVLELZiCFylZ1NAAYLs-HrjPy0ag,1312
+liger_kernel_nightly-0.5.2.dev20250109023714.dist-info/METADATA,sha256=Qixggh8iZsja-nfhW4tPgwipEP_pCro3KxqZkBeyZ7s,21055
+liger_kernel_nightly-0.5.2.dev20250109023714.dist-info/NOTICE,sha256=njwnoPZLh9AN8SJQzxvCGLHi-8X__AvWRze6joNXIY8,2066
+liger_kernel_nightly-0.5.2.dev20250109023714.dist-info/WHEEL,sha256=P9jw-gEje8ByB7_hXoICnHtVCrEwMQh-630tKvQWehc,91
+liger_kernel_nightly-0.5.2.dev20250109023714.dist-info/top_level.txt,sha256=2eghu4hA3LnkM7ElW92tQ8zegWKgSbeo-k-aGe1YnvY,13
+liger_kernel_nightly-0.5.2.dev20250109023714.dist-info/RECORD,,