liger-kernel 0.5.4__py3-none-any.whl → 0.5.6__py3-none-any.whl

This diff compares two publicly released versions of the package as published to a supported registry. It is provided for informational purposes only and reflects the package contents exactly as they appear in the public registry.
Files changed (44)
  1. liger_kernel/chunked_loss/cpo_loss.py +51 -11
  2. liger_kernel/chunked_loss/dpo_loss.py +30 -4
  3. liger_kernel/chunked_loss/functional.py +2 -0
  4. liger_kernel/chunked_loss/fused_linear_distillation.py +20 -5
  5. liger_kernel/chunked_loss/fused_linear_ppo.py +331 -0
  6. liger_kernel/chunked_loss/fused_linear_preference.py +2 -2
  7. liger_kernel/chunked_loss/fused_linear_unpaired_preference.py +112 -17
  8. liger_kernel/chunked_loss/grpo_loss.py +137 -61
  9. liger_kernel/chunked_loss/jsd_loss.py +43 -13
  10. liger_kernel/chunked_loss/kto_loss.py +50 -12
  11. liger_kernel/chunked_loss/orpo_loss.py +37 -5
  12. liger_kernel/chunked_loss/simpo_loss.py +47 -11
  13. liger_kernel/ops/cross_entropy.py +7 -2
  14. liger_kernel/ops/dyt.py +225 -0
  15. liger_kernel/ops/fused_linear_jsd.py +2 -1
  16. liger_kernel/ops/jsd.py +30 -11
  17. liger_kernel/ops/kl_div.py +2 -2
  18. liger_kernel/transformers/__init__.py +4 -0
  19. liger_kernel/transformers/dyt.py +20 -0
  20. liger_kernel/transformers/functional.py +5 -0
  21. liger_kernel/transformers/model/gemma.py +8 -16
  22. liger_kernel/transformers/model/gemma2.py +7 -16
  23. liger_kernel/transformers/model/llama.py +8 -15
  24. liger_kernel/transformers/model/llava.py +369 -0
  25. liger_kernel/transformers/model/loss_utils.py +57 -0
  26. liger_kernel/transformers/model/mistral.py +9 -10
  27. liger_kernel/transformers/model/mixtral.py +8 -15
  28. liger_kernel/transformers/model/mllama.py +8 -15
  29. liger_kernel/transformers/model/olmo2.py +8 -16
  30. liger_kernel/transformers/model/paligemma.py +397 -0
  31. liger_kernel/transformers/model/phi3.py +8 -15
  32. liger_kernel/transformers/model/qwen2.py +8 -15
  33. liger_kernel/transformers/model/qwen2_5_vl.py +204 -0
  34. liger_kernel/transformers/model/qwen2_vl.py +9 -10
  35. liger_kernel/transformers/monkey_patch.py +286 -12
  36. liger_kernel/utils.py +1 -3
  37. {liger_kernel-0.5.4.dist-info → liger_kernel-0.5.6.dist-info}/METADATA +11 -7
  38. liger_kernel-0.5.6.dist-info/RECORD +80 -0
  39. {liger_kernel-0.5.4.dist-info → liger_kernel-0.5.6.dist-info}/WHEEL +1 -1
  40. liger_kernel/chunked_loss/fused_linear_rlhf.py +0 -213
  41. liger_kernel-0.5.4.dist-info/RECORD +0 -74
  42. {liger_kernel-0.5.4.dist-info → liger_kernel-0.5.6.dist-info/licenses}/LICENSE +0 -0
  43. {liger_kernel-0.5.4.dist-info → liger_kernel-0.5.6.dist-info/licenses}/NOTICE +0 -0
  44. {liger_kernel-0.5.4.dist-info → liger_kernel-0.5.6.dist-info}/top_level.txt +0 -0
liger_kernel/chunked_loss/kto_loss.py CHANGED
@@ -7,10 +7,10 @@ from liger_kernel.chunked_loss.fused_linear_unpaired_preference import LigerFuse
 class LigerFusedLinearKTOFunction(LigerFusedLinearUnpairedPreferenceBase):
     @staticmethod
     def preference_loss_fn(
-        average_log_prob_chunk,
+        log_prob_chunk,
         preference_labels_chunk,
         full_target,
-        ref_average_log_prob_chunk=None,
+        ref_log_prob_chunk=None,
         beta=0.1,
         kl=None,
     ):
@@ -43,30 +43,34 @@ class LigerFusedLinearKTOFunction(LigerFusedLinearUnpairedPreferenceBase):
         3. Maintain reasonable distance from the reference model
 
         Args:
-            average_log_prob_chunk: Log probabilities for the chunk (batch_size,)
+            log_prob_chunk: Log probabilities for the chunk (batch_size,)
             preference_labels_chunk: Preference labels for the chunk (batch_size,)
             full_target: Non chunked full target tensor
-            ref_average_log_prob_chunk: Reference log probs for the chunk (batch_size,)
+            ref_log_prob_chunk: Reference log probs for the chunk (batch_size,)
             beta: Weight for the KTO loss
             kl: KL divergence between the policy model and the reference model for the chosen responses. Shape: (batch_size,)
         Returns:
            - loss: The KTO loss value
        """
-        if ref_average_log_prob_chunk is not None:
-            logratios_chunk = average_log_prob_chunk - ref_average_log_prob_chunk
+        if ref_log_prob_chunk is not None:
+            logratios_chunk = log_prob_chunk - ref_log_prob_chunk
         else:
-            logratios_chunk = average_log_prob_chunk
-
+            logratios_chunk = log_prob_chunk
         multiplier_chunk = torch.where(preference_labels_chunk, 1, -1)
         if kl is not None:
             losses = 1 - F.sigmoid(beta * (logratios_chunk - kl) * multiplier_chunk)
         else:
             losses = 1 - F.sigmoid(beta * logratios_chunk * multiplier_chunk)
 
-        return losses.sum() / (full_target.shape[0])
+        rewards = beta * logratios_chunk
+        chosen_rewards_sum = (rewards * preference_labels_chunk.unsqueeze(1)).sum()
+        rejected_rewards_sum = (rewards * (~preference_labels_chunk).unsqueeze(1)).sum()
 
-    @staticmethod
+        return losses.sum() / (full_target.shape[0]), chosen_rewards_sum, rejected_rewards_sum
+
+    @classmethod
     def forward(
+        cls,
         ctx,
         _input,
         weight,
@@ -81,15 +85,38 @@ class LigerFusedLinearKTOFunction(LigerFusedLinearUnpairedPreferenceBase):
         beta=0.1,
         compiled=True,
         use_ref_model=True,
+        average_log_prob=False,
+        chunk_size=1,
     ):
-        return LigerFusedLinearUnpairedPreferenceBase.forward(
+        """
+        Fused linear layer with KTO loss.
+        Args:
+            _input (torch.Tensor): Input tensor. Shape: (batch_size * seq_len, hidden_size)
+            weight (torch.Tensor): Weight tensor. Shape: (vocab_size, hidden_size)
+            target (torch.LongTensor): Target tensor. Shape: (batch_size * seq_len,)
+            preference_labels (torch.Tensor): Preference labels tensor. Shape: (batch_size,)
+            bias (torch.Tensor, optional): Bias tensor. Shape: (vocab_size,)
+            ref_input (torch.Tensor, optional): Reference model input tensor. Shape: (batch_size * seq_len, hidden_size)
+            ref_weight (torch.Tensor, optional): Reference model weight tensor. Shape: (vocab_size, hidden_size)
+            ref_bias (torch.Tensor, optional): Reference model bias tensor. Shape: (vocab_size,)
+            kl (torch.Tensor, optional): KL divergence tensor. Shape: (batch_size,)
+            ignore_index (int): Index to ignore in loss computation
+            beta (float): Temperature parameter for the KTO loss
+            compiled (bool): Whether to use torch compile
+            use_ref_model (bool): Whether to use a reference model
+            average_log_prob (bool): Whether to average the log probability per non-masked token
+            chunk_size (int): Size of chunks for processing
+        Returns:
+            torch.Tensor: Computed loss
+        """
+        return super().forward(
+            cls=cls,
             ctx=ctx,
             _input=_input,
             weight=weight,
             target=target,
             preference_labels=preference_labels,
             bias=bias,
-            loss_fn=LigerFusedLinearKTOFunction.preference_loss_fn,
             ignore_index=ignore_index,
             beta=beta,
             compiled=compiled,
@@ -97,7 +124,9 @@ class LigerFusedLinearKTOFunction(LigerFusedLinearUnpairedPreferenceBase):
             ref_input=ref_input,
             ref_weight=ref_weight,
             ref_bias=ref_bias,
+            average_log_prob=average_log_prob,
             kl=kl,
+            chunk_size=chunk_size,
         )
 
     @staticmethod
@@ -115,6 +144,7 @@ class LigerFusedLinearKTOFunction(LigerFusedLinearUnpairedPreferenceBase):
             None,
             None,
             None,
+            None,
         )
 
 
@@ -129,6 +159,8 @@ class LigerFusedLinearKTOLoss(torch.nn.Module):
         beta: float = 0.1,
         compiled: bool = True,
         use_ref_model: bool = False,
+        average_log_prob: bool = False,
+        chunk_size: int = 1,
     ):
         """
         Args:
@@ -136,12 +168,16 @@ class LigerFusedLinearKTOLoss(torch.nn.Module):
             beta (float): Temperature parameter for the KTO loss
             compiled (bool): Whether to use compiled operations
             use_ref_model (bool): Whether to use a reference model for the DPO loss.
+            average_log_prob (bool): Whether to average the log probability per non-masked token
+            chunk_size (int): Size of chunks for processing
         """
         super().__init__()
         self.ignore_index = ignore_index
         self.beta = beta
         self.compiled = compiled
         self.use_ref_model = use_ref_model
+        self.average_log_prob = average_log_prob
+        self.chunk_size = chunk_size
 
     def forward(
         self,
@@ -169,4 +205,6 @@ class LigerFusedLinearKTOLoss(torch.nn.Module):
             self.beta,
             self.compiled,
             self.use_ref_model,
+            self.average_log_prob,
+            self.chunk_size,
         )
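For context, a hedged usage sketch (not part of the diff): the KTO module gains two constructor options in 0.5.6, and its chunk-level loss now also returns the summed chosen and rejected rewards. Constructor arguments below follow the __init__ shown above; the import path is assumed from the package layout.

# Hedged usage sketch; constructor arguments taken from the diff, import path assumed.
from liger_kernel.chunked_loss import LigerFusedLinearKTOLoss

kto_loss = LigerFusedLinearKTOLoss(
    ignore_index=-100,
    beta=0.1,
    use_ref_model=True,
    average_log_prob=False,  # new in 0.5.6: average log-probs over non-masked tokens
    chunk_size=1,            # new in 0.5.6: size of chunks used by the fused computation
)
# Note: preference_loss_fn now returns (loss, chosen_rewards_sum, rejected_rewards_sum).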
liger_kernel/chunked_loss/orpo_loss.py CHANGED
@@ -42,8 +42,9 @@ class LigerFusedLinearORPOFunction(LigerFusedLinearPreferenceBase):
 
         return loss, chosen_rewards, rejected_rewards, log_odds_ratio, log_odds_chosen
 
-    @staticmethod
+    @classmethod
     def forward(
+        cls,
         ctx,
         _input,
         weight,
@@ -54,25 +55,43 @@ class LigerFusedLinearORPOFunction(LigerFusedLinearPreferenceBase):
         compute_nll_loss=True,
         nll_target=None,
         compiled=True,
+        chunk_size=1,
     ):
-        return LigerFusedLinearPreferenceBase.forward(
+        """
+        Fused linear layer with ORPO loss.
+        Args:
+            _input (torch.Tensor): Input tensor. Shape: (batch_size * seq_len, hidden_size)
+            weight (torch.Tensor): Weight tensor. Shape: (vocab_size, hidden_size)
+            target (torch.LongTensor): Target tensor. Shape: (batch_size * seq_len,)
+            bias (torch.Tensor, optional): Bias tensor. Shape: (vocab_size,)
+            ignore_index (int): Index to ignore in loss computation
+            beta (float): Weight for the odds ratio loss
+            compute_nll_loss (bool): Whether to compute the NLL loss
+            nll_target (torch.LongTensor, optional): Target tensor for NLL loss. Shape: (batch_size * seq_len,)
+            compiled (bool): Whether to use torch compile
+            chunk_size (int): Size of chunks for processing
+        Returns:
+            torch.Tensor: Computed loss
+        """
+        return super().forward(
+            cls=cls,
             ctx=ctx,
             _input=_input,
             weight=weight,
             target=target,
             bias=bias,
-            loss_fn=LigerFusedLinearORPOFunction.preference_loss_fn,
             ignore_index=ignore_index,
             beta=beta,
             compute_nll_loss=compute_nll_loss,
             nll_target=nll_target,
             compiled=compiled,
+            chunk_size=chunk_size,
         )
 
     @staticmethod
     def backward(ctx, *grad_output):
         grads = LigerFusedLinearPreferenceBase.backward(ctx, grad_output)[:4]
-        return *grads, None, None, None, None, None
+        return *grads, None, None, None, None, None, None
 
 
 class LigerFusedLinearORPOLoss(torch.nn.Module):
@@ -86,19 +105,31 @@ class LigerFusedLinearORPOLoss(torch.nn.Module):
         beta: float = 0.1,
         compute_nll_loss: bool = True,
         compiled: bool = True,
+        chunk_size: int = 1,
     ):
         """
         Args:
             ignore_index (int): Index to ignore in the loss.
             beta (float): Weight for the odds ratio loss.
+            compute_nll_loss (bool): Whether to compute the NLL loss.
+            compiled (bool): Whether to use the torch compiled kernel.
+            chunk_size (int): Size of chunks for processing.
         """
         super().__init__()
         self.ignore_index = ignore_index
         self.beta = beta
         self.compute_nll_loss = compute_nll_loss
         self.compiled = compiled
+        self.chunk_size = chunk_size
 
-    def forward(self, lin_weight, _input, target, bias=None, nll_target=None):
+    def forward(
+        self,
+        lin_weight,
+        _input,
+        target,
+        bias=None,
+        nll_target=None,
+    ):
         return LigerFusedLinearORPOFunction.apply(
             _input,
             lin_weight,
@@ -109,4 +140,5 @@ class LigerFusedLinearORPOLoss(torch.nn.Module):
             self.compute_nll_loss,
             nll_target,
             self.compiled,
+            self.chunk_size,
         )
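The KTO, ORPO, and SimPO diffs share one structural change: forward() becomes a @classmethod that passes cls to the base class, and the explicit loss_fn= argument disappears, presumably so the base resolves the subclass's preference_loss_fn itself. A minimal illustrative sketch of that pattern, with placeholder class names that are not the package's real base classes:

# Illustrative-only sketch of the classmethod/cls dispatch pattern; names are placeholders.
import torch

class _PreferenceBaseSketch:
    @classmethod
    def forward(cls, ctx, _input, **kwargs):
        loss_fn = cls.preference_loss_fn  # resolved from the concrete subclass via `cls`
        return loss_fn(_input, **kwargs)

class _ORPOLikeSketch(_PreferenceBaseSketch):
    @staticmethod
    def preference_loss_fn(x, beta=0.1):
        return (beta * x).mean()

loss = _ORPOLikeSketch.forward(ctx=None, _input=torch.randn(4), beta=0.2)

This removes the need for each subclass to hand its own loss function back to the base, at the cost of one extra None in each backward() return tuple (visible in the hunks above).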
liger_kernel/chunked_loss/simpo_loss.py CHANGED
@@ -47,8 +47,9 @@ class LigerFusedLinearSimPOFunction(LigerFusedLinearPreferenceBase):
 
         return loss, chosen_rewards, rejected_rewards
 
-    @staticmethod
+    @classmethod
     def forward(
+        cls,
         ctx,
         _input,
         weight,
@@ -61,27 +62,47 @@ class LigerFusedLinearSimPOFunction(LigerFusedLinearPreferenceBase):
         compute_nll_loss=False,
         compiled=True,
         gamma=0.5,
+        chunk_size=1,
     ):
-        return LigerFusedLinearPreferenceBase.forward(
-            ctx,
-            _input,
-            weight,
-            target,
-            bias,
-            loss_fn=LigerFusedLinearSimPOFunction.preference_loss_fn,
-            compute_nll_loss=compute_nll_loss,
+        """
+        Fused linear layer with SimPO loss.
+        Args:
+            _input (torch.Tensor): Input tensor. Shape: (batch_size * seq_len, hidden_size)
+            weight (torch.Tensor): Weight tensor. Shape: (vocab_size, hidden_size)
+            target (torch.LongTensor): Target tensor. Shape: (batch_size * seq_len,)
+            bias (torch.Tensor, optional): Bias tensor. Shape: (vocab_size,)
+            ignore_index (int): Index to ignore in loss computation
+            beta (float): Weight for the odds ratio loss
+            alpha (float): Weight for the alpha parameter
+            label_smoothing (float): Label smoothing factor
+            compute_nll_loss (bool): Whether to compute the NLL loss
+            compiled (bool): Whether to use torch compile
+            gamma (float): Weight for the gamma parameter
+            chunk_size (int): Size of chunks for processing
+        Returns:
+            torch.Tensor: Computed loss
+        """
+        return super().forward(
+            cls=cls,
+            ctx=ctx,
+            _input=_input,
+            weight=weight,
+            target=target,
+            bias=bias,
             ignore_index=ignore_index,
             alpha=alpha,
             beta=beta,
             label_smoothing=label_smoothing,
+            compute_nll_loss=compute_nll_loss,
             compiled=compiled,
             gamma=gamma,
+            chunk_size=chunk_size,
         )
 
     @staticmethod
     def backward(ctx, *grad_output):
         grads = LigerFusedLinearPreferenceBase.backward(ctx, grad_output)[:4]
-        return *grads, None, None, None, None, None, None, None
+        return *grads, None, None, None, None, None, None, None, None
 
 
 class LigerFusedLinearSimPOLoss(torch.nn.Module):
@@ -98,11 +119,18 @@ class LigerFusedLinearSimPOLoss(torch.nn.Module):
         compute_nll_loss: bool = True,
         compiled: bool = True,
         gamma: float = 0.5,
+        chunk_size: int = 1,
     ):
         """
         Args:
             ignore_index (int): Index to ignore in the loss.
             beta (float): Weight for the odds ratio loss.
+            alpha (float): Weight for the alpha parameter.
+            label_smoothing (float): Label smoothing factor.
+            compute_nll_loss (bool): Whether to compute the NLL loss.
+            compiled (bool): Whether to use the torch compiled kernel.
+            gamma (float): Weight for the gamma parameter.
+            chunk_size (int): Size of chunks for processing.
         """
         super().__init__()
         self.ignore_index = ignore_index
@@ -112,8 +140,15 @@ class LigerFusedLinearSimPOLoss(torch.nn.Module):
         self.compute_nll_loss = compute_nll_loss
         self.compiled = compiled
         self.gamma = gamma
+        self.chunk_size = chunk_size
 
-    def forward(self, lin_weight, _input, target, bias=None):
+    def forward(
+        self,
+        lin_weight,
+        _input,
+        target,
+        bias=None,
+    ):
         return LigerFusedLinearSimPOFunction.apply(
             _input,
             lin_weight,
@@ -126,4 +161,5 @@ class LigerFusedLinearSimPOLoss(torch.nn.Module):
             self.compute_nll_loss,
             self.compiled,
             self.gamma,
+            self.chunk_size,
         )
liger_kernel/ops/cross_entropy.py CHANGED
@@ -9,6 +9,7 @@ import triton.language as tl
 from liger_kernel.ops.utils import compare_version
 from liger_kernel.ops.utils import element_mul_kernel
 from liger_kernel.ops.utils import is_hip
+from liger_kernel.utils import infer_device
 
 if compare_version("triton", operator.ge, "3.0.0"):
     try:
@@ -59,7 +60,7 @@ def liger_cross_entropy_kernel(
        z_loss_ptr: Pointer to tensor to store the z loss. No operation if RETURN_Z_LOSS is 0.
        loss_stride (int): The stride of the loss tensor.
        n_cols (int): The number of columns in the input tensor.
-       n_non_ignore (flaot): The number of non-ignored elements in the batch.
+       n_non_ignore (float): The number of non-ignored elements in the batch.
        sum_non_ignore_weight (float): The sum of non-ignored target's weights in the batch.
        weight_sum (float): The sum of weight tensor.
        ignore_index (int): The index to ignore in the target.
@@ -258,7 +259,7 @@ def liger_cross_entropy_kernel(
 # The hard limit of TRITON_MAX_TENSOR_NUMEL is 1048576 https://github.com/triton-lang/triton/blob/ba42a5c68fd0505f8c42f4202d53be0f8d9a5fe0/python/triton/language/core.py#L19
 # However, setting limit as 65536 as in LayerNorm tutorial is faster because of less register spilling
 # The optimal maximum block size depends on your hardware, your kernel, and your dtype
-MAX_FUSED_SIZE = 65536 // 2  # the best size we found by manually tuning
+MAX_FUSED_SIZE = 4096 if infer_device() == "xpu" else 65536 // 2  # the best size we found by manually tuning
 
 
 def cross_entropy_forward(
@@ -285,6 +286,10 @@ def cross_entropy_forward(
 
     target_mask = target != ignore_index
     n_non_ignore = target_mask.sum().item()
+    assert (target * target_mask).max() < _input.shape[-1], (
+        f"Target {target.max()} is out of bounds. Expected < {_input.shape[-1]}"
+    )
+    assert (target * target_mask).min() >= 0, f"Target {target.min()} is out of bounds. Expected >= 0"
    sum_non_ignore_weight = n_non_ignore
    weight_sum = 0.0
    if weight is not None:
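The new assertions validate targets on the host before launching the kernel. Multiplying by target_mask first zeroes out positions equal to ignore_index (commonly -100), so only real labels are bounds-checked. A standalone illustration of why that works, not package code:

# Standalone illustration of the masked bounds check used above.
import torch

vocab_size = 8
ignore_index = -100
target = torch.tensor([3, ignore_index, 7, 0])
target_mask = target != ignore_index

assert (target * target_mask).max() < vocab_size
assert (target * target_mask).min() >= 0  # passes even though target.min() == -100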
liger_kernel/ops/dyt.py ADDED
@@ -0,0 +1,225 @@
+import operator
+
+import torch
+import triton
+import triton.language as tl
+
+from liger_kernel.ops.utils import calculate_settings
+from liger_kernel.ops.utils import compare_version
+from liger_kernel.ops.utils import ensure_contiguous
+from liger_kernel.ops.utils import infer_device
+
+if compare_version("triton", operator.ge, "3.0.0"):
+    try:
+        # typical import path with dispatch available
+        from triton.language.extra.libdevice import tanh
+    except ModuleNotFoundError:
+        # for working with NGC containers
+        from triton.language.extra.cuda.libdevice import tanh
+else:
+    from triton.language.math import tanh
+
+
+@triton.jit
+def _dyt_fwd_kernel(
+    x_ptr,
+    x_row_stride,
+    alpha_ptr,
+    gamma_ptr,
+    beta_ptr,
+    y_ptr,
+    y_row_stride,
+    n_cols,
+    BLOCK_SIZE: tl.constexpr,
+):
+    """
+    Reference:
+    https://arxiv.org/abs/2503.10622
+
+    Shapes:
+        - x: (BT, C)
+        - alpha: (1)
+        - gamma: (C)
+        - beta: (C)
+    """
+    row_idx = tl.program_id(0)
+    offsets = tl.arange(0, BLOCK_SIZE)
+    mask = offsets < n_cols
+
+    x_ptr += row_idx * x_row_stride
+    y_ptr += row_idx * y_row_stride
+
+    alpha = tl.load(alpha_ptr)
+    gamma = tl.load(gamma_ptr + offsets, mask=mask)
+    beta = tl.load(beta_ptr + offsets, mask=mask)
+    x = tl.load(x_ptr + offsets, mask=mask)
+    y = gamma * tanh((alpha * x).cast(tl.float32)) + beta
+    tl.store(y_ptr + offsets, y, mask=mask)
+
+
+@triton.jit
+def _dyt_bwd_kernel(
+    x_ptr,
+    x_row_stride,
+    dy_ptr,
+    dy_row_stride,
+    dx_ptr,
+    dx_row_stride,
+    alpha_ptr,
+    dalpha_ptr,
+    gamma_ptr,
+    dgamma_ptr,
+    dgamma_row_stride,
+    n_cols,
+    n_rows,
+    ROWS_PER_PROGRAM: tl.constexpr,
+    BLOCK_SIZE: tl.constexpr,
+):
+    """
+    Reference:
+    https://arxiv.org/abs/2503.10622
+
+    Shapes:
+        - x: (BT, C)
+        - alpha: (1)
+        - gamma: (C)
+        - dx: (BT, C)
+        - dy: (BT, C)
+        - dgamma: (sm_count, C)
+        - dalpha: (sm_count,)
+    """
+    # d(gamma * tanh(alpha * x) + beta) / dx
+    # = gamma * (1 - tanh^2(alpha * x)) * alpha
+    # d(gamma * tanh(alpha * x) + beta) / dalpha
+    # = gamma * (1 - tanh^2(alpha * x)) * x
+    # d(gamma * tanh(alpha * x) + beta) / dgamma
+    # = tanh(alpha * x)
+    # d(gamma * tanh(alpha * x)) / dbeta = 1
+    pid = tl.program_id(0)
+
+    row_start = pid * ROWS_PER_PROGRAM
+    row_end = min((pid + 1) * ROWS_PER_PROGRAM, n_rows)
+    offsets = tl.arange(0, BLOCK_SIZE)
+    mask = offsets < n_cols
+
+    dalpha = 0.0
+    dgamma = tl.zeros((BLOCK_SIZE,), dtype=tl.float32)
+
+    x_ptr += row_start * x_row_stride
+    dx_ptr += row_start * dx_row_stride
+    dy_ptr += row_start * dy_row_stride
+    alpha = tl.load(alpha_ptr)
+    gamma = tl.load(gamma_ptr + offsets, mask=mask, other=0.0)
+
+    for _ in tl.range(row_start, row_end):
+        dy = tl.load(dy_ptr + offsets, mask=mask, other=0.0)
+        x = tl.load(x_ptr + offsets, mask=mask, other=0.0)
+        tanh_ax = tanh((alpha * x).cast(tl.float32))
+        sech2_ax = 1 - tanh_ax * tanh_ax
+
+        dx = dy * gamma * sech2_ax * alpha
+        dalpha += tl.sum(dy * gamma * sech2_ax * x)
+        dgamma += dy * tanh_ax
+        tl.store(dx_ptr + offsets, dx, mask=mask)
+
+        dy_ptr += dy_row_stride
+        x_ptr += x_row_stride
+        dx_ptr += dx_row_stride
+
+    tl.store(dgamma_ptr + pid * dgamma_row_stride + offsets, dgamma, mask=mask)
+    tl.store(dalpha_ptr + pid, dalpha)
+
+    pass
+
+
+def liger_dyt_fwd(x, alpha, gamma, beta):
+    shape = x.shape
+    dim = shape[-1]
+    x = x.view(-1, dim)
+    n_rows, n_cols = x.shape
+    y = torch.empty_like(x)
+    BLOCK_SIZE, num_warps = calculate_settings(n_cols)
+    _dyt_fwd_kernel[(n_rows,)](
+        x_ptr=x,
+        alpha_ptr=alpha,
+        gamma_ptr=gamma,
+        beta_ptr=beta,
+        y_ptr=y,
+        x_row_stride=x.stride(0),
+        y_row_stride=y.stride(0),
+        n_cols=n_cols,
+        BLOCK_SIZE=BLOCK_SIZE,
+        num_warps=num_warps,
+    )
+    return y.view(*shape)
+
+
+def liger_dyt_bwd(dy, x, alpha, gamma):
+    shape = dy.shape
+    dtype = x.dtype
+    dim = shape[-1]
+    dy = dy.view(-1, dim)
+    x = x.view(-1, dim)
+    n_rows, n_cols = dy.shape
+    BLOCK_SIZE, num_warps = calculate_settings(n_cols)
+    sm_count = 1
+    device = infer_device()
+    if device == "cuda":
+        sm_count = torch.cuda.get_device_properties(x.device).multi_processor_count
+    elif device == "xpu":
+        sm_count = torch.xpu.get_device_properties(x.device).gpu_subslice_count
+    if n_cols > BLOCK_SIZE:
+        raise RuntimeError(
+            f"Feature dimension {dim} exceeds maximum supported size of {BLOCK_SIZE}. Consider using a smaller feature dimension."
+        )
+
+    dx = torch.empty_like(x, dtype=torch.float32)
+    _dalpha = torch.empty((sm_count,), dtype=torch.float32, device=x.device)
+    _dgamma = torch.empty((sm_count, n_cols), dtype=torch.float32, device=x.device)
+
+    grid = (sm_count,)
+    rows_per_program = triton.cdiv(n_rows, sm_count)
+    _dyt_bwd_kernel[grid](
+        x_ptr=x,
+        x_row_stride=x.stride(0),
+        dy_ptr=dy,
+        dy_row_stride=dy.stride(0),
+        dx_ptr=dx,
+        dx_row_stride=dx.stride(0),
+        alpha_ptr=alpha,
+        dalpha_ptr=_dalpha,
+        gamma_ptr=gamma,
+        dgamma_ptr=_dgamma,
+        dgamma_row_stride=_dgamma.stride(0),
+        n_cols=n_cols,
+        n_rows=n_rows,
+        ROWS_PER_PROGRAM=rows_per_program,
+        BLOCK_SIZE=BLOCK_SIZE,
+        num_warps=num_warps,
+    )
+    dalpha = _dalpha.sum(dim=0, keepdim=True).to(dtype)
+    dgamma = _dgamma.sum(dim=0).to(dtype)
+    dbeta = dy.sum(dim=0).to(dtype)
+    return dx.view(*shape), dalpha, dgamma, dbeta
+
+
+class LigerDyTFunction(torch.autograd.Function):
+    @staticmethod
+    @ensure_contiguous
+    def forward(ctx, x, alpha, gamma, beta):
+        y = liger_dyt_fwd(x, alpha, gamma, beta)
+        ctx.save_for_backward(x, alpha, gamma)
+        return y
+
+    @staticmethod
+    @ensure_contiguous
+    def backward(ctx, grad_output):
+        x, alpha, gamma = ctx.saved_tensors
+        dx, dalpha, dgamma, dbeta = liger_dyt_bwd(
+            grad_output,
+            x,
+            alpha,
+            gamma,
+        )
+
+        return (dx, dalpha, dgamma, dbeta)
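The new dyt.py implements DyT (Dynamic Tanh, https://arxiv.org/abs/2503.10622), y = gamma * tanh(alpha * x) + beta, as fused Triton forward and backward kernels. A plain-PyTorch reference of the same transform, useful as a mental model or for checking the kernels; the module below is an illustrative stand-in, not the package's LigerDyT wrapper, and the init_alpha default is only an assumption:

# Plain-PyTorch reference of the DyT transform the Triton kernels above implement.
import torch
import torch.nn as nn

class DyTReference(nn.Module):
    def __init__(self, hidden_size: int, init_alpha: float = 0.5):
        super().__init__()
        self.alpha = nn.Parameter(torch.full((1,), init_alpha))  # scalar scale inside tanh
        self.gamma = nn.Parameter(torch.ones(hidden_size))       # per-channel scale
        self.beta = nn.Parameter(torch.zeros(hidden_size))       # per-channel shift

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.gamma * torch.tanh(self.alpha * x) + self.beta

x = torch.randn(2, 16, 64)
y = DyTReference(hidden_size=64)(x)  # autograd yields the same gradients the fused backward computes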
liger_kernel/ops/fused_linear_jsd.py CHANGED
@@ -8,11 +8,12 @@ from liger_kernel.ops.utils import amp_custom_bwd
 from liger_kernel.ops.utils import amp_custom_fwd
 from liger_kernel.ops.utils import element_mul_kernel
 from liger_kernel.ops.utils import is_hip
+from liger_kernel.utils import infer_device
 
 # The hard limit of TRITON_MAX_TENSOR_NUMEL is 1048576 https://github.com/triton-lang/triton/blob/ba42a5c68fd0505f8c42f4202d53be0f8d9a5fe0/python/triton/language/core.py#L19
 # However, setting limit as 65536 as in LayerNorm tutorial is faster because of less register spilling
 # The optimal maximum block size depends on your hardware, your kernel, and your dtype
-MAX_FUSED_SIZE = 65536 // 2
+MAX_FUSED_SIZE = 4096 if infer_device() == "xpu" else 65536 // 2
 
 
 def fused_linear_jsd_forward(
liger_kernel/ops/jsd.py CHANGED
@@ -51,24 +51,43 @@ def _jsd_kernel(
     Y = tl.load(Y_ptr + offsets, mask=mask, other=float("-inf")).to(tl.float32)
 
     if beta == 0.0:  # forward KL
-        Y_prob = tl.exp(Y)
+        Y_max = tl.max(Y, axis=0)
+        Y_shifted = Y - Y_max
+        Y_prob = tl.exp(Y_shifted) * tl.exp(Y_max)  # Compensate for the shift
         loss = Y_prob * (Y - X)
         dX = -Y_prob
-    elif beta == 1.0:
-        X_prob = tl.exp(X)
+    elif beta == 1.0:  # reverse KL
+        X_max = tl.max(X, axis=0)
+        X_shifted = X - X_max
+        X_prob = tl.exp(X_shifted) * tl.exp(X_max)  # Compensate for the shift
         loss = X_prob * (X - Y)
         dX = loss + X_prob
     else:
-        Q = tl.exp(X)
-        P = tl.exp(Y)
-        M = beta * P + (1 - beta) * Q
-        log_M = tl.log(M)
+        max_val = tl.maximum(tl.max(X, axis=0), tl.max(Y, axis=0))
+        X_shifted = X - max_val
+        Y_shifted = Y - max_val
 
-        loss = beta * P * Y + (1 - beta) * Q * X - M * log_M
-        dX = (1 - beta) * Q * (X - log_M)
+        # Pre-compute exp(max_val) since it's used twice
+        exp_max = tl.exp(max_val)
+
+        # Compute exp terms with compensation
+        Q = tl.exp(X_shifted) * exp_max  # = exp(X)
+        P = tl.exp(Y_shifted) * exp_max  # = exp(Y)
+
+        # Pre-compute common terms
+        beta_P = beta * P
+        one_minus_beta_Q = (1 - beta) * Q
+        M = beta_P + one_minus_beta_Q
+        log_M = tl.log(M)  # No need to compensate as M is already in original scale
+
+        loss = beta_P * Y + one_minus_beta_Q * X - M * log_M
+        dX = one_minus_beta_Q * (X - log_M)
+
+    # Pre-compute scaling factor
+    scale = 1.0 / n_non_ignore
+    loss = loss * scale
+    dX = dX * scale
 
-    loss = loss / n_non_ignore
-    dX = dX / n_non_ignore
     tl.store(loss_ptr + offsets, loss, mask=mask)
     tl.store(dX_ptr + offsets, dX, mask=mask)
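The rewritten JSD kernel applies a max shift before exponentiating and then multiplies the shift back in, so the intermediate exponentials stay better conditioned while the result is algebraically unchanged; it also folds the 1/n_non_ignore division into a single precomputed scale. A quick illustrative check of the shift-and-compensate identity, not package code:

# Quick check that exp(X - max) * exp(max) matches a direct exp(X) on log-probabilities,
# the form of input the JSD kernel receives.
import torch

logits = torch.randn(4, 32) * 10
X = torch.log_softmax(logits, dim=-1)              # log-probs
X_max = X.max(dim=-1, keepdim=True).values
direct = torch.exp(X)
shifted = torch.exp(X - X_max) * torch.exp(X_max)  # "compensate for the shift"
torch.testing.assert_close(direct, shifted)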