liger-kernel 0.5.8__py3-none-any.whl → 0.5.10__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (42)
  1. liger_kernel/chunked_loss/dpo_loss.py +8 -1
  2. liger_kernel/chunked_loss/fused_linear_preference.py +0 -1
  3. liger_kernel/chunked_loss/jsd_loss.py +2 -2
  4. liger_kernel/ops/cross_entropy.py +4 -1
  5. liger_kernel/ops/dyt.py +113 -179
  6. liger_kernel/ops/fused_linear_cross_entropy.py +4 -3
  7. liger_kernel/ops/grpo_loss.py +310 -0
  8. liger_kernel/ops/sparsemax.py +167 -0
  9. liger_kernel/transformers/__init__.py +11 -0
  10. liger_kernel/transformers/dyt.py +5 -3
  11. liger_kernel/transformers/fsdp.py +55 -0
  12. liger_kernel/transformers/functional.py +8 -0
  13. liger_kernel/transformers/fused_linear_cross_entropy.py +1 -2
  14. liger_kernel/transformers/grpo_loss.py +98 -0
  15. liger_kernel/transformers/model/gemma.py +8 -12
  16. liger_kernel/transformers/model/gemma2.py +8 -10
  17. liger_kernel/transformers/model/gemma3.py +3 -9
  18. liger_kernel/transformers/model/glm4.py +119 -0
  19. liger_kernel/transformers/model/llama.py +64 -15
  20. liger_kernel/transformers/model/llava.py +0 -8
  21. liger_kernel/transformers/model/mistral.py +8 -10
  22. liger_kernel/transformers/model/mixtral.py +8 -12
  23. liger_kernel/transformers/model/mllama.py +8 -11
  24. liger_kernel/transformers/model/olmo2.py +8 -10
  25. liger_kernel/transformers/model/paligemma.py +0 -8
  26. liger_kernel/transformers/model/phi3.py +8 -12
  27. liger_kernel/transformers/model/qwen2.py +8 -12
  28. liger_kernel/transformers/model/qwen2_5_vl.py +3 -7
  29. liger_kernel/transformers/model/qwen2_vl.py +3 -7
  30. liger_kernel/transformers/model/qwen3.py +112 -0
  31. liger_kernel/transformers/model/qwen3_moe.py +128 -0
  32. liger_kernel/transformers/monkey_patch.py +243 -13
  33. liger_kernel/transformers/sparsemax.py +16 -0
  34. liger_kernel/transformers/swiglu.py +21 -0
  35. liger_kernel/transformers/trainer/orpo_trainer.py +1 -53
  36. liger_kernel/utils.py +11 -0
  37. {liger_kernel-0.5.8.dist-info → liger_kernel-0.5.10.dist-info}/METADATA +36 -20
  38. {liger_kernel-0.5.8.dist-info → liger_kernel-0.5.10.dist-info}/RECORD +42 -34
  39. {liger_kernel-0.5.8.dist-info → liger_kernel-0.5.10.dist-info}/WHEEL +1 -1
  40. {liger_kernel-0.5.8.dist-info → liger_kernel-0.5.10.dist-info}/licenses/LICENSE +0 -0
  41. {liger_kernel-0.5.8.dist-info → liger_kernel-0.5.10.dist-info}/licenses/NOTICE +0 -0
  42. {liger_kernel-0.5.8.dist-info → liger_kernel-0.5.10.dist-info}/top_level.txt +0 -0
liger_kernel/chunked_loss/dpo_loss.py CHANGED
@@ -68,6 +68,7 @@ class LigerFusedLinearDPOFunction(LigerFusedLinearPreferenceBase):
         compute_nll_loss=False,
         compiled=True,
         use_ref_model=True,
+        average_log_prob=False,
         chunk_size=1,
     ):
         """
@@ -85,6 +86,7 @@ class LigerFusedLinearDPOFunction(LigerFusedLinearPreferenceBase):
             compute_nll_loss (bool): Whether to compute the NLL loss
             compiled (bool): Whether to use torch compile
             use_ref_model (bool): Whether to use a reference model
+            average_log_prob (bool): Whether to average the log probability per non-masked token
             chunk_size (int): Size of chunks for processing.
         Returns:
             torch.Tensor: Computed loss
@@ -104,13 +106,14 @@ class LigerFusedLinearDPOFunction(LigerFusedLinearPreferenceBase):
             ref_input=ref_input,
             ref_weight=ref_weight,
             ref_bias=ref_bias,
+            average_log_prob=average_log_prob,
             chunk_size=chunk_size,
         )
 
     @staticmethod
     def backward(ctx, *grad_output):
         grads = LigerFusedLinearPreferenceBase.backward(ctx, grad_output)[:4]
-        return *grads, None, None, None, None, None, None, None, None, None
+        return *grads, None, None, None, None, None, None, None, None, None, None
 
 
 class LigerFusedLinearDPOLoss(torch.nn.Module):
@@ -125,6 +128,7 @@ class LigerFusedLinearDPOLoss(torch.nn.Module):
         compute_nll_loss: bool = False,
         compiled: bool = True,
         use_ref_model: bool = True,
+        average_log_prob: bool = False,
         chunk_size: int = 1,
     ):
         """
@@ -134,6 +138,7 @@ class LigerFusedLinearDPOLoss(torch.nn.Module):
             compute_nll_loss (bool): Whether to compute the NLL loss.
             compiled (bool): Whether to use the torch compiled kernel.
             use_ref_model (bool): Whether to use a reference model for the DPO loss.
+            average_log_prob (bool): Whether to average the log probability per non-masked token.
             chunk_size (int): Size of chunks for processing.
         """
         super().__init__()
@@ -142,6 +147,7 @@ class LigerFusedLinearDPOLoss(torch.nn.Module):
         self.compute_nll_loss = compute_nll_loss
         self.compiled = compiled
         self.use_ref_model = use_ref_model
+        self.average_log_prob = average_log_prob
         self.chunk_size = chunk_size
 
     def forward(
@@ -167,5 +173,6 @@ class LigerFusedLinearDPOLoss(torch.nn.Module):
             self.compute_nll_loss,
             self.compiled,
             self.use_ref_model,
+            self.average_log_prob,
             self.chunk_size,
         )
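The new average_log_prob flag threads from the module constructor down through LigerFusedLinearDPOFunction.apply. A minimal construction sketch using only keyword arguments visible in this diff (the values chosen are illustrative, and any other constructor arguments keep their defaults):

from liger_kernel.chunked_loss.dpo_loss import LigerFusedLinearDPOLoss

# Illustrative configuration; every keyword below appears in the diff above.
dpo_loss = LigerFusedLinearDPOLoss(
    compute_nll_loss=False,
    compiled=True,
    use_ref_model=True,
    average_log_prob=True,  # new relative to 0.5.8: average log-probs over non-masked tokens
    chunk_size=1,
)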
liger_kernel/chunked_loss/fused_linear_preference.py CHANGED
@@ -222,7 +222,6 @@ class LigerFusedLinearPreferenceBase(torch.autograd.Function):
             (_ref_chosen_input_chunks if use_ref_model else [None] * len(_chosen_input_chunks)),
             (_ref_rejected_input_chunks if use_ref_model else [None] * len(_rejected_input_chunks)),
             (_chosen_nll_target_chunks if nll_target is not None else [None] * len(_chosen_input_chunks)),
-            strict=False,
         ):
             input_chunk = torch.cat([chosen_input_chunk, rejected_input_chunk], dim=0)
             ref_input_chunk = (
liger_kernel/chunked_loss/jsd_loss.py CHANGED
@@ -150,8 +150,8 @@ class LigerFusedLinearJSDLoss(torch.nn.Module):
         teacher_input: torch.Tensor,
         teacher_weight: torch.Tensor,
         true_labels: torch.LongTensor,
-        student_bias: torch.Tensor,
-        teacher_bias: torch.Tensor,
+        student_bias: torch.Tensor = None,
+        teacher_bias: torch.Tensor = None,
     ) -> torch.Tensor:
         """
         Compute the JSD distillation loss.
liger_kernel/ops/cross_entropy.py CHANGED
@@ -351,7 +351,10 @@ def cross_entropy_backward(_input, grad_output):
     # If cross entropy is the last layer, grad_output is 1.0. Skip the mul to save time
     if torch.equal(grad_output, torch.tensor(1.0, device=grad_output.device)):
         pass
-
+    # If reduction is 'none'
+    elif grad_output.ndim > 0:
+        _input = _input * grad_output.unsqueeze(dim=1)
+    # If reduction is ['mean', 'sum'], grad_output is just a scalar
     # We use a Triton kernel instead of a PyTorch operation because modifying inputs in-place
     # for gradient storage and backward multiple times causes anomalies with PyTorch but not with Triton.
     else:
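The new elif branch covers reduction='none', where grad_output carries one value per token rather than a scalar. A small standalone sketch of the broadcast it performs (shapes are illustrative; in the kernel, _input is the logits buffer reused to store per-token gradients):

import torch

num_tokens, vocab_size = 4, 10
_input = torch.randn(num_tokens, vocab_size)   # per-token, per-class gradients
grad_output = torch.randn(num_tokens)          # per-token upstream gradients (reduction='none')

# Same operation as the added branch: scale each row by its upstream gradient.
scaled = _input * grad_output.unsqueeze(dim=1)
assert scaled.shape == (num_tokens, vocab_size)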
liger_kernel/ops/dyt.py CHANGED
@@ -4,7 +4,8 @@ import torch
 import triton
 import triton.language as tl
 
-from liger_kernel.ops.utils import calculate_settings
+from triton.language.extra.libdevice import tanh
+
 from liger_kernel.ops.utils import compare_version
 from liger_kernel.ops.utils import ensure_contiguous
 from liger_kernel.ops.utils import infer_device
@@ -20,187 +21,126 @@ else:
     from triton.language.math import tanh
 
 
+# @triton.autotune([triton.Config({"BLOCK_N":bn}, num_stages=ns, num_warps=nw)
+#                   for bn in [1024, 2048, 4096]
+#                   for ns in [1,2,4]
+#                   for nw in [4, 8, 16, 32]
+#                   ],
+#                  key=['N'])
 @triton.jit
-def _dyt_fwd_kernel(
-    x_ptr,
-    x_row_stride,
-    alpha_ptr,
-    gamma_ptr,
-    beta_ptr,
-    y_ptr,
-    y_row_stride,
-    n_cols,
-    BLOCK_SIZE: tl.constexpr,
-):
-    """
-    Reference:
-    https://arxiv.org/abs/2503.10622
-
-    Shapes:
-        - x: (BT, C)
-        - alpha: (1)
-        - gamma: (C)
-        - beta: (C)
-    """
-    row_idx = tl.program_id(0)
-    offsets = tl.arange(0, BLOCK_SIZE)
-    mask = offsets < n_cols
-
-    x_ptr += row_idx * x_row_stride
-    y_ptr += row_idx * y_row_stride
-
-    alpha = tl.load(alpha_ptr)
-    gamma = tl.load(gamma_ptr + offsets, mask=mask)
-    beta = tl.load(beta_ptr + offsets, mask=mask)
-    x = tl.load(x_ptr + offsets, mask=mask)
-    y = gamma * tanh((alpha * x).cast(tl.float32)) + beta
-    tl.store(y_ptr + offsets, y, mask=mask)
+def _dyt_fwd_kernel(X, Y, Alpha, Gamma, Beta, HAVE_BETA: tl.constexpr, N: tl.constexpr, BLOCK_N: tl.constexpr = 1024):
+    col = tl.cast(tl.program_id(0), tl.int64) * BLOCK_N + tl.arange(0, BLOCK_N)
+    mask = col < N
+    row_id = tl.cast(tl.program_id(1), tl.int64)
+
+    X += row_id * N
+    Y += row_id * N
+    alpha = tl.load(Alpha).to(tl.float32)
+
+    gamma = tl.load(Gamma + col, mask=mask, other=0.0).to(tl.float32)
 
+    x = tl.load(X + col, mask=mask, other=0.0).to(tl.float32)
 
+    tanh_x = tanh(alpha * x)
+    y = tanh_x * gamma
+    if HAVE_BETA:
+        beta = tl.load(Beta + col, mask=mask, other=0.0).to(tl.float32)
+        y += beta
+    tl.store(Y + col, y, mask=mask)
+
+
+# @triton.autotune([triton.Config({"BLOCK_N":bn}, num_stages=ns, num_warps=nw)
+#                   for bn in [1024, 2048, 4096]
+#                   for ns in [1,2,4]
+#                   for nw in [4, 8, 16]
+#                   ],
+#                  key=['N'])
 @triton.jit
 def _dyt_bwd_kernel(
-    x_ptr,
-    x_row_stride,
-    dy_ptr,
-    dy_row_stride,
-    dx_ptr,
-    dx_row_stride,
-    alpha_ptr,
-    dalpha_ptr,
-    gamma_ptr,
-    dgamma_ptr,
-    dgamma_row_stride,
-    n_cols,
-    n_rows,
-    ROWS_PER_PROGRAM: tl.constexpr,
-    BLOCK_SIZE: tl.constexpr,
+    DY, DX, DA, DG, DB, X, Alpha, Gamma, HAVE_BETA: tl.constexpr, M, N: tl.constexpr, BLOCK_N: tl.constexpr = 1024
 ):
-    """
-    Reference:
-    https://arxiv.org/abs/2503.10622
-
-    Shapes:
-        - x: (BT, C)
-        - alpha: (1)
-        - gamma: (C)
-        - dx: (BT, C)
-        - dy: (BT, C)
-        - dgamma: (sm_count, C)
-        - dalpha: (sm_count,)
-    """
-    # d(gamma * tanh(alpha * x) + beta) / dx
-    # = gamma * (1 - tanh^2(alpha * x)) * alpha
-    # d(gamma * tanh(alpha * x) + beta) / dalpha
-    # = gamma * (1 - tanh^2(alpha * x)) * x
-    # d(gamma * tanh(alpha * x) + beta) / dgamma
-    # = tanh(alpha * x)
-    # d(gamma * tanh(alpha * x)) / dbeta = 1
-    pid = tl.program_id(0)
-
-    row_start = pid * ROWS_PER_PROGRAM
-    row_end = min((pid + 1) * ROWS_PER_PROGRAM, n_rows)
-    offsets = tl.arange(0, BLOCK_SIZE)
-    mask = offsets < n_cols
-
-    dalpha = 0.0
-    dgamma = tl.zeros((BLOCK_SIZE,), dtype=tl.float32)
-
-    x_ptr += row_start * x_row_stride
-    dx_ptr += row_start * dx_row_stride
-    dy_ptr += row_start * dy_row_stride
-    alpha = tl.load(alpha_ptr)
-    gamma = tl.load(gamma_ptr + offsets, mask=mask, other=0.0)
-
-    for _ in tl.range(row_start, row_end):
-        dy = tl.load(dy_ptr + offsets, mask=mask, other=0.0)
-        x = tl.load(x_ptr + offsets, mask=mask, other=0.0)
-        tanh_ax = tanh((alpha * x).cast(tl.float32))
-        sech2_ax = 1 - tanh_ax * tanh_ax
-
-        dx = dy * gamma * sech2_ax * alpha
-        dalpha += tl.sum(dy * gamma * sech2_ax * x)
-        dgamma += dy * tanh_ax
-        tl.store(dx_ptr + offsets, dx, mask=mask)
-
-        dy_ptr += dy_row_stride
-        x_ptr += x_row_stride
-        dx_ptr += dx_row_stride
-
-    tl.store(dgamma_ptr + pid * dgamma_row_stride + offsets, dgamma, mask=mask)
-    tl.store(dalpha_ptr + pid, dalpha)
-
-    pass
+    col = tl.cast(tl.program_id(0), tl.int64) * BLOCK_N + tl.arange(0, BLOCK_N)
+    mask = col < N
+    start_row_id = tl.cast(tl.program_id(1), tl.int64)
+
+    alpha = tl.load(Alpha).to(tl.float32)
+    da = 0.0
+    gamma = tl.load(Gamma + col, mask=mask, other=0.0).to(tl.float32)
+    dg = tl.zeros((BLOCK_N,), dtype=tl.float32)
+    if HAVE_BETA:
+        db = tl.zeros((BLOCK_N,), dtype=tl.float32)
+    for row_id in range(start_row_id, M, tl.num_programs(1)):
+        x = tl.load(X + row_id * N + col, mask=mask, other=0.0).to(tl.float32)
+        dy = tl.load(DY + row_id * N + col, mask=mask, other=0.0).to(tl.float32)
+        tanh_x = tanh(alpha * x)
+        if HAVE_BETA:
+            db += dy
+        dg += dy * tanh_x
+        tmp = (1 - tanh_x * tanh_x) * dy * gamma
+        da += tl.sum(x * tmp, 0)
+        dx = alpha * tmp
+        tl.store(DX + row_id * N + col, dx, mask=mask)
+
+    tl.store(DG + start_row_id * N + col, dg, mask=mask)
+    if HAVE_BETA:
+        tl.store(DB + start_row_id * N + col, db, mask=mask)
+    tl.store(DA + start_row_id * tl.cdiv(N, 512) + tl.program_id(0), da)
 
 
 def liger_dyt_fwd(x, alpha, gamma, beta):
-    shape = x.shape
-    dim = shape[-1]
-    x = x.view(-1, dim)
-    n_rows, n_cols = x.shape
+    assert x.is_contiguous()
+    HAVE_BETA = True if beta is not None else False
+    input_shape = x.shape
+    x = x.view(-1, input_shape[-1])
+    M, N = x.shape
+
     y = torch.empty_like(x)
-    BLOCK_SIZE, num_warps = calculate_settings(n_cols)
-    _dyt_fwd_kernel[(n_rows,)](
-        x_ptr=x,
-        alpha_ptr=alpha,
-        gamma_ptr=gamma,
-        beta_ptr=beta,
-        y_ptr=y,
-        x_row_stride=x.stride(0),
-        y_row_stride=y.stride(0),
-        n_cols=n_cols,
-        BLOCK_SIZE=BLOCK_SIZE,
-        num_warps=num_warps,
+
+    if N >= 4096:
+        kwargs = {"BLOCK_N": min(triton.next_power_of_2(N), 2048), "num_warps": 4, "num_stages": 1}
+    else:
+        kwargs = {"BLOCK_N": min(triton.next_power_of_2(N), 1024), "num_warps": 4, "num_stages": 1}
+
+    grid = lambda meta: (triton.cdiv(N, meta["BLOCK_N"]), M)
+    _dyt_fwd_kernel[(grid)](
+        x,
+        y,
+        alpha,
+        gamma,
+        beta,
+        HAVE_BETA,
+        N,
+        **kwargs,
     )
-    return y.view(*shape)
-
-
-def liger_dyt_bwd(dy, x, alpha, gamma):
-    shape = dy.shape
-    dtype = x.dtype
-    dim = shape[-1]
-    dy = dy.view(-1, dim)
-    x = x.view(-1, dim)
-    n_rows, n_cols = dy.shape
-    BLOCK_SIZE, num_warps = calculate_settings(n_cols)
-    sm_count = 1
+    return y.view(input_shape)
+
+
+def liger_dyt_bwd(dy, x, alpha, gamma, beta):
+    assert dy.is_contiguous()
+    input_shape = x.shape
+    x = x.view(-1, input_shape[-1])
+    M, N = x.shape
+    HAVE_BETA = True if beta is not None else False
+
     device = infer_device()
     if device == "cuda":
-        sm_count = torch.cuda.get_device_properties(x.device).multi_processor_count
+        NUM_SMS = torch.cuda.get_device_properties(x.device).multi_processor_count
     elif device == "xpu":
-        sm_count = torch.xpu.get_device_properties(x.device).gpu_subslice_count
-    if n_cols > BLOCK_SIZE:
-        raise RuntimeError(
-            f"Feature dimension {dim} exceeds maximum supported size of {BLOCK_SIZE}. Consider using a smaller feature dimension."
-        )
-
-    dx = torch.empty_like(x, dtype=torch.float32)
-    _dalpha = torch.empty((sm_count,), dtype=torch.float32, device=x.device)
-    _dgamma = torch.empty((sm_count, n_cols), dtype=torch.float32, device=x.device)
-
-    grid = (sm_count,)
-    rows_per_program = triton.cdiv(n_rows, sm_count)
-    _dyt_bwd_kernel[grid](
-        x_ptr=x,
-        x_row_stride=x.stride(0),
-        dy_ptr=dy,
-        dy_row_stride=dy.stride(0),
-        dx_ptr=dx,
-        dx_row_stride=dx.stride(0),
-        alpha_ptr=alpha,
-        dalpha_ptr=_dalpha,
-        gamma_ptr=gamma,
-        dgamma_ptr=_dgamma,
-        dgamma_row_stride=_dgamma.stride(0),
-        n_cols=n_cols,
-        n_rows=n_rows,
-        ROWS_PER_PROGRAM=rows_per_program,
-        BLOCK_SIZE=BLOCK_SIZE,
-        num_warps=num_warps,
-    )
-    dalpha = _dalpha.sum(dim=0, keepdim=True).to(dtype)
-    dgamma = _dgamma.sum(dim=0).to(dtype)
-    dbeta = dy.sum(dim=0).to(dtype)
-    return dx.view(*shape), dalpha, dgamma, dbeta
+        NUM_SMS = torch.xpu.get_device_properties(x.device).gpu_subslice_count
+
+    da = torch.zeros(NUM_SMS, triton.cdiv(N, 512), dtype=torch.float32, device=x.device)
+    dg = torch.empty(NUM_SMS, N, dtype=torch.float32, device=x.device)
+    db = torch.empty(NUM_SMS, N, dtype=torch.float32, device=x.device) if HAVE_BETA else None
+    dx = torch.empty_like(dy)
+
+    kwargs = {"BLOCK_N": min(triton.next_power_of_2(N), 1024), "num_warps": 8, "num_stages": 2}
+    grid = lambda meta: (triton.cdiv(N, meta["BLOCK_N"]), NUM_SMS)
+    _dyt_bwd_kernel[grid](dy, dx, da, dg, db, x, alpha, gamma, HAVE_BETA, M, N, **kwargs)
+    if HAVE_BETA:
+        db = db.sum(0).to(x.dtype)
+    dg = dg.sum(0).to(gamma.dtype)
+    da = da.sum().to(x.dtype).unsqueeze(0)
+    return dx.view(input_shape), da, dg, db
 
 
 class LigerDyTFunction(torch.autograd.Function):
@@ -208,18 +148,12 @@ class LigerDyTFunction(torch.autograd.Function):
     @ensure_contiguous
     def forward(ctx, x, alpha, gamma, beta):
         y = liger_dyt_fwd(x, alpha, gamma, beta)
-        ctx.save_for_backward(x, alpha, gamma)
+        ctx.save_for_backward(x, alpha, gamma, beta)
         return y
 
     @staticmethod
     @ensure_contiguous
-    def backward(ctx, grad_output):
-        x, alpha, gamma = ctx.saved_tensors
-        dx, dalpha, dgamma, dbeta = liger_dyt_bwd(
-            grad_output,
-            x,
-            alpha,
-            gamma,
-        )
-
-        return (dx, dalpha, dgamma, dbeta)
+    def backward(ctx, dy):
+        x, alpha, gamma, beta = ctx.saved_tensors
+        dx, dalpha, dgamma, dbeta = liger_dyt_bwd(dy, x, alpha, gamma, beta)
+        return dx, dalpha, dgamma, dbeta
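Both the old and new kernels compute Dynamic Tanh (DyT, https://arxiv.org/abs/2503.10622): y = gamma * tanh(alpha * x) + beta, with beta now optional via the HAVE_BETA flag. A minimal eager-mode reference the Triton kernels could be checked against (the function name and shapes are illustrative, not part of the package):

import torch

def dyt_reference(x, alpha, gamma, beta=None):
    # Elementwise tanh scaled by a learnable scalar alpha, followed by an
    # affine transform over the last dimension; beta is optional, mirroring
    # the HAVE_BETA flag in the kernels above.
    y = torch.tanh(alpha * x) * gamma
    if beta is not None:
        y = y + beta
    return y

x = torch.randn(8, 256)      # (rows, hidden)
alpha = torch.tensor(0.5)    # scalar parameter
gamma = torch.ones(256)      # (hidden,)
out = dyt_reference(x, alpha, gamma, beta=torch.zeros(256))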
liger_kernel/ops/fused_linear_cross_entropy.py CHANGED
@@ -143,9 +143,10 @@ def fused_linear_cross_entropy_forward(
             alpha=1.0,
         )
 
-        if reduction == "none":
-            loss = loss_1d
-            z_loss = z_loss_1d if return_z_loss else None
+        # Need extra calculations for backward if reduction=='none'. Not supporting reduction='none' now.
+        # if reduction == "none":
+        #     loss = loss_1d
+        #     z_loss = z_loss_1d if return_z_loss else None
 
         else:
             loss = torch.sum(loss_1d)
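For orientation, the fused forward computes the same quantity as projecting hidden states through the lm_head weight and applying cross entropy, without materializing the full logits tensor; with the branch above commented out, only the scalar reductions ('mean', 'sum') remain supported. A rough unfused sketch under those assumptions (names and shapes are illustrative):

import torch
import torch.nn.functional as F

hidden = torch.randn(8, 64)             # (num_tokens, hidden_size)
weight = torch.randn(100, 64)           # (vocab_size, hidden_size) lm_head weight
target = torch.randint(0, 100, (8,))    # (num_tokens,)

# Unfused equivalent: materialize logits, then reduce. The fused kernel avoids
# allocating the (num_tokens, vocab_size) logits tensor; reduction='none' is
# not supported on the fused path after this change.
logits = hidden @ weight.T
loss = F.cross_entropy(logits, target, reduction="mean")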