liger-kernel 0.5.4__py3-none-any.whl → 0.5.5__py3-none-any.whl

This diff reflects the changes between publicly released versions of the package as they appear in their respective public registries; it is provided for informational purposes only.
@@ -30,20 +30,24 @@ class LigerFusedLinearJSDFunction(LigerFusedLinearDistillationBase):
         jsd_loss = beta * teacher_kl + (1 - beta) * student_kl
         return jsd_loss
 
-    @staticmethod
+    @classmethod
     def forward(
+        cls,
         ctx,
         student_input: torch.Tensor,
         student_weight: torch.Tensor,
         teacher_input: torch.Tensor,
         teacher_weight: torch.Tensor,
         true_labels: torch.LongTensor,
+        student_bias: torch.Tensor,
+        teacher_bias: torch.Tensor,
         weight_hard_loss: float = 0.5,
         weight_soft_loss: float = 0.5,
         beta: float = 0.5,
         ignore_index: int = -100,
         temperature: float = 1.0,
         compiled: bool = True,
+        chunk_size: int = 1024,
     ):
         """
         Fused linear layer with JSD distillation loss.
@@ -59,18 +63,21 @@ class LigerFusedLinearJSDFunction(LigerFusedLinearDistillationBase):
             ignore_index (int): Index to ignore in loss computation
             temperature (float): Temperature for softening/sharpening distributions
             compiled (bool): Whether to use torch compile
+            chunk_size (int): Size of chunks for processing.
         Returns:
             torch.Tensor: Computed loss
         """
-        return LigerFusedLinearDistillationBase.forward(
+        return super().forward(
+            cls=cls,
             ctx=ctx,
             student_input=student_input,
             student_weight=student_weight,
             teacher_input=teacher_input,
             teacher_weight=teacher_weight,
             target=true_labels,
-            loss_fn=LigerFusedLinearJSDFunction.distillation_loss_fn,
-            chunk_size=1,
+            student_bias=student_bias,
+            teacher_bias=teacher_bias,
+            chunk_size=chunk_size,
             weight_hard_loss=weight_hard_loss,
             weight_soft_loss=weight_soft_loss,
             beta=beta,
@@ -81,9 +88,19 @@ class LigerFusedLinearJSDFunction(LigerFusedLinearDistillationBase):
 
     @staticmethod
     def backward(ctx, grad_output):
-        grads = LigerFusedLinearDistillationBase.backward(ctx, grad_output)[:4]
+        grads = LigerFusedLinearDistillationBase.backward(ctx, grad_output)[:6]
 
-        return (*grads, None, None, None, None, None, None, None)
+        return (
+            *grads,
+            None,  # teacher_bias
+            None,  # weight_hard_loss
+            None,  # weight_soft_loss
+            None,  # beta
+            None,  # ignore_index
+            None,  # temperature
+            None,  # compiled
+            None,  # chunk_size
+        )
 
 
 class LigerFusedLinearJSDLoss(torch.nn.Module):
@@ -99,6 +116,7 @@ class LigerFusedLinearJSDLoss(torch.nn.Module):
         ignore_index: int = -100,
         temperature: float = 1.0,
         compiled: bool = True,
+        chunk_size: int = 1024,
     ):
         """
         Args:
@@ -108,6 +126,7 @@ class LigerFusedLinearJSDLoss(torch.nn.Module):
             temperature (float): Temperature for softening distributions
             compiled (bool): Whether to use torch compile
             beta (float): Coefficient beta of generalized JSD in the interval [0, 1]. Default: `0.5`.
+            chunk_size (int): Size of chunks for processing.
         """
         super().__init__()
         assert temperature != 0, "Temperature cannot be 0."
@@ -117,6 +136,7 @@ class LigerFusedLinearJSDLoss(torch.nn.Module):
         self.temperature = temperature
         self.compiled = compiled
         self.beta = beta
+        self.chunk_size = chunk_size
 
     def forward(
         self,
@@ -125,6 +145,8 @@ class LigerFusedLinearJSDLoss(torch.nn.Module):
         teacher_input: torch.Tensor,
         teacher_weight: torch.Tensor,
         true_labels: torch.LongTensor,
+        student_bias: torch.Tensor,
+        teacher_bias: torch.Tensor,
     ) -> torch.Tensor:
         """
         Compute the JSD distillation loss.
@@ -145,10 +167,13 @@ class LigerFusedLinearJSDLoss(torch.nn.Module):
             teacher_input,
             teacher_weight,
             true_labels,
+            student_bias,
+            teacher_bias,
             self.weight_hard_loss,
             self.weight_soft_loss,
             self.beta,
             self.ignore_index,
             self.temperature,
             self.compiled,
+            self.chunk_size,
         )
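
Taken together, the JSD hunks above change the public API: forward now threads student_bias and teacher_bias through to the distillation base class, and the chunk size becomes a tunable argument (default 1024) instead of the hard-coded 1 used in 0.5.4. A minimal usage sketch of the 0.5.5 signature follows; the import path, argument order, and tensor sizes are assumptions for illustration and are not shown in this diff.

import torch
from liger_kernel.chunked_loss import LigerFusedLinearJSDLoss  # import path assumed

# Illustrative sizes: (batch * seq_len, hidden) inputs, (vocab, hidden) weights.
B_T, H, V = 8, 64, 128
student_input = torch.randn(B_T, H, requires_grad=True)
student_weight = torch.randn(V, H, requires_grad=True)
student_bias = torch.randn(V, requires_grad=True)   # new argument in 0.5.5
teacher_input = torch.randn(B_T, H)
teacher_weight = torch.randn(V, H)
teacher_bias = torch.randn(V)                        # new argument in 0.5.5
true_labels = torch.randint(0, V, (B_T,))

jsd_loss = LigerFusedLinearJSDLoss(chunk_size=1024)  # chunk_size is now configurable
loss = jsd_loss(
    student_input, student_weight, teacher_input, teacher_weight,
    true_labels, student_bias, teacher_bias,
)
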
@@ -7,10 +7,10 @@ from liger_kernel.chunked_loss.fused_linear_unpaired_preference import LigerFuse
 class LigerFusedLinearKTOFunction(LigerFusedLinearUnpairedPreferenceBase):
     @staticmethod
     def preference_loss_fn(
-        average_log_prob_chunk,
+        log_prob_chunk,
         preference_labels_chunk,
         full_target,
-        ref_average_log_prob_chunk=None,
+        ref_log_prob_chunk=None,
         beta=0.1,
         kl=None,
     ):
@@ -43,30 +43,34 @@ class LigerFusedLinearKTOFunction(LigerFusedLinearUnpairedPreferenceBase):
         3. Maintain reasonable distance from the reference model
 
         Args:
-            average_log_prob_chunk: Log probabilities for the chunk (batch_size,)
+            log_prob_chunk: Log probabilities for the chunk (batch_size,)
             preference_labels_chunk: Preference labels for the chunk (batch_size,)
             full_target: Non chunked full target tensor
-            ref_average_log_prob_chunk: Reference log probs for the chunk (batch_size,)
+            ref_log_prob_chunk: Reference log probs for the chunk (batch_size,)
             beta: Weight for the KTO loss
             kl: KL divergence between the policy model and the reference model for the chosen responses. Shape: (batch_size,)
         Returns:
             - loss: The KTO loss value
         """
-        if ref_average_log_prob_chunk is not None:
-            logratios_chunk = average_log_prob_chunk - ref_average_log_prob_chunk
+        if ref_log_prob_chunk is not None:
+            logratios_chunk = log_prob_chunk - ref_log_prob_chunk
         else:
-            logratios_chunk = average_log_prob_chunk
-
+            logratios_chunk = log_prob_chunk
         multiplier_chunk = torch.where(preference_labels_chunk, 1, -1)
         if kl is not None:
             losses = 1 - F.sigmoid(beta * (logratios_chunk - kl) * multiplier_chunk)
         else:
             losses = 1 - F.sigmoid(beta * logratios_chunk * multiplier_chunk)
 
-        return losses.sum() / (full_target.shape[0])
+        rewards = beta * logratios_chunk
+        chosen_rewards_sum = (rewards * preference_labels_chunk.unsqueeze(1)).sum()
+        rejected_rewards_sum = (rewards * (~preference_labels_chunk).unsqueeze(1)).sum()
 
-    @staticmethod
+        return losses.sum() / (full_target.shape[0]), chosen_rewards_sum, rejected_rewards_sum
+
+    @classmethod
     def forward(
+        cls,
         ctx,
         _input,
         weight,
@@ -81,15 +85,38 @@ class LigerFusedLinearKTOFunction(LigerFusedLinearUnpairedPreferenceBase):
         beta=0.1,
         compiled=True,
         use_ref_model=True,
+        average_log_prob=False,
+        chunk_size=1,
     ):
-        return LigerFusedLinearUnpairedPreferenceBase.forward(
+        """
+        Fused linear layer with KTO loss.
+        Args:
+            _input (torch.Tensor): Input tensor. Shape: (batch_size * seq_len, hidden_size)
+            weight (torch.Tensor): Weight tensor. Shape: (vocab_size, hidden_size)
+            target (torch.LongTensor): Target tensor. Shape: (batch_size * seq_len,)
+            preference_labels (torch.Tensor): Preference labels tensor. Shape: (batch_size,)
+            bias (torch.Tensor, optional): Bias tensor. Shape: (vocab_size,)
+            ref_input (torch.Tensor, optional): Reference model input tensor. Shape: (batch_size * seq_len, hidden_size)
+            ref_weight (torch.Tensor, optional): Reference model weight tensor. Shape: (vocab_size, hidden_size)
+            ref_bias (torch.Tensor, optional): Reference model bias tensor. Shape: (vocab_size,)
+            kl (torch.Tensor, optional): KL divergence tensor. Shape: (batch_size,)
+            ignore_index (int): Index to ignore in loss computation
+            beta (float): Temperature parameter for the KTO loss
+            compiled (bool): Whether to use torch compile
+            use_ref_model (bool): Whether to use a reference model
+            average_log_prob (bool): Whether to average the log probability per non-masked token
+            chunk_size (int): Size of chunks for processing
+        Returns:
+            torch.Tensor: Computed loss
+        """
+        return super().forward(
+            cls=cls,
             ctx=ctx,
             _input=_input,
             weight=weight,
             target=target,
             preference_labels=preference_labels,
             bias=bias,
-            loss_fn=LigerFusedLinearKTOFunction.preference_loss_fn,
             ignore_index=ignore_index,
             beta=beta,
             compiled=compiled,
@@ -97,7 +124,9 @@ class LigerFusedLinearKTOFunction(LigerFusedLinearUnpairedPreferenceBase):
             ref_input=ref_input,
             ref_weight=ref_weight,
             ref_bias=ref_bias,
+            average_log_prob=average_log_prob,
             kl=kl,
+            chunk_size=chunk_size,
         )
 
     @staticmethod
@@ -115,6 +144,7 @@ class LigerFusedLinearKTOFunction(LigerFusedLinearUnpairedPreferenceBase):
             None,
             None,
             None,
+            None,
         )
 
 
@@ -129,6 +159,8 @@ class LigerFusedLinearKTOLoss(torch.nn.Module):
         beta: float = 0.1,
         compiled: bool = True,
         use_ref_model: bool = False,
+        average_log_prob: bool = False,
+        chunk_size: int = 1,
     ):
         """
         Args:
@@ -136,12 +168,16 @@ class LigerFusedLinearKTOLoss(torch.nn.Module):
             beta (float): Temperature parameter for the KTO loss
             compiled (bool): Whether to use compiled operations
             use_ref_model (bool): Whether to use a reference model for the DPO loss.
+            average_log_prob (bool): Whether to average the log probability per non-masked token
+            chunk_size (int): Size of chunks for processing
         """
         super().__init__()
         self.ignore_index = ignore_index
         self.beta = beta
         self.compiled = compiled
         self.use_ref_model = use_ref_model
+        self.average_log_prob = average_log_prob
+        self.chunk_size = chunk_size
 
     def forward(
         self,
@@ -169,4 +205,6 @@ class LigerFusedLinearKTOLoss(torch.nn.Module):
             self.beta,
             self.compiled,
             self.use_ref_model,
+            self.average_log_prob,
+            self.chunk_size,
         )
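
Two behavioural notes fall out of the KTO hunks: preference_loss_fn now returns a tuple (loss, chosen_rewards_sum, rejected_rewards_sum) instead of a bare loss, and the module grows average_log_prob and chunk_size knobs. A construction-only sketch follows; the module's full forward argument list is not visible in this diff, so only instantiation is shown, and the import path is assumed.

from liger_kernel.chunked_loss import LigerFusedLinearKTOLoss  # import path assumed

kto_loss = LigerFusedLinearKTOLoss(
    ignore_index=-100,
    beta=0.1,
    use_ref_model=True,
    average_log_prob=False,  # new in 0.5.5: length-normalize log-probs per non-masked token
    chunk_size=1,            # new in 0.5.5: forwarded to the chunked base class
)
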
@@ -42,8 +42,9 @@ class LigerFusedLinearORPOFunction(LigerFusedLinearPreferenceBase):
 
         return loss, chosen_rewards, rejected_rewards, log_odds_ratio, log_odds_chosen
 
-    @staticmethod
+    @classmethod
     def forward(
+        cls,
         ctx,
         _input,
         weight,
@@ -54,25 +55,43 @@ class LigerFusedLinearORPOFunction(LigerFusedLinearPreferenceBase):
         compute_nll_loss=True,
         nll_target=None,
         compiled=True,
+        chunk_size=1,
     ):
-        return LigerFusedLinearPreferenceBase.forward(
+        """
+        Fused linear layer with ORPO loss.
+        Args:
+            _input (torch.Tensor): Input tensor. Shape: (batch_size * seq_len, hidden_size)
+            weight (torch.Tensor): Weight tensor. Shape: (vocab_size, hidden_size)
+            target (torch.LongTensor): Target tensor. Shape: (batch_size * seq_len,)
+            bias (torch.Tensor, optional): Bias tensor. Shape: (vocab_size,)
+            ignore_index (int): Index to ignore in loss computation
+            beta (float): Weight for the odds ratio loss
+            compute_nll_loss (bool): Whether to compute the NLL loss
+            nll_target (torch.LongTensor, optional): Target tensor for NLL loss. Shape: (batch_size * seq_len,)
+            compiled (bool): Whether to use torch compile
+            chunk_size (int): Size of chunks for processing
+        Returns:
+            torch.Tensor: Computed loss
+        """
+        return super().forward(
+            cls=cls,
             ctx=ctx,
             _input=_input,
             weight=weight,
             target=target,
             bias=bias,
-            loss_fn=LigerFusedLinearORPOFunction.preference_loss_fn,
             ignore_index=ignore_index,
             beta=beta,
             compute_nll_loss=compute_nll_loss,
             nll_target=nll_target,
             compiled=compiled,
+            chunk_size=chunk_size,
         )
 
     @staticmethod
     def backward(ctx, *grad_output):
         grads = LigerFusedLinearPreferenceBase.backward(ctx, grad_output)[:4]
-        return *grads, None, None, None, None, None
+        return *grads, None, None, None, None, None, None
 
 
 class LigerFusedLinearORPOLoss(torch.nn.Module):
@@ -86,19 +105,31 @@ class LigerFusedLinearORPOLoss(torch.nn.Module):
         beta: float = 0.1,
         compute_nll_loss: bool = True,
         compiled: bool = True,
+        chunk_size: int = 1,
     ):
         """
         Args:
             ignore_index (int): Index to ignore in the loss.
             beta (float): Weight for the odds ratio loss.
+            compute_nll_loss (bool): Whether to compute the NLL loss.
+            compiled (bool): Whether to use the torch compiled kernel.
+            chunk_size (int): Size of chunks for processing.
         """
         super().__init__()
         self.ignore_index = ignore_index
         self.beta = beta
         self.compute_nll_loss = compute_nll_loss
         self.compiled = compiled
+        self.chunk_size = chunk_size
 
-    def forward(self, lin_weight, _input, target, bias=None, nll_target=None):
+    def forward(
+        self,
+        lin_weight,
+        _input,
+        target,
+        bias=None,
+        nll_target=None,
+    ):
         return LigerFusedLinearORPOFunction.apply(
             _input,
             lin_weight,
@@ -109,4 +140,5 @@ class LigerFusedLinearORPOLoss(torch.nn.Module):
             self.compute_nll_loss,
             nll_target,
             self.compiled,
+            self.chunk_size,
         )
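
On the module side the ORPO forward arguments are unchanged (only reformatted one per line), so existing call sites keep working; chunk_size is stored on the module and appended to the apply() call. A hedged sketch of 0.5.5 usage follows; the import path, the chosen/rejected batch layout, and the handling of the return value are assumptions, not shown in this diff.

import torch
from liger_kernel.chunked_loss import LigerFusedLinearORPOLoss  # import path assumed

# Illustrative sizes; first half of the batch chosen, second half rejected (assumed layout).
B_T, H, V = 8, 64, 128
lin_weight = torch.randn(V, H, requires_grad=True)
_input = torch.randn(B_T, H, requires_grad=True)
target = torch.randint(0, V, (B_T,))

orpo_loss = LigerFusedLinearORPOLoss(beta=0.1, chunk_size=1)  # chunk_size new in 0.5.5
out = orpo_loss(lin_weight, _input, target)  # loss (plus auxiliary statistics, depending on version)
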
@@ -47,8 +47,9 @@ class LigerFusedLinearSimPOFunction(LigerFusedLinearPreferenceBase):
 
         return loss, chosen_rewards, rejected_rewards
 
-    @staticmethod
+    @classmethod
     def forward(
+        cls,
         ctx,
         _input,
         weight,
@@ -61,27 +62,47 @@ class LigerFusedLinearSimPOFunction(LigerFusedLinearPreferenceBase):
         compute_nll_loss=False,
         compiled=True,
         gamma=0.5,
+        chunk_size=1,
     ):
-        return LigerFusedLinearPreferenceBase.forward(
-            ctx,
-            _input,
-            weight,
-            target,
-            bias,
-            loss_fn=LigerFusedLinearSimPOFunction.preference_loss_fn,
-            compute_nll_loss=compute_nll_loss,
+        """
+        Fused linear layer with SimPO loss.
+        Args:
+            _input (torch.Tensor): Input tensor. Shape: (batch_size * seq_len, hidden_size)
+            weight (torch.Tensor): Weight tensor. Shape: (vocab_size, hidden_size)
+            target (torch.LongTensor): Target tensor. Shape: (batch_size * seq_len,)
+            bias (torch.Tensor, optional): Bias tensor. Shape: (vocab_size,)
+            ignore_index (int): Index to ignore in loss computation
+            beta (float): Weight for the odds ratio loss
+            alpha (float): Weight for the alpha parameter
+            label_smoothing (float): Label smoothing factor
+            compute_nll_loss (bool): Whether to compute the NLL loss
+            compiled (bool): Whether to use torch compile
+            gamma (float): Weight for the gamma parameter
+            chunk_size (int): Size of chunks for processing
+        Returns:
+            torch.Tensor: Computed loss
+        """
+        return super().forward(
+            cls=cls,
+            ctx=ctx,
+            _input=_input,
+            weight=weight,
+            target=target,
+            bias=bias,
             ignore_index=ignore_index,
             alpha=alpha,
             beta=beta,
             label_smoothing=label_smoothing,
+            compute_nll_loss=compute_nll_loss,
             compiled=compiled,
             gamma=gamma,
+            chunk_size=chunk_size,
         )
 
     @staticmethod
     def backward(ctx, *grad_output):
         grads = LigerFusedLinearPreferenceBase.backward(ctx, grad_output)[:4]
-        return *grads, None, None, None, None, None, None, None
+        return *grads, None, None, None, None, None, None, None, None
 
 
 class LigerFusedLinearSimPOLoss(torch.nn.Module):
@@ -98,11 +119,18 @@ class LigerFusedLinearSimPOLoss(torch.nn.Module):
         compute_nll_loss: bool = True,
         compiled: bool = True,
         gamma: float = 0.5,
+        chunk_size: int = 1,
     ):
         """
         Args:
             ignore_index (int): Index to ignore in the loss.
             beta (float): Weight for the odds ratio loss.
+            alpha (float): Weight for the alpha parameter.
+            label_smoothing (float): Label smoothing factor.
+            compute_nll_loss (bool): Whether to compute the NLL loss.
+            compiled (bool): Whether to use the torch compiled kernel.
+            gamma (float): Weight for the gamma parameter.
+            chunk_size (int): Size of chunks for processing.
         """
         super().__init__()
         self.ignore_index = ignore_index
@@ -112,8 +140,15 @@ class LigerFusedLinearSimPOLoss(torch.nn.Module):
         self.compute_nll_loss = compute_nll_loss
         self.compiled = compiled
         self.gamma = gamma
+        self.chunk_size = chunk_size
 
-    def forward(self, lin_weight, _input, target, bias=None):
+    def forward(
+        self,
+        lin_weight,
+        _input,
+        target,
+        bias=None,
+    ):
         return LigerFusedLinearSimPOFunction.apply(
             _input,
             lin_weight,
@@ -126,4 +161,5 @@ class LigerFusedLinearSimPOLoss(torch.nn.Module):
             self.compute_nll_loss,
             self.compiled,
             self.gamma,
+            self.chunk_size,
         )
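
SimPO mirrors the ORPO change: chunk_size is stored on the module, appended to the apply() call, and the Function's backward returns one extra None gradient to match. Instantiation under the new signature, as a sketch; the import path is assumed and only chunk_size is new relative to 0.5.4.

from liger_kernel.chunked_loss import LigerFusedLinearSimPOLoss  # import path assumed

simpo_loss = LigerFusedLinearSimPOLoss(
    beta=0.1,
    gamma=0.5,
    compute_nll_loss=False,
    chunk_size=1,  # new in 0.5.5
)
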
@@ -285,6 +285,10 @@ def cross_entropy_forward(
 
     target_mask = target != ignore_index
     n_non_ignore = target_mask.sum().item()
+    assert (target * target_mask).max() < _input.shape[-1], (
+        f"Target {target.max()} is out of bounds. Expected < {_input.shape[-1]}"
+    )
+    assert (target * target_mask).min() >= 0, f"Target {target.min()} is out of bounds. Expected >= 0"
     sum_non_ignore_weight = n_non_ignore
     weight_sum = 0.0
     if weight is not None:
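
The new assertions validate the targets before the kernel runs. Multiplying by target_mask zeroes out the ignore_index positions (typically -100), so only real labels are range-checked. A small self-contained illustration of why the masking is needed (toy values, not taken from the diff):

import torch

vocab_size = 8
ignore_index = -100
target = torch.tensor([2, 5, ignore_index, 7])
target_mask = target != ignore_index

# Without the mask, min() would be -100 and the ">= 0" check would fail even
# though that position is ignored by the loss; masked, it becomes 0 and passes.
assert (target * target_mask).max() < vocab_size
assert (target * target_mask).min() >= 0
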
@@ -17,6 +17,7 @@ from liger_kernel.transformers.monkey_patch import apply_liger_kernel_to_mllama
 from liger_kernel.transformers.monkey_patch import apply_liger_kernel_to_olmo2  # noqa: F401
 from liger_kernel.transformers.monkey_patch import apply_liger_kernel_to_phi3  # noqa: F401
 from liger_kernel.transformers.monkey_patch import apply_liger_kernel_to_qwen2  # noqa: F401
+from liger_kernel.transformers.monkey_patch import apply_liger_kernel_to_qwen2_5_vl  # noqa: F401
 from liger_kernel.transformers.monkey_patch import apply_liger_kernel_to_qwen2_vl  # noqa: F401
 from liger_kernel.transformers.rms_norm import LigerRMSNorm  # noqa: F401
 from liger_kernel.transformers.rope import liger_rotary_pos_emb  # noqa: F401
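
The added import re-exports the new Qwen2.5-VL monkey patch from liger_kernel.transformers alongside the existing model patches. A usage sketch follows; the patch is assumed to accept the same optional keyword arguments (and defaults) as the other apply_liger_kernel_to_* functions, which this diff does not show.

from liger_kernel.transformers import apply_liger_kernel_to_qwen2_5_vl

# Patch Hugging Face's Qwen2.5-VL modeling code with Liger kernels
# before the model is instantiated.
apply_liger_kernel_to_qwen2_5_vl()
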