liger-kernel 0.5.5__py3-none-any.whl → 0.5.7__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (39)
  1. liger_kernel/chunked_loss/functional.py +2 -0
  2. liger_kernel/chunked_loss/fused_linear_distillation.py +17 -2
  3. liger_kernel/chunked_loss/fused_linear_ppo.py +346 -0
  4. liger_kernel/chunked_loss/grpo_loss.py +134 -60
  5. liger_kernel/chunked_loss/jsd_loss.py +12 -7
  6. liger_kernel/ops/cross_entropy.py +3 -2
  7. liger_kernel/ops/dyt.py +225 -0
  8. liger_kernel/ops/fused_linear_jsd.py +2 -1
  9. liger_kernel/ops/jsd.py +32 -12
  10. liger_kernel/ops/kl_div.py +15 -8
  11. liger_kernel/ops/layer_norm.py +14 -1
  12. liger_kernel/ops/rms_norm.py +12 -1
  13. liger_kernel/transformers/__init__.py +133 -15
  14. liger_kernel/transformers/dyt.py +20 -0
  15. liger_kernel/transformers/functional.py +5 -0
  16. liger_kernel/transformers/gema3_rms.py +8 -0
  17. liger_kernel/transformers/model/gemma.py +17 -20
  18. liger_kernel/transformers/model/gemma2.py +17 -21
  19. liger_kernel/transformers/model/gemma3.py +335 -0
  20. liger_kernel/transformers/model/llama.py +17 -19
  21. liger_kernel/transformers/model/llava.py +369 -0
  22. liger_kernel/transformers/model/loss_utils.py +64 -0
  23. liger_kernel/transformers/model/mistral.py +28 -25
  24. liger_kernel/transformers/model/mixtral.py +20 -26
  25. liger_kernel/transformers/model/mllama.py +17 -19
  26. liger_kernel/transformers/model/olmo2.py +17 -20
  27. liger_kernel/transformers/model/paligemma.py +397 -0
  28. liger_kernel/transformers/model/phi3.py +17 -19
  29. liger_kernel/transformers/model/qwen2.py +17 -19
  30. liger_kernel/transformers/model/qwen2_5_vl.py +9 -10
  31. liger_kernel/transformers/model/qwen2_vl.py +9 -10
  32. liger_kernel/transformers/monkey_patch.py +392 -13
  33. {liger_kernel-0.5.5.dist-info → liger_kernel-0.5.7.dist-info}/METADATA +11 -6
  34. {liger_kernel-0.5.5.dist-info → liger_kernel-0.5.7.dist-info}/RECORD +38 -31
  35. {liger_kernel-0.5.5.dist-info → liger_kernel-0.5.7.dist-info}/WHEEL +1 -1
  36. liger_kernel/chunked_loss/fused_linear_rlhf.py +0 -240
  37. {liger_kernel-0.5.5.dist-info → liger_kernel-0.5.7.dist-info/licenses}/LICENSE +0 -0
  38. {liger_kernel-0.5.5.dist-info → liger_kernel-0.5.7.dist-info/licenses}/NOTICE +0 -0
  39. {liger_kernel-0.5.5.dist-info → liger_kernel-0.5.7.dist-info}/top_level.txt +0 -0
liger_kernel/chunked_loss/grpo_loss.py CHANGED
@@ -1,66 +1,92 @@
 import torch
 
-from liger_kernel.chunked_loss.fused_linear_rlhf import LigerFusedLinearRLHFBase
+from liger_kernel.chunked_loss.fused_linear_ppo import LigerFusedLinearPPOBase
 
 
-class LigerFusedLinearGRPOFunction(LigerFusedLinearRLHFBase):
+def k3_loss_fn(log_p, log_q):
+    # computes k3 estimate of KL[q, p]
+    # ref: http://joschu.net/blog/kl-approx.html
+    return torch.exp(log_p - log_q) - (log_p - log_q) - 1.0
+
+
+def clip_coef_fn(coef, epsilon_low, epsilon_high):
+    return torch.clamp(coef, 1 - epsilon_low, 1 + epsilon_high)
+
+
+class LigerFusedLinearGRPOFunction(LigerFusedLinearPPOBase):
     @staticmethod
-    def rlhf_loss_fn(
+    def ppo_loss_fn(
         log_probs,
+        selected_token_ids,
         attention_mask,
-        rewards,
-        ref_log_probs=None,
-        beta=0.1,
+        advantages,
+        full_attention_mask,
+        ref_per_token_logps=None,  # shape: [chunk_size, seq_len]
+        old_per_token_logps=None,
+        ref_log_probs=None,  # used when ref_per_token_logps is None (shape: [chunk_size, seq_len, vocab_size])
+        epsilon_low=0.2,
+        epsilon_high=0.2,
+        beta=0.04,
+        loss_type="bnpo",  # ["grpo", "bnpo", "dr_grpo"]
+        max_completion_length=None,  # Required for dr_grpo
         **kwargs,
     ):
         """GRPO Loss Function matching GRPOTrainer implementation."""
-        # Get chosen token probabilities
-        chosen_tokens = log_probs.argmax(dim=-1)  # (batch_size, seq_len)
-        chosen_token_logprobs = log_probs.gather(dim=-1, index=chosen_tokens.unsqueeze(-1)).squeeze(
+        per_token_logps = log_probs.gather(dim=-1, index=selected_token_ids.unsqueeze(-1)).squeeze(
             -1
         )  # (batch_size, seq_len)
 
         # Get reference model probabilities
-        if ref_log_probs is not None:
-            with torch.no_grad():
-                ref_token_logprobs = ref_log_probs.gather(dim=-1, index=chosen_tokens.unsqueeze(-1)).squeeze(-1)
-        else:
-            ref_token_logprobs = chosen_token_logprobs.detach()
-
-        # Compute advantages per batch entry in a grouped fashion
-        mean_grouped_rewards = rewards.mean()  # [batch_size,]
-        std_grouped_rewards = rewards.std()  # [batch_size,]
-
-        # Calculate advantages using the same epsilon as in GRPOTrainer
-        eps = 1e-4
-        advantages = (rewards - mean_grouped_rewards) / (std_grouped_rewards + eps)
+        if ref_per_token_logps is None:
+            if ref_log_probs is not None:
+                with torch.no_grad():
+                    ref_per_token_logps = ref_log_probs.gather(dim=-1, index=selected_token_ids.unsqueeze(-1)).squeeze(
+                        -1
+                    )
+            else:
+                ref_per_token_logps = per_token_logps.detach()
 
         # Compute policy gradient loss with importance sampling ratio
-        ratio = torch.exp(chosen_token_logprobs - chosen_token_logprobs.detach())
-        policy_loss = -ratio * advantages.unsqueeze(1)
-
-        # Compute KL penalty
-        kl_div = (
-            torch.exp(ref_token_logprobs - chosen_token_logprobs) - (ref_token_logprobs - chosen_token_logprobs) - 1.0
-        )
+        old_per_token_logps = old_per_token_logps if old_per_token_logps is not None else per_token_logps.detach()
+        coef_1 = torch.exp(per_token_logps - old_per_token_logps)
+        coef_2 = clip_coef_fn(coef_1, epsilon_low, epsilon_high)
+        per_token_loss1 = coef_1 * advantages.unsqueeze(1)
+        per_token_loss2 = coef_2 * advantages.unsqueeze(1)
+        per_token_loss = -torch.min(per_token_loss1, per_token_loss2)
+        if beta != 0.0:
+            # Compute KL penalty (approximates KL[per_token_logps, ref_per_token_logps])
+            kl_div = k3_loss_fn(ref_per_token_logps, per_token_logps)
+            # Combine losses
+            per_token_loss = per_token_loss + beta * kl_div
 
-        # Combine losses
-        per_token_loss = policy_loss + beta * kl_div
-
-        # Apply masking and normalize
-        masked_loss = per_token_loss * attention_mask
-        seq_lengths = attention_mask.sum()
-        seq_lengths = torch.clamp(seq_lengths, min=1.0)
-        loss = masked_loss.sum() / seq_lengths
+        # Note: We normalize by the number of tokens in the batch (using full_attention_mask),
+        # which is consistent with the DAPO loss implementation (https://arxiv.org/html/2503.14476v1)
+        # and TRL GRPO implementation
+        # (https://github.com/huggingface/trl/blob/e751a16df56e70190fb94bed4a2035eec3303777/trl/trainer/grpo_trainer.py#L966)
+        if loss_type == "grpo":
+            # Average per-sequence loss
+            loss = (
+                (per_token_loss * attention_mask).sum(-1) / torch.clamp(attention_mask.sum(-1), min=1.0)
+            ).sum() / full_attention_mask.shape[0]
+        elif loss_type == "bnpo":
+            # Batch Normalized Per-token loss (original implementation)
+            loss = (per_token_loss * attention_mask).sum() / torch.clamp(full_attention_mask.sum(), min=1.0)
+        elif loss_type == "dr_grpo":
+            # Dimension-Reduced GRPO (normalize by batch_size * max_completion_length)
+            if max_completion_length is None:
+                raise ValueError("max_completion_length must be provided for loss_type 'dr_grpo'")
+            loss = (per_token_loss * attention_mask).sum() / (full_attention_mask.shape[0] * max_completion_length)
+        else:
+            raise ValueError(f"Unknown loss type: {loss_type}")
 
         # Calculate metrics
-        metrics = (
-            chosen_token_logprobs.mean(),  # mean log prob
-            chosen_token_logprobs.std(),  # std log prob
-            log_probs.mean(),  # mean all log probs
-            ((kl_div * attention_mask).sum(dim=1) / attention_mask.sum(dim=1)).mean(),  # mean KL div
+        metrics = []
+        if beta != 0.0:
+            metrics.append(((kl_div * attention_mask).sum() / torch.clamp(full_attention_mask.sum(), min=1.0)))
+        is_clipped = ((coef_1 < 1 - epsilon_low) & (advantages.unsqueeze(1) < 0)) | (
+            (coef_1 > 1 + epsilon_high) & (advantages.unsqueeze(1) > 0)
         )
-
+        metrics.append((is_clipped * attention_mask).sum() / torch.clamp(full_attention_mask.sum(), min=1.0))
         return loss, metrics
 
     @classmethod
@@ -69,16 +95,23 @@ class LigerFusedLinearGRPOFunction(LigerFusedLinearRLHFBase):
         ctx,
         _input,
         weight,
+        selected_token_ids,
         attention_mask,
-        rewards,
+        advantages,
         bias=None,
+        ref_per_token_logps=None,
+        old_per_token_logps=None,
         ref_input=None,
         ref_weight=None,
        ref_bias=None,
-        beta=0.1,
+        beta=0.04,
+        epsilon_low=0.2,
+        epsilon_high=0.2,
+        loss_type="bnpo",
+        max_completion_length=None,
+        temperature=1.0,
         compiled=True,
         use_ref_model=True,
-        num_generations=1,
         chunk_size=1,
     ):
         """
@@ -86,16 +119,20 @@ class LigerFusedLinearGRPOFunction(LigerFusedLinearRLHFBase):
         Args:
             _input (torch.Tensor): Input tensor. Shape: (batch_size * seq_len, hidden_size)
             weight (torch.Tensor): Weight tensor. Shape: (vocab_size, hidden_size)
+            selected_token_ids (torch.Tensor): Selected token ids tensor. Shape: (batch_size, seq_len)
             attention_mask (torch.Tensor): Attention mask tensor. Shape: (batch_size, seq_len)
-            rewards (torch.Tensor): Rewards tensor. Shape: (batch_size,)
+            advantages (torch.Tensor): Advantages tensor. Shape: (batch_size,)
             bias (torch.Tensor, optional): Bias tensor. Shape: (vocab_size,)
+            ref_per_token_logps: Reference model log probs per token tensor. Shape:(batch_size, seq_len)
             ref_input (torch.Tensor, optional): Reference model input tensor. Shape: (batch_size * seq_len, hidden_size)
             ref_weight (torch.Tensor, optional): Reference model weight tensor. Shape: (vocab_size, hidden_size)
             ref_bias (torch.Tensor, optional): Reference model bias tensor. Shape: (vocab_size,)
             beta (float): Weight for the KL penalty
+            loss_type (str): Type of loss calculation ("grpo", "bnpo", "dr_grpo"). Defaults to "bnpo".
+            max_completion_length (int, optional): Maximum completion length, required for "dr_grpo". Defaults to None.
+            temperature (float): Temperature for the logits
             compiled (bool): Whether to use torch compile
             use_ref_model (bool): Whether to use a reference model
-            num_generations (int): Number of generations per prompt
             chunk_size (int): Size of chunks for processing.
         Returns:
             torch.Tensor: Computed loss
@@ -105,16 +142,23 @@ class LigerFusedLinearGRPOFunction(LigerFusedLinearRLHFBase):
             ctx=ctx,
             _input=_input,
             weight=weight,
+            selected_token_ids=selected_token_ids,
             attention_mask=attention_mask,
-            rewards=rewards,
+            advantages=advantages,
             bias=bias,
+            ref_per_token_logps=ref_per_token_logps,
+            old_per_token_logps=old_per_token_logps,
             ref_input=ref_input,
             ref_weight=ref_weight,
             ref_bias=ref_bias,
             beta=beta,
+            epsilon_low=epsilon_low,
+            epsilon_high=epsilon_high,
+            loss_type=loss_type,
+            max_completion_length=max_completion_length,
+            temperature=temperature,
             compiled=compiled,
             use_ref_model=use_ref_model,
-            num_generations=num_generations,
             chunk_size=chunk_size,
         )
 
@@ -126,16 +170,24 @@ class LigerFusedLinearGRPOFunction(LigerFusedLinearRLHFBase):
            grad_output: Gradient of the loss (scalar)
            grad_metrics: Gradients of the metrics (not used in backward computation)
        """
-        grads = LigerFusedLinearRLHFBase.backward(ctx, grad_output)
+        grads = LigerFusedLinearPPOBase.backward(ctx, grad_output)
         return (
-            *grads[:5],  # grad_input, grad_weight, grad_attention_mask, grad_rewards, grad_bias
+            *grads[
+                :6
+            ],  # grad_input, grad_weight, grad_selected_token_ids, grad_attention_mask, grad_advantages, grad_bias
+            None,  # grad_ref_per_token_logps
+            None,  # grad_old_per_token_logps
             None,  # grad_ref_input
             None,  # grad_ref_weight
             None,  # grad_ref_bias
             None,  # grad_beta
+            None,  # grad_epsilon_low
+            None,  # grad_epsilon_high
+            None,  # grad_loss_type (string, not differentiable)
+            None,  # grad_max_completion_length (int, not differentiable)
+            None,  # grad_temperature
             None,  # grad_compiled
             None,  # grad_use_ref_model
-            None,  # grad_num_generations
             None,  # grad_chunk_size
         )
 
@@ -145,34 +197,49 @@ class LigerFusedLinearGRPOLoss(torch.nn.Module):
 
     def __init__(
         self,
-        beta: float = 0.1,
+        beta: float = 0.04,
         compiled: bool = True,
         use_ref_model: bool = True,
-        num_generations: int = 1,
         chunk_size: int = 1,
+        epsilon_low: float = 0.2,
+        epsilon_high: float = 0.2,
+        loss_type: str = "bnpo",
+        max_completion_length: int | None = None,
+        temperature: float = 1.0,
     ):
         """
         Args:
            beta (float): Weight for the KL penalty.
            compiled (bool): Whether to use torch compile.
            use_ref_model (bool): Whether to use a reference model.
-            num_generations (int): Number of generations per prompt.
            chunk_size (int): Size of chunks for processing.
+            epsilon_low (float): Lower bound for the importance sampling ratio.
+            epsilon_high (float): Upper bound for the importance sampling ratio.
+            loss_type (str): Type of loss calculation ("grpo", "bnpo", "dr_grpo"). Defaults to "bnpo".
+            max_completion_length (int, optional): Maximum completion length, required for "dr_grpo". Defaults to None.
+            temperature (float): Temperature for the logits.
        """
        super().__init__()
        self.beta = beta
        self.compiled = compiled
        self.use_ref_model = use_ref_model
-        self.num_generations = num_generations
        self.chunk_size = chunk_size
+        self.epsilon_low = epsilon_low
+        self.epsilon_high = epsilon_high
+        self.loss_type = loss_type
+        self.max_completion_length = max_completion_length
+        self.temperature = temperature
 
     def forward(
         self,
         _input,
         lin_weight,
+        selected_token_ids,
         attention_mask,
-        rewards,
+        advantages,
         bias=None,
+        ref_per_token_logps=None,
+        old_per_token_logps=None,
         ref_input=None,
         ref_weight=None,
         ref_bias=None,
@@ -180,15 +247,22 @@ class LigerFusedLinearGRPOLoss(torch.nn.Module):
         return LigerFusedLinearGRPOFunction.apply(
             _input,
             lin_weight,
+            selected_token_ids,
             attention_mask,
-            rewards,
+            advantages,
             bias,
+            ref_per_token_logps,
+            old_per_token_logps,
             ref_input,
             ref_weight,
             ref_bias,
             self.beta,
+            self.epsilon_low,
+            self.epsilon_high,
+            self.loss_type,
+            self.max_completion_length,
+            self.temperature,
             self.compiled,
             self.use_ref_model,
-            self.num_generations,
             self.chunk_size,
         )
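
A note on the new `k3_loss_fn`: with `selected_token_ids` sampled from the current policy, `k3_loss_fn(ref_per_token_logps, per_token_logps)` is Schulman's k3 estimator of KL(policy ‖ reference). A minimal sanity check against a closed-form KL, using only `torch.distributions` (the Gaussian setup is illustrative, not part of the package):

```python
import torch

def k3_loss_fn(log_p, log_q):
    # same estimator as in the diff above
    return torch.exp(log_p - log_q) - (log_p - log_q) - 1.0

torch.manual_seed(0)
q = torch.distributions.Normal(0.0, 1.0)  # sampling distribution (policy)
p = torch.distributions.Normal(0.5, 1.0)  # reference distribution
x = q.sample((200_000,))                  # x ~ q

# k3 is an unbiased, always non-negative estimator of KL(q || p)
k3 = k3_loss_fn(p.log_prob(x), q.log_prob(x)).mean()
exact = torch.distributions.kl_divergence(q, p)  # = 0.5**2 / 2 = 0.125
print(f"k3 ~ {k3:.4f}, exact = {exact:.4f}")
```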
liger_kernel/chunked_loss/jsd_loss.py CHANGED
@@ -19,15 +19,20 @@ class LigerFusedLinearJSDFunction(LigerFusedLinearDistillationBase):
         student_log_probs = F.log_softmax(student_logits, dim=-1)
         teacher_log_probs = F.log_softmax(teacher_logits, dim=-1)
 
-        # Compute probabilities (only required for mean calculation)
-        mean_probs = beta * student_log_probs.exp() + (1 - beta) * teacher_log_probs.exp()
-        log_mean_probs = mean_probs.log()
+        if beta == 0:
+            jsd_loss = F.kl_div(student_log_probs, teacher_log_probs, reduction="sum", log_target=True)
+        elif beta == 1:
+            jsd_loss = F.kl_div(teacher_log_probs, student_log_probs, reduction="sum", log_target=True)
+        else:
+            # Compute probabilities (only required for mean calculation)
+            mean_probs = (1 - beta) * student_log_probs.exp() + beta * teacher_log_probs.exp()
+            log_mean_probs = mean_probs.log()
 
-        student_kl = F.kl_div(log_mean_probs, student_log_probs, reduction="sum", log_target=True)
-        teacher_kl = F.kl_div(log_mean_probs, teacher_log_probs, reduction="sum", log_target=True)
+            student_kl = F.kl_div(log_mean_probs, student_log_probs, reduction="sum", log_target=True)
+            teacher_kl = F.kl_div(log_mean_probs, teacher_log_probs, reduction="sum", log_target=True)
 
-        # JSD is the weighted average of the KL divergences
-        jsd_loss = beta * teacher_kl + (1 - beta) * student_kl
+            # JSD is the weighted average of the KL divergences
+            jsd_loss = beta * teacher_kl + (1 - beta) * student_kl
         return jsd_loss
 
     @classmethod
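
The new branches encode the endpoint convention: `beta == 0` returns the forward KL(teacher ‖ student) and `beta == 1` the reverse KL(student ‖ teacher), rather than the limit of the mixture formula, which collapses to zero at the endpoints (at beta = 0 the mixture equals the student distribution, so KL(student ‖ mixture) vanishes). A standalone sketch of all three cases with hypothetical tensors, PyTorch only; recall that `F.kl_div(input, target, log_target=True)` computes KL(target ‖ input):

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)
s_lp = F.log_softmax(torch.randn(4, 10), dim=-1)  # student log-probs
t_lp = F.log_softmax(torch.randn(4, 10), dim=-1)  # teacher log-probs

forward_kl = F.kl_div(s_lp, t_lp, reduction="sum", log_target=True)  # beta == 0
reverse_kl = F.kl_div(t_lp, s_lp, reduction="sum", log_target=True)  # beta == 1

beta = 0.5  # interior beta: KLs against the mixture M = (1 - beta) * S + beta * T
log_m = ((1 - beta) * s_lp.exp() + beta * t_lp.exp()).log()
jsd = beta * F.kl_div(log_m, t_lp, reduction="sum", log_target=True) + (
    1 - beta
) * F.kl_div(log_m, s_lp, reduction="sum", log_target=True)
```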
liger_kernel/ops/cross_entropy.py CHANGED
@@ -9,6 +9,7 @@ import triton.language as tl
 from liger_kernel.ops.utils import compare_version
 from liger_kernel.ops.utils import element_mul_kernel
 from liger_kernel.ops.utils import is_hip
+from liger_kernel.utils import infer_device
 
 if compare_version("triton", operator.ge, "3.0.0"):
     try:
@@ -59,7 +60,7 @@ def liger_cross_entropy_kernel(
         z_loss_ptr: Pointer to tensor to store the z loss. No operation if RETURN_Z_LOSS is 0.
         loss_stride (int): The stride of the loss tensor.
         n_cols (int): The number of columns in the input tensor.
-        n_non_ignore (flaot): The number of non-ignored elements in the batch.
+        n_non_ignore (float): The number of non-ignored elements in the batch.
         sum_non_ignore_weight (float): The sum of non-ignored target's weights in the batch.
         weight_sum (float): The sum of weight tensor.
         ignore_index (int): The index to ignore in the target.
@@ -258,7 +259,7 @@ def liger_cross_entropy_kernel(
 # The hard limit of TRITON_MAX_TENSOR_NUMEL is 1048576 https://github.com/triton-lang/triton/blob/ba42a5c68fd0505f8c42f4202d53be0f8d9a5fe0/python/triton/language/core.py#L19
 # However, setting limit as 65536 as in LayerNorm tutorial is faster because of less register spilling
 # The optimal maximum block size depends on your hardware, your kernel, and your dtype
-MAX_FUSED_SIZE = 65536 // 2  # the best size we found by manually tuning
+MAX_FUSED_SIZE = 4096 if infer_device() == "xpu" else 65536 // 2  # the best size we found by manually tuning
 
 
 def cross_entropy_forward(
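
The `MAX_FUSED_SIZE` change makes the block-size cap device dependent: XPU hits register pressure at much smaller tiles than CUDA, hence the 4096 cap. A sketch of how such a cap typically feeds into Triton block-size selection (the helper name here is hypothetical; the package's actual logic lives in `calculate_settings` in `liger_kernel.ops.utils`):

```python
import triton

def pick_block_size(n_cols: int, device: str) -> int:
    # clamp the next power of two to a per-device maximum
    max_fused_size = 4096 if device == "xpu" else 65536 // 2
    return min(max_fused_size, triton.next_power_of_2(n_cols))

assert pick_block_size(50_000, "cuda") == 32768  # capped at 65536 // 2
assert pick_block_size(50_000, "xpu") == 4096    # tighter XPU cap
assert pick_block_size(1_000, "xpu") == 1024     # small vocab: uncapped
```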
liger_kernel/ops/dyt.py ADDED
@@ -0,0 +1,225 @@
+import operator
+
+import torch
+import triton
+import triton.language as tl
+
+from liger_kernel.ops.utils import calculate_settings
+from liger_kernel.ops.utils import compare_version
+from liger_kernel.ops.utils import ensure_contiguous
+from liger_kernel.ops.utils import infer_device
+
+if compare_version("triton", operator.ge, "3.0.0"):
+    try:
+        # typical import path with dispatch available
+        from triton.language.extra.libdevice import tanh
+    except ModuleNotFoundError:
+        # for working with NGC containers
+        from triton.language.extra.cuda.libdevice import tanh
+else:
+    from triton.language.math import tanh
+
+
+@triton.jit
+def _dyt_fwd_kernel(
+    x_ptr,
+    x_row_stride,
+    alpha_ptr,
+    gamma_ptr,
+    beta_ptr,
+    y_ptr,
+    y_row_stride,
+    n_cols,
+    BLOCK_SIZE: tl.constexpr,
+):
+    """
+    Reference:
+    https://arxiv.org/abs/2503.10622
+
+    Shapes:
+        - x: (BT, C)
+        - alpha: (1)
+        - gamma: (C)
+        - beta: (C)
+    """
+    row_idx = tl.program_id(0)
+    offsets = tl.arange(0, BLOCK_SIZE)
+    mask = offsets < n_cols
+
+    x_ptr += row_idx * x_row_stride
+    y_ptr += row_idx * y_row_stride
+
+    alpha = tl.load(alpha_ptr)
+    gamma = tl.load(gamma_ptr + offsets, mask=mask)
+    beta = tl.load(beta_ptr + offsets, mask=mask)
+    x = tl.load(x_ptr + offsets, mask=mask)
+    y = gamma * tanh((alpha * x).cast(tl.float32)) + beta
+    tl.store(y_ptr + offsets, y, mask=mask)
+
+
+@triton.jit
+def _dyt_bwd_kernel(
+    x_ptr,
+    x_row_stride,
+    dy_ptr,
+    dy_row_stride,
+    dx_ptr,
+    dx_row_stride,
+    alpha_ptr,
+    dalpha_ptr,
+    gamma_ptr,
+    dgamma_ptr,
+    dgamma_row_stride,
+    n_cols,
+    n_rows,
+    ROWS_PER_PROGRAM: tl.constexpr,
+    BLOCK_SIZE: tl.constexpr,
+):
+    """
+    Reference:
+    https://arxiv.org/abs/2503.10622
+
+    Shapes:
+        - x: (BT, C)
+        - alpha: (1)
+        - gamma: (C)
+        - dx: (BT, C)
+        - dy: (BT, C)
+        - dgamma: (sm_count, C)
+        - dalpha: (sm_count,)
+    """
+    # d(gamma * tanh(alpha * x) + beta) / dx
+    # = gamma * (1 - tanh^2(alpha * x)) * alpha
+    # d(gamma * tanh(alpha * x) + beta) / dalpha
+    # = gamma * (1 - tanh^2(alpha * x)) * x
+    # d(gamma * tanh(alpha * x) + beta) / dgamma
+    # = tanh(alpha * x)
+    # d(gamma * tanh(alpha * x)) / dbeta = 1
+    pid = tl.program_id(0)
+
+    row_start = pid * ROWS_PER_PROGRAM
+    row_end = min((pid + 1) * ROWS_PER_PROGRAM, n_rows)
+    offsets = tl.arange(0, BLOCK_SIZE)
+    mask = offsets < n_cols
+
+    dalpha = 0.0
+    dgamma = tl.zeros((BLOCK_SIZE,), dtype=tl.float32)
+
+    x_ptr += row_start * x_row_stride
+    dx_ptr += row_start * dx_row_stride
+    dy_ptr += row_start * dy_row_stride
+    alpha = tl.load(alpha_ptr)
+    gamma = tl.load(gamma_ptr + offsets, mask=mask, other=0.0)
+
+    for _ in tl.range(row_start, row_end):
+        dy = tl.load(dy_ptr + offsets, mask=mask, other=0.0)
+        x = tl.load(x_ptr + offsets, mask=mask, other=0.0)
+        tanh_ax = tanh((alpha * x).cast(tl.float32))
+        sech2_ax = 1 - tanh_ax * tanh_ax
+
+        dx = dy * gamma * sech2_ax * alpha
+        dalpha += tl.sum(dy * gamma * sech2_ax * x)
+        dgamma += dy * tanh_ax
+        tl.store(dx_ptr + offsets, dx, mask=mask)
+
+        dy_ptr += dy_row_stride
+        x_ptr += x_row_stride
+        dx_ptr += dx_row_stride
+
+    tl.store(dgamma_ptr + pid * dgamma_row_stride + offsets, dgamma, mask=mask)
+    tl.store(dalpha_ptr + pid, dalpha)
+
+    pass
+
+
+def liger_dyt_fwd(x, alpha, gamma, beta):
+    shape = x.shape
+    dim = shape[-1]
+    x = x.view(-1, dim)
+    n_rows, n_cols = x.shape
+    y = torch.empty_like(x)
+    BLOCK_SIZE, num_warps = calculate_settings(n_cols)
+    _dyt_fwd_kernel[(n_rows,)](
+        x_ptr=x,
+        alpha_ptr=alpha,
+        gamma_ptr=gamma,
+        beta_ptr=beta,
+        y_ptr=y,
+        x_row_stride=x.stride(0),
+        y_row_stride=y.stride(0),
+        n_cols=n_cols,
+        BLOCK_SIZE=BLOCK_SIZE,
+        num_warps=num_warps,
+    )
+    return y.view(*shape)
+
+
+def liger_dyt_bwd(dy, x, alpha, gamma):
+    shape = dy.shape
+    dtype = x.dtype
+    dim = shape[-1]
+    dy = dy.view(-1, dim)
+    x = x.view(-1, dim)
+    n_rows, n_cols = dy.shape
+    BLOCK_SIZE, num_warps = calculate_settings(n_cols)
+    sm_count = 1
+    device = infer_device()
+    if device == "cuda":
+        sm_count = torch.cuda.get_device_properties(x.device).multi_processor_count
+    elif device == "xpu":
+        sm_count = torch.xpu.get_device_properties(x.device).gpu_subslice_count
+    if n_cols > BLOCK_SIZE:
+        raise RuntimeError(
+            f"Feature dimension {dim} exceeds maximum supported size of {BLOCK_SIZE}. Consider using a smaller feature dimension."
+        )
+
+    dx = torch.empty_like(x, dtype=torch.float32)
+    _dalpha = torch.empty((sm_count,), dtype=torch.float32, device=x.device)
+    _dgamma = torch.empty((sm_count, n_cols), dtype=torch.float32, device=x.device)
+
+    grid = (sm_count,)
+    rows_per_program = triton.cdiv(n_rows, sm_count)
+    _dyt_bwd_kernel[grid](
+        x_ptr=x,
+        x_row_stride=x.stride(0),
+        dy_ptr=dy,
+        dy_row_stride=dy.stride(0),
+        dx_ptr=dx,
+        dx_row_stride=dx.stride(0),
+        alpha_ptr=alpha,
+        dalpha_ptr=_dalpha,
+        gamma_ptr=gamma,
+        dgamma_ptr=_dgamma,
+        dgamma_row_stride=_dgamma.stride(0),
+        n_cols=n_cols,
+        n_rows=n_rows,
+        ROWS_PER_PROGRAM=rows_per_program,
+        BLOCK_SIZE=BLOCK_SIZE,
+        num_warps=num_warps,
+    )
+    dalpha = _dalpha.sum(dim=0, keepdim=True).to(dtype)
+    dgamma = _dgamma.sum(dim=0).to(dtype)
+    dbeta = dy.sum(dim=0).to(dtype)
+    return dx.view(*shape), dalpha, dgamma, dbeta
+
+
+class LigerDyTFunction(torch.autograd.Function):
+    @staticmethod
+    @ensure_contiguous
+    def forward(ctx, x, alpha, gamma, beta):
+        y = liger_dyt_fwd(x, alpha, gamma, beta)
+        ctx.save_for_backward(x, alpha, gamma)
+        return y
+
+    @staticmethod
+    @ensure_contiguous
+    def backward(ctx, grad_output):
+        x, alpha, gamma = ctx.saved_tensors
+        dx, dalpha, dgamma, dbeta = liger_dyt_bwd(
+            grad_output,
+            x,
+            alpha,
+            gamma,
+        )
+
+        return (dx, dalpha, dgamma, dbeta)
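
These kernels implement DyT (Dynamic Tanh, arXiv:2503.10622): y = gamma * tanh(alpha * x) + beta, a normalization-free replacement for LayerNorm-style layers, with the hand-derived gradients listed in the backward kernel's comments. A plain-PyTorch reference is handy for cross-checking them; this is a test sketch, not part of the package (on a GPU the fused path would be `liger_kernel.ops.dyt.LigerDyTFunction.apply`):

```python
import torch

def dyt_reference(x, alpha, gamma, beta):
    # elementwise tanh squashing with a learnable scalar alpha and a
    # per-channel affine (gamma, beta), broadcast over the last dim
    return gamma * torch.tanh(alpha * x) + beta

# float64 inputs so gradcheck's finite differences are reliable
x = torch.randn(2, 8, 64, dtype=torch.float64, requires_grad=True)
alpha = torch.full((1,), 0.5, dtype=torch.float64, requires_grad=True)
gamma = torch.rand(64, dtype=torch.float64, requires_grad=True)
beta = torch.rand(64, dtype=torch.float64, requires_grad=True)

# autograd reproduces the dx/dalpha/dgamma/dbeta formulas in the kernel comments
torch.autograd.gradcheck(dyt_reference, (x, alpha, gamma, beta))
```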
liger_kernel/ops/fused_linear_jsd.py CHANGED
@@ -8,11 +8,12 @@ from liger_kernel.ops.utils import amp_custom_bwd
 from liger_kernel.ops.utils import amp_custom_fwd
 from liger_kernel.ops.utils import element_mul_kernel
 from liger_kernel.ops.utils import is_hip
+from liger_kernel.utils import infer_device
 
 # The hard limit of TRITON_MAX_TENSOR_NUMEL is 1048576 https://github.com/triton-lang/triton/blob/ba42a5c68fd0505f8c42f4202d53be0f8d9a5fe0/python/triton/language/core.py#L19
 # However, setting limit as 65536 as in LayerNorm tutorial is faster because of less register spilling
 # The optimal maximum block size depends on your hardware, your kernel, and your dtype
-MAX_FUSED_SIZE = 65536 // 2
+MAX_FUSED_SIZE = 4096 if infer_device() == "xpu" else 65536 // 2
 
 
 def fused_linear_jsd_forward(
liger_kernel/ops/jsd.py CHANGED
@@ -5,6 +5,7 @@ import triton
 import triton.language as tl
 
 from liger_kernel.ops.utils import ensure_contiguous
+from liger_kernel.utils import infer_device
 
 
 @triton.jit
@@ -51,29 +52,48 @@ def _jsd_kernel(
     Y = tl.load(Y_ptr + offsets, mask=mask, other=float("-inf")).to(tl.float32)
 
     if beta == 0.0:  # forward KL
-        Y_prob = tl.exp(Y)
+        Y_max = tl.max(Y, axis=0)
+        Y_shifted = Y - Y_max
+        Y_prob = tl.exp(Y_shifted) * tl.exp(Y_max)  # Compensate for the shift
         loss = Y_prob * (Y - X)
         dX = -Y_prob
-    elif beta == 1.0:
-        X_prob = tl.exp(X)
+    elif beta == 1.0:  # reverse KL
+        X_max = tl.max(X, axis=0)
+        X_shifted = X - X_max
+        X_prob = tl.exp(X_shifted) * tl.exp(X_max)  # Compensate for the shift
         loss = X_prob * (X - Y)
         dX = loss + X_prob
     else:
-        Q = tl.exp(X)
-        P = tl.exp(Y)
-        M = beta * P + (1 - beta) * Q
-        log_M = tl.log(M)
+        max_val = tl.maximum(tl.max(X, axis=0), tl.max(Y, axis=0))
+        X_shifted = X - max_val
+        Y_shifted = Y - max_val
 
-        loss = beta * P * Y + (1 - beta) * Q * X - M * log_M
-        dX = (1 - beta) * Q * (X - log_M)
+        # Pre-compute exp(max_val) since it's used twice
+        exp_max = tl.exp(max_val)
+
+        # Compute exp terms with compensation
+        Q = tl.exp(X_shifted) * exp_max  # = exp(X)
+        P = tl.exp(Y_shifted) * exp_max  # = exp(Y)
+
+        # Pre-compute common terms
+        beta_P = beta * P
+        one_minus_beta_Q = (1 - beta) * Q
+        M = beta_P + one_minus_beta_Q
+        log_M = tl.log(M)  # No need to compensate as M is already in original scale
+
+        loss = beta_P * Y + one_minus_beta_Q * X - M * log_M
+        dX = one_minus_beta_Q * (X - log_M)
+
+    # Pre-compute scaling factor
+    scale = 1.0 / n_non_ignore
+    loss = loss * scale
+    dX = dX * scale
 
-    loss = loss / n_non_ignore
-    dX = dX / n_non_ignore
     tl.store(loss_ptr + offsets, loss, mask=mask)
     tl.store(dX_ptr + offsets, dX, mask=mask)
 
 
-MAX_FUSED_SIZE = 65536
+MAX_FUSED_SIZE = 4096 if infer_device() == "xpu" else 65536
 
 
 def jsd_forward(_input, target, shift_labels, beta, ignore_index, has_label):
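
Since `exp(X - m) * exp(m)` is algebraically `exp(X)`, the shifted form above changes the order of operations rather than the final math, so an unshifted PyTorch reference should agree with the kernel to floating-point tolerance. A minimal reference for the per-element loss and dX (a validation sketch under that assumption; X and Y are student and teacher log-probabilities):

```python
import torch

def jsd_kernel_reference(X, Y, beta, n_non_ignore):
    Q, P = X.exp(), Y.exp()
    if beta == 0.0:    # forward KL(teacher || student)
        loss, dX = P * (Y - X), -P
    elif beta == 1.0:  # reverse KL(student || teacher)
        loss = Q * (X - Y)
        dX = loss + Q
    else:              # generalized JSD against M = beta * P + (1 - beta) * Q
        M = beta * P + (1 - beta) * Q
        log_M = M.log()
        loss = beta * P * Y + (1 - beta) * Q * X - M * log_M
        dX = (1 - beta) * Q * (X - log_M)
    return loss / n_non_ignore, dX / n_non_ignore
```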