liger-kernel 0.5.4__py3-none-any.whl → 0.5.6__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (44)
  1. liger_kernel/chunked_loss/cpo_loss.py +51 -11
  2. liger_kernel/chunked_loss/dpo_loss.py +30 -4
  3. liger_kernel/chunked_loss/functional.py +2 -0
  4. liger_kernel/chunked_loss/fused_linear_distillation.py +20 -5
  5. liger_kernel/chunked_loss/fused_linear_ppo.py +331 -0
  6. liger_kernel/chunked_loss/fused_linear_preference.py +2 -2
  7. liger_kernel/chunked_loss/fused_linear_unpaired_preference.py +112 -17
  8. liger_kernel/chunked_loss/grpo_loss.py +137 -61
  9. liger_kernel/chunked_loss/jsd_loss.py +43 -13
  10. liger_kernel/chunked_loss/kto_loss.py +50 -12
  11. liger_kernel/chunked_loss/orpo_loss.py +37 -5
  12. liger_kernel/chunked_loss/simpo_loss.py +47 -11
  13. liger_kernel/ops/cross_entropy.py +7 -2
  14. liger_kernel/ops/dyt.py +225 -0
  15. liger_kernel/ops/fused_linear_jsd.py +2 -1
  16. liger_kernel/ops/jsd.py +30 -11
  17. liger_kernel/ops/kl_div.py +2 -2
  18. liger_kernel/transformers/__init__.py +4 -0
  19. liger_kernel/transformers/dyt.py +20 -0
  20. liger_kernel/transformers/functional.py +5 -0
  21. liger_kernel/transformers/model/gemma.py +8 -16
  22. liger_kernel/transformers/model/gemma2.py +7 -16
  23. liger_kernel/transformers/model/llama.py +8 -15
  24. liger_kernel/transformers/model/llava.py +369 -0
  25. liger_kernel/transformers/model/loss_utils.py +57 -0
  26. liger_kernel/transformers/model/mistral.py +9 -10
  27. liger_kernel/transformers/model/mixtral.py +8 -15
  28. liger_kernel/transformers/model/mllama.py +8 -15
  29. liger_kernel/transformers/model/olmo2.py +8 -16
  30. liger_kernel/transformers/model/paligemma.py +397 -0
  31. liger_kernel/transformers/model/phi3.py +8 -15
  32. liger_kernel/transformers/model/qwen2.py +8 -15
  33. liger_kernel/transformers/model/qwen2_5_vl.py +204 -0
  34. liger_kernel/transformers/model/qwen2_vl.py +9 -10
  35. liger_kernel/transformers/monkey_patch.py +286 -12
  36. liger_kernel/utils.py +1 -3
  37. {liger_kernel-0.5.4.dist-info → liger_kernel-0.5.6.dist-info}/METADATA +11 -7
  38. liger_kernel-0.5.6.dist-info/RECORD +80 -0
  39. {liger_kernel-0.5.4.dist-info → liger_kernel-0.5.6.dist-info}/WHEEL +1 -1
  40. liger_kernel/chunked_loss/fused_linear_rlhf.py +0 -213
  41. liger_kernel-0.5.4.dist-info/RECORD +0 -74
  42. {liger_kernel-0.5.4.dist-info → liger_kernel-0.5.6.dist-info/licenses}/LICENSE +0 -0
  43. {liger_kernel-0.5.4.dist-info → liger_kernel-0.5.6.dist-info/licenses}/NOTICE +0 -0
  44. {liger_kernel-0.5.4.dist-info → liger_kernel-0.5.6.dist-info}/top_level.txt +0 -0
liger_kernel/chunked_loss/cpo_loss.py
@@ -39,8 +39,9 @@ class LigerFusedLinearCPOFunction(LigerFusedLinearPreferenceBase):
 
         return loss, chosen_rewards, rejected_rewards
 
-    @staticmethod
+    @classmethod
     def forward(
+        cls,
         ctx,
         _input,
         weight,
@@ -52,27 +53,48 @@ class LigerFusedLinearCPOFunction(LigerFusedLinearPreferenceBase):
         label_smoothing=0.0,
         compute_nll_loss=True,
         compiled=True,
+        average_log_prob=False,
+        chunk_size=1,
     ):
-        return LigerFusedLinearPreferenceBase.forward(
-            ctx,
-            _input,
-            weight,
-            target,
-            bias,
-            loss_fn=LigerFusedLinearCPOFunction.preference_loss_fn,
+        """
+        Fused linear layer with CPO loss.
+        Args:
+            _input (torch.Tensor): Input tensor. Shape: (batch_size * seq_len, hidden_size)
+            weight (torch.Tensor): Weight tensor. Shape: (vocab_size, hidden_size)
+            target (torch.LongTensor): Target tensor. Shape: (batch_size * seq_len,)
+            bias (torch.Tensor, optional): Bias tensor. Shape: (vocab_size,)
+            ignore_index (int): Index to ignore in loss computation
+            beta (float): Weight for the odds ratio loss
+            alpha (float): Weight for the alpha parameter
+            label_smoothing (float): Label smoothing factor
+            compute_nll_loss (bool): Whether to compute the NLL loss
+            compiled (bool): Whether to use torch compile
+            average_log_prob (bool): Whether to average the log probability per non-masked token
+            chunk_size (int): Size of chunks for processing.
+        Returns:
+            torch.Tensor: Computed loss
+        """
+        return super().forward(
+            cls=cls,
+            ctx=ctx,
+            _input=_input,
+            weight=weight,
+            target=target,
+            bias=bias,
             ignore_index=ignore_index,
             alpha=alpha,
             beta=beta,
             label_smoothing=label_smoothing,
             compute_nll_loss=compute_nll_loss,
-            average_log_prob=False,
+            average_log_prob=average_log_prob,
             compiled=compiled,
+            chunk_size=chunk_size,
         )
 
     @staticmethod
     def backward(ctx, *grad_output):
         grads = LigerFusedLinearPreferenceBase.backward(ctx, grad_output)[:4]
-        return *grads, None, None, None, None, None, None
+        return *grads, None, None, None, None, None, None, None, None
 
 
 class LigerFusedLinearCPOLoss(torch.nn.Module):
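Editor's note on the backward change above: torch.autograd.Function.backward must return one gradient entry per forward argument, so each newly added forward keyword (here average_log_prob and chunk_size) needs a matching None placeholder for its non-differentiable gradient. A minimal sketch of that pattern with a hypothetical toy function, not Liger code:

import torch

class ScaleByConstant(torch.autograd.Function):
    # Toy example: forward takes (ctx, x, scale); backward must return one
    # entry per forward input after ctx, so the non-tensor `scale` argument
    # gets a None gradient placeholder.
    @staticmethod
    def forward(ctx, x, scale=2.0):
        ctx.scale = scale
        return x * scale

    @staticmethod
    def backward(ctx, grad_output):
        return grad_output * ctx.scale, None  # grad for x, None for scale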
@@ -88,11 +110,19 @@ class LigerFusedLinearCPOLoss(torch.nn.Module):
         label_smoothing: float = 0.0,
         compute_nll_loss: bool = True,
         compiled: bool = True,
+        average_log_prob: bool = False,
+        chunk_size: int = 1,
     ):
         """
         Args:
             ignore_index (int): Index to ignore in the loss.
             beta (float): Weight for the odds ratio loss.
+            alpha (float): Weight for the alpha parameter.
+            label_smoothing (float): Label smoothing factor.
+            compute_nll_loss (bool): Whether to compute the NLL loss.
+            compiled (bool): Whether to use the torch compiled kernel.
+            average_log_prob (bool): Whether to average the log probability per non-masked token.
+            chunk_size (int): Size of chunks for processing.
         """
         super().__init__()
         self.ignore_index = ignore_index
@@ -101,8 +131,16 @@ class LigerFusedLinearCPOLoss(torch.nn.Module):
         self.label_smoothing = label_smoothing
         self.compute_nll_loss = compute_nll_loss
         self.compiled = compiled
+        self.average_log_prob = average_log_prob
+        self.chunk_size = chunk_size
 
-    def forward(self, lin_weight, _input, target, bias=None):
+    def forward(
+        self,
+        lin_weight,
+        _input,
+        target,
+        bias=None,
+    ):
         return LigerFusedLinearCPOFunction.apply(
             _input,
             lin_weight,
@@ -114,4 +152,6 @@ class LigerFusedLinearCPOLoss(torch.nn.Module):
             self.label_smoothing,
             self.compute_nll_loss,
             self.compiled,
+            self.average_log_prob,
+            self.chunk_size,
         )
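Taken together, the CPO changes surface two new user-facing knobs, average_log_prob and chunk_size. A minimal usage sketch based on the signatures shown above; the shapes, hyperparameters, and compiled=False setting are illustrative only, and the return value is whatever LigerFusedLinearCPOFunction.apply produces (the loss plus its auxiliary outputs):

import torch
from liger_kernel.chunked_loss.cpo_loss import LigerFusedLinearCPOLoss

B_T, H, V = 8, 64, 128  # (batch * seq_len) rows, hidden size, vocab size -- illustrative
lin_weight = torch.randn(V, H, requires_grad=True)
_input = torch.randn(B_T, H, requires_grad=True)   # chosen and rejected rows concatenated
target = torch.randint(0, V, (B_T,))

cpo_loss = LigerFusedLinearCPOLoss(
    ignore_index=-100,
    beta=0.1,
    average_log_prob=True,  # new in 0.5.6
    chunk_size=1,           # new in 0.5.6
    compiled=False,         # skip torch.compile for a quick smoke test
)
outputs = cpo_loss(lin_weight, _input, target)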
liger_kernel/chunked_loss/dpo_loss.py
@@ -52,8 +52,9 @@ class LigerFusedLinearDPOFunction(LigerFusedLinearPreferenceBase):
         loss = -F.logsigmoid(logits_diff).sum() / (full_target.shape[0] // 2)
         return loss, chosen_rewards, rejected_rewards
 
-    @staticmethod
+    @classmethod
     def forward(
+        cls,
         ctx,
         _input,
         weight,
@@ -67,14 +68,34 @@ class LigerFusedLinearDPOFunction(LigerFusedLinearPreferenceBase):
         compute_nll_loss=False,
         compiled=True,
         use_ref_model=True,
+        chunk_size=1,
     ):
-        return LigerFusedLinearPreferenceBase.forward(
+        """
+        Fused linear layer with DPO loss.
+        Args:
+            _input (torch.Tensor): Input tensor. Shape: (batch_size * seq_len, hidden_size)
+            weight (torch.Tensor): Weight tensor. Shape: (vocab_size, hidden_size)
+            target (torch.LongTensor): Target tensor. Shape: (batch_size * seq_len,)
+            bias (torch.Tensor, optional): Bias tensor. Shape: (vocab_size,)
+            ref_input (torch.Tensor, optional): Reference model input tensor. Shape: (batch_size * seq_len, hidden_size)
+            ref_weight (torch.Tensor, optional): Reference model weight tensor. Shape: (vocab_size, hidden_size)
+            ref_bias (torch.Tensor, optional): Reference model bias tensor. Shape: (vocab_size,)
+            ignore_index (int): Index to ignore in loss computation
+            beta (float): Weight for the odds ratio loss
+            compute_nll_loss (bool): Whether to compute the NLL loss
+            compiled (bool): Whether to use torch compile
+            use_ref_model (bool): Whether to use a reference model
+            chunk_size (int): Size of chunks for processing.
+        Returns:
+            torch.Tensor: Computed loss
+        """
+        return super().forward(
+            cls=cls,
             ctx=ctx,
             _input=_input,
             weight=weight,
             target=target,
             bias=bias,
-            loss_fn=LigerFusedLinearDPOFunction.preference_loss_fn,
             ignore_index=ignore_index,
             beta=beta,
             compute_nll_loss=compute_nll_loss,
@@ -83,12 +104,13 @@ class LigerFusedLinearDPOFunction(LigerFusedLinearPreferenceBase):
             ref_input=ref_input,
             ref_weight=ref_weight,
             ref_bias=ref_bias,
+            chunk_size=chunk_size,
         )
 
     @staticmethod
     def backward(ctx, *grad_output):
         grads = LigerFusedLinearPreferenceBase.backward(ctx, grad_output)[:4]
-        return *grads, None, None, None, None, None, None, None, None
+        return *grads, None, None, None, None, None, None, None, None, None
 
 
 class LigerFusedLinearDPOLoss(torch.nn.Module):
@@ -103,6 +125,7 @@ class LigerFusedLinearDPOLoss(torch.nn.Module):
         compute_nll_loss: bool = False,
         compiled: bool = True,
         use_ref_model: bool = True,
+        chunk_size: int = 1,
     ):
         """
         Args:
@@ -111,6 +134,7 @@ class LigerFusedLinearDPOLoss(torch.nn.Module):
             compute_nll_loss (bool): Whether to compute the NLL loss.
            compiled (bool): Whether to use the torch compiled kernel.
             use_ref_model (bool): Whether to use a reference model for the DPO loss.
+            chunk_size (int): Size of chunks for processing.
         """
         super().__init__()
         self.ignore_index = ignore_index
@@ -118,6 +142,7 @@ class LigerFusedLinearDPOLoss(torch.nn.Module):
         self.compute_nll_loss = compute_nll_loss
         self.compiled = compiled
         self.use_ref_model = use_ref_model
+        self.chunk_size = chunk_size
 
     def forward(
         self,
@@ -142,4 +167,5 @@ class LigerFusedLinearDPOLoss(torch.nn.Module):
             self.compute_nll_loss,
             self.compiled,
             self.use_ref_model,
+            self.chunk_size,
         )
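The DPO changes are the same chunk_size plumbing plus the switch to subclass dispatch. A constructor-only sketch of the new option; the forward call (weights, inputs, targets, and the optional reference-model tensors) follows the Function signature documented above, which this excerpt only shows in part:

from liger_kernel.chunked_loss.dpo_loss import LigerFusedLinearDPOLoss

# chunk_size (new in 0.5.6) controls how many rows each fused forward/backward
# chunk processes; smaller chunks lower peak activation memory at the cost of
# more per-chunk launches.
dpo_loss = LigerFusedLinearDPOLoss(
    ignore_index=-100,
    beta=0.1,
    use_ref_model=True,
    compiled=True,
    chunk_size=1,
)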
liger_kernel/chunked_loss/functional.py
@@ -1,5 +1,6 @@
 from liger_kernel.chunked_loss.cpo_loss import LigerFusedLinearCPOFunction
 from liger_kernel.chunked_loss.dpo_loss import LigerFusedLinearDPOFunction
+from liger_kernel.chunked_loss.grpo_loss import LigerFusedLinearGRPOFunction
 from liger_kernel.chunked_loss.jsd_loss import LigerFusedLinearJSDFunction
 from liger_kernel.chunked_loss.kto_loss import LigerFusedLinearKTOFunction
 from liger_kernel.chunked_loss.orpo_loss import LigerFusedLinearORPOFunction
@@ -11,3 +12,4 @@ liger_fused_linear_jsd = LigerFusedLinearJSDFunction.apply
 liger_fused_linear_cpo = LigerFusedLinearCPOFunction.apply
 liger_fused_linear_simpo = LigerFusedLinearSimPOFunction.apply
 liger_fused_linear_kto = LigerFusedLinearKTOFunction.apply
+liger_fused_linear_grpo = LigerFusedLinearGRPOFunction.apply
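The new functional entry point mirrors the existing aliases: each name is simply the corresponding Function's apply. A small import sketch; the GRPO argument order follows LigerFusedLinearGRPOFunction.forward, which lives in grpo_loss.py and is not shown in this excerpt:

from liger_kernel.chunked_loss.functional import liger_fused_linear_grpo

# liger_fused_linear_grpo is assigned from LigerFusedLinearGRPOFunction.apply,
# so it is the raw autograd entry point rather than an nn.Module wrapper.
assert callable(liger_fused_linear_grpo)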
liger_kernel/chunked_loss/fused_linear_distillation.py
@@ -115,9 +115,24 @@ class LigerFusedLinearDistillationBase(torch.autograd.Function):
         student_logits_chunk /= temperature
         teacher_logits_chunk /= temperature
 
+        # If the teacher and student token size is different, pad student logits to match the teacher's.
+        # This only applies to cases where they share exactly the same vocab and tokenizer just
+        # that teacher logit is padded for some training efficiency such as
+        # https://huggingface.co/Qwen/Qwen1.5-72B-Chat/discussions/1#662883f568adf59b07b176d2
+        teacher_vocab_size = teacher_weight.shape[0]
+        student_vocab_size = student_weight.shape[0]
+        if teacher_vocab_size > student_vocab_size:
+            pad_size = teacher_vocab_size - student_vocab_size
+            pad_tensor = torch.zeros(
+                (*student_logits_chunk.shape[:-1], pad_size),
+                dtype=student_logits_chunk.dtype,
+                device=student_logits_chunk.device,
+            )
+            student_logits_chunk = torch.cat([student_logits_chunk, pad_tensor], dim=-1)
+
         hard_loss /= full_target.shape[0]
 
-        soft_loss = distillation_loss_fn(student_logits_chunk, teacher_logits_chunk)
+        soft_loss = distillation_loss_fn(student_logits_chunk, teacher_logits_chunk, **loss_kwargs)
         soft_loss /= full_target.shape[0]
 
         loss = weight_hard_loss * hard_loss + weight_soft_loss * soft_loss
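The padding branch above only applies when the teacher checkpoint carries a padded copy of the student's vocabulary (same tokenizer, extra dummy rows kept for throughput). A standalone sketch of that step, using a hypothetical helper name:

import torch

def pad_student_logits(student_logits_chunk: torch.Tensor, teacher_vocab_size: int) -> torch.Tensor:
    # Zero-pad the student's vocab dimension up to the teacher's padded vocab
    # so the two logit tensors can be compared element-wise by the soft loss.
    student_vocab_size = student_logits_chunk.shape[-1]
    if teacher_vocab_size <= student_vocab_size:
        return student_logits_chunk
    pad_tensor = torch.zeros(
        (*student_logits_chunk.shape[:-1], teacher_vocab_size - student_vocab_size),
        dtype=student_logits_chunk.dtype,
        device=student_logits_chunk.device,
    )
    return torch.cat([student_logits_chunk, pad_tensor], dim=-1)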
@@ -125,6 +140,7 @@ class LigerFusedLinearDistillationBase(torch.autograd.Function):
 
     @staticmethod
     def forward(
+        cls,
         ctx,
         student_input,
         student_weight,
@@ -133,7 +149,6 @@ class LigerFusedLinearDistillationBase(torch.autograd.Function):
         target,
         student_bias=None,
         teacher_bias=None,
-        loss_fn=None,
         chunk_size=1024,
         ignore_index=-100,
         weight_hard_loss=0.5,
@@ -175,14 +190,14 @@ class LigerFusedLinearDistillationBase(torch.autograd.Function):
 
         loss_func_to_call = partial(
             LigerFusedLinearDistillationBase._compute_loss,
-            distillation_loss_fn=loss_fn,
+            distillation_loss_fn=cls.distillation_loss_fn,
             full_target=target,
             ignore_index=ignore_index,
             weight_hard_loss=weight_hard_loss,
             weight_soft_loss=weight_soft_loss,
-            beta=beta,
             compute_ce_loss=compute_ce_loss,
             temperature=temperature,
+            beta=beta,
             **loss_kwargs,
         )
 
@@ -263,4 +278,4 @@ class LigerFusedLinearDistillationBase(torch.autograd.Function):
             grad_weight = grad_weight * grad_output
             grad_bias = grad_bias * grad_output if grad_bias is not None else None
 
-        return grad_input, grad_weight, None, grad_bias
+        return grad_input, grad_weight, None, None, None, grad_bias
liger_kernel/chunked_loss/fused_linear_ppo.py (new file)
@@ -0,0 +1,331 @@
+from abc import abstractmethod
+from functools import partial
+
+import torch
+import torch._dynamo.config
+import torch.nn.functional as F
+
+
+class LigerFusedLinearPPOBase(torch.autograd.Function):
+    @abstractmethod
+    def ppo_loss_fn(*args, **kwargs):
+        """
+        To be extended by subclasses.
+        """
+        raise NotImplementedError("PPO loss function must be implemented.")
+
+    @staticmethod
+    def forward(
+        cls,
+        ctx,
+        _input,
+        weight,
+        selected_token_ids,
+        attention_mask,
+        advantages,
+        bias=None,
+        ref_per_token_logps=None,
+        old_per_token_logps=None,
+        ref_input=None,
+        ref_weight=None,
+        ref_bias=None,
+        epsilon_low=0.2,
+        epsilon_high=0.2,
+        beta=0.04,
+        temperature=1.0,
+        compiled=True,
+        use_ref_model=False,
+        chunk_size=1,
+    ):
+        # TODO: check torch compile matmul
+        """Chunked forward pass for PPO loss computation.
+
+        Args:
+            cls: The class
+            ctx: Context for backward
+            _input: Input tensor
+            weight: Weight tensor
+            selected_token_ids: Selected token ids tensor
+            attention_mask: Attention mask tensor
+            advantages: Advantages tensor
+            bias: Bias tensor
+            ref_per_token_logps: Reference model log probs per token tensor
+            old_per_token_logps: Old per token log probabilities tensor
+            ref_input: Reference model input tensor
+            ref_weight: Reference model weight tensor
+            ref_bias: Reference model bias tensor
+            epsilon_low: Lower bound for clipping the importance sampling ratio
+            epsilon_high: Upper bound for clipping the importance sampling ratio
+            beta: Weight for the KL penalty
+            temperature: Temperature for the logits
+            compiled: Whether to use torch compile
+            use_ref_model: Whether to use a reference model
+            chunk_size: Size of chunks for processing in other loss modules
+        """
+        if use_ref_model:
+            assert ref_per_token_logps is not None or ref_input is not None, (
+                "If use_ref_model is True, ref_per_token_logps or ref_input must be provided"
+            )
+            if ref_per_token_logps is not None and ref_input is not None:
+                raise Warning("Both ref_per_token_logps and ref_input are provided. Using ref_per_token_logps.")
+        # Initialize accumulators
+        loss_acc = torch.zeros((), device=_input.device, dtype=torch.float32)
+        grad_weight = torch.zeros_like(weight)  # [V, H]
+        grad_inputs = []
+        grad_bias = torch.zeros_like(bias) if bias is not None else None  # [V]
+        aggregated_metrics = []
+
+        # Create a partial function with fixed arguments
+        compute_loss = partial(
+            LigerFusedLinearPPOBase._compute_chunk_loss,
+            ref_weight=ref_weight,
+            ref_bias=ref_bias,
+            full_attention_mask=attention_mask,
+            epsilon_low=epsilon_low,
+            epsilon_high=epsilon_high,
+            beta=beta,
+            temperature=temperature,
+            use_ref_model=use_ref_model,
+            ppo_loss_fn=cls.ppo_loss_fn,
+        )
+
+        def fused_fwd_bwd(
+            input_chunk,
+            selected_token_ids_chunk,
+            attention_mask_chunk,
+            advantages_chunk,
+            ref_per_token_logps_chunk,
+            old_per_token_logps_chunk,
+            ref_input_chunk,
+        ):
+            """Fused forward and backward for a chunk."""
+            argnums = (0, 1, 5) if bias is not None else (0, 1)
+            return torch.func.grad_and_value(compute_loss, argnums=argnums, has_aux=True)(
+                input_chunk,  # arg 0
+                weight,  # arg 1
+                selected_token_ids_chunk,  # arg 2
+                attention_mask_chunk,  # arg 3
+                advantages_chunk,  # arg 4
+                bias,  # arg 5
+                ref_per_token_logps_chunk=ref_per_token_logps_chunk,  # arg 6
+                old_per_token_logps_chunk=old_per_token_logps_chunk,  # arg 7
+                ref_input_chunk=ref_input_chunk,  # arg 8
+            )
+
+        def accumulate_chunk(
+            input_chunk,
+            selected_token_ids_chunk,
+            attention_mask_chunk,
+            advantages_chunk,
+            ref_per_token_logps_chunk=None,
+            old_per_token_logps_chunk=None,
+            ref_input_chunk=None,
+        ):
+            (chunk_grad_input, chunk_grad_weight, *chunk_grad_bias), (chunk_loss, chunk_metrics) = fused_fwd_bwd(
+                input_chunk,
+                selected_token_ids_chunk,
+                attention_mask_chunk,
+                advantages_chunk,
+                ref_per_token_logps_chunk,
+                old_per_token_logps_chunk,
+                ref_input_chunk,
+            )
+            if bias is not None:
+                grad_bias.add_(chunk_grad_bias[0])
+
+            # Accumulate gradients and loss
+            grad_weight.add_(chunk_grad_weight)
+            grad_inputs.append(chunk_grad_input)
+            loss_acc.add_(chunk_loss)
+            # Initialize storage for metrics on first chunk
+            if len(aggregated_metrics) == 0:
+                for metric in chunk_metrics:
+                    if metric.ndim == 0:
+                        aggregated_metrics.append(torch.zeros((), device=metric.device))
+                    else:
+                        aggregated_metrics.append([])
+
+            # Accumulate metrics
+            for i, metric in enumerate(chunk_metrics):
+                if metric.ndim == 0:
+                    aggregated_metrics[i].add_(metric)
+                else:
+                    aggregated_metrics[i].append(metric)
+
+        if compiled:
+            # TODO: Figure out what is better to compile here
+            # accumulate_chunk = torch.compile(accumulate_chunk)
+            fused_fwd_bwd = torch.compile(fused_fwd_bwd)
+
+        # Process input in chunks based on chunk_size
+        chunks = max(1, _input.shape[0] // chunk_size)
+        _input_chunks = torch.chunk(_input, chunks=chunks, dim=0)
+        _selected_token_ids_chunks = torch.chunk(selected_token_ids, chunks=chunks, dim=0)
+        _attention_mask_chunks = torch.chunk(attention_mask, chunks=chunks, dim=0)
+        _advantages_chunks = torch.chunk(advantages, chunks=chunks, dim=0)
+        _ref_per_token_logps_chunks = (
+            torch.chunk(ref_per_token_logps, chunks=chunks, dim=0)
+            if use_ref_model and ref_per_token_logps is not None
+            else [None] * chunks
+        )
+        _old_per_token_logps_chunks = (
+            torch.chunk(old_per_token_logps, chunks=chunks, dim=0)
+            if old_per_token_logps is not None
+            else [None] * chunks
+        )
+        # if ref_log_probs is not none, then we don't need ref_input to calculate the log probs
+        _ref_input_chunks = (
+            torch.chunk(ref_input, chunks=chunks, dim=0)
+            if use_ref_model and ref_per_token_logps is None
+            else [None] * chunks
+        )
+
+        for (
+            input_chunk,
+            selected_token_ids_chunk,
+            attention_mask_chunk,
+            advantages_chunk,
+            ref_per_token_logps_chunk,
+            old_per_token_logps_chunk,
+            ref_input_chunk,
+        ) in zip(
+            _input_chunks,
+            _selected_token_ids_chunks,
+            _attention_mask_chunks,
+            _advantages_chunks,
+            _ref_per_token_logps_chunks,
+            _old_per_token_logps_chunks,
+            _ref_input_chunks,
+        ):
+            # Mark dynamic dimensions
+            torch._dynamo.mark_dynamic(input_chunk, 1)
+            torch._dynamo.mark_dynamic(selected_token_ids_chunk, 1)
+            torch._dynamo.mark_dynamic(attention_mask_chunk, 1)
+            if ref_per_token_logps_chunk is not None:
+                torch._dynamo.mark_dynamic(ref_per_token_logps_chunk, 1)
+            if ref_input_chunk is not None:
+                torch._dynamo.mark_dynamic(ref_input_chunk, 1)
+            if old_per_token_logps_chunk is not None:
+                torch._dynamo.mark_dynamic(old_per_token_logps_chunk, 1)
+
+            accumulate_chunk(
+                input_chunk,
+                selected_token_ids_chunk,
+                attention_mask_chunk,
+                advantages_chunk,
+                ref_per_token_logps_chunk,
+                old_per_token_logps_chunk,
+                ref_input_chunk,
+            )
+
+        # Combine gradients
+        grad_input = torch.cat(grad_inputs, dim=0)
+
+        # Save for backward
+        ctx.save_for_backward(grad_input, grad_weight, grad_bias)
+
+        # Finalize metrics
+        final_metrics = []
+        for metric in aggregated_metrics:
+            if isinstance(metric, list):
+                final_metrics.append(torch.cat(metric, dim=0))
+            else:
+                final_metrics.append(metric)
+
+        return loss_acc, tuple(final_metrics)
+
+    @staticmethod
+    def _compute_chunk_loss(
+        input_chunk,
+        weight,
+        selected_token_ids_chunk,
+        attention_mask_chunk,
+        advantages_chunk,
+        bias=None,
+        ref_per_token_logps_chunk=None,
+        old_per_token_logps_chunk=None,
+        ref_input_chunk=None,
+        ref_weight=None,
+        ref_bias=None,
+        full_attention_mask=None,
+        epsilon_low=0.2,
+        epsilon_high=0.2,
+        beta=0.04,
+        temperature=1.0,
+        use_ref_model=False,
+        ppo_loss_fn=None,
+    ):
+        """Compute loss for a single chunk."""
+        # Get policy log probabilities using chunk_forward
+        log_probs, _ = LigerFusedLinearPPOBase.chunk_forward(input_chunk, weight, bias=bias, temperature=temperature)
+
+        # Get reference log probabilities if needed
+        ref_log_probs = None
+        if use_ref_model and ref_per_token_logps_chunk is None:
+            with torch.no_grad():
+                ref_log_probs, _ = LigerFusedLinearPPOBase.chunk_forward(
+                    ref_input_chunk, ref_weight, bias=ref_bias, temperature=temperature
+                )
+
+        # Compute chunk loss and metrics using the provided loss function
+        chunk_loss, chunk_metrics = ppo_loss_fn(
+            log_probs=log_probs,
+            selected_token_ids=selected_token_ids_chunk,
+            attention_mask=attention_mask_chunk,
+            advantages=advantages_chunk,
+            full_attention_mask=full_attention_mask,
+            ref_per_token_logps=ref_per_token_logps_chunk.float() if ref_per_token_logps_chunk is not None else None,
+            old_per_token_logps=old_per_token_logps_chunk.float() if old_per_token_logps_chunk is not None else None,
+            ref_log_probs=ref_log_probs,  # used when ref_per_token_logps is None
+            epsilon_low=epsilon_low,
+            epsilon_high=epsilon_high,
+            beta=beta,
+        )
+
+        return chunk_loss, chunk_metrics
+
+    @staticmethod
+    def chunk_forward(input_chunk, weight, bias=None, temperature=1.0):
+        """Forward pass computation for a single chunk without explicit reshaping."""
+        # Directly compute logits via batched matrix multiplication: [B, T, H] @ [H, V] -> [B, T, V]
+        logits = torch.matmul(input_chunk, weight.t())
+        if bias is not None:
+            logits = logits + bias  # Broadcasts bias to [B, T, V]
+        if temperature != 1.0:
+            logits = logits / temperature
+
+        # Compute log probabilities using softmax over the last dimension
+        log_probs = F.log_softmax(logits.float(), dim=-1)
+
+        return log_probs, logits
+
+    @staticmethod
+    def backward(ctx, grad_output, *grad_metrics):
+        """Backward pass for PPO loss."""
+        grad_input, grad_weight, grad_bias = ctx.saved_tensors
+        if grad_output != 1.0:
+            grad_input = grad_input * grad_output
+            grad_weight = grad_weight * grad_output
+            if grad_bias is not None:
+                grad_bias = grad_bias * grad_output
+
+        return (
+            grad_input,
+            grad_weight,
+            None,  # grad_selected_token_ids
+            None,  # grad_attention_mask
+            None,  # grad_advantages
+            grad_bias,
+            None,  # grad_ref_per_token_logps
+            None,  # grad_old_per_token_logps
+            None,  # grad_ref_input
+            None,  # grad_ref_weight
+            None,  # grad_ref_bias
+            None,  # grad_epsilon_low
+            None,  # grad_epsilon_high
+            None,  # grad_beta
+            None,  # grad_temperature
+            None,  # grad_compiled
+            None,  # grad_use_ref_model
+            None,  # grad_chunk_size
+        )
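The core of the new PPO base is the chunked fused forward/backward: the batch is split into max(1, B // chunk_size) chunks, torch.func.grad_and_value computes the loss and its gradients per chunk, and the results are accumulated into pre-allocated buffers so full-batch logits never materialize. A stripped-down sketch of that accumulation pattern, with a toy cross-entropy loss standing in for the PPO objective:

import torch
import torch.nn.functional as F

def _toy_chunk_loss(input_chunk, weight, target_chunk):
    # Stand-in for _compute_chunk_loss: plain cross-entropy over one chunk.
    logits = input_chunk @ weight.t()
    loss = F.cross_entropy(logits.float(), target_chunk, reduction="sum")
    return loss, logits.detach()  # (loss, aux) pairs with has_aux=True below

def chunked_fwd_bwd(_input, weight, target, chunk_size=1):
    grad_weight = torch.zeros_like(weight)
    grad_inputs = []
    loss_acc = torch.zeros((), dtype=torch.float32)
    chunks = max(1, _input.shape[0] // chunk_size)
    for input_chunk, target_chunk in zip(
        torch.chunk(_input, chunks=chunks, dim=0),
        torch.chunk(target, chunks=chunks, dim=0),
    ):
        (g_input, g_weight), (chunk_loss, _aux) = torch.func.grad_and_value(
            _toy_chunk_loss, argnums=(0, 1), has_aux=True
        )(input_chunk, weight, target_chunk)
        grad_inputs.append(g_input)   # per-chunk input grads are concatenated later
        grad_weight.add_(g_weight)    # weight grads are summed across chunks
        loss_acc.add_(chunk_loss)
    return loss_acc, torch.cat(grad_inputs, dim=0), grad_weight

In the real implementation the per-chunk call is additionally wrapped in torch.compile when compiled=True, and chunk dimensions are marked dynamic with torch._dynamo.mark_dynamic so varying sequence lengths do not trigger recompilation.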
liger_kernel/chunked_loss/fused_linear_preference.py
@@ -16,12 +16,12 @@ class LigerFusedLinearPreferenceBase(torch.autograd.Function):
 
     @staticmethod
     def forward(
+        cls,
         ctx,
         _input,
         weight,
         target,
         bias=None,
-        loss_fn=None,
         chunk_size=1,
         ignore_index=-100,
         alpha=1.0,
@@ -89,7 +89,7 @@ class LigerFusedLinearPreferenceBase(torch.autograd.Function):
 
         compute_loss = partial(
             LigerFusedLinearPreferenceBase._compute_loss,
-            preference_loss_fn=loss_fn,
+            preference_loss_fn=cls.preference_loss_fn,
             ignore_index=ignore_index,
             alpha=alpha,
             beta=beta,
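Dropping the loss_fn= parameter in favor of cls.preference_loss_fn (and cls.distillation_loss_fn in the distillation base) moves loss selection from a runtime argument to subclass dispatch: each subclass defines the loss as a static method, declares its forward as a classmethod, and passes cls into the base. A minimal sketch of the pattern with hypothetical toy classes:

import torch
import torch.nn.functional as F

class ChunkedPreferenceBase(torch.autograd.Function):
    # Stripped-down analogue of LigerFusedLinearPreferenceBase.forward.
    @staticmethod
    def forward(cls, ctx, logits_diff):
        # The base looks the loss up on the subclass instead of taking loss_fn=.
        return cls.preference_loss_fn(logits_diff)

class ToySigmoidPreference(ChunkedPreferenceBase):
    @staticmethod
    def preference_loss_fn(logits_diff):
        return -F.logsigmoid(logits_diff).mean()

    @classmethod
    def forward(cls, ctx, logits_diff):
        return super().forward(cls=cls, ctx=ctx, logits_diff=logits_diff)

# Dispatch check only; the autograd machinery (.apply) is not exercised here.
loss = ToySigmoidPreference.forward(ctx=None, logits_diff=torch.randn(4))

This is the same reason the CPO and DPO forwards above switched to @classmethod and now call super().forward(cls=cls, ...).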