liger-kernel-nightly 0.4.2.dev20241203232039__py3-none-any.whl → 0.4.2.dev20241206180928__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
liger_kernel/chunked_loss/cpo_loss.py
@@ -9,7 +9,7 @@ from liger_kernel.chunked_loss.fused_linear_preference import (
 class LigerFusedLinearCPOFunction(LigerFusedLinearPreferenceBase):
 
     @staticmethod
-    def preference_loss_fn(chosen_logps, rejected_logps, beta=0.1):
+    def preference_loss_fn(chosen_logps, rejected_logps, full_target, beta=0.1):
         """
         Compute CPO loss.
         Args:
@@ -18,7 +18,7 @@ class LigerFusedLinearCPOFunction(LigerFusedLinearPreferenceBase):
             beta (float): Weight for the odds ratio loss.
         """
         logits = beta * (chosen_logps - rejected_logps)
-        loss = F.logsigmoid(logits).mean()
+        loss = F.logsigmoid(logits).sum() / (full_target.shape[0] // 2)
         return loss
 
     @staticmethod
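
The normalization here changes from a per-chunk `.mean()` to `.sum() / (full_target.shape[0] // 2)`, i.e. a sum divided by the number of preference pairs in the *full* batch. Because the loss function is invoked once per chunk and the results accumulated, the chunk-wise sums now add up to a batch-level mean. A minimal sketch of why (made-up shapes, not the package's code):

```python
import torch
import torch.nn.functional as F

# Hypothetical shapes: 8 preference pairs, processed in 4 chunks.
beta = 0.1
chosen_logps, rejected_logps = torch.randn(8), torch.randn(8)
full_target = torch.zeros(16, 32)  # stand-in for the full concatenated target

num_pairs = full_target.shape[0] // 2
whole = F.logsigmoid(beta * (chosen_logps - rejected_logps)).sum() / num_pairs
chunked = sum(
    F.logsigmoid(beta * (c - r)).sum() / num_pairs
    for c, r in zip(chosen_logps.chunk(4), rejected_logps.chunk(4))
)
assert torch.allclose(whole, chunked)  # chunk-wise .mean() would not satisfy this
```
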
@@ -55,7 +55,7 @@ class LigerFusedLinearCPOFunction(LigerFusedLinearPreferenceBase):
         )
 
     @staticmethod
-    def backward(ctx, grad_output):
+    def backward(ctx, *grad_output):
         # Get gradients for _input, weight, bias, and target from the base class
         grads = LigerFusedLinearPreferenceBase.backward(ctx, grad_output)[:4]
         # Return these gradients, followed by None for the remaining inputs

liger_kernel/chunked_loss/dpo_loss.py
@@ -12,6 +12,7 @@ class LigerFusedLinearDPOFunction(LigerFusedLinearPreferenceBase):
     def preference_loss_fn(
         chosen_logps,
         rejected_logps,
+        full_target,
         ref_chosen_logps=None,
         ref_rejected_logps=None,
         beta=0.1,
@@ -34,8 +35,8 @@ class LigerFusedLinearDPOFunction(LigerFusedLinearPreferenceBase):
         rejected_logratios = rejected_logps - ref_rejected_logps
 
         logits_diff = beta * (chosen_logratios - rejected_logratios)
-        losses = -F.logsigmoid(logits_diff)
-        return losses.sum()
+        loss = -F.logsigmoid(logits_diff).sum() / (full_target.shape[0] // 2)
+        return loss
 
     @staticmethod
     def forward(
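
For reference, the quantity assembled in this hunk is the standard DPO objective, now averaged over the $B/2$ preference pairs of the full batch instead of summed:

$$
\mathcal{L}_{\mathrm{DPO}} = -\frac{1}{B/2}\sum_{i=1}^{B/2} \log\sigma\Big(\beta\big[(\log\pi_\theta(y_c^{(i)}) - \log\pi_{\mathrm{ref}}(y_c^{(i)})) - (\log\pi_\theta(y_r^{(i)}) - \log\pi_{\mathrm{ref}}(y_r^{(i)}))\big]\Big)
$$

where the two bracketed log-ratio terms are `chosen_logratios` and `rejected_logratios` in the code.
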
@@ -73,7 +74,7 @@ class LigerFusedLinearDPOFunction(LigerFusedLinearPreferenceBase):
         )
 
     @staticmethod
-    def backward(ctx, grad_output):
+    def backward(ctx, *grad_output):
         # Get gradients for _input, weight, bias, and target from the base class
         grads = LigerFusedLinearPreferenceBase.backward(ctx, grad_output)[:4]
         # Return these gradients, followed by None for the remaining inputs

liger_kernel/chunked_loss/fused_linear_preference.py
@@ -52,7 +52,17 @@ class LigerFusedLinearPreferenceBase(torch.autograd.Function):
 
         chosen_logps = average_log_prob[:len_chosen_chunk]
         rejected_logps = average_log_prob[len_chosen_chunk:]
-        return chosen_logps, rejected_logps, chosen_nll_loss
+
+        chosen_logits = logits_chunk[:len_chosen_chunk]
+        rejected_logits = logits_chunk[len_chosen_chunk:]
+
+        return (
+            chosen_logps,
+            rejected_logps,
+            chosen_logits,
+            rejected_logits,
+            chosen_nll_loss,
+        )
 
     @staticmethod
     def forward(
@@ -103,6 +113,12 @@ class LigerFusedLinearPreferenceBase(torch.autograd.Function):
         grad_rejected_inputs = []
         grad_bias = torch.zeros_like(bias) if bias is not None else None
         loss_acc = torch.zeros((), device=_input.device)
+        policy_chosen_logps = []
+        policy_rejected_logps = []
+        policy_chosen_logits_mean = torch.zeros((), device=_input.device)
+        policy_rejected_logits_mean = torch.zeros((), device=_input.device)
+        policy_nll_loss = torch.zeros((), device=_input.device)
+        aggregated_aux_outputs = []  # aggregated aux outputs from all chunks
 
         loss_func_to_call = partial(
             LigerFusedLinearPreferenceBase._compute_loss,
@@ -118,32 +134,72 @@ class LigerFusedLinearPreferenceBase(torch.autograd.Function):
             **loss_kwargs,
         )
 
+        def accumulate_helper(input_chunk, target_chunk):
+            if bias is not None:
+                return torch.func.grad_and_value(
+                    loss_func_to_call, argnums=(0, 1, 3), has_aux=True
+                )(input_chunk, weight, target_chunk, bias)
+            else:
+                return torch.func.grad_and_value(
+                    loss_func_to_call, argnums=(0, 1), has_aux=True
+                )(input_chunk, weight, target_chunk)
+
         def accumulate_chunk(input_chunk, target_chunk):
             if bias is not None:
                 (chunk_grad_input, chunk_grad_weight, chunk_grad_bias), (
                     chunk_loss,
-                    (chunk_or_loss, chunk_chosen_logps, chunk_rejected_logps),
-                ) = torch.func.grad_and_value(
-                    loss_func_to_call, argnums=(0, 1, 3), has_aux=True
-                )(
-                    input_chunk, weight, target_chunk, bias
-                )
-                grad_bias.add_(chunk_grad_bias)
+                    (
+                        chunk_chosen_logps,
+                        chunk_rejected_logps,
+                        chunk_chosen_logits_mean,
+                        chunk_rejected_logits_mean,
+                        chunk_nll_loss,
+                        *aux_outputs,
+                    ),
+                ) = accumulate_helper(input_chunk, target_chunk)
+                grad_bias.add_(chunk_grad_bias)  # accumulate bias gradient
             else:
                 (chunk_grad_input, chunk_grad_weight), (
                     chunk_loss,
-                    (chunk_or_loss, chunk_chosen_logps, chunk_rejected_logps),
-                ) = torch.func.grad_and_value(
-                    loss_func_to_call, argnums=(0, 1), has_aux=True
-                )(
-                    input_chunk, weight, target_chunk
-                )
+                    (
+                        chunk_chosen_logps,
+                        chunk_rejected_logps,
+                        chunk_chosen_logits_mean,
+                        chunk_rejected_logits_mean,
+                        chunk_nll_loss,
+                        *aux_outputs,
+                    ),
+                ) = accumulate_helper(input_chunk, target_chunk)
+
             grad_weight.add_(chunk_grad_weight)
             loss_acc.add_(chunk_loss)
+            policy_chosen_logps.append(chunk_chosen_logps)
+            policy_rejected_logps.append(chunk_rejected_logps)
+            policy_chosen_logits_mean.add_(chunk_chosen_logits_mean)
+            policy_rejected_logits_mean.add_(chunk_rejected_logits_mean)
+            policy_nll_loss.add_(chunk_nll_loss)
+
+            # Initialize storage for aux_outputs
+            if len(aggregated_aux_outputs) == 0:
+                for aux in aux_outputs:
+                    if aux.ndim == 0:
+                        aggregated_aux_outputs.append(
+                            torch.zeros((), device=aux.device)
+                        )
+                    else:
+                        aggregated_aux_outputs.append([])
+
+            # Process each aux_output
+            for i, aux in enumerate(aux_outputs):
+                if aux.ndim == 0:
+                    aggregated_aux_outputs[i].add_(aux)
+                else:
+                    aggregated_aux_outputs[i].append(aux)
+
             return chunk_grad_input
 
         if compiled:
-            accumulate_chunk = torch.compile(accumulate_chunk)
+            accumulate_helper = torch.compile(accumulate_helper)
 
         len_chosen = target.shape[0] // 2
         chunks = max(1, _input.shape[0] // (2 * CHUNK_SIZE))
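
The differentiable work is now isolated in `accumulate_helper`, and `torch.compile` is applied to that helper rather than to the whole of `accumulate_chunk`, whose Python-side list appends and in-place accumulation stay eager. A self-contained sketch of the `torch.func.grad_and_value(..., has_aux=True)` pattern the helper relies on, with a toy function and names that are not the package's:

```python
import torch

def toy_loss(x, w):
    loss = ((x @ w) ** 2).mean()
    aux = (x.detach().sum(), w.detach().sum())  # extra metrics, passed through untouched
    return loss, aux

x, w = torch.randn(4, 3), torch.randn(3, 2)
# has_aux=True -> returns (grads, (loss, aux)); argnums picks which inputs get grads
(grad_x, grad_w), (loss, aux) = torch.func.grad_and_value(
    toy_loss, argnums=(0, 1), has_aux=True
)(x, w)
```
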
@@ -168,6 +224,12 @@ class LigerFusedLinearPreferenceBase(torch.autograd.Function):
                 [chosen_target_chunk, rejected_target_chunk], dim=0
             )
 
+            # mark input_chunk, target_chunk, and target dimension 1 as dynamic to prevent torch.compile recompilation
+            torch._dynamo.mark_dynamic(input_chunk, 1)
+            torch._dynamo.mark_dynamic(target_chunk, 1)
+            torch._dynamo.mark_dynamic(target, 1)
+
+            # accumulate loss, gradients, and metrics
             grad_input = accumulate_chunk(input_chunk, target_chunk)
 
             grad_chosen_inputs.append(grad_input[: chosen_target_chunk.shape[0]])
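
`torch._dynamo.mark_dynamic(t, dim)` (a private, subject-to-change API) declares a dimension symbolic up front, so the compiled helper is traced once with a dynamic sequence length instead of recompiling for every distinct chunk shape. A toy illustration, not the package's code:

```python
import torch

@torch.compile
def f(x):
    return x.sin().sum()

x = torch.randn(4, 128)
torch._dynamo.mark_dynamic(x, 1)  # treat dim 1 (sequence length) as symbolic
f(x)                    # compiles once with a dynamic dim 1
f(torch.randn(4, 256))  # different length, same compiled graph
```
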
@@ -175,21 +237,37 @@ class LigerFusedLinearPreferenceBase(torch.autograd.Function):
 
         # combine grad_chosen_inputs and grad_rejected_inputs
         grad_inputs = grad_chosen_inputs + grad_rejected_inputs
+        policy_chosen_logps = torch.cat(policy_chosen_logps, dim=0)
+        policy_rejected_logps = torch.cat(policy_rejected_logps, dim=0)
+
+        # Aggregate aux outputs lists into tensors
+        for i, aux in enumerate(aggregated_aux_outputs):
+            if isinstance(aux, list):
+                aggregated_aux_outputs[i] = torch.cat(aux, dim=0)
 
         ctx.save_for_backward(
             torch.cat(grad_inputs, dim=0),
             grad_weight,
             grad_bias,
         )
-        return loss_acc
+        return_vars = (
+            policy_chosen_logps,
+            policy_rejected_logps,
+            policy_chosen_logits_mean,
+            policy_rejected_logits_mean,
+            policy_nll_loss,
+        )
+        return loss_acc, (*return_vars, *aggregated_aux_outputs)
 
     @staticmethod
-    def backward(ctx, grad_output):
+    def backward(ctx, *grad_output):
         grad_input, grad_weight, grad_bias = ctx.saved_tensors
-        if torch.ne(grad_output, torch.tensor(1.0, device=grad_output.device)):
-            grad_input = grad_input * grad_output
-            grad_weight = grad_weight * grad_output
-            grad_bias = grad_bias * grad_output if grad_bias is not None else None
+        if torch.ne(
+            grad_output[0][0], torch.tensor(1.0, device=grad_output[0][0].device)
+        ):
+            grad_input = grad_input * grad_output[0][0]
+            grad_weight = grad_weight * grad_output[0][0]
+            grad_bias = grad_bias * grad_output[0][0] if grad_bias is not None else None
 
         return grad_input, grad_weight, None, grad_bias, None, None, None
 
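Since `forward` now returns `(loss, metrics)` instead of a single tensor, autograd hands `backward` one incoming gradient per output, hence the variadic `*grad_output`. A toy sketch of the pattern (hypothetical `Toy` function, with simpler indexing than the code above):

```python
import torch

class Toy(torch.autograd.Function):
    @staticmethod
    def forward(ctx, x):
        ctx.save_for_backward(x)
        return x.sum(), x.detach() * 2  # (loss, extra metrics)

    @staticmethod
    def backward(ctx, *grad_output):  # one incoming gradient per forward output
        (x,) = ctx.saved_tensors
        return grad_output[0] * torch.ones_like(x)  # grad_output[0] is d(loss)

loss, metrics = Toy.apply(torch.randn(3, requires_grad=True))
loss.backward()
```
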
@@ -228,40 +306,64 @@ class LigerFusedLinearPreferenceBase(torch.autograd.Function):
             ref_bias (torch.Tensor, optional): Reference bias tensor. Shape: (vocab_size,).
             loss_kwargs (dict): Additional arguments for the loss function.
         """
-        chosen_logps, rejected_logps, chosen_nll_loss = (
-            LigerFusedLinearPreferenceBase.chunk_forward(
-                input_chunk,
-                weight,
-                target_chunk,
-                bias=bias,
-                ignore_index=ignore_index,
-                compute_nll_loss=compute_nll_loss,
-            )
+        (
+            chosen_logps,
+            rejected_logps,
+            chosen_logits,
+            rejected_logits,
+            chosen_nll_loss,
+        ) = LigerFusedLinearPreferenceBase.chunk_forward(
+            input_chunk,
+            weight,
+            target_chunk,
+            bias=bias,
+            ignore_index=ignore_index,
+            compute_nll_loss=compute_nll_loss,
         )
         chosen_nll_loss = (
             chosen_nll_loss
             / (full_target[: full_target.shape[0] // 2] != ignore_index).sum()
         )
+        chosen_logits_mean = chosen_logits.sum() / (
+            full_target.shape[0] // 2 * input_chunk.shape[1] * weight.shape[0]
+        )
+        rejected_logits_mean = rejected_logits.sum() / (
+            full_target.shape[0] // 2 * input_chunk.shape[1] * weight.shape[0]
+        )
 
         if use_ref_model:
             with torch.no_grad():
-                ref_chosen_logps, ref_rejected_logps, _ = (
-                    LigerFusedLinearPreferenceBase.chunk_forward(
-                        input_chunk,
-                        ref_weight,
-                        target_chunk,
-                        ref_bias,
-                        ignore_index=ignore_index,
-                        compute_nll_loss=False,
-                    )
+                (
+                    ref_chosen_logps,
+                    ref_rejected_logps,
+                    ref_chosen_logits,
+                    ref_rejected_logits,
+                    ref_chosen_nll_loss,
+                ) = LigerFusedLinearPreferenceBase.chunk_forward(
+                    input_chunk,
+                    ref_weight,
+                    target_chunk,
+                    ref_bias,
+                    ignore_index=ignore_index,
+                    compute_nll_loss=False,  # We don't need NLL loss for the reference model
                 )
             loss_kwargs["ref_chosen_logps"] = ref_chosen_logps
             loss_kwargs["ref_rejected_logps"] = ref_rejected_logps
 
-        alignment_loss = preference_loss_fn(
-            chosen_logps, rejected_logps, beta=beta, **loss_kwargs
+        preference_loss_outputs = preference_loss_fn(
+            chosen_logps, rejected_logps, full_target, beta=beta, **loss_kwargs
         )
-        alignment_loss = alignment_loss / (full_target.shape[0] // 2)
+        if isinstance(preference_loss_outputs, tuple):
+            preference_loss, *aux_outputs = preference_loss_outputs
+        else:
+            preference_loss, aux_outputs = preference_loss_outputs, []
 
-        loss = alpha * chosen_nll_loss - alignment_loss
-        return loss, (alignment_loss, chosen_logps, rejected_logps)
+        loss = alpha * chosen_nll_loss - preference_loss
+        return_vars = (
+            chosen_logps,
+            rejected_logps,
+            chosen_logits_mean,
+            rejected_logits_mean,
+            chosen_nll_loss,
+        )
+        return loss, (*return_vars, *aux_outputs)
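
`_compute_loss` now tolerates both return styles from `preference_loss_fn`: a bare scalar (CPO, DPO, SimPO) or a `(loss, *aux)` tuple (ORPO). A hedged sketch of a custom loss function exploiting the tuple form; the name and the aux metric are illustrative, not part of the package:

```python
import torch.nn.functional as F

def my_preference_loss_fn(chosen_logps, rejected_logps, full_target, beta=0.1):
    logits = beta * (chosen_logps - rejected_logps)
    loss = F.logsigmoid(logits).sum() / (full_target.shape[0] // 2)
    margin = (chosen_logps - rejected_logps).mean()  # illustrative aux metric
    return loss, margin  # returning just `loss` is equally valid
```
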

liger_kernel/chunked_loss/orpo_loss.py
@@ -9,7 +9,7 @@ from liger_kernel.chunked_loss.fused_linear_preference import (
 class LigerFusedLinearORPOFunction(LigerFusedLinearPreferenceBase):
 
     @staticmethod
-    def preference_loss_fn(chosen_logps, rejected_logps, beta=0.1):
+    def preference_loss_fn(chosen_logps, rejected_logps, full_target, beta=0.1):
         """
         Compute odds-ratio loss.
         Args:
@@ -22,7 +22,15 @@ class LigerFusedLinearORPOFunction(LigerFusedLinearPreferenceBase):
             - torch.log1p(-torch.exp(rejected_logps))
         )
         ratio = F.logsigmoid(log_odds)
-        return beta * ratio.sum()
+        loss = beta * ratio.sum() / (full_target.shape[0] // 2)
+
+        chosen_rewards = beta * chosen_logps
+        rejected_rewards = beta * rejected_logps
+
+        log_odds_ratio = torch.sum(ratio) / (full_target.shape[0] // 2)
+        log_odds_chosen = torch.sum(log_odds) / (full_target.shape[0] // 2)
+
+        return loss, chosen_rewards, rejected_rewards, log_odds_ratio, log_odds_chosen
 
     @staticmethod
     def forward(
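
For context, the new ORPO return values match the odds-ratio definitions: writing $p_c = e^{\log p_c}$ and $p_r = e^{\log p_r}$ for the average per-token probabilities of the chosen and rejected sequences,

$$
\mathrm{log\_odds} = \log\frac{p_c/(1-p_c)}{p_r/(1-p_r)}, \qquad
\mathrm{loss} = \frac{\beta}{B/2}\sum_i \log\sigma(\mathrm{log\_odds}_i), \qquad
r_c = \beta\log p_c, \quad r_r = \beta\log p_r,
$$

so the extra outputs are the per-sequence rewards plus the batch-averaged `log_odds_ratio` and `log_odds_chosen` that the new `LigerORPOTrainer` (added below) logs as metrics.
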
@@ -56,7 +64,7 @@ class LigerFusedLinearORPOFunction(LigerFusedLinearPreferenceBase):
         )
 
     @staticmethod
-    def backward(ctx, grad_output):
+    def backward(ctx, *grad_output):
         # Get gradients for _input, weight, bias, and target from the base class
         grads = LigerFusedLinearPreferenceBase.backward(ctx, grad_output)[:4]
         # Return these gradients, followed by None for the remaining inputs

liger_kernel/chunked_loss/simpo_loss.py
@@ -9,7 +9,9 @@ from liger_kernel.chunked_loss.fused_linear_preference import (
 class LigerFusedLinearSimPOFunction(LigerFusedLinearPreferenceBase):
 
     @staticmethod
-    def preference_loss_fn(chosen_logps, rejected_logps, beta=0.1, gamma=0.5):
+    def preference_loss_fn(
+        chosen_logps, rejected_logps, full_target, beta=0.1, gamma=0.5
+    ):
         """
         Compute SimPO loss.
         Args:
@@ -19,7 +21,7 @@ class LigerFusedLinearSimPOFunction(LigerFusedLinearPreferenceBase):
             gamma (float): The simpo gamma, margin term.
         """
         logits = beta * (chosen_logps - rejected_logps) - gamma
-        loss = F.logsigmoid(logits).mean()
+        loss = F.logsigmoid(logits).sum() / (full_target.shape[0] // 2)
         return loss
 
     @staticmethod
@@ -58,7 +60,7 @@ class LigerFusedLinearSimPOFunction(LigerFusedLinearPreferenceBase):
         )
 
     @staticmethod
-    def backward(ctx, grad_output):
+    def backward(ctx, *grad_output):
         # Get gradients for _input, weight, bias, and target from the base class
         grads = LigerFusedLinearPreferenceBase.backward(ctx, grad_output)[:4]
         # Return these gradients, followed by None for the remaining inputs

liger_kernel/transformers/__init__.py
@@ -22,6 +22,7 @@ from liger_kernel.transformers.monkey_patch import ( # noqa: F401
     apply_liger_kernel_to_qwen2,
     apply_liger_kernel_to_qwen2_vl,
 )
+from liger_kernel.transformers.orpo_trainer import LigerORPOTrainer  # noqa: F401
 from liger_kernel.transformers.rms_norm import LigerRMSNorm  # noqa: F401
 from liger_kernel.transformers.rope import liger_rotary_pos_emb  # noqa: F401
 from liger_kernel.transformers.swiglu import (  # noqa: F401

liger_kernel/transformers/orpo_trainer.py
@@ -0,0 +1,171 @@
+from typing import Any, Callable, Dict, List, Literal, Tuple, Union
+
+import torch
+import torch.nn as nn
+from torch.distributed.fsdp import FullyShardedDataParallel
+from trl.trainer import ORPOTrainer
+
+from liger_kernel.chunked_loss import LigerFusedLinearORPOLoss
+
+
+class _FSDPForwardRedirection:
+    """
+    Modified based on
+    https://github.com/Lightning-AI/pytorch-lightning/blob/d3f9c83d6efa4f1def36aa6c199600946cdb9117/src/lightning/pytorch/strategies/strategy.py#L601-L648
+    Redirect a method call through FullyShardedDataParallel.forward so that the FSDP module's root pre-forward and
+    post-forward can be properly executed around the method call.
+    This is needed in cases where we call a submodule of a FSDP module. For instance, when we want to call only
+    the `LlamaModel` part out of a FSDP-wrapped `LlamaForCausalLM` to get the hidden states without involving
+    GPU-memory-heavy `lm_head` and cross entropy computation, doing this directly (i.e. `model.model.forward()`)
+    will not work because the first `nn.Embedding` layer is not independently wrapped as a FSDP module (because of
+    the transformer-based wrapping policy), and not calling it through the FSDP root module's forward will not
+    all-gather its parameters, thus resulting in "RuntimeError: 'weight' must be 2-D". Similarly, if we want to
+    call just the `lm_head` part of a model, we need this trick too to properly get its params all-gathered.
+    """
+
+    def __call__(
+        self,
+        wrapper_module: FullyShardedDataParallel,
+        method: Callable,
+        *args: Any,
+        **kwargs: Any,
+    ):
+        """Reroutes a method call through the `wrapper_module`'s `forward` method.
+        Args:
+            wrapper_module: The FSDP-wrapped module whose root forward should bracket the call.
+            method: The method to call on the wrapped module; its inputs get redirected through
+                the `wrapper_module`'s `forward` method.
+            *args: The positional arguments to `method`. They will get passed to a patched
+                `forward` method instead.
+            **kwargs: The keyword arguments to `method`. They will get passed to a patched
+                `forward` method instead.
+        """
+        assert isinstance(wrapper_module, FullyShardedDataParallel)
+        original_module = wrapper_module._fsdp_wrapped_module
+        original_forward = original_module.forward
+
+        def wrapped_forward(*_args: Any, **_kwargs: Any) -> Any:
+            # Unpatch ourselves immediately before calling `method`,
+            # because `method` itself may want to call the real `forward`
+            original_module.forward = original_forward  # type: ignore[method-assign]
+            # Call the actual method e.g. `.training_step(...)`
+            out = method(*_args, **_kwargs)
+            return out
+
+        # Patch the original_module's forward so we can redirect the arguments back to the real method
+        original_module.forward = wrapped_forward  # type: ignore[method-assign]
+        wrapper_output = wrapper_module(*args, **kwargs)
+        return wrapper_output
+
+
+class LigerORPOTrainer(ORPOTrainer):
+    def concatenated_forward(
+        self, model: nn.Module, batch: Dict[str, Union[List, torch.LongTensor]]
+    ) -> Tuple[
+        torch.FloatTensor, torch.FloatTensor, torch.FloatTensor, torch.FloatTensor
+    ]:
+        """
+        Run the given model on the given batch of inputs, concatenating the chosen and rejected inputs together.
+        We do this to avoid doing two forward passes, because it's faster for FSDP.
+        """
+        concatenated_batch = self.concatenated_inputs(
+            batch,
+            is_encoder_decoder=self.is_encoder_decoder,
+            label_pad_token_id=self.label_pad_token_id,
+            padding_value=self.padding_value,
+            device=self.accelerator.device,
+        )
+        model_kwargs = (
+            {
+                "decoder_input_ids": self._shift_right(
+                    concatenated_batch["concatenated_labels"]
+                ),
+            }
+            if self.is_encoder_decoder
+            else {}
+        )
+
+        if self.aux_loss_enabled:
+            model_kwargs["output_router_logits"] = True
+
+        if isinstance(model, FullyShardedDataParallel):
+            outputs = _FSDPForwardRedirection()(
+                model,
+                model._fsdp_wrapped_module.model,
+                concatenated_batch["concatenated_input_ids"],
+                attention_mask=concatenated_batch["concatenated_attention_mask"],
+                use_cache=False,
+                **model_kwargs,
+            )
+        else:
+            if isinstance(model, torch.nn.DataParallel):
+                model = model.module
+            outputs = model.model(
+                concatenated_batch["concatenated_input_ids"],
+                attention_mask=concatenated_batch["concatenated_attention_mask"],
+                use_cache=False,
+                **model_kwargs,
+            )
+
+        orpo_loss_fn = LigerFusedLinearORPOLoss(
+            ignore_index=self.label_pad_token_id, beta=self.beta
+        )
+
+        def orpo_partial(lm_head, last_hidden_state, concatenated_labels):
+            return orpo_loss_fn(
+                lm_head.weight, last_hidden_state, concatenated_labels, lm_head.bias
+            )
+
+        orpo_loss, aux_outputs = _FSDPForwardRedirection()(
+            model,
+            orpo_partial,
+            model.lm_head,
+            outputs.last_hidden_state,
+            concatenated_batch["concatenated_labels"],
+        )
+        return orpo_loss, aux_outputs
+
+    def get_batch_loss_metrics(
+        self,
+        model,
+        batch: Dict[str, Union[List, torch.LongTensor]],
+        train_eval: Literal["train", "eval"] = "train",
+    ):
+        """Compute the ORPO loss and other metrics for the given batch of inputs for train or test."""
+        metrics = {}
+        loss, aux_outputs = self.concatenated_forward(model, batch)
+        (
+            policy_chosen_logps,
+            policy_rejected_logps,
+            policy_chosen_logits,
+            policy_rejected_logits,
+            policy_nll_loss,
+        ) = aux_outputs[:5]
+
+        chosen_rewards, rejected_rewards, log_odds_ratio, log_odds_chosen = aux_outputs[
+            5:
+        ]
+
+        reward_accuracies = (chosen_rewards > rejected_rewards).float()
+
+        prefix = "eval_" if train_eval == "eval" else ""
+        metrics[f"{prefix}rewards/chosen"] = chosen_rewards.mean()
+        metrics[f"{prefix}rewards/rejected"] = rejected_rewards.mean()
+        metrics[f"{prefix}rewards/accuracies"] = reward_accuracies.mean()
+        metrics[f"{prefix}rewards/margins"] = (chosen_rewards - rejected_rewards).mean()
+        metrics[f"{prefix}logps/rejected"] = policy_rejected_logps.detach().mean()
+        metrics[f"{prefix}logps/chosen"] = policy_chosen_logps.detach().mean()
+        metrics[f"{prefix}logits/rejected"] = policy_rejected_logits.detach().mean()
+        metrics[f"{prefix}logits/chosen"] = policy_chosen_logits.detach().mean()
+        metrics[f"{prefix}nll_loss"] = policy_nll_loss.detach().mean()
+        metrics[f"{prefix}log_odds_ratio"] = log_odds_ratio
+        metrics[f"{prefix}log_odds_chosen"] = log_odds_chosen
+        for k, v in metrics.items():
+            metrics[k] = v.item()
+
+        return loss, metrics
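
Since `LigerORPOTrainer` only overrides `concatenated_forward` and `get_batch_loss_metrics`, it should be constructed exactly like trl's `ORPOTrainer`. A hedged usage sketch; the model, dataset, and config values below are placeholders, not recommendations from the package:

```python
from datasets import load_dataset
from transformers import AutoModelForCausalLM, AutoTokenizer
from trl import ORPOConfig

from liger_kernel.transformers import LigerORPOTrainer

model = AutoModelForCausalLM.from_pretrained("meta-llama/Llama-3.2-1B-Instruct")
tokenizer = AutoTokenizer.from_pretrained("meta-llama/Llama-3.2-1B-Instruct")
dataset = load_dataset("trl-lib/ultrafeedback_binarized", split="train")

trainer = LigerORPOTrainer(
    model=model,
    args=ORPOConfig(output_dir="orpo-out", beta=0.1, max_length=1024),
    train_dataset=dataset,
    tokenizer=tokenizer,  # trl>=0.11; newer trl versions use `processing_class`
)
trainer.train()
```
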

liger_kernel_nightly-0.4.2.dev20241203232039.dist-info/METADATA → liger_kernel_nightly-0.4.2.dev20241206180928.dist-info/METADATA
@@ -1,6 +1,6 @@
 Metadata-Version: 2.1
 Name: liger_kernel_nightly
-Version: 0.4.2.dev20241203232039
+Version: 0.4.2.dev20241206180928
 Summary: Efficient Triton kernels for LLM Training
 License: BSD 2-CLAUSE LICENSE
 Copyright 2024 LinkedIn Corporation
@@ -34,6 +34,7 @@ Requires-Dist: torch>=2.1.2
 Requires-Dist: triton>=2.3.1
 Provides-Extra: dev
 Requires-Dist: transformers>=4.44.2; extra == "dev"
+Requires-Dist: trl>=0.11.0; extra == "dev"
 Requires-Dist: matplotlib>=3.7.2; extra == "dev"
 Requires-Dist: flake8>=4.0.1.1; extra == "dev"
 Requires-Dist: black>=24.4.2; extra == "dev"

liger_kernel_nightly-0.4.2.dev20241203232039.dist-info/RECORD → liger_kernel_nightly-0.4.2.dev20241206180928.dist-info/RECORD
@@ -2,12 +2,12 @@ liger_kernel/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
 liger_kernel/env_report.py,sha256=jye8RvUkmhqaIshdeIpoUABoAu7FPKJUib4FnAfvkpw,1132
 liger_kernel/utils.py,sha256=HJa-xVKOohDn6pLVIx-Fv0V9h0QAL3qZGQNRICI-OpI,249
 liger_kernel/chunked_loss/__init__.py,sha256=R2wCcz4Y0kTAve926DH3k182XKezpXeACMHj05g9Mm8,346
-liger_kernel/chunked_loss/cpo_loss.py,sha256=H2L6mNtU8RMJ17u4aMZ9FHEfBvg1Z_hliY5-jZxiDBM,3079
-liger_kernel/chunked_loss/dpo_loss.py,sha256=XcCGLVmTVdEX30q41XRXXK_c-MSumVJ-l4tQwobUv2w,4228
+liger_kernel/chunked_loss/cpo_loss.py,sha256=P20txjErLCSfSfToFT8pnuVPqFU4Bbybt3zRXfGEV-0,3122
+liger_kernel/chunked_loss/dpo_loss.py,sha256=NZyM4ju56MBVrUTI_7-jGMx5pWWDYzwx7ALoMj1G8Ec,4276
 liger_kernel/chunked_loss/functional.py,sha256=9Gr-YXIuEzEJkBUhDx3G2fuQayckLor7cC7svhmPML4,549
-liger_kernel/chunked_loss/fused_linear_preference.py,sha256=nkEpNWTHh5GmlnHOnGx5ifjigbOuUhc3hRy7RehXDbE,10838
-liger_kernel/chunked_loss/orpo_loss.py,sha256=DZ-_hm1twllBWujEV4M4-VDBkxMDBvoGqMGe-aGP1hA,3147
-liger_kernel/chunked_loss/simpo_loss.py,sha256=Jpl_U6DfxlzyHnlKN2i05K0vwz-ouiTmxlLGb439FwY,3328
+liger_kernel/chunked_loss/fused_linear_preference.py,sha256=nod7GcsTBV_L6RGRd55meB2D5KWzETVSnIz6xFbjVCc,14891
+liger_kernel/chunked_loss/orpo_loss.py,sha256=GGwc3pLGGJzb_P_C7IogcA1EfdAcM1uktfKPmI1z2jk,3523
+liger_kernel/chunked_loss/simpo_loss.py,sha256=FtURWbXGjoAKyiVYF7fkMv8Us7uk3UrSg21pWOFk11Y,3385
 liger_kernel/ops/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
 liger_kernel/ops/cross_entropy.py,sha256=VqaYB9Zirc51eZ28OmjEZRrrV9UysRjS_vhIftB9sKo,15753
 liger_kernel/ops/fused_linear_cross_entropy.py,sha256=Tnw4gyAYVVdnCOqhOuLEzbUQ3goOTnoAfk3pqSIM5ac,9301
@@ -24,7 +24,7 @@ liger_kernel/ops/swiglu.py,sha256=Fwxtd76rhHKT9ShQAGca9RsnASplAVxtYKHmiT73_yA,29
 liger_kernel/ops/utils.py,sha256=_VQvd1PX5JXm5xaiBrk2gANp3qr4kM7qYG3ypkBwkMs,3850
 liger_kernel/ops/experimental/embedding.py,sha256=LYR66dB-jhvhtUjeV4PnNro-n77J1mdlmpSLSxB3Y6U,4186
 liger_kernel/ops/experimental/mm_int8int2.py,sha256=JpGVZCgRC6T8XMUJ_QbZRS2XU1bh0urIZphs5DTc1mY,13358
-liger_kernel/transformers/__init__.py,sha256=gia-eBxr7TLxU0GdDf8AfCY4WgDlFLqIGSt7EoQGsBA,1336
+liger_kernel/transformers/__init__.py,sha256=P5JR3fI-znhG92nRrFS2j0TIJTLhP-xD5dvEy4HP9ik,1418
 liger_kernel/transformers/auto_model.py,sha256=RMIwQHSiXoksXFTIqFZ4PLBgoqkxJJAT3q1Qh47bGN8,1552
 liger_kernel/transformers/cross_entropy.py,sha256=yEm_YQ7oa3_BzT3hdW6KrAslduhSqWcJQVNZZDcWCg4,1758
 liger_kernel/transformers/functional.py,sha256=sUBoU8Vb4pLpr9G6IdkRsToYgh-rCXL4OLYat7Tv_GU,4450
@@ -36,6 +36,7 @@ liger_kernel/transformers/jsd.py,sha256=sbr8DnKSYZJH9pv2rpmboNijYGpZKbhb2-WSGp5_
 liger_kernel/transformers/kl_div.py,sha256=qVhjBg6tjRyue5iZ3NFxo8uySY4JuIFJyv0IM_50F24,431
 liger_kernel/transformers/layer_norm.py,sha256=fd6o4kSHJWolQMWxh-l1qObfgL08ruNbUoBiANKX1ow,972
 liger_kernel/transformers/monkey_patch.py,sha256=Fk2v4GZQDJzfh3Cpc6BHNJbs_tungDyWmqS9nuG9Lc4,38406
+liger_kernel/transformers/orpo_trainer.py,sha256=mC8ePS-Oq-BrdM0lKpgSBLuYLqYsWxH_4Q2RnDthz5M,7643
 liger_kernel/transformers/qwen2vl_mrope.py,sha256=SfSQVwOe7ArrVfpmIdfZrdzCxmcj7V-YQp9zDu17-ao,1043
 liger_kernel/transformers/rms_norm.py,sha256=AHstklNIO1PLHjjCBU-TPuUD-Fl_pycJUTLlJNojbV8,1189
 liger_kernel/transformers/rope.py,sha256=m-ah8vZBYW8tfplTXCiAPMHJWlB1tdp_JPXJeWE-Boo,943
@@ -54,9 +55,9 @@ liger_kernel/transformers/model/qwen2.py,sha256=EyhSSzQOskGjSnCsKMZpd1s5IAIlHd5P
 liger_kernel/transformers/model/qwen2_vl.py,sha256=bIQe2bWiY--G84FhCD29Gdi64_qHP6vbcGsK6vKysQE,8547
 liger_kernel/triton/__init__.py,sha256=yfRe0zMb47QnqjecZWG7LnanfCTzeku7SgWRAwNVmzU,101
 liger_kernel/triton/monkey_patch.py,sha256=5BcGKTtdqeYchypBIBopGIWPx1-cFALz7sOKoEsqXJ0,1584
-liger_kernel_nightly-0.4.2.dev20241203232039.dist-info/LICENSE,sha256=OhzLDHJ0to4a8sodVLELZiCFylZ1NAAYLs-HrjPy0ag,1312
-liger_kernel_nightly-0.4.2.dev20241203232039.dist-info/METADATA,sha256=GD7sOJhLqOExLzto7Qhlp554vRb1JDkM_zULsZ8HhYU,21897
-liger_kernel_nightly-0.4.2.dev20241203232039.dist-info/NOTICE,sha256=njwnoPZLh9AN8SJQzxvCGLHi-8X__AvWRze6joNXIY8,2066
-liger_kernel_nightly-0.4.2.dev20241203232039.dist-info/WHEEL,sha256=P9jw-gEje8ByB7_hXoICnHtVCrEwMQh-630tKvQWehc,91
-liger_kernel_nightly-0.4.2.dev20241203232039.dist-info/top_level.txt,sha256=2eghu4hA3LnkM7ElW92tQ8zegWKgSbeo-k-aGe1YnvY,13
-liger_kernel_nightly-0.4.2.dev20241203232039.dist-info/RECORD,,
+liger_kernel_nightly-0.4.2.dev20241206180928.dist-info/LICENSE,sha256=OhzLDHJ0to4a8sodVLELZiCFylZ1NAAYLs-HrjPy0ag,1312
+liger_kernel_nightly-0.4.2.dev20241206180928.dist-info/METADATA,sha256=WAAJkbzUZII072MIUuE8_72lDZNPoRac1suRYzGTrsg,21940
+liger_kernel_nightly-0.4.2.dev20241206180928.dist-info/NOTICE,sha256=njwnoPZLh9AN8SJQzxvCGLHi-8X__AvWRze6joNXIY8,2066
+liger_kernel_nightly-0.4.2.dev20241206180928.dist-info/WHEEL,sha256=P9jw-gEje8ByB7_hXoICnHtVCrEwMQh-630tKvQWehc,91
+liger_kernel_nightly-0.4.2.dev20241206180928.dist-info/top_level.txt,sha256=2eghu4hA3LnkM7ElW92tQ8zegWKgSbeo-k-aGe1YnvY,13
+liger_kernel_nightly-0.4.2.dev20241206180928.dist-info/RECORD,,