liger-kernel-nightly 0.4.2.dev20241209195823__py3-none-any.whl → 0.4.2.dev20241209234352__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -11,11 +11,25 @@ class LigerFusedLinearCPOFunction(LigerFusedLinearPreferenceBase):
11
11
  @staticmethod
12
12
  def preference_loss_fn(chosen_logps, rejected_logps, full_target, beta=0.1):
13
13
  """
14
- Compute odds-ratio loss.
14
+ Paper: https://arxiv.org/pdf/2401.08417
15
+
16
+ Formula:
17
+ L(π_θ; U) = -E_(x,y_w,y_l)~D[log σ(β log π_θ(y_w|x) - β log π_θ(y_l|x))]
18
+
19
+ Where:
20
+ - π_θ(y|x): Policy (model) probability
21
+ - y_w: Chosen sequence
22
+ - y_l: Rejected sequence
23
+ - σ: Sigmoid function
24
+ - β: Temperature parameter
25
+ - E: Expected value over the dataset D
26
+ - D: Dataset of preferences
27
+
15
28
  Args:
16
29
  chosen_logps (torch.Tensor): Avg log probabilities of chosen tokens. Shape: (batch_size,).
17
30
  rejected_logps (torch.Tensor): Avg log probabilities of rejected tokens. Shape: (batch_size,).
18
- beta (float): Weight for the odds ratio loss.
31
+ full_target (torch.Tensor): Non-chunked full target tensor
32
+ beta (float): Weight for the CPO loss
19
33
  """
20
34
  logits = beta * (chosen_logps - rejected_logps)
21
35
  loss = F.logsigmoid(logits).sum() / (full_target.shape[0] // 2)
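
For quick reference, a self-contained sketch of the preference term computed in this hunk (the function name and toy tensors are illustrative, not part of the package API; chosen and rejected sequences are stacked along the batch dimension of full_target, hence the // 2 divisor):

import torch
import torch.nn.functional as F

def cpo_preference_loss(chosen_logps, rejected_logps, full_target, beta=0.1):
    # logits = beta * (avg log p(chosen) - avg log p(rejected))
    logits = beta * (chosen_logps - rejected_logps)
    # Sum of log-sigmoid, normalized by the number of preference pairs
    return F.logsigmoid(logits).sum() / (full_target.shape[0] // 2)

# Toy usage with two preference pairs
chosen = torch.tensor([-0.5, -1.0])
rejected = torch.tensor([-1.5, -2.0])
full_target = torch.zeros(4, 8, dtype=torch.long)   # 2 chosen + 2 rejected sequences
print(cpo_preference_loss(chosen, rejected, full_target))
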
@@ -34,12 +48,6 @@ class LigerFusedLinearCPOFunction(LigerFusedLinearPreferenceBase):
34
48
  compute_nll_loss=True,
35
49
  compiled=True,
36
50
  ):
37
- """
38
- Fused linear layer with CPO (Odds-Ratio Preference Optimization) loss.
39
- Handles both the forward and backward pass of the final linear layer with CPO loss.
40
- Inspired from LigerFusedLinearCrossEntropyFunction (https://arxiv.org/abs/2410.10989) which fuses final linear layer and CE loss.
41
- """
42
-
43
51
  return LigerFusedLinearPreferenceBase.forward(
44
52
  ctx,
45
53
  _input,
@@ -56,9 +64,7 @@ class LigerFusedLinearCPOFunction(LigerFusedLinearPreferenceBase):
56
64
 
57
65
  @staticmethod
58
66
  def backward(ctx, *grad_output):
59
- # Get gradients for _input, weight, bias, and target from the base class
60
67
  grads = LigerFusedLinearPreferenceBase.backward(ctx, grad_output)[:4]
61
- # Return these gradients, followed by None for the remaining inputs
62
68
  return *grads, None, None, None, None, None
63
69
 
64
70
 
@@ -18,14 +18,28 @@ class LigerFusedLinearDPOFunction(LigerFusedLinearPreferenceBase):
18
18
  beta=0.1,
19
19
  ):
20
20
  """
21
- Compute DPO loss (Direct Preference Optimization).
21
+ Paper: https://arxiv.org/pdf/2305.18290
22
+
23
+ Formula:
24
+ L_DPO = -E[ log_sigmoid( β * (log(π(y_w|x)/π_ref(y_w|x)) - log(π(y_l|x)/π_ref(y_l|x))) ) ]
25
+
26
+ Where:
27
+ - π(y|x): Policy (model) probability
28
+ - π_ref(y|x): Reference model probability
29
+ - y_w: Chosen sequence
30
+ - y_l: Rejected sequence
31
+ - β: Weight for the direct preference loss
32
+ - E: Expected value over the dataset
33
+
22
34
  Args:
23
- chosen_logps (torch.Tensor): Avg log probabilities of chosen tokens. Shape: (batch_size,).
24
- rejected_logps (torch.Tensor): Avg log probabilities of rejected tokens. Shape: (batch_size,).
25
- ref_chosen_logps (torch.Tensor, optional): Reference log probabilities of chosen tokens. Shape: (batch_size,).
26
- ref_rejected_logps (torch.Tensor, optional): Reference log probabilities of rejected tokens. Shape: (batch_size,).
27
- beta (float): Weight for the direct preference loss.
35
+ chosen_logps: Log probabilities of chosen tokens (batch_size,)
36
+ rejected_logps: Log probabilities of rejected tokens (batch_size,)
37
+ full_target: Non-chunked full target tensor
38
+ ref_chosen_logps: Reference log probs of chosen tokens (batch_size,)
39
+ ref_rejected_logps: Reference log probs of rejected tokens (batch_size,)
40
+ beta: Weight for the direct preference loss
28
41
  """
42
+
29
43
  if ref_chosen_logps is None:
30
44
  ref_chosen_logps = torch.tensor(0.0, device=chosen_logps.device)
31
45
  if ref_rejected_logps is None:
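
A minimal standalone sketch of the formula documented above, following the reference-log-prob defaulting shown in this hunk (the function name is hypothetical; in the kernel this term is further combined with an NLL loss inside _compute_loss):

import torch
import torch.nn.functional as F

def dpo_preference_loss(chosen_logps, rejected_logps, full_target,
                        ref_chosen_logps=None, ref_rejected_logps=None, beta=0.1):
    # Missing reference log-probs default to 0.0, so the reference term drops out
    if ref_chosen_logps is None:
        ref_chosen_logps = torch.tensor(0.0, device=chosen_logps.device)
    if ref_rejected_logps is None:
        ref_rejected_logps = torch.tensor(0.0, device=rejected_logps.device)
    chosen_logratios = chosen_logps - ref_chosen_logps        # log(pi/pi_ref) for y_w
    rejected_logratios = rejected_logps - ref_rejected_logps  # log(pi/pi_ref) for y_l
    logits = beta * (chosen_logratios - rejected_logratios)
    # L_DPO = -E[log sigmoid(...)], averaged over the number of preference pairs
    return -F.logsigmoid(logits).sum() / (full_target.shape[0] // 2)
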
@@ -53,10 +67,6 @@ class LigerFusedLinearDPOFunction(LigerFusedLinearPreferenceBase):
53
67
  compiled=True,
54
68
  use_ref_model=True,
55
69
  ):
56
- """
57
- Fused linear layer with DPO (Direct Preference Optimization) loss.
58
- Handles both the forward and backward pass of the final linear layer with DPO loss.
59
- """
60
70
  return LigerFusedLinearPreferenceBase.forward(
61
71
  ctx=ctx,
62
72
  _input=_input,
@@ -75,9 +85,7 @@ class LigerFusedLinearDPOFunction(LigerFusedLinearPreferenceBase):
75
85
 
76
86
  @staticmethod
77
87
  def backward(ctx, *grad_output):
78
- # Get gradients for _input, weight, bias, and target from the base class
79
88
  grads = LigerFusedLinearPreferenceBase.backward(ctx, grad_output)[:4]
80
- # Return these gradients, followed by None for the remaining inputs
81
89
  return *grads, None, None, None, None, None, None, None
82
90
 
83
91
 
@@ -8,159 +8,12 @@ from torch.nn import functional as F
8
8
  class LigerFusedLinearPreferenceBase(torch.autograd.Function):
9
9
 
10
10
  @abstractmethod
11
- def preference_loss_fn(chosen_logps, rejected_logps, beta=0.1):
11
+ def preference_loss_fn(*args, **kwargs):
12
12
  """
13
- Compute preference loss.
14
- Args:
15
- chosen_logps (torch.Tensor): Avg log probabilities of chosen tokens. Shape: (batch_size,).
16
- rejected_logps (torch.Tensor): Avg log probabilities of rejected tokens. Shape: (batch_size,).
17
- beta (float): Weight for the odds ratio loss.
13
+ To be extended by subclasses.
18
14
  """
19
15
  raise NotImplementedError("Preference loss function must be implemented.")
20
16
 
21
- @staticmethod
22
- def chunk_forward(
23
- input_chunk,
24
- weight,
25
- target_chunk,
26
- bias=None,
27
- ignore_index=-100,
28
- compute_nll_loss=True,
29
- ):
30
- len_chosen_chunk = target_chunk.shape[0] // 2
31
- logits_chunk = input_chunk @ weight.t()
32
- if bias is not None:
33
- logits_chunk = logits_chunk + bias
34
- log_probs_chunk = F.log_softmax(logits_chunk.float(), dim=-1)
35
-
36
- chosen_nll_loss = 0.0
37
- if compute_nll_loss:
38
- chosen_nll_loss = F.nll_loss(
39
- log_probs_chunk[:len_chosen_chunk].view(-1, log_probs_chunk.shape[-1]),
40
- target_chunk[:len_chosen_chunk].view(-1),
41
- reduction="sum",
42
- ignore_index=ignore_index,
43
- )
44
-
45
- loss_mask = target_chunk != ignore_index
46
- label_chunk = torch.where(loss_mask, target_chunk, 0)
47
-
48
- per_token_logps = log_probs_chunk.gather(-1, label_chunk.unsqueeze(-1)).squeeze(
49
- -1
50
- )
51
- average_log_prob = (per_token_logps * loss_mask).sum(-1) / loss_mask.sum(-1)
52
-
53
- chosen_logps = average_log_prob[:len_chosen_chunk]
54
- rejected_logps = average_log_prob[len_chosen_chunk:]
55
-
56
- chosen_logits = logits_chunk[:len_chosen_chunk]
57
- rejected_logits = logits_chunk[len_chosen_chunk:]
58
-
59
- return (
60
- chosen_logps,
61
- rejected_logps,
62
- chosen_logits,
63
- rejected_logits,
64
- chosen_nll_loss,
65
- )
66
-
67
- @staticmethod
68
- def _compute_loss(
69
- input_chunk,
70
- weight,
71
- target_chunk,
72
- bias=None,
73
- preference_loss_fn=None,
74
- full_target=None,
75
- ignore_index=-100,
76
- alpha=1.0,
77
- beta=0.1,
78
- compute_nll_loss=True,
79
- use_ref_model=False,
80
- ref_weight=None,
81
- ref_bias=None,
82
- **loss_kwargs,
83
- ):
84
- """
85
- Compute the total loss for a chunk of input and target, while using an alignment/preference loss function.
86
- Args:
87
- preference_loss_fn (callable): Loss function to compute the loss on a chunk of input/target.
88
- input_chunk (torch.Tensor): Chunk of input tensor. Shape: (2 * chunk_size, sequence_length, hidden_size).
89
- weight (torch.Tensor): Weight tensor. Shape: (vocab_size, hidden_size).
90
- target_chunk (torch.Tensor): Chunk of target tensor. Shape: (2 * chunk_size, sequence_length).
91
- bias (torch.Tensor, optional): Bias tensor. Shape: (vocab_size,).
92
- full_target (torch.Tensor): Full target tensor. Shape: (batch_size, sequence_length).
93
- ignore_index (int): Index to ignore for loss computation.
94
- alpha (float): Weight for the NLL loss.
95
- beta (float): Weight for the odds ratio loss.
96
- compute_nll_loss (bool): Whether to compute NLL loss.
97
- use_ref_model (bool): Whether to use a reference model for the alignment loss.
98
- ref_weight (torch.Tensor): Reference weight tensor. Shape: (vocab_size, hidden_size).
99
- ref_bias (torch.Tensor, optional): Reference bias tensor. Shape: (vocab_size,).
100
- loss_kwargs (dict): Additional arguments for the loss function.
101
- """
102
- (
103
- chosen_logps,
104
- rejected_logps,
105
- chosen_logits,
106
- rejected_logits,
107
- chosen_nll_loss,
108
- ) = LigerFusedLinearPreferenceBase.chunk_forward(
109
- input_chunk,
110
- weight,
111
- target_chunk,
112
- bias=bias,
113
- ignore_index=ignore_index,
114
- compute_nll_loss=compute_nll_loss,
115
- )
116
- chosen_nll_loss = (
117
- chosen_nll_loss
118
- / (full_target[: full_target.shape[0] // 2] != ignore_index).sum()
119
- )
120
- chosen_logits_mean = chosen_logits.sum() / (
121
- full_target.shape[0] // 2 * input_chunk.shape[1] * weight.shape[0]
122
- )
123
- rejected_logits_mean = rejected_logits.sum() / (
124
- full_target.shape[0] // 2 * input_chunk.shape[1] * weight.shape[0]
125
- )
126
-
127
- if use_ref_model:
128
- with torch.no_grad():
129
- (
130
- ref_chosen_logps,
131
- ref_rejected_logps,
132
- ref_chosen_logits,
133
- ref_rejected_logits,
134
- ref_chosen_nll_loss,
135
- ) = LigerFusedLinearPreferenceBase.chunk_forward(
136
- input_chunk,
137
- ref_weight,
138
- target_chunk,
139
- ref_bias,
140
- ignore_index=ignore_index,
141
- compute_nll_loss=False, # We don't need NLL loss for the reference model
142
- )
143
- loss_kwargs["ref_chosen_logps"] = ref_chosen_logps
144
- loss_kwargs["ref_rejected_logps"] = ref_rejected_logps
145
-
146
- preference_loss_outputs = preference_loss_fn(
147
- chosen_logps, rejected_logps, full_target, beta=beta, **loss_kwargs
148
- )
149
- if isinstance(preference_loss_outputs, tuple):
150
- preference_loss, *aux_outputs = preference_loss_outputs
151
- else:
152
- preference_loss, aux_outputs = preference_loss_outputs, []
153
-
154
- loss = alpha * chosen_nll_loss - preference_loss
155
- return_vars = (
156
- chosen_logps,
157
- rejected_logps,
158
- chosen_logits_mean,
159
- rejected_logits_mean,
160
- chosen_nll_loss,
161
- )
162
- return loss, (*return_vars, *aux_outputs)
163
-
164
17
  @staticmethod
165
18
  def forward(
166
19
  ctx,
@@ -176,6 +29,7 @@ class LigerFusedLinearPreferenceBase(torch.autograd.Function):
176
29
  compute_nll_loss=True,
177
30
  compiled=True,
178
31
  use_ref_model=False,
32
+ # TODO: ref input
179
33
  ref_weight=None,
180
34
  ref_bias=None,
181
35
  **loss_kwargs,
@@ -184,6 +38,14 @@ class LigerFusedLinearPreferenceBase(torch.autograd.Function):
184
38
  Base class for fused linear layer with preference loss.
185
39
  Expects _input to be stacked with chosen and rejected inputs on the batch dimension.
186
40
 
41
+ The mental model is:
42
+
43
+ forward()
44
+ ├── Loop over chunks
45
+ └── _compute_loss()
46
+     ├── chunk_forward() # Compute logits and log probs
47
+     └── preference_loss_fn() # Calculate preference loss
48
+
187
49
  Args:
188
50
  _input (torch.Tensor): Input tensor. Shape: (batch_size, seq_len, hidden_size).
189
51
  weight (torch.Tensor): Weight tensor. Shape: (vocab_size, hidden_size).
@@ -191,10 +53,9 @@ class LigerFusedLinearPreferenceBase(torch.autograd.Function):
191
53
  bias (torch.Tensor, optional): Bias tensor. Shape: (vocab_size,).
192
54
  loss_fn (callable): Loss function to compute the loss on a chunk of input/target.
193
55
  chunk_size (int): Size of a chunk (# of batches of stacked chosen and rejected inputs).
194
- compute_nll_loss (bool): Whether to compute NLL loss.
195
56
  ignore_index (int): Index to ignore for loss computation.
196
57
  alpha (float): Weight for the NLL loss.
197
- beta (float): Weight for the odds ratio loss.
58
+ beta (float): Weight for the preference loss.
198
59
  compute_nll_loss (bool): Whether to compute NLL loss.
199
60
  compiled (bool): Whether to use torch compile for chunk accumulation.
200
61
  use_ref_model (bool): Whether to use a reference model for the alignment loss.
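
To make the mental model above concrete, here is a heavily simplified, eager-only sketch of the chunked loop (no gradient accumulation, bias, NLL term, or reference model; the name chunked_preference_forward is illustrative):

import torch
import torch.nn.functional as F

def chunked_preference_forward(_input, weight, target, preference_loss_fn,
                               chunk_size=1, beta=0.1, ignore_index=-100):
    # Split the stacked (chosen, rejected) batch into chunks, run the final
    # linear layer per chunk, compute avg log-probs, apply the preference loss.
    len_chosen = target.shape[0] // 2
    chunks = max(1, len_chosen // chunk_size)
    loss_acc = torch.zeros((), device=_input.device)
    for ci, ri, ct, rt in zip(
        torch.chunk(_input[:len_chosen], chunks),
        torch.chunk(_input[len_chosen:], chunks),
        torch.chunk(target[:len_chosen], chunks),
        torch.chunk(target[len_chosen:], chunks),
    ):
        input_chunk = torch.cat([ci, ri], dim=0)
        target_chunk = torch.cat([ct, rt], dim=0)
        logits = input_chunk @ weight.t()                      # final linear layer
        log_probs = F.log_softmax(logits.float(), dim=-1)
        mask = target_chunk != ignore_index
        labels = torch.where(mask, target_chunk, torch.zeros_like(target_chunk))
        tok_logps = log_probs.gather(-1, labels.unsqueeze(-1)).squeeze(-1)
        avg_logps = (tok_logps * mask).sum(-1) / mask.sum(-1)  # per-sequence avg log-prob
        n = ci.shape[0]
        # The real _compute_loss combines this with the NLL term as alpha*nll - pref_loss.
        loss_acc = loss_acc + preference_loss_fn(
            avg_logps[:n], avg_logps[n:], target, beta=beta
        )
    return loss_acc

Because each preference_loss_fn call already normalizes by the total number of pairs in the full target, summing per-chunk values reproduces the whole-batch reduction; the shipped kernel additionally wraps each chunk in torch.func.grad_and_value so gradients are accumulated in the same pass.
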
@@ -205,11 +66,16 @@ class LigerFusedLinearPreferenceBase(torch.autograd.Function):
205
66
  # TODO: Tune CHUNK_SIZE to fully utilize the GPU
206
67
  CHUNK_SIZE = chunk_size
207
68
 
69
+ # Gradients to be accumulated
208
70
  grad_weight = torch.zeros_like(weight)
209
71
  grad_chosen_inputs = []
210
72
  grad_rejected_inputs = []
211
73
  grad_bias = torch.zeros_like(bias) if bias is not None else None
74
+
75
+ # Loss to be accumulated
212
76
  loss_acc = torch.zeros((), device=_input.device)
77
+
78
+ # Metrics to be recorded
213
79
  policy_chosen_logps = []
214
80
  policy_rejected_logps = []
215
81
  policy_chosen_logits_mean = torch.zeros((), device=_input.device)
@@ -217,7 +83,7 @@ class LigerFusedLinearPreferenceBase(torch.autograd.Function):
217
83
  policy_nll_loss = torch.zeros((), device=_input.device)
218
84
  aggregated_aux_outputs = [] # aggregated aux outputs from all chunks
219
85
 
220
- loss_func_to_call = partial(
86
+ compute_loss = partial(
221
87
  LigerFusedLinearPreferenceBase._compute_loss,
222
88
  preference_loss_fn=loss_fn,
223
89
  ignore_index=ignore_index,
@@ -231,14 +97,17 @@ class LigerFusedLinearPreferenceBase(torch.autograd.Function):
231
97
  **loss_kwargs,
232
98
  )
233
99
 
234
- def accumulate_core(input_chunk, target_chunk):
100
+ def fused_fwd_bwd(input_chunk, target_chunk):
101
+ """
102
+ Fused forward and backward pass for a chunk of input and target.
103
+ """
235
104
  if bias is not None:
236
105
  return torch.func.grad_and_value(
237
- loss_func_to_call, argnums=(0, 1, 3), has_aux=True
106
+ compute_loss, argnums=(0, 1, 3), has_aux=True
238
107
  )(input_chunk, weight, target_chunk, bias)
239
108
  else:
240
109
  return torch.func.grad_and_value(
241
- loss_func_to_call, argnums=(0, 1), has_aux=True
110
+ compute_loss, argnums=(0, 1), has_aux=True
242
111
  )(input_chunk, weight, target_chunk)
243
112
 
244
113
  def accumulate_chunk(input_chunk, target_chunk):
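
fused_fwd_bwd leans on torch.func.grad_and_value with has_aux=True, which returns the gradients together with the (loss, aux) pair from one call. A tiny toy illustration of that API (unrelated to the kernel's actual loss):

import torch

def toy_loss(x, w):
    loss = ((x @ w) ** 2).sum()
    return loss, loss.detach()   # (scalar loss, aux payload)

x = torch.randn(4, 3)
w = torch.randn(3)
# argnums=(0, 1): differentiate w.r.t. x and w; has_aux=True mirrors the kernel,
# where per-chunk metrics ride along with the scalar loss.
(grad_x, grad_w), (loss, aux) = torch.func.grad_and_value(
    toy_loss, argnums=(0, 1), has_aux=True
)(x, w)
print(grad_x.shape, grad_w.shape, float(loss), float(aux))
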
@@ -253,7 +122,7 @@ class LigerFusedLinearPreferenceBase(torch.autograd.Function):
253
122
  chunk_nll_loss,
254
123
  *aux_outputs,
255
124
  ),
256
- ) = accumulate_core(input_chunk, target_chunk)
125
+ ) = fused_fwd_bwd(input_chunk, target_chunk)
257
126
  grad_bias.add_(chunk_grad_bias) # accumulate bias gradient
258
127
  else:
259
128
  (chunk_grad_input, chunk_grad_weight), (
@@ -266,16 +135,26 @@ class LigerFusedLinearPreferenceBase(torch.autograd.Function):
266
135
  chunk_nll_loss,
267
136
  *aux_outputs,
268
137
  ),
269
- ) = accumulate_core(input_chunk, target_chunk)
138
+ ) = fused_fwd_bwd(input_chunk, target_chunk)
270
139
 
140
+ # Accumulate gradients
271
141
  grad_weight.add_(chunk_grad_weight)
142
+ grad_chosen_inputs.append(chunk_grad_input[: chosen_target_chunk.shape[0]])
143
+ grad_rejected_inputs.append(
144
+ chunk_grad_input[chosen_target_chunk.shape[0] :]
145
+ )
146
+
147
+ # Accumulate loss
272
148
  loss_acc.add_(chunk_loss)
149
+
150
+ # Accumulate metrics
273
151
  policy_chosen_logps.append(chunk_chosen_logps)
274
152
  policy_rejected_logps.append(chunk_rejected_logps)
275
153
  policy_chosen_logits_mean.add_(chunk_chosen_logits_mean)
276
154
  policy_rejected_logits_mean.add_(chunk_rejected_logits_mean)
277
155
  policy_nll_loss.add_(chunk_nll_loss)
278
156
 
157
+ # aux_outputs
279
158
  # Initialize storage for aux_outputs
280
159
  if len(aggregated_aux_outputs) == 0:
281
160
  for aux in aux_outputs:
@@ -293,10 +172,8 @@ class LigerFusedLinearPreferenceBase(torch.autograd.Function):
293
172
  else:
294
173
  aggregated_aux_outputs[i].append(aux)
295
174
 
296
- return chunk_grad_input
297
-
298
175
  if compiled:
299
- accumulate_core = torch.compile(accumulate_core)
176
+ fused_fwd_bwd = torch.compile(fused_fwd_bwd)
300
177
 
301
178
  len_chosen = target.shape[0] // 2
302
179
  chunks = max(1, _input.shape[0] // (2 * CHUNK_SIZE))
@@ -327,10 +204,7 @@ class LigerFusedLinearPreferenceBase(torch.autograd.Function):
327
204
  torch._dynamo.mark_dynamic(target, 1)
328
205
 
329
206
  # accumulate loss, gradients, and metrics
330
- grad_input = accumulate_chunk(input_chunk, target_chunk)
331
-
332
- grad_chosen_inputs.append(grad_input[: chosen_target_chunk.shape[0]])
333
- grad_rejected_inputs.append(grad_input[chosen_target_chunk.shape[0] :])
207
+ accumulate_chunk(input_chunk, target_chunk)
334
208
 
335
209
  # combine grad_chosen_inputs and grad_rejected_inputs
336
210
  grad_inputs = grad_chosen_inputs + grad_rejected_inputs
@@ -367,3 +241,146 @@ class LigerFusedLinearPreferenceBase(torch.autograd.Function):
367
241
  grad_bias = grad_bias * grad_output[0][0] if grad_bias is not None else None
368
242
 
369
243
  return grad_input, grad_weight, None, grad_bias, None, None, None
244
+
245
+ @staticmethod
246
+ def chunk_forward(
247
+ input_chunk,
248
+ weight,
249
+ target_chunk,
250
+ bias=None,
251
+ ignore_index=-100,
252
+ compute_nll_loss=True,
253
+ ):
254
+ len_chosen_chunk = target_chunk.shape[0] // 2
255
+ logits_chunk = input_chunk @ weight.t()
256
+ if bias is not None:
257
+ logits_chunk = logits_chunk + bias
258
+ log_probs_chunk = F.log_softmax(logits_chunk.float(), dim=-1)
259
+
260
+ chosen_nll_loss = 0.0
261
+ if compute_nll_loss:
262
+ chosen_nll_loss = F.nll_loss(
263
+ log_probs_chunk[:len_chosen_chunk].view(-1, log_probs_chunk.shape[-1]),
264
+ target_chunk[:len_chosen_chunk].view(-1),
265
+ reduction="sum",
266
+ ignore_index=ignore_index,
267
+ )
268
+
269
+ loss_mask = target_chunk != ignore_index
270
+ label_chunk = torch.where(loss_mask, target_chunk, 0)
271
+
272
+ per_token_logps = log_probs_chunk.gather(-1, label_chunk.unsqueeze(-1)).squeeze(
273
+ -1
274
+ )
275
+ average_log_prob = (per_token_logps * loss_mask).sum(-1) / loss_mask.sum(-1)
276
+
277
+ chosen_logps = average_log_prob[:len_chosen_chunk]
278
+ rejected_logps = average_log_prob[len_chosen_chunk:]
279
+
280
+ chosen_logits = logits_chunk[:len_chosen_chunk]
281
+ rejected_logits = logits_chunk[len_chosen_chunk:]
282
+
283
+ return (
284
+ chosen_logps,
285
+ rejected_logps,
286
+ chosen_logits,
287
+ rejected_logits,
288
+ chosen_nll_loss,
289
+ )
290
+
291
+ @staticmethod
292
+ def _compute_loss(
293
+ input_chunk,
294
+ weight,
295
+ target_chunk,
296
+ bias=None,
297
+ preference_loss_fn=None,
298
+ full_target=None,
299
+ ignore_index=-100,
300
+ alpha=1.0,
301
+ beta=0.1,
302
+ compute_nll_loss=True,
303
+ use_ref_model=False,
304
+ ref_weight=None,
305
+ ref_bias=None,
306
+ **loss_kwargs,
307
+ ):
308
+ """
309
+ Compute the total loss for a chunk of input and target, while using an alignment/preference loss function.
310
+ Args:
311
+ preference_loss_fn (callable): Loss function to compute the loss on a chunk of input/target.
312
+ input_chunk (torch.Tensor): Chunk of input tensor. Shape: (2 * chunk_size, sequence_length, hidden_size).
313
+ weight (torch.Tensor): Weight tensor. Shape: (vocab_size, hidden_size).
314
+ target_chunk (torch.Tensor): Chunk of target tensor. Shape: (2 * chunk_size, sequence_length).
315
+ bias (torch.Tensor, optional): Bias tensor. Shape: (vocab_size,).
316
+ full_target (torch.Tensor): Full target tensor. Shape: (batch_size, sequence_length).
317
+ ignore_index (int): Index to ignore for loss computation.
318
+ alpha (float): Weight for the NLL loss.
319
+ beta (float): Weight for the preference loss.
320
+ compute_nll_loss (bool): Whether to compute NLL loss.
321
+ use_ref_model (bool): Whether to use a reference model for the alignment loss.
322
+ ref_weight (torch.Tensor): Reference weight tensor. Shape: (vocab_size, hidden_size).
323
+ ref_bias (torch.Tensor, optional): Reference bias tensor. Shape: (vocab_size,).
324
+ loss_kwargs (dict): Additional arguments for the loss function.
325
+ """
326
+ (
327
+ chosen_logps,
328
+ rejected_logps,
329
+ chosen_logits,
330
+ rejected_logits,
331
+ chosen_nll_loss,
332
+ ) = LigerFusedLinearPreferenceBase.chunk_forward(
333
+ input_chunk,
334
+ weight,
335
+ target_chunk,
336
+ bias=bias,
337
+ ignore_index=ignore_index,
338
+ compute_nll_loss=compute_nll_loss,
339
+ )
340
+ chosen_nll_loss = (
341
+ chosen_nll_loss
342
+ / (full_target[: full_target.shape[0] // 2] != ignore_index).sum()
343
+ )
344
+ chosen_logits_mean = chosen_logits.sum() / (
345
+ full_target.shape[0] // 2 * input_chunk.shape[1] * weight.shape[0]
346
+ )
347
+ rejected_logits_mean = rejected_logits.sum() / (
348
+ full_target.shape[0] // 2 * input_chunk.shape[1] * weight.shape[0]
349
+ )
350
+
351
+ if use_ref_model:
352
+ with torch.no_grad():
353
+ (
354
+ ref_chosen_logps,
355
+ ref_rejected_logps,
356
+ ref_chosen_logits,
357
+ ref_rejected_logits,
358
+ ref_chosen_nll_loss,
359
+ ) = LigerFusedLinearPreferenceBase.chunk_forward(
360
+ input_chunk,
361
+ ref_weight,
362
+ target_chunk,
363
+ ref_bias,
364
+ ignore_index=ignore_index,
365
+ compute_nll_loss=False, # We don't need NLL loss for the reference model
366
+ )
367
+ loss_kwargs["ref_chosen_logps"] = ref_chosen_logps
368
+ loss_kwargs["ref_rejected_logps"] = ref_rejected_logps
369
+
370
+ preference_loss_outputs = preference_loss_fn(
371
+ chosen_logps, rejected_logps, full_target, beta=beta, **loss_kwargs
372
+ )
373
+ if isinstance(preference_loss_outputs, tuple):
374
+ preference_loss, *aux_outputs = preference_loss_outputs
375
+ else:
376
+ preference_loss, aux_outputs = preference_loss_outputs, []
377
+
378
+ loss = alpha * chosen_nll_loss - preference_loss
379
+ return_vars = (
380
+ chosen_logps,
381
+ rejected_logps,
382
+ chosen_logits_mean,
383
+ rejected_logits_mean,
384
+ chosen_nll_loss,
385
+ )
386
+ return loss, (*return_vars, *aux_outputs)
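
The masked, per-sequence average log-probability that chunk_forward computes can be sanity-checked in isolation with toy shapes (tensors below are made up; ignore_index=-100 as in the kernel):

import torch
import torch.nn.functional as F

torch.manual_seed(0)
logits = torch.randn(2, 5, 7)                    # (batch, seq_len, vocab)
target = torch.tensor([[1, 2, 3, -100, -100],    # -100 marks ignored positions
                       [4, 5, 6, 0, -100]])

log_probs = F.log_softmax(logits.float(), dim=-1)
loss_mask = target != -100
labels = torch.where(loss_mask, target, torch.zeros_like(target))
per_token_logps = log_probs.gather(-1, labels.unsqueeze(-1)).squeeze(-1)
# Average only over non-ignored positions, exactly as chunk_forward does
average_log_prob = (per_token_logps * loss_mask).sum(-1) / loss_mask.sum(-1)
print(average_log_prob)                          # one value per sequence
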
@@ -11,10 +11,24 @@ class LigerFusedLinearORPOFunction(LigerFusedLinearPreferenceBase):
11
11
  @staticmethod
12
12
  def preference_loss_fn(chosen_logps, rejected_logps, full_target, beta=0.1):
13
13
  """
14
- Compute odds-ratio loss.
14
+ Paper: https://arxiv.org/pdf/2403.07691
15
+
16
+ Formula:
17
+ Compute odds-ratio loss: L_OR = -log(σ(log(odds_θ(y_w|x) / odds_θ(y_l|x))))
18
+ where odds_θ(y|x) = P_θ(y|x) / (1 - P_θ(y|x))
19
+
20
+ Where:
21
+ - P_θ(y|x): Policy (model) probability
22
+ - y_w: Chosen sequence
23
+ - y_l: Rejected sequence
24
+ - σ: Sigmoid function
25
+ - β: Weight for the odds ratio loss
26
+ - odds_θ: Odds function for the policy
27
+
15
28
  Args:
16
29
  chosen_logps (torch.Tensor): Avg log probabilities of chosen tokens. Shape: (batch_size,).
17
30
  rejected_logps (torch.Tensor): Avg log probabilities of rejected tokens. Shape: (batch_size,).
31
+ full_target (torch.Tensor): Non-chunked full target tensor
18
32
  beta (float): Weight for the odds ratio loss.
19
33
  """
20
34
  log_odds = (chosen_logps - rejected_logps) - (
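
A hedged sketch of the odds-ratio term from the formula above, using log(1 - P) = log1p(-exp(log P)) on the averaged log-probs; it follows the base class convention in which _compute_loss subtracts the returned value, and is illustrative rather than byte-identical to the shipped function:

import torch
import torch.nn.functional as F

def orpo_odds_ratio_term(chosen_logps, rejected_logps, full_target, beta=0.1):
    # log odds(y|x) = log P - log(1 - P); with avg log-probs,
    # log(1 - P) = log1p(-exp(log P))
    log_odds = (chosen_logps - rejected_logps) - (
        torch.log1p(-torch.exp(chosen_logps))
        - torch.log1p(-torch.exp(rejected_logps))
    )
    # beta * log sigma(log-odds ratio), averaged per preference pair;
    # subtracting this in _compute_loss yields the L_OR penalty
    return beta * F.logsigmoid(log_odds).sum() / (full_target.shape[0] // 2)

# Toy usage: avg log-probs are < 0, so exp(logp) < 1 and log1p stays finite
chosen = torch.tensor([-0.7, -0.3])
rejected = torch.tensor([-1.4, -1.0])
full_target = torch.zeros(4, 6, dtype=torch.long)
print(orpo_odds_ratio_term(chosen, rejected, full_target))
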
@@ -44,12 +58,6 @@ class LigerFusedLinearORPOFunction(LigerFusedLinearPreferenceBase):
44
58
  compute_nll_loss=True,
45
59
  compiled=True,
46
60
  ):
47
- """
48
- Fused linear layer with ORPO (Odds-Ratio Preference Optimization) loss.
49
- Handles both the forward and backward pass of the final linear layer with ORPO loss.
50
- Inspired from LigerFusedLinearCrossEntropyFunction (https://arxiv.org/abs/2410.10989) which fuses final linear layer and CE loss.
51
- """
52
-
53
61
  return LigerFusedLinearPreferenceBase.forward(
54
62
  ctx=ctx,
55
63
  _input=_input,
@@ -65,9 +73,7 @@ class LigerFusedLinearORPOFunction(LigerFusedLinearPreferenceBase):
65
73
 
66
74
  @staticmethod
67
75
  def backward(ctx, *grad_output):
68
- # Get gradients for _input, weight, bias, and target from the base class
69
76
  grads = LigerFusedLinearPreferenceBase.backward(ctx, grad_output)[:4]
70
- # Return these gradients, followed by None for the remaining inputs
71
77
  return *grads, None, None, None, None
72
78
 
73
79
 
@@ -13,12 +13,26 @@ class LigerFusedLinearSimPOFunction(LigerFusedLinearPreferenceBase):
13
13
  chosen_logps, rejected_logps, full_target, beta=0.1, gamma=0.5
14
14
  ):
15
15
  """
16
- Compute odds-ratio loss.
16
+ Paper: https://arxiv.org/pdf/2405.14734
17
+
18
+ Formula:
19
+ L_SimPO(π_θ) = -E [log σ(β/|y_w| log π_θ(y_w|x) - β/|y_l| log π_θ(y_l|x) - γ)]
20
+
21
+ Where:
22
+ - π_θ(y|x): Policy (model) probability
23
+ - y_w: Chosen sequence
24
+ - y_l: Rejected sequence
25
+ - |y_w|, |y_l|: Sequence lengths
26
+ - σ: Sigmoid function
27
+ - β: beta weight
28
+ - γ: gamma, the target reward margin
29
+
17
30
  Args:
18
31
  chosen_logps (torch.Tensor): Avg log probabilities of chosen tokens. Shape: (batch_size,).
19
32
  rejected_logps (torch.Tensor): Avg log probabilities of rejected tokens. Shape: (batch_size,).
20
- beta (float): Weight for the odds ratio loss.
21
- gamma (float): The simpo gamma, margin term.
33
+ full_target: Non-chunked full target tensor
34
+ beta (float): beta weight
35
+ gamma (float): Target reward margin
22
36
  """
23
37
  logits = beta * (chosen_logps - rejected_logps) - gamma
24
38
  loss = F.logsigmoid(logits).sum() / (full_target.shape[0] // 2)
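
The SimPO term shown above as a standalone sketch (illustrative name; the averaged log-probs already carry the 1/|y| length normalization from the formula):

import torch.nn.functional as F

def simpo_preference_loss(chosen_logps, rejected_logps, full_target,
                          beta=0.1, gamma=0.5):
    # Length-normalized log-prob difference, shifted by the target reward margin gamma
    logits = beta * (chosen_logps - rejected_logps) - gamma
    # Sum of log-sigmoid, normalized by the number of preference pairs
    return F.logsigmoid(logits).sum() / (full_target.shape[0] // 2)
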
@@ -38,12 +52,6 @@ class LigerFusedLinearSimPOFunction(LigerFusedLinearPreferenceBase):
38
52
  compiled=True,
39
53
  gamma=0.5,
40
54
  ):
41
- """
42
- Fused linear layer with SimPO (Simple Preference Optimization) loss. https://arxiv.org/pdf/2405.14734
43
- Handles both the forward and backward pass of the final linear layer with SimPO loss.
44
- Inspired from LigerFusedLinearCrossEntropyFunction (https://arxiv.org/abs/2410.10989) which fuses final linear layer and CE loss.
45
- """
46
-
47
55
  return LigerFusedLinearPreferenceBase.forward(
48
56
  ctx,
49
57
  _input,
@@ -61,9 +69,7 @@ class LigerFusedLinearSimPOFunction(LigerFusedLinearPreferenceBase):
61
69
 
62
70
  @staticmethod
63
71
  def backward(ctx, *grad_output):
64
- # Get gradients for _input, weight, bias, and target from the base class
65
72
  grads = LigerFusedLinearPreferenceBase.backward(ctx, grad_output)[:4]
66
- # Return these gradients, followed by None for the remaining inputs
67
73
  return *grads, None, None, None, None, None, None
68
74
 
69
75
 
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.1
2
2
  Name: liger_kernel_nightly
3
- Version: 0.4.2.dev20241209195823
3
+ Version: 0.4.2.dev20241209234352
4
4
  Summary: Efficient Triton kernels for LLM Training
5
5
  License: BSD 2-CLAUSE LICENSE
6
6
  Copyright 2024 LinkedIn Corporation
@@ -2,13 +2,13 @@ liger_kernel/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
2
2
  liger_kernel/env_report.py,sha256=1ETxx6HW4bKMK5aa5xaFzEmx0Ibc_kNryL_gXBVyyrI,1374
3
3
  liger_kernel/utils.py,sha256=HJa-xVKOohDn6pLVIx-Fv0V9h0QAL3qZGQNRICI-OpI,249
4
4
  liger_kernel/chunked_loss/__init__.py,sha256=R2wCcz4Y0kTAve926DH3k182XKezpXeACMHj05g9Mm8,346
5
- liger_kernel/chunked_loss/cpo_loss.py,sha256=P20txjErLCSfSfToFT8pnuVPqFU4Bbybt3zRXfGEV-0,3122
6
- liger_kernel/chunked_loss/dpo_loss.py,sha256=NZyM4ju56MBVrUTI_7-jGMx5pWWDYzwx7ALoMj1G8Ec,4276
5
+ liger_kernel/chunked_loss/cpo_loss.py,sha256=Qu1Ul2A12sp6CqIT-atPbHWFb_LLtINEA9mOpIRx_0g,3097
6
+ liger_kernel/chunked_loss/dpo_loss.py,sha256=H9_RRhclckHYM2sd75tgbnf8IxC_PU2JCALbgtPQvwc,4222
7
7
  liger_kernel/chunked_loss/functional.py,sha256=9Gr-YXIuEzEJkBUhDx3G2fuQayckLor7cC7svhmPML4,549
8
8
  liger_kernel/chunked_loss/fused_linear_distillation.py,sha256=2BH6DCPjsR2zS6zcwFPcIIZRhLF8SohjGdKsAJ_301o,10222
9
- liger_kernel/chunked_loss/fused_linear_preference.py,sha256=_4MDZMzrNNgm91c6qdLEuXG1M8HyglZioiufv5opJOI,14881
10
- liger_kernel/chunked_loss/orpo_loss.py,sha256=GGwc3pLGGJzb_P_C7IogcA1EfdAcM1uktfKPmI1z2jk,3523
11
- liger_kernel/chunked_loss/simpo_loss.py,sha256=FtURWbXGjoAKyiVYF7fkMv8Us7uk3UrSg21pWOFk11Y,3385
9
+ liger_kernel/chunked_loss/fused_linear_preference.py,sha256=vlWfaaIECWvCQhY9PM7zRI0vKThIrydMf6P44bXn1EE,15114
10
+ liger_kernel/chunked_loss/orpo_loss.py,sha256=ZuKGjbkIYzV4UzvupNdq6vyxCp7-BztQkUt8ZnFvKos,3531
11
+ liger_kernel/chunked_loss/simpo_loss.py,sha256=Wa4LOlDG9PbJkOOkKg8hbKvnKgg7OTBz6-qIkwPK1yw,3275
12
12
  liger_kernel/ops/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
13
13
  liger_kernel/ops/cross_entropy.py,sha256=VqaYB9Zirc51eZ28OmjEZRrrV9UysRjS_vhIftB9sKo,15753
14
14
  liger_kernel/ops/fused_linear_cross_entropy.py,sha256=Tnw4gyAYVVdnCOqhOuLEzbUQ3goOTnoAfk3pqSIM5ac,9301
@@ -56,9 +56,9 @@ liger_kernel/transformers/model/qwen2.py,sha256=EyhSSzQOskGjSnCsKMZpd1s5IAIlHd5P
56
56
  liger_kernel/transformers/model/qwen2_vl.py,sha256=bIQe2bWiY--G84FhCD29Gdi64_qHP6vbcGsK6vKysQE,8547
57
57
  liger_kernel/triton/__init__.py,sha256=yfRe0zMb47QnqjecZWG7LnanfCTzeku7SgWRAwNVmzU,101
58
58
  liger_kernel/triton/monkey_patch.py,sha256=5BcGKTtdqeYchypBIBopGIWPx1-cFALz7sOKoEsqXJ0,1584
59
- liger_kernel_nightly-0.4.2.dev20241209195823.dist-info/LICENSE,sha256=OhzLDHJ0to4a8sodVLELZiCFylZ1NAAYLs-HrjPy0ag,1312
60
- liger_kernel_nightly-0.4.2.dev20241209195823.dist-info/METADATA,sha256=rdhqAHF-DhOwy_DKk5SVEAC65LcW-IeyMY5QcYRUwSg,22801
61
- liger_kernel_nightly-0.4.2.dev20241209195823.dist-info/NOTICE,sha256=njwnoPZLh9AN8SJQzxvCGLHi-8X__AvWRze6joNXIY8,2066
62
- liger_kernel_nightly-0.4.2.dev20241209195823.dist-info/WHEEL,sha256=P9jw-gEje8ByB7_hXoICnHtVCrEwMQh-630tKvQWehc,91
63
- liger_kernel_nightly-0.4.2.dev20241209195823.dist-info/top_level.txt,sha256=2eghu4hA3LnkM7ElW92tQ8zegWKgSbeo-k-aGe1YnvY,13
64
- liger_kernel_nightly-0.4.2.dev20241209195823.dist-info/RECORD,,
59
+ liger_kernel_nightly-0.4.2.dev20241209234352.dist-info/LICENSE,sha256=OhzLDHJ0to4a8sodVLELZiCFylZ1NAAYLs-HrjPy0ag,1312
60
+ liger_kernel_nightly-0.4.2.dev20241209234352.dist-info/METADATA,sha256=DXgBwRWN509ykIXn_83UuDRiwhZ-1RQPv4ubuieBXBA,22801
61
+ liger_kernel_nightly-0.4.2.dev20241209234352.dist-info/NOTICE,sha256=njwnoPZLh9AN8SJQzxvCGLHi-8X__AvWRze6joNXIY8,2066
62
+ liger_kernel_nightly-0.4.2.dev20241209234352.dist-info/WHEEL,sha256=P9jw-gEje8ByB7_hXoICnHtVCrEwMQh-630tKvQWehc,91
63
+ liger_kernel_nightly-0.4.2.dev20241209234352.dist-info/top_level.txt,sha256=2eghu4hA3LnkM7ElW92tQ8zegWKgSbeo-k-aGe1YnvY,13
64
+ liger_kernel_nightly-0.4.2.dev20241209234352.dist-info/RECORD,,