liger-kernel 0.5.3__py3-none-any.whl → 0.5.5__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (35)
  1. liger_kernel/chunked_loss/__init__.py +1 -0
  2. liger_kernel/chunked_loss/cpo_loss.py +51 -11
  3. liger_kernel/chunked_loss/dpo_loss.py +30 -4
  4. liger_kernel/chunked_loss/fused_linear_distillation.py +3 -3
  5. liger_kernel/chunked_loss/fused_linear_preference.py +2 -2
  6. liger_kernel/chunked_loss/fused_linear_rlhf.py +240 -0
  7. liger_kernel/chunked_loss/fused_linear_unpaired_preference.py +112 -17
  8. liger_kernel/chunked_loss/grpo_loss.py +194 -0
  9. liger_kernel/chunked_loss/jsd_loss.py +31 -6
  10. liger_kernel/chunked_loss/kto_loss.py +53 -15
  11. liger_kernel/chunked_loss/orpo_loss.py +37 -5
  12. liger_kernel/chunked_loss/simpo_loss.py +47 -11
  13. liger_kernel/ops/cross_entropy.py +7 -3
  14. liger_kernel/ops/fused_linear_cross_entropy.py +3 -3
  15. liger_kernel/ops/fused_linear_jsd.py +3 -3
  16. liger_kernel/ops/jsd.py +3 -3
  17. liger_kernel/ops/layer_norm.py +20 -7
  18. liger_kernel/ops/tvd.py +207 -0
  19. liger_kernel/ops/utils.py +1 -2
  20. liger_kernel/transformers/__init__.py +4 -0
  21. liger_kernel/transformers/cross_entropy.py +3 -3
  22. liger_kernel/transformers/functional.py +17 -0
  23. liger_kernel/transformers/fused_linear_cross_entropy.py +3 -3
  24. liger_kernel/transformers/group_norm.py +6 -6
  25. liger_kernel/transformers/model/olmo2.py +124 -0
  26. liger_kernel/transformers/model/qwen2_5_vl.py +205 -0
  27. liger_kernel/transformers/monkey_patch.py +239 -27
  28. liger_kernel/transformers/tvd.py +13 -0
  29. liger_kernel/utils.py +48 -1
  30. {liger_kernel-0.5.3.dist-info → liger_kernel-0.5.5.dist-info}/METADATA +19 -4
  31. {liger_kernel-0.5.3.dist-info → liger_kernel-0.5.5.dist-info}/RECORD +35 -29
  32. {liger_kernel-0.5.3.dist-info → liger_kernel-0.5.5.dist-info}/WHEEL +1 -1
  33. {liger_kernel-0.5.3.dist-info → liger_kernel-0.5.5.dist-info}/LICENSE +0 -0
  34. {liger_kernel-0.5.3.dist-info → liger_kernel-0.5.5.dist-info}/NOTICE +0 -0
  35. {liger_kernel-0.5.3.dist-info → liger_kernel-0.5.5.dist-info}/top_level.txt +0 -0
liger_kernel/chunked_loss/kto_loss.py CHANGED
@@ -7,10 +7,10 @@ from liger_kernel.chunked_loss.fused_linear_unpaired_preference import LigerFuse
 class LigerFusedLinearKTOFunction(LigerFusedLinearUnpairedPreferenceBase):
     @staticmethod
     def preference_loss_fn(
-        average_log_prob_chunk,
+        log_prob_chunk,
         preference_labels_chunk,
         full_target,
-        ref_average_log_prob_chunk=None,
+        ref_log_prob_chunk=None,
         beta=0.1,
         kl=None,
     ):
@@ -43,30 +43,34 @@ class LigerFusedLinearKTOFunction(LigerFusedLinearUnpairedPreferenceBase):
         3. Maintain reasonable distance from the reference model
 
         Args:
-            chosen_logps: Log probabilities of chosen tokens (batch_size,)
-            rejected_logps: Log probabilities of rejected tokens (batch_size,)
+            log_prob_chunk: Log probabilities for the chunk (batch_size,)
+            preference_labels_chunk: Preference labels for the chunk (batch_size,)
             full_target: Non chunked full target tensor
-            ref_chosen_logps: Reference log probs of chosen tokens (batch_size,)
-            ref_rejected_logps: Reference log probs of rejected tokens (batch_size,)
-            beta: Weight for the direct preference loss
+            ref_log_prob_chunk: Reference log probs for the chunk (batch_size,)
+            beta: Weight for the KTO loss
             kl: KL divergence between the policy model and the reference model for the chosen responses. Shape: (batch_size,)
         Returns:
-            Tuple of (loss, chosen_rewards, rejected_rewards):
             - loss: The KTO loss value
-            - chosen_rewards: Reward signals for chosen responses (detached)
-            - rejected_rewards: Reward signals for rejected responses (detached)
         """
-        logratios_chunk = average_log_prob_chunk - ref_average_log_prob_chunk
+        if ref_log_prob_chunk is not None:
+            logratios_chunk = log_prob_chunk - ref_log_prob_chunk
+        else:
+            logratios_chunk = log_prob_chunk
         multiplier_chunk = torch.where(preference_labels_chunk, 1, -1)
         if kl is not None:
             losses = 1 - F.sigmoid(beta * (logratios_chunk - kl) * multiplier_chunk)
         else:
             losses = 1 - F.sigmoid(beta * logratios_chunk * multiplier_chunk)
 
-        return losses.sum() / (full_target.shape[0])
+        rewards = beta * logratios_chunk
+        chosen_rewards_sum = (rewards * preference_labels_chunk.unsqueeze(1)).sum()
+        rejected_rewards_sum = (rewards * (~preference_labels_chunk).unsqueeze(1)).sum()
 
-    @staticmethod
+        return losses.sum() / (full_target.shape[0]), chosen_rewards_sum, rejected_rewards_sum
+
+    @classmethod
     def forward(
+        cls,
         ctx,
         _input,
         weight,
@@ -81,15 +85,38 @@ class LigerFusedLinearKTOFunction(LigerFusedLinearUnpairedPreferenceBase):
         beta=0.1,
         compiled=True,
         use_ref_model=True,
+        average_log_prob=False,
+        chunk_size=1,
     ):
-        return LigerFusedLinearUnpairedPreferenceBase.forward(
+        """
+        Fused linear layer with KTO loss.
+        Args:
+            _input (torch.Tensor): Input tensor. Shape: (batch_size * seq_len, hidden_size)
+            weight (torch.Tensor): Weight tensor. Shape: (vocab_size, hidden_size)
+            target (torch.LongTensor): Target tensor. Shape: (batch_size * seq_len,)
+            preference_labels (torch.Tensor): Preference labels tensor. Shape: (batch_size,)
+            bias (torch.Tensor, optional): Bias tensor. Shape: (vocab_size,)
+            ref_input (torch.Tensor, optional): Reference model input tensor. Shape: (batch_size * seq_len, hidden_size)
+            ref_weight (torch.Tensor, optional): Reference model weight tensor. Shape: (vocab_size, hidden_size)
+            ref_bias (torch.Tensor, optional): Reference model bias tensor. Shape: (vocab_size,)
+            kl (torch.Tensor, optional): KL divergence tensor. Shape: (batch_size,)
+            ignore_index (int): Index to ignore in loss computation
+            beta (float): Temperature parameter for the KTO loss
+            compiled (bool): Whether to use torch compile
+            use_ref_model (bool): Whether to use a reference model
+            average_log_prob (bool): Whether to average the log probability per non-masked token
+            chunk_size (int): Size of chunks for processing
+        Returns:
+            torch.Tensor: Computed loss
+        """
+        return super().forward(
+            cls=cls,
             ctx=ctx,
             _input=_input,
             weight=weight,
             target=target,
             preference_labels=preference_labels,
             bias=bias,
-            loss_fn=LigerFusedLinearKTOFunction.preference_loss_fn,
             ignore_index=ignore_index,
             beta=beta,
             compiled=compiled,
@@ -97,7 +124,9 @@ class LigerFusedLinearKTOFunction(LigerFusedLinearUnpairedPreferenceBase):
             ref_input=ref_input,
             ref_weight=ref_weight,
             ref_bias=ref_bias,
+            average_log_prob=average_log_prob,
             kl=kl,
+            chunk_size=chunk_size,
         )
 
     @staticmethod
@@ -115,6 +144,7 @@ class LigerFusedLinearKTOFunction(LigerFusedLinearUnpairedPreferenceBase):
             None,
             None,
             None,
+            None,
         )
 
 
@@ -129,6 +159,8 @@ class LigerFusedLinearKTOLoss(torch.nn.Module):
         beta: float = 0.1,
         compiled: bool = True,
         use_ref_model: bool = False,
+        average_log_prob: bool = False,
+        chunk_size: int = 1,
     ):
         """
         Args:
@@ -136,12 +168,16 @@ class LigerFusedLinearKTOLoss(torch.nn.Module):
             beta (float): Temperature parameter for the KTO loss
             compiled (bool): Whether to use compiled operations
             use_ref_model (bool): Whether to use a reference model for the DPO loss.
+            average_log_prob (bool): Whether to average the log probability per non-masked token
+            chunk_size (int): Size of chunks for processing
         """
         super().__init__()
         self.ignore_index = ignore_index
         self.beta = beta
         self.compiled = compiled
         self.use_ref_model = use_ref_model
+        self.average_log_prob = average_log_prob
+        self.chunk_size = chunk_size
 
     def forward(
         self,
@@ -169,4 +205,6 @@ class LigerFusedLinearKTOLoss(torch.nn.Module):
             self.beta,
             self.compiled,
             self.use_ref_model,
+            self.average_log_prob,
+            self.chunk_size,
         )
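
The KTO changes above are the most substantial in this group: preference_loss_fn now also returns summed chosen/rejected rewards, forward becomes a classmethod that defers to the base class, and the module gains average_log_prob and chunk_size options. A minimal construction sketch, assuming the class is importable from its module path as in earlier releases; the argument values are illustrative and only parameters visible in the hunks above are used.

# Hedged sketch: constructing the KTO loss module with the options introduced in this diff.
from liger_kernel.chunked_loss.kto_loss import LigerFusedLinearKTOLoss

kto_loss = LigerFusedLinearKTOLoss(
    ignore_index=-100,        # index excluded from the loss (illustrative value)
    beta=0.1,                 # temperature for the KTO loss
    use_ref_model=True,       # compare policy log-probs against a reference model
    average_log_prob=False,   # new: average log-probs over non-masked tokens
    chunk_size=1,             # new: number of rows processed per chunk
)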
liger_kernel/chunked_loss/orpo_loss.py CHANGED
@@ -42,8 +42,9 @@ class LigerFusedLinearORPOFunction(LigerFusedLinearPreferenceBase):
 
         return loss, chosen_rewards, rejected_rewards, log_odds_ratio, log_odds_chosen
 
-    @staticmethod
+    @classmethod
     def forward(
+        cls,
         ctx,
         _input,
         weight,
@@ -54,25 +55,43 @@ class LigerFusedLinearORPOFunction(LigerFusedLinearPreferenceBase):
         compute_nll_loss=True,
         nll_target=None,
         compiled=True,
+        chunk_size=1,
     ):
-        return LigerFusedLinearPreferenceBase.forward(
+        """
+        Fused linear layer with ORPO loss.
+        Args:
+            _input (torch.Tensor): Input tensor. Shape: (batch_size * seq_len, hidden_size)
+            weight (torch.Tensor): Weight tensor. Shape: (vocab_size, hidden_size)
+            target (torch.LongTensor): Target tensor. Shape: (batch_size * seq_len,)
+            bias (torch.Tensor, optional): Bias tensor. Shape: (vocab_size,)
+            ignore_index (int): Index to ignore in loss computation
+            beta (float): Weight for the odds ratio loss
+            compute_nll_loss (bool): Whether to compute the NLL loss
+            nll_target (torch.LongTensor, optional): Target tensor for NLL loss. Shape: (batch_size * seq_len,)
+            compiled (bool): Whether to use torch compile
+            chunk_size (int): Size of chunks for processing
+        Returns:
+            torch.Tensor: Computed loss
+        """
+        return super().forward(
+            cls=cls,
             ctx=ctx,
             _input=_input,
             weight=weight,
             target=target,
             bias=bias,
-            loss_fn=LigerFusedLinearORPOFunction.preference_loss_fn,
             ignore_index=ignore_index,
             beta=beta,
             compute_nll_loss=compute_nll_loss,
             nll_target=nll_target,
             compiled=compiled,
+            chunk_size=chunk_size,
         )
 
     @staticmethod
     def backward(ctx, *grad_output):
         grads = LigerFusedLinearPreferenceBase.backward(ctx, grad_output)[:4]
-        return *grads, None, None, None, None, None
+        return *grads, None, None, None, None, None, None
 
 
 class LigerFusedLinearORPOLoss(torch.nn.Module):
@@ -86,19 +105,31 @@ class LigerFusedLinearORPOLoss(torch.nn.Module):
         beta: float = 0.1,
         compute_nll_loss: bool = True,
         compiled: bool = True,
+        chunk_size: int = 1,
     ):
         """
         Args:
             ignore_index (int): Index to ignore in the loss.
             beta (float): Weight for the odds ratio loss.
+            compute_nll_loss (bool): Whether to compute the NLL loss.
+            compiled (bool): Whether to use the torch compiled kernel.
+            chunk_size (int): Size of chunks for processing.
         """
         super().__init__()
         self.ignore_index = ignore_index
         self.beta = beta
         self.compute_nll_loss = compute_nll_loss
         self.compiled = compiled
+        self.chunk_size = chunk_size
 
-    def forward(self, lin_weight, _input, target, bias=None, nll_target=None):
+    def forward(
+        self,
+        lin_weight,
+        _input,
+        target,
+        bias=None,
+        nll_target=None,
+    ):
         return LigerFusedLinearORPOFunction.apply(
             _input,
             lin_weight,
@@ -109,4 +140,5 @@ class LigerFusedLinearORPOLoss(torch.nn.Module):
             self.compute_nll_loss,
             nll_target,
             self.compiled,
+            self.chunk_size,
         )
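
ORPO follows the same pattern (classmethod forward, explicit chunk_size, expanded docstrings), and its module-level forward(lin_weight, _input, target, bias=None, nll_target=None) signature is fully visible above. A hedged usage sketch; the sizes, dtypes, and CUDA device are illustrative assumptions, and how chosen/rejected rows are packed in the batch follows the library's convention rather than anything shown in this diff.

# Hedged sketch: calling the chunked ORPO loss on hidden states plus the lm_head weight.
import torch
from liger_kernel.chunked_loss.orpo_loss import LigerFusedLinearORPOLoss

B, T, H, V = 2, 128, 4096, 32000  # assumed batch, sequence length, hidden size, vocab size
hidden = torch.randn(B * T, H, device="cuda", dtype=torch.bfloat16, requires_grad=True)
lm_head_weight = torch.randn(V, H, device="cuda", dtype=torch.bfloat16)
target = torch.randint(0, V, (B * T,), device="cuda")

orpo_loss = LigerFusedLinearORPOLoss(beta=0.1, chunk_size=1)
out = orpo_loss(lm_head_weight, hidden, target)  # loss, plus any auxiliary outputs defined by the base class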
liger_kernel/chunked_loss/simpo_loss.py CHANGED
@@ -47,8 +47,9 @@ class LigerFusedLinearSimPOFunction(LigerFusedLinearPreferenceBase):
 
         return loss, chosen_rewards, rejected_rewards
 
-    @staticmethod
+    @classmethod
     def forward(
+        cls,
         ctx,
         _input,
         weight,
@@ -61,27 +62,47 @@ class LigerFusedLinearSimPOFunction(LigerFusedLinearPreferenceBase):
         compute_nll_loss=False,
         compiled=True,
         gamma=0.5,
+        chunk_size=1,
     ):
-        return LigerFusedLinearPreferenceBase.forward(
-            ctx,
-            _input,
-            weight,
-            target,
-            bias,
-            loss_fn=LigerFusedLinearSimPOFunction.preference_loss_fn,
-            compute_nll_loss=compute_nll_loss,
+        """
+        Fused linear layer with SimPO loss.
+        Args:
+            _input (torch.Tensor): Input tensor. Shape: (batch_size * seq_len, hidden_size)
+            weight (torch.Tensor): Weight tensor. Shape: (vocab_size, hidden_size)
+            target (torch.LongTensor): Target tensor. Shape: (batch_size * seq_len,)
+            bias (torch.Tensor, optional): Bias tensor. Shape: (vocab_size,)
+            ignore_index (int): Index to ignore in loss computation
+            beta (float): Weight for the odds ratio loss
+            alpha (float): Weight for the alpha parameter
+            label_smoothing (float): Label smoothing factor
+            compute_nll_loss (bool): Whether to compute the NLL loss
+            compiled (bool): Whether to use torch compile
+            gamma (float): Weight for the gamma parameter
+            chunk_size (int): Size of chunks for processing
+        Returns:
+            torch.Tensor: Computed loss
+        """
+        return super().forward(
+            cls=cls,
+            ctx=ctx,
+            _input=_input,
+            weight=weight,
+            target=target,
+            bias=bias,
             ignore_index=ignore_index,
             alpha=alpha,
             beta=beta,
             label_smoothing=label_smoothing,
+            compute_nll_loss=compute_nll_loss,
             compiled=compiled,
             gamma=gamma,
+            chunk_size=chunk_size,
         )
 
     @staticmethod
     def backward(ctx, *grad_output):
         grads = LigerFusedLinearPreferenceBase.backward(ctx, grad_output)[:4]
-        return *grads, None, None, None, None, None, None, None
+        return *grads, None, None, None, None, None, None, None, None
 
 
 class LigerFusedLinearSimPOLoss(torch.nn.Module):
@@ -98,11 +119,18 @@ class LigerFusedLinearSimPOLoss(torch.nn.Module):
         compute_nll_loss: bool = True,
         compiled: bool = True,
         gamma: float = 0.5,
+        chunk_size: int = 1,
    ):
         """
         Args:
             ignore_index (int): Index to ignore in the loss.
             beta (float): Weight for the odds ratio loss.
+            alpha (float): Weight for the alpha parameter.
+            label_smoothing (float): Label smoothing factor.
+            compute_nll_loss (bool): Whether to compute the NLL loss.
+            compiled (bool): Whether to use the torch compiled kernel.
+            gamma (float): Weight for the gamma parameter.
+            chunk_size (int): Size of chunks for processing.
         """
         super().__init__()
         self.ignore_index = ignore_index
@@ -112,8 +140,15 @@ class LigerFusedLinearSimPOLoss(torch.nn.Module):
         self.compute_nll_loss = compute_nll_loss
         self.compiled = compiled
         self.gamma = gamma
+        self.chunk_size = chunk_size
 
-    def forward(self, lin_weight, _input, target, bias=None):
+    def forward(
+        self,
+        lin_weight,
+        _input,
+        target,
+        bias=None,
+    ):
         return LigerFusedLinearSimPOFunction.apply(
             _input,
             lin_weight,
@@ -126,4 +161,5 @@ class LigerFusedLinearSimPOLoss(torch.nn.Module):
             self.compute_nll_loss,
             self.compiled,
             self.gamma,
+            self.chunk_size,
         )
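
A detail worth noting across the ORPO, SimPO, and KTO hunks: every time forward gains an argument (here chunk_size), backward must return one more None, because torch.autograd.Function.backward returns exactly one gradient per forward input, with None for non-differentiable arguments. A self-contained illustration of that rule (not liger-kernel code):

import torch

class Scale(torch.autograd.Function):
    @staticmethod
    def forward(ctx, x, factor, chunk_size=1):  # three inputs besides ctx
        ctx.factor = factor
        return x * factor

    @staticmethod
    def backward(ctx, grad_out):
        # One entry per forward input: a gradient for x, None for the
        # non-tensor factor, and None for the newly added chunk_size.
        return grad_out * ctx.factor, None, None

x = torch.ones(3, requires_grad=True)
Scale.apply(x, 2.0, 1).sum().backward()
print(x.grad)  # tensor([2., 2., 2.])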
liger_kernel/ops/cross_entropy.py CHANGED
@@ -285,13 +285,17 @@ def cross_entropy_forward(
 
     target_mask = target != ignore_index
     n_non_ignore = target_mask.sum().item()
+    assert (target * target_mask).max() < _input.shape[-1], (
+        f"Target {target.max()} is out of bounds. Expected < {_input.shape[-1]}"
+    )
+    assert (target * target_mask).min() >= 0, f"Target {target.min()} is out of bounds. Expected >= 0"
     sum_non_ignore_weight = n_non_ignore
     weight_sum = 0.0
     if weight is not None:
         assert weight.shape[0] == V, f"If given, weight has to be a Tensor of size V. Got: {weight.shape}"
-        assert torch.is_floating_point(
-            weight
-        ), f"If given, weight has to be a Tensor of floating point dtype. Got: {weight.dtype}"
+        assert torch.is_floating_point(weight), (
+            f"If given, weight has to be a Tensor of floating point dtype. Got: {weight.dtype}"
+        )
         sum_non_ignore_weight = torch.gather(weight, dim=0, index=target.masked_select(target_mask)).sum().item()
         weight_sum = weight.sum().item()
         # ensure weight is contiguous
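
The new bounds assertions multiply by target_mask before range-checking so that ignore_index entries (typically -100) do not trigger a false positive; only real labels must fall in [0, V). A small illustration with made-up values:

import torch

V = 8
target = torch.tensor([3, -100, 7, 5])
target_mask = target != -100
checked = target * target_mask  # tensor([3, 0, 7, 5]); ignored entries are zeroed out
assert checked.max() < V and checked.min() >= 0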
liger_kernel/ops/fused_linear_cross_entropy.py CHANGED
@@ -58,9 +58,9 @@ def fused_linear_cross_entropy_forward(
     ce_weight_sum = 0.0
     if ce_weight is not None:
         assert ce_weight.shape[0] == V, f"If given, weight has to be a Tensor of size V. Got: {ce_weight.shape}"
-        assert torch.is_floating_point(
-            ce_weight
-        ), f"If given, weight has to be a Tensor of floating point dtype. Got: {ce_weight.dtype}"
+        assert torch.is_floating_point(ce_weight), (
+            f"If given, weight has to be a Tensor of floating point dtype. Got: {ce_weight.dtype}"
+        )
         total_sum_non_ignore_ce_weight = (
             torch.gather(ce_weight, dim=0, index=target.masked_select(target_mask)).sum().item()
         )
liger_kernel/ops/fused_linear_jsd.py CHANGED
@@ -195,9 +195,9 @@ class LigerFusedLinearJSDFunction(torch.autograd.Function):
         """
         has_label = False
         if shift_labels is not None:
-            assert shift_labels.shape == (
-                teacher_input.shape[0],
-            ), f"the shape of shift_labels must be (BT,). Got: {shift_labels.shape}"
+            assert shift_labels.shape == (teacher_input.shape[0],), (
+                f"the shape of shift_labels must be (BT,). Got: {shift_labels.shape}"
+            )
             shift_labels = shift_labels.contiguous()
             has_label = True
 
liger_kernel/ops/jsd.py CHANGED
@@ -157,9 +157,9 @@ class LigerJSDFunction(torch.autograd.Function):
         """
         has_label = False
         if shift_labels is not None:
-            assert shift_labels.shape == (
-                _input.shape[0],
-            ), f"the shape of shift_labels must be (BT,). Got: {shift_labels.shape}"
+            assert shift_labels.shape == (_input.shape[0],), (
+                f"the shape of shift_labels must be (BT,). Got: {shift_labels.shape}"
+            )
             shift_labels = shift_labels.contiguous()
             has_label = True
 
liger_kernel/ops/layer_norm.py CHANGED
@@ -57,13 +57,14 @@ def _layer_norm_forward_kernel(
     B_row = tl.load(B_ptr + col_offsets, mask=mask, other=0)
 
     mean = tl.sum(X_row, axis=0) / n_cols
-    var = tl.sum((X_row - mean) * (X_row - mean), axis=0) / n_cols
+    Xmm = tl.where(mask, X_row - mean, 0)
+    var = tl.sum(Xmm * Xmm, axis=0) / n_cols
     rstd = rsqrt(var + eps)
 
     tl.store(Mean_ptr, mean)
     tl.store(RSTD_ptr, rstd)
 
-    Y_row = (X_row - mean) * rstd * W_row + B_row
+    Y_row = Xmm * rstd * W_row + B_row
 
     tl.store(Y_ptr + col_offsets, Y_row, mask=mask)
 
@@ -147,9 +148,11 @@ def layer_norm_forward(X, W, B, eps):
     Y = torch.empty((n_rows, n_cols), dtype=X.dtype, device=X.device)
     Mean = torch.empty(n_rows, dtype=X.dtype, device=X.device)
     RSTD = torch.empty(n_rows, dtype=X.dtype, device=X.device)
-    assert (
-        X.shape[1] == W.shape[0]
-    ), f"Incompatible hidden size dimension between input tensor with shape[1] = {X.shape[1]} and weight tensor with shape[0] = {W.shape[0]}"
+    if X.shape[1] != W.shape[0]:
+        raise ValueError(
+            f"Incompatible dimensions: input feature size (X.shape[1]={X.shape[1]}) "
+            f"must match weight size (W.shape[0]={W.shape[0]})"
+        )
 
     _layer_norm_forward_kernel[(n_rows,)](
         Y,
@@ -190,11 +193,21 @@ def layer_norm_backward(dY, X, W, B, Mean, RSTD):
 
     BLOCK_SIZE, num_warps = calculate_settings(n_cols)
     if n_cols > BLOCK_SIZE:
-        raise RuntimeError("This layer norm doesn't support feature dim >= 64KB.")
+        raise RuntimeError(
+            f"Feature dimension {n_cols} exceeds maximum supported size of {BLOCK_SIZE}. Consider using a smaller feature dimension."
+        )
 
     rows_per_program = math.ceil(n_rows / sm_count)
     grid = (sm_count,)
-    triton_dtype = tl.float32 if X.dtype == torch.float32 else tl.bfloat16
+    triton_dtype = (
+        tl.float32
+        if X.dtype == torch.float32
+        else tl.bfloat16
+        if X.dtype == torch.bfloat16
+        else tl.float16
+        if X.dtype == torch.float16
+        else tl.float32  # fallback to float32 for other types
+    )
     _layer_norm_backward_kernel[grid](
         X,
         W,
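
The forward-kernel change masks out padded lanes before squaring: X_row is loaded over a full BLOCK_SIZE, so when n_cols is smaller than the block, the padded positions (loaded as 0) previously contributed (0 - mean)^2 to the variance. A plain-PyTorch illustration of the difference, with made-up numbers:

import torch

n_cols, BLOCK_SIZE = 3, 4                # the row is narrower than the padded block
x = torch.tensor([1.0, 2.0, 3.0, 0.0])   # last lane is padding, loaded as 0
mask = torch.tensor([True, True, True, False])

mean = x.sum() / n_cols                  # 2.0; zero padding does not affect the sum
old_var = ((x - mean) ** 2).sum() / n_cols                                  # padding adds (0 - 2)^2 -> 2.0
new_var = (torch.where(mask, x - mean, torch.zeros_like(x)) ** 2).sum() / n_cols  # 0.6667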
liger_kernel/ops/tvd.py ADDED
@@ -0,0 +1,207 @@
+from typing import Literal
+from typing import Optional
+
+import torch
+import triton
+import triton.language as tl
+
+from liger_kernel.ops.utils import ensure_contiguous
+
+MAX_FUSED_SIZE = 65536 // 4
+
+REDUCTION_LITERAL = Literal["none", "sum", "mean", "batchmean"]
+
+_REDUCTION_MODE_NONE = tl.constexpr(0)
+_REDUCTION_MODE_SUM = tl.constexpr(1)
+_REDUCTION_MODE_MEAN = tl.constexpr(2)
+_REDUCTION_MODE_BATCHMEAN = tl.constexpr(3)
+
+_str_to_reduction_mode = {
+    "none": _REDUCTION_MODE_NONE.value,
+    "sum": _REDUCTION_MODE_SUM.value,
+    "mean": _REDUCTION_MODE_MEAN.value,
+    "batchmean": _REDUCTION_MODE_BATCHMEAN.value,
+}
+
+
+def get_num_warps(BLOCK_SIZE):
+    num_warps = 4
+    if BLOCK_SIZE >= 32768:
+        num_warps = 32
+    elif BLOCK_SIZE >= 8192:
+        num_warps = 16
+    elif BLOCK_SIZE >= 2048:
+        num_warps = 8
+
+    return num_warps
+
+
+@triton.jit
+def _tv_distance_kernel(
+    p_ptr,
+    p_stride,
+    q_ptr,
+    q_stride,
+    loss_ptr,
+    loss_stride,
+    grads_ptr,
+    grads_stride,
+    label_ptr,
+    ignore_index: tl.constexpr,
+    n_cols,
+    BLOCK_SIZE: tl.constexpr,
+    HAS_LABEL: tl.constexpr,
+    reduction: tl.constexpr = _REDUCTION_MODE_BATCHMEAN,
+):
+    pid = tl.program_id(0).to(tl.int64)
+    p_ptr += pid * p_stride
+    q_ptr += pid * q_stride
+    loss_ptr += pid * loss_stride
+    grads_ptr += pid * grads_stride
+    label_ptr += pid
+
+    base_offsets = tl.arange(0, BLOCK_SIZE)
+
+    if HAS_LABEL:
+        label = tl.load(label_ptr)
+        if label == ignore_index:
+            for i in range(0, n_cols, BLOCK_SIZE):
+                offsets = i + base_offsets
+                mask = offsets < n_cols
+                tl.store(grads_ptr + offsets, 0.0, mask=mask)
+                if reduction == _REDUCTION_MODE_NONE:
+                    tl.store(loss_ptr + offsets, 0.0, mask=mask)
+            return
+
+    loss_sum = 0.0
+    for i in range(0, n_cols, BLOCK_SIZE):
+        offsets = i + base_offsets
+        mask = offsets < n_cols
+
+        p = tl.load(p_ptr + offsets, mask=mask, other=0.0)
+        q = tl.load(q_ptr + offsets, mask=mask, other=0.0)
+
+        # TVD(P || Q) = 0.5 * |P - Q|
+        tv_loss = 0.5 * tl.abs(p - q)
+
+        grad_res = tl.where(p > q, 0.5, -0.5)
+
+        tl.store(grads_ptr + offsets, grad_res, mask=mask)
+
+        if reduction == _REDUCTION_MODE_NONE:
+            tl.store(loss_ptr + offsets, tv_loss, mask=mask)
+        else:
+            loss_sum += tl.sum(tv_loss, axis=0)
+
+    if reduction != _REDUCTION_MODE_NONE:
+        tl.store(loss_ptr, loss_sum)
+
+
+def tv_distance_forward_triton(p, q, shift_labels, reduction, ignore_index, has_label):
+    BT, V = p.shape
+
+    BLOCK_SIZE = min(MAX_FUSED_SIZE, triton.next_power_of_2(V))
+    num_warps = get_num_warps(BLOCK_SIZE)
+
+    grid = (BT,)
+
+    reduction = _str_to_reduction_mode[reduction]
+
+    out_size = (BT, V) if reduction == _REDUCTION_MODE_NONE.value else (BT,)
+    output_tensor = torch.zeros(out_size, device=p.device, dtype=torch.float32)
+    grads = torch.empty_like(p)
+
+    n_non_ignore = (shift_labels != ignore_index).sum().item() if has_label else BT
+
+    _tv_distance_kernel[grid](
+        p,
+        p.stride(0),
+        q,
+        q.stride(0),
+        output_tensor,
+        output_tensor.stride(0),
+        grads,
+        grads.stride(0),
+        shift_labels if has_label else torch.empty(1, device=p.device),
+        ignore_index,
+        V,
+        BLOCK_SIZE=BLOCK_SIZE,
+        HAS_LABEL=has_label,
+        num_warps=num_warps,
+        reduction=reduction,
+    )
+
+    if reduction == _REDUCTION_MODE_BATCHMEAN.value:
+        return output_tensor.sum() / n_non_ignore, grads / n_non_ignore
+    elif reduction == _REDUCTION_MODE_SUM.value:
+        return output_tensor.sum(dim=0), grads
+    elif reduction == _REDUCTION_MODE_MEAN.value:
+        return output_tensor.sum() / (n_non_ignore * V), grads / (n_non_ignore * V)
+    else:
+        return output_tensor, grads
+
+
+def tvd_backward_triton(grad_output, grads):
+    # If cross entropy is the last layer, grad_output is 1.0. Skip the mul then.
+    if torch.equal(grad_output, torch.tensor(1.0, device=grad_output.device)):
+        return grads
+
+    return grads * grad_output
+
+
+class LigerTVDLossFunction(torch.autograd.Function):
+    """
+    Class implementing the forward and backward pass for the Total Variation Distance Loss using Triton.
+    """
+
+    @staticmethod
+    @ensure_contiguous
+    def forward(
+        ctx,
+        p: torch.Tensor,
+        q: torch.Tensor,
+        shift_labels: Optional[torch.Tensor] = None,
+        reduction: REDUCTION_LITERAL = "batchmean",
+        ignore_index: int = -100,
+    ) -> torch.Tensor:
+        """A forward pass for the Total Variation Distance Loss.
+
+        Args:
+            ctx: Torch autograd context
+            p (torch.Tensor): A tensor of shape (BT, V) containing the first distribution.
+            q (torch.Tensor): A tensor of shape (BT, V) containing the second distribution.
+            shift_labels (Optional[torch.Tensor]): A tensor of shape (BT,) containing the labels.
+            reduction (REDUCTION_LITERAL, optional): The reduction method to be applied. Defaults to "batchmean".
+            ignore_index (int, optional): The index to ignore during loss calculation. Defaults to -100.
+
+        Returns:
+            torch.Tensor: The computed Total Variation Distance Loss.
+        """
+        has_label = False
+        if shift_labels is not None:
+            assert shift_labels.shape == (p.shape[0],), (
+                f"the shape of shift_labels must be (BT,). Got: {shift_labels.shape}"
+            )
+            shift_labels = shift_labels.contiguous()
+            has_label = True
+
+        loss, grads = tv_distance_forward_triton(p, q, shift_labels, reduction, ignore_index, has_label)
+        ctx.save_for_backward(grads)
+        return loss
+
+    @staticmethod
+    @ensure_contiguous
+    def backward(ctx, grad_output: torch.Tensor) -> torch.Tensor:
+        """A backward pass for the Total Variation Distance Loss.
+
+        Args:
+            ctx: Torch autograd context
+            grad_output (torch.Tensor): The gradient of the loss with respect to the output.
+
+        Returns:
+            tuple[torch.Tensor, None, None, None, None]: The gradient of the loss with respect to the inputs.
+        """
+        (grads,) = ctx.saved_tensors
+        grads = tvd_backward_triton(grad_output, grads)
+
+        return grads, None, None, None, None
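
The new tvd.py op computes the total variation distance 0.5 * |p - q| over the vocabulary dimension, with a "batchmean" reduction by default and an optional shift_labels/ignore_index path for masking. A hedged usage sketch through the autograd function shown above; the shapes are illustrative and a CUDA device is assumed since the kernel is written in Triton.

import torch
from liger_kernel.ops.tvd import LigerTVDLossFunction

BT, V = 4, 16  # assumed (batch * sequence length, vocab size)
p = torch.softmax(torch.randn(BT, V, device="cuda"), dim=-1).requires_grad_(True)
q = torch.softmax(torch.randn(BT, V, device="cuda"), dim=-1)

loss = LigerTVDLossFunction.apply(p, q)  # default reduction: "batchmean"
loss.backward()                          # gradient of p comes from the sign tensor saved in forward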