liger-kernel 0.5.2__py3-none-any.whl → 0.5.4__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (69)
  1. liger_kernel/chunked_loss/README.md +25 -0
  2. liger_kernel/chunked_loss/__init__.py +3 -0
  3. liger_kernel/chunked_loss/cpo_loss.py +18 -8
  4. liger_kernel/chunked_loss/dpo_loss.py +20 -10
  5. liger_kernel/chunked_loss/functional.py +4 -0
  6. liger_kernel/chunked_loss/fused_linear_distillation.py +58 -44
  7. liger_kernel/chunked_loss/fused_linear_preference.py +108 -60
  8. liger_kernel/chunked_loss/fused_linear_rlhf.py +213 -0
  9. liger_kernel/chunked_loss/fused_linear_unpaired_preference.py +246 -0
  10. liger_kernel/chunked_loss/grpo_loss.py +160 -0
  11. liger_kernel/chunked_loss/jsd_loss.py +154 -0
  12. liger_kernel/chunked_loss/kto_loss.py +172 -0
  13. liger_kernel/chunked_loss/orpo_loss.py +8 -9
  14. liger_kernel/chunked_loss/simpo_loss.py +22 -8
  15. liger_kernel/env_report.py +5 -12
  16. liger_kernel/ops/cross_entropy.py +102 -51
  17. liger_kernel/ops/experimental/embedding.py +1 -3
  18. liger_kernel/ops/experimental/mm_int8int2.py +3 -9
  19. liger_kernel/ops/fused_linear_cross_entropy.py +89 -55
  20. liger_kernel/ops/fused_linear_jsd.py +14 -32
  21. liger_kernel/ops/geglu.py +6 -17
  22. liger_kernel/ops/group_norm.py +11 -28
  23. liger_kernel/ops/jsd.py +5 -9
  24. liger_kernel/ops/kl_div.py +8 -11
  25. liger_kernel/ops/layer_norm.py +23 -12
  26. liger_kernel/ops/qwen2vl_mrope.py +8 -25
  27. liger_kernel/ops/rms_norm.py +14 -32
  28. liger_kernel/ops/rope.py +31 -33
  29. liger_kernel/ops/swiglu.py +4 -8
  30. liger_kernel/ops/tvd.py +207 -0
  31. liger_kernel/ops/utils.py +3 -2
  32. liger_kernel/transformers/__init__.py +19 -24
  33. liger_kernel/transformers/auto_model.py +6 -13
  34. liger_kernel/transformers/cross_entropy.py +7 -9
  35. liger_kernel/transformers/experimental/embedding.py +1 -3
  36. liger_kernel/transformers/functional.py +28 -7
  37. liger_kernel/transformers/fused_linear_cross_entropy.py +15 -10
  38. liger_kernel/transformers/geglu.py +1 -4
  39. liger_kernel/transformers/group_norm.py +9 -15
  40. liger_kernel/transformers/jsd.py +1 -3
  41. liger_kernel/transformers/kl_div.py +1 -3
  42. liger_kernel/transformers/layer_norm.py +3 -9
  43. liger_kernel/transformers/model/gemma.py +18 -40
  44. liger_kernel/transformers/model/gemma2.py +19 -41
  45. liger_kernel/transformers/model/llama.py +22 -48
  46. liger_kernel/transformers/model/mistral.py +14 -26
  47. liger_kernel/transformers/model/mixtral.py +24 -54
  48. liger_kernel/transformers/model/mllama.py +16 -36
  49. liger_kernel/transformers/model/olmo2.py +124 -0
  50. liger_kernel/transformers/model/phi3.py +18 -40
  51. liger_kernel/transformers/model/qwen2.py +18 -40
  52. liger_kernel/transformers/model/qwen2_vl.py +36 -32
  53. liger_kernel/transformers/monkey_patch.py +214 -144
  54. liger_kernel/transformers/rms_norm.py +4 -4
  55. liger_kernel/transformers/rope.py +2 -2
  56. liger_kernel/transformers/swiglu.py +2 -8
  57. liger_kernel/transformers/trainer/__init__.py +1 -3
  58. liger_kernel/transformers/trainer/orpo_trainer.py +31 -18
  59. liger_kernel/transformers/tvd.py +13 -0
  60. liger_kernel/triton/__init__.py +1 -3
  61. liger_kernel/triton/monkey_patch.py +1 -3
  62. liger_kernel/utils.py +49 -0
  63. {liger_kernel-0.5.2.dist-info → liger_kernel-0.5.4.dist-info}/METADATA +53 -26
  64. liger_kernel-0.5.4.dist-info/RECORD +74 -0
  65. {liger_kernel-0.5.2.dist-info → liger_kernel-0.5.4.dist-info}/WHEEL +1 -1
  66. liger_kernel-0.5.2.dist-info/RECORD +0 -65
  67. {liger_kernel-0.5.2.dist-info → liger_kernel-0.5.4.dist-info}/LICENSE +0 -0
  68. {liger_kernel-0.5.2.dist-info → liger_kernel-0.5.4.dist-info}/NOTICE +0 -0
  69. {liger_kernel-0.5.2.dist-info → liger_kernel-0.5.4.dist-info}/top_level.txt +0 -0
liger_kernel/ops/fused_linear_jsd.py CHANGED
@@ -4,12 +4,10 @@ import torch
  import triton

  from liger_kernel.ops.jsd import _jsd_kernel
- from liger_kernel.ops.utils import (
-     amp_custom_bwd,
-     amp_custom_fwd,
-     element_mul_kernel,
-     is_hip,
- )
+ from liger_kernel.ops.utils import amp_custom_bwd
+ from liger_kernel.ops.utils import amp_custom_fwd
+ from liger_kernel.ops.utils import element_mul_kernel
+ from liger_kernel.ops.utils import is_hip

  # The hard limit of TRITON_MAX_TENSOR_NUMEL is 1048576 https://github.com/triton-lang/triton/blob/ba42a5c68fd0505f8c42f4202d53be0f8d9a5fe0/python/triton/language/core.py#L19
  # However, setting limit as 65536 as in LayerNorm tutorial is faster because of less register spilling
@@ -43,16 +41,10 @@ def fused_linear_jsd_forward(
      BLOCK_SIZE = min(MAX_FUSED_SIZE, triton.next_power_of_2(V))

      inc_factor = triton.cdiv(V, H)  # (V + H - 1) // H
-     chunk_size = triton.next_power_of_2(
-         triton.cdiv(BT, inc_factor)
-     )  # (BT + inc_factor - 1) // inc_factor
+     chunk_size = triton.next_power_of_2(triton.cdiv(BT, inc_factor))  # (BT + inc_factor - 1) // inc_factor
      num_chunks = triton.cdiv(BT, chunk_size)  # (BT + chunk_size - 1) // chunk_size

-     grad_weight = (
-         torch.zeros_like(student_weight, device=device)
-         if student_weight.requires_grad
-         else None
-     )
+     grad_weight = torch.zeros_like(student_weight, device=device) if student_weight.requires_grad else None
      grad_input = torch.zeros_like(student_input)
      # we use fp32 for loss accumulator
      loss_1d = torch.zeros((BT, V), dtype=torch.float32, device=device)
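
For reference, the chunking arithmetic in the hunk above sizes each chunk so that the materialized (chunk_size x V) logits stay roughly as large as the (chunk_size x H) input slice. A standalone sketch with assumed sizes (BT = batch * sequence length, H = hidden size, V = vocab size; the numbers are illustrative, not taken from the package):

import triton

BT, H, V = 4096, 4096, 128256  # assumed example sizes
inc_factor = triton.cdiv(V, H)                                    # ceil(V / H) = 32
chunk_size = triton.next_power_of_2(triton.cdiv(BT, inc_factor))  # next_pow2(ceil(4096 / 32)) = 128
num_chunks = triton.cdiv(BT, chunk_size)                          # ceil(4096 / 128) = 32
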
@@ -73,12 +65,8 @@ def fused_linear_jsd_forward(
          # shape: chunk_size x V
          # For anything starting from logits to the final JSD loss, we do computation
          # in FP32 to avoid losing numerical stability.
-         student_logits_chunk = (student_input_chunk @ student_weight.t()).to(
-             torch.float32
-         )
-         teacher_logits_chunk = (teacher_input_chunk @ teacher_weight.t()).to(
-             torch.float32
-         )
+         student_logits_chunk = (student_input_chunk @ student_weight.t()).to(torch.float32)
+         teacher_logits_chunk = (teacher_input_chunk @ teacher_weight.t()).to(torch.float32)
          chunk_n_rows = student_logits_chunk.shape[0]

          # unreduced loss
@@ -104,9 +92,7 @@
              dX_ptr=student_prob_chunk,
              dX_stride=student_prob_chunk.stride(-2),
              label_ptr=(
-                 shift_labels[start_idx:end_idx]
-                 if has_label
-                 else torch.empty(1, device=device)
+                 shift_labels[start_idx:end_idx] if has_label else torch.empty(1, device=device)
              ),  # dummy ptr if no label
              beta=jsd_beta,
              n_non_ignore=n_non_ignore,
@@ -121,9 +107,7 @@
          student_logits_chunk = (
              student_prob_chunk
              - torch.softmax(student_logits_chunk, dim=-1)
-             * student_prob_chunk.sum(dim=-1, keepdim=True).broadcast_to(
-                 student_prob_chunk.shape
-             )
+             * student_prob_chunk.sum(dim=-1, keepdim=True).broadcast_to(student_prob_chunk.shape)
          ) / temperature
          # now we traverse back to grad w.r.t. input to `lm_head` and grad
          # w.r.t. `lm_head` which should be computed in original dtype
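
The in-place update in the hunk above is the standard log-softmax backward. Assuming `student_prob_chunk` holds the gradient $g$ with respect to the student log-probabilities (it is the buffer passed to the JSD kernel as `dX_ptr`), and writing $s$ for the softmax of the student logits, the gradient with respect to the logits is $\partial L/\partial z_i = g_i - s_i \sum_j g_j$; the trailing division by `temperature` applies the chain rule for the temperature-scaled logits before the result is folded back into `grad_input` and `grad_weight`.
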
@@ -211,9 +195,9 @@ class LigerFusedLinearJSDFunction(torch.autograd.Function):
          """
          has_label = False
          if shift_labels is not None:
-             assert shift_labels.shape == (
-                 teacher_input.shape[0],
-             ), f"the shape of shift_labels must be (BT,). Got: {shift_labels.shape}"
+             assert shift_labels.shape == (teacher_input.shape[0],), (
+                 f"the shape of shift_labels must be (BT,). Got: {shift_labels.shape}"
+             )
              shift_labels = shift_labels.contiguous()
              has_label = True

@@ -239,7 +223,5 @@
      @amp_custom_bwd
      def backward(ctx, grad_output):
          (grad_input, grad_weight) = ctx.saved_tensors
-         grad_input, grad_weight = fused_linear_jsd_backward(
-             grad_output, grad_input, grad_weight
-         )
+         grad_input, grad_weight = fused_linear_jsd_backward(grad_output, grad_input, grad_weight)
          return (grad_input, grad_weight, None, None, None, None, None, None)
liger_kernel/ops/geglu.py CHANGED
@@ -4,11 +4,9 @@ import torch
  import triton
  import triton.language as tl

- from liger_kernel.ops.utils import (
-     calculate_settings,
-     compare_version,
-     ensure_contiguous,
- )
+ from liger_kernel.ops.utils import calculate_settings
+ from liger_kernel.ops.utils import compare_version
+ from liger_kernel.ops.utils import ensure_contiguous

  if compare_version("triton", operator.ge, "3.0.0"):
      try:
@@ -22,9 +20,7 @@ else:


  @triton.jit
- def _geglu_tanh_forward_kernel(
-     a, b, c, stride, n_cols: tl.constexpr, BLOCK_SIZE: tl.constexpr
- ):
+ def _geglu_tanh_forward_kernel(a, b, c, stride, n_cols: tl.constexpr, BLOCK_SIZE: tl.constexpr):
      program_id = tl.program_id(0).to(tl.int64)

      # locate start index
@@ -49,9 +45,7 @@


  @triton.jit
- def _geglu_tanh_backward_kernel(
-     dc, a, b, stride, n_cols: tl.constexpr, BLOCK_SIZE: tl.constexpr
- ):
+ def _geglu_tanh_backward_kernel(dc, a, b, stride, n_cols: tl.constexpr, BLOCK_SIZE: tl.constexpr):
      program_id = tl.program_id(0).to(tl.int64)

      # locate start index
@@ -80,12 +74,7 @@ def _geglu_tanh_backward_kernel(
      # where z = sqrt(2/pi) * (a + 0.044715 * a^3)
      term1 = 0.5 * (1 + tanh_result)
      tanh_sq = tanh_result * tanh_result
-     term2 = (
-         0.5
-         * a_row
-         * (1 - tanh_sq)
-         * (sqrt_2_over_pi * (1 + 3 * 0.044715 * a_row * a_row))
-     )
+     term2 = 0.5 * a_row * (1 - tanh_sq) * (sqrt_2_over_pi * (1 + 3 * 0.044715 * a_row * a_row))
      da_row = dc_row * b_row * (term1 + term2)

      tl.store(a + col_offsets, da_row, mask=mask)
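
The `term1 + term2` expression above is the derivative of the tanh-approximated GELU with respect to its input. A quick standalone check against autograd (the sample tensors are made up, not taken from the kernel's tests):

import math
import torch

a = torch.randn(8, dtype=torch.float64, requires_grad=True)
b = torch.randn(8, dtype=torch.float64)

sqrt_2_over_pi = math.sqrt(2.0 / math.pi)
z = sqrt_2_over_pi * (a + 0.044715 * a**3)
geglu = 0.5 * a * (1.0 + torch.tanh(z)) * b  # GEGLU forward: gelu_tanh(a) * b
geglu.sum().backward()

tanh_result = torch.tanh(z)
term1 = 0.5 * (1 + tanh_result)
term2 = 0.5 * a * (1 - tanh_result * tanh_result) * (sqrt_2_over_pi * (1 + 3 * 0.044715 * a * a))
da_manual = b * (term1 + term2)
print(torch.allclose(a.grad, da_manual.detach()))  # expected: True
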
liger_kernel/ops/group_norm.py CHANGED
@@ -4,7 +4,8 @@ import torch
  import triton
  import triton.language as tl

- from liger_kernel.ops.utils import compare_version, ensure_contiguous
+ from liger_kernel.ops.utils import compare_version
+ from liger_kernel.ops.utils import ensure_contiguous

  if compare_version("triton", operator.ge, "3.0.0"):
      try:
@@ -73,9 +74,7 @@ def _group_norm_forward_kernel(

      # Normalize
      hidden_size_per_channel = hidden_size // channels_per_group
-     for channel_idx in tl.range(
-         group_idx * channels_per_group, (group_idx + 1) * channels_per_group
-     ):
+     for channel_idx in tl.range(group_idx * channels_per_group, (group_idx + 1) * channels_per_group):
          W = tl.load(W_ptr + channel_idx)
          B = tl.load(B_ptr + channel_idx)
          for i in range(0, hidden_size_per_channel, BLOCK_SIZE):
@@ -132,21 +131,15 @@ def _group_norm_backward_kernel(
      UPSTREAM_ptr += batch_idx * X_row_stride

      # Mean and rstd are the same shape so have the same strides
-     mean = tl.load(
-         Mean_ptr + batch_idx * Mean_ptr_row_stride + group_idx * Mean_ptr_col_stride
-     )
-     rstd = tl.load(
-         RSTD_ptr + batch_idx * Mean_ptr_row_stride + group_idx * Mean_ptr_col_stride
-     )
+     mean = tl.load(Mean_ptr + batch_idx * Mean_ptr_row_stride + group_idx * Mean_ptr_col_stride)
+     rstd = tl.load(RSTD_ptr + batch_idx * Mean_ptr_row_stride + group_idx * Mean_ptr_col_stride)

      c1 = 0.0
      c2 = 0.0
      block_range = tl.arange(0, BLOCK_SIZE)

      # We need to compute the sum terms of the backprop equations across all channels in the group
-     for channel_idx in range(
-         group_idx * channels_per_group, (group_idx + 1) * channels_per_group
-     ):
+     for channel_idx in range(group_idx * channels_per_group, (group_idx + 1) * channels_per_group):
          dW = 0.0
          dB = 0.0
          # Move the pointers to the correct channel
@@ -181,9 +174,7 @@
      c1 = c1 / N
      c2 = c2 / N

-     for channel_idx in tl.range(
-         group_idx * channels_per_group, (group_idx + 1) * channels_per_group
-     ):
+     for channel_idx in tl.range(group_idx * channels_per_group, (group_idx + 1) * channels_per_group):
          # Move the pointers to the correct channel
          W = tl.load(W_ptr + channel_idx)
          for i in range(0, hidden_size, BLOCK_SIZE):
@@ -203,9 +194,7 @@
              x_hat = (X - mean) * rstd
              wdy = W * UPSTREAM_grad
              dx = (wdy - (x_hat * c1 + c2)) * rstd
-             tl.store(
-                 DX_ptr + channel_idx * X_col_stride + hidden_size_offsets, dx, mask=mask
-             )
+             tl.store(DX_ptr + channel_idx * X_col_stride + hidden_size_offsets, dx, mask=mask)


  def group_norm_forward(X, num_channels, num_groups, W, B, eps):
@@ -216,9 +205,7 @@ def group_norm_forward(X, num_channels, num_groups, W, B, eps):
      X = X.view(batch_size, num_groups, -1).contiguous()
      hidden_size = X.shape[-1]
      BLOCK_SIZE = min(MAX_FUSED_SIZE, triton.next_power_of_2(hidden_size))
-     Y = torch.empty(
-         (batch_size, num_groups, hidden_size), dtype=X.dtype, device=X.device
-     )
+     Y = torch.empty((batch_size, num_groups, hidden_size), dtype=X.dtype, device=X.device)
      Mean = torch.zeros((batch_size, num_groups), dtype=X.dtype, device=X.device)
      RSTD = torch.zeros((batch_size, num_groups), dtype=X.dtype, device=X.device)

@@ -307,16 +294,12 @@ class LigerGroupNormFunction(torch.autograd.Function):
          )
          ctx.num_channels = num_channels
          ctx.num_groups = num_groups
-         ctx.save_for_backward(
-             X, affine_scaling_weight, affine_shifting_bias, Mean, RSTD
-         )
+         ctx.save_for_backward(X, affine_scaling_weight, affine_shifting_bias, Mean, RSTD)
          return Y

      @staticmethod
      @ensure_contiguous
      def backward(ctx, dY):
          X, W, B, Mean, RSTD = ctx.saved_tensors
-         DX, DW, DB = group_norm_backward(
-             dY, X, W, B, Mean, RSTD, ctx.num_channels, ctx.num_groups
-         )
+         DX, DW, DB = group_norm_backward(dY, X, W, B, Mean, RSTD, ctx.num_channels, ctx.num_groups)
          return DX, DW, DB, None, None, None
liger_kernel/ops/jsd.py CHANGED
@@ -98,9 +98,7 @@ def jsd_forward(_input, target, shift_labels, beta, ignore_index, has_label):
          loss_stride=loss.stride(-2),
          dX_ptr=dX,
          dX_stride=dX.stride(-2),
-         label_ptr=(
-             shift_labels if has_label else torch.empty(1, device=_input.device)
-         ),  # dummy ptr if no label
+         label_ptr=(shift_labels if has_label else torch.empty(1, device=_input.device)),  # dummy ptr if no label
          beta=beta,
          n_non_ignore=n_non_ignore,
          ignore_index=ignore_index,
@@ -159,15 +157,13 @@
          """
          has_label = False
          if shift_labels is not None:
-             assert shift_labels.shape == (
-                 _input.shape[0],
-             ), f"the shape of shift_labels must be (BT,). Got: {shift_labels.shape}"
+             assert shift_labels.shape == (_input.shape[0],), (
+                 f"the shape of shift_labels must be (BT,). Got: {shift_labels.shape}"
+             )
              shift_labels = shift_labels.contiguous()
              has_label = True

-         loss, dX = jsd_forward(
-             _input, target, shift_labels, beta, ignore_index, has_label
-         )
+         loss, dX = jsd_forward(_input, target, shift_labels, beta, ignore_index, has_label)
          ctx.save_for_backward(dX)
          return loss

liger_kernel/ops/kl_div.py CHANGED
@@ -4,7 +4,8 @@ import torch
  import triton
  import triton.language as tl

- from liger_kernel.ops.utils import ensure_contiguous, is_hip
+ from liger_kernel.ops.utils import ensure_contiguous
+ from liger_kernel.ops.utils import is_hip


  def get_num_warps(BLOCK_SIZE):
@@ -23,10 +24,10 @@ MAX_FUSED_SIZE = 65536 // 4  # 65536 // 4 or 8 works the best

  REDUCTION_LITERAL = Literal["none", "sum", "mean", "batchmean"]

- _REDUCTION_MODE_NONE = tl.constexpr(0)
- _REDUCTION_MODE_SUM = tl.constexpr(1)
- _REDUCTION_MODE_MEAN = tl.constexpr(2)
- _REDUCTION_MODE_BATCHMEAN = tl.constexpr(3)
+ _REDUCTION_MODE_NONE: tl.constexpr = tl.constexpr(0)
+ _REDUCTION_MODE_SUM: tl.constexpr = tl.constexpr(1)
+ _REDUCTION_MODE_MEAN: tl.constexpr = tl.constexpr(2)
+ _REDUCTION_MODE_BATCHMEAN: tl.constexpr = tl.constexpr(3)

  _str_to_reduction_mode = {
      "none": _REDUCTION_MODE_NONE.value,
@@ -218,9 +219,7 @@ class LigerKLDivLossFunction(torch.autograd.Function):
          ctx.save_for_backward(y_true)
          ctx.reduction = reduction
          ctx.log_target = log_target
-         return kldiv_forward_triton(
-             y_pred, y_true, log_target=log_target, reduction=reduction, eps=eps
-         )
+         return kldiv_forward_triton(y_pred, y_true, log_target=log_target, reduction=reduction, eps=eps)

      @staticmethod
      @ensure_contiguous
@@ -238,9 +237,7 @@

          new_grads = torch.empty_like(y_true)

-         derivative = kldiv_backward_triton(
-             y_true, grad_output, new_grads, ctx.log_target
-         )
+         derivative = kldiv_backward_triton(y_true, grad_output, new_grads, ctx.log_target)

          if ctx.reduction == "batchmean":
              derivative = derivative / y_true.shape[0]
liger_kernel/ops/layer_norm.py CHANGED
@@ -5,11 +5,9 @@ import torch
  import triton
  import triton.language as tl

- from liger_kernel.ops.utils import (
-     calculate_settings,
-     compare_version,
-     ensure_contiguous,
- )
+ from liger_kernel.ops.utils import calculate_settings
+ from liger_kernel.ops.utils import compare_version
+ from liger_kernel.ops.utils import ensure_contiguous

  if compare_version("triton", operator.ge, "3.0.0"):
      try:
@@ -59,13 +57,14 @@ def _layer_norm_forward_kernel(
      B_row = tl.load(B_ptr + col_offsets, mask=mask, other=0)

      mean = tl.sum(X_row, axis=0) / n_cols
-     var = tl.sum((X_row - mean) * (X_row - mean), axis=0) / n_cols
+     Xmm = tl.where(mask, X_row - mean, 0)
+     var = tl.sum(Xmm * Xmm, axis=0) / n_cols
      rstd = rsqrt(var + eps)

      tl.store(Mean_ptr, mean)
      tl.store(RSTD_ptr, rstd)

-     Y_row = (X_row - mean) * rstd * W_row + B_row
+     Y_row = Xmm * rstd * W_row + B_row

      tl.store(Y_ptr + col_offsets, Y_row, mask=mask)

@@ -149,9 +148,11 @@ def layer_norm_forward(X, W, B, eps):
      Y = torch.empty((n_rows, n_cols), dtype=X.dtype, device=X.device)
      Mean = torch.empty(n_rows, dtype=X.dtype, device=X.device)
      RSTD = torch.empty(n_rows, dtype=X.dtype, device=X.device)
-     assert (
-         X.shape[1] == W.shape[0]
-     ), f"Incompatible hidden size dimension between input tensor with shape[1] = {X.shape[1]} and weight tensor with shape[0] = {W.shape[0]}"
+     if X.shape[1] != W.shape[0]:
+         raise ValueError(
+             f"Incompatible dimensions: input feature size (X.shape[1]={X.shape[1]}) "
+             f"must match weight size (W.shape[0]={W.shape[0]})"
+         )

      _layer_norm_forward_kernel[(n_rows,)](
          Y,
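
The `tl.where` masking introduced in the `_layer_norm_forward_kernel` hunk above matters because each row is loaded into a power-of-two block with `other=0`: without it, every padded lane would contribute (0 - mean)^2 to the variance. A small illustration with made-up numbers (n_cols = 3 padded to a block of 8):

import torch

x = torch.tensor([1.0, 2.0, 3.0])
padded = torch.tensor([1.0, 2.0, 3.0, 0.0, 0.0, 0.0, 0.0, 0.0])  # loaded block, other=0
mask = torch.tensor([True, True, True, False, False, False, False, False])

mean = x.sum() / 3
var_unmasked = ((padded - mean) ** 2).sum() / 3        # 22/3: padding pollutes the variance
xmm = torch.where(mask, padded - mean, torch.tensor(0.0))
var_masked = (xmm * xmm).sum() / 3                     # 2/3, matches x.var(unbiased=False)
print(var_unmasked.item(), var_masked.item(), x.var(unbiased=False).item())
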
@@ -192,11 +193,21 @@ def layer_norm_backward(dY, X, W, B, Mean, RSTD):

      BLOCK_SIZE, num_warps = calculate_settings(n_cols)
      if n_cols > BLOCK_SIZE:
-         raise RuntimeError("This layer norm doesn't support feature dim >= 64KB.")
+         raise RuntimeError(
+             f"Feature dimension {n_cols} exceeds maximum supported size of {BLOCK_SIZE}. Consider using a smaller feature dimension."
+         )

      rows_per_program = math.ceil(n_rows / sm_count)
      grid = (sm_count,)
-     triton_dtype = tl.float32 if X.dtype == torch.float32 else tl.bfloat16
+     triton_dtype = (
+         tl.float32
+         if X.dtype == torch.float32
+         else tl.bfloat16
+         if X.dtype == torch.bfloat16
+         else tl.float16
+         if X.dtype == torch.float16
+         else tl.float32  # fallback to float32 for other types
+     )
      _layer_norm_backward_kernel[grid](
          X,
          W,
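
The chained conditional above picks the Triton dtype matching the input's torch dtype and falls back to tl.float32 for anything else. A hypothetical table-driven equivalent (similar in spirit to the `torch_to_triton_dtype` helper imported in the rms_norm diff further down; the names here are illustrative):

import torch
import triton.language as tl

_TORCH_TO_TRITON = {
    torch.float32: tl.float32,
    torch.bfloat16: tl.bfloat16,
    torch.float16: tl.float16,
}

def select_triton_dtype(dtype):
    # Any unlisted torch dtype falls back to tl.float32, matching the inline fallback above.
    return _TORCH_TO_TRITON.get(dtype, tl.float32)
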
liger_kernel/ops/qwen2vl_mrope.py CHANGED
@@ -67,36 +67,20 @@ def _triton_qwen2vl_mrope(
      # program instance (i.e. for the current token) separately
      # ####################################################################
      # left half of the head
-     first_half_q_offsets = (
-         tl.arange(0, pad_n_qh)[:, None] * hd + tl.arange(0, pad_hd // 2)[None, :]
-     )
-     first_half_k_offsets = (
-         tl.arange(0, pad_n_kh)[:, None] * hd + tl.arange(0, pad_hd // 2)[None, :]
-     )
-     first_q_mask = (tl.arange(0, pad_n_qh)[:, None] < n_qh) & (
-         tl.arange(0, pad_hd // 2)[None, :] < hd // 2
-     )
-     first_k_mask = (tl.arange(0, pad_n_kh)[:, None] < n_kh) & (
-         tl.arange(0, pad_hd // 2)[None, :] < hd // 2
-     )
-     q_tile_1 = tl.load(q_ptr + first_half_q_offsets, mask=first_q_mask, other=0).to(
-         sin_row.dtype
-     )
-     k_tile_1 = tl.load(k_ptr + first_half_k_offsets, mask=first_k_mask, other=0).to(
-         sin_row.dtype
-     )
+     first_half_q_offsets = tl.arange(0, pad_n_qh)[:, None] * hd + tl.arange(0, pad_hd // 2)[None, :]
+     first_half_k_offsets = tl.arange(0, pad_n_kh)[:, None] * hd + tl.arange(0, pad_hd // 2)[None, :]
+     first_q_mask = (tl.arange(0, pad_n_qh)[:, None] < n_qh) & (tl.arange(0, pad_hd // 2)[None, :] < hd // 2)
+     first_k_mask = (tl.arange(0, pad_n_kh)[:, None] < n_kh) & (tl.arange(0, pad_hd // 2)[None, :] < hd // 2)
+     q_tile_1 = tl.load(q_ptr + first_half_q_offsets, mask=first_q_mask, other=0).to(sin_row.dtype)
+     k_tile_1 = tl.load(k_ptr + first_half_k_offsets, mask=first_k_mask, other=0).to(sin_row.dtype)

      # right half of the head
      second_half_q_offsets = first_half_q_offsets + (hd // 2)
      second_half_k_offsets = first_half_k_offsets + (hd // 2)
      second_q_mask = first_q_mask
      second_k_mask = first_k_mask
-     q_tile_2 = tl.load(q_ptr + second_half_q_offsets, mask=second_q_mask, other=0).to(
-         sin_row.dtype
-     )
-     k_tile_2 = tl.load(k_ptr + second_half_k_offsets, mask=second_k_mask, other=0).to(
-         sin_row.dtype
-     )
+     q_tile_2 = tl.load(q_ptr + second_half_q_offsets, mask=second_q_mask, other=0).to(sin_row.dtype)
+     k_tile_2 = tl.load(k_ptr + second_half_k_offsets, mask=second_k_mask, other=0).to(sin_row.dtype)

      if not BACKWARD_PASS:
          # y = [x1, x2] * [cos, cos] + [-x2, x1] * [sin, sin]
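
The comment above is the usual rotary-embedding rotation with the head dimension split into two halves. A minimal PyTorch reference of the same formula (shapes and names are illustrative, not the kernel's API):

import torch

def apply_rotary_ref(x, cos, sin):
    # x: (..., head_dim); cos/sin broadcastable against each half of the last dim
    x1, x2 = x.chunk(2, dim=-1)  # left half = x1, right half = x2
    # y = [x1, x2] * [cos, cos] + [-x2, x1] * [sin, sin]
    return torch.cat((x1 * cos - x2 * sin, x2 * cos + x1 * sin), dim=-1)
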
@@ -124,7 +108,6 @@


  def qwen2vl_mrope_forward(q, k, cos, sin, mrope_section):
-
      # transpose it back to the physical shape because Triton looks at the physical storage
      # note: q and k are incontiguous before the transformation and will become contiguous after transpose
      q = q.transpose(1, 2)
liger_kernel/ops/rms_norm.py CHANGED
@@ -17,12 +17,10 @@ import torch
  import triton
  import triton.language as tl

- from liger_kernel.ops.utils import (
-     calculate_settings,
-     compare_version,
-     ensure_contiguous,
-     torch_to_triton_dtype,
- )
+ from liger_kernel.ops.utils import calculate_settings
+ from liger_kernel.ops.utils import compare_version
+ from liger_kernel.ops.utils import ensure_contiguous
+ from liger_kernel.ops.utils import torch_to_triton_dtype

  if compare_version("triton", operator.ge, "3.0.0"):
      try:
@@ -35,9 +33,9 @@ else:
      from triton.language.math import rsqrt


- _CASTING_MODE_NONE = tl.constexpr(-1)
- _CASTING_MODE_LLAMA = tl.constexpr(0)
- _CASTING_MODE_GEMMA = tl.constexpr(1)
+ _CASTING_MODE_NONE: tl.constexpr = tl.constexpr(-1)
+ _CASTING_MODE_LLAMA: tl.constexpr = tl.constexpr(0)
+ _CASTING_MODE_GEMMA: tl.constexpr = tl.constexpr(1)


  @triton.jit
@@ -177,9 +175,7 @@ def _rms_norm_backward_kernel(

          dX_row = rstd_row * m

-         dX_row += (rstd_row) * (
-             -(1 / n_cols) * rstd_row * rstd_row * tl.sum(m * X_row, axis=0) * X_row
-         )
+         dX_row += (rstd_row) * (-(1 / n_cols) * rstd_row * rstd_row * tl.sum(m * X_row, axis=0) * X_row)

          # calculate the gradient of W
          if casting_mode == _CASTING_MODE_LLAMA:
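
The two terms above follow from differentiating RMSNorm directly. With $r = \big(\tfrac{1}{n}\sum_j x_j^2 + \epsilon\big)^{-1/2}$, $y_i = w_i x_i r$, and $m_i = w_i\,\partial L/\partial y_i$ (ignoring the offset and casting-mode handling this kernel also performs), the input gradient is $\partial L/\partial x_i = r\,m_i - \tfrac{r^3}{n}\big(\sum_j m_j x_j\big) x_i$, which is `dX_row = rstd_row * m` plus the correction added on the line above.
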
@@ -207,14 +203,10 @@ _str_to_casting_mode = {

  def rms_norm_forward(X, W, eps, offset, casting_mode):
      if not isinstance(casting_mode, int):
-         assert (
-             casting_mode in _str_to_casting_mode
-         ), f"Invalid casting mode: {casting_mode}"
+         assert casting_mode in _str_to_casting_mode, f"Invalid casting mode: {casting_mode}"
          casting_mode = _str_to_casting_mode[casting_mode]
      else:
-         assert (
-             casting_mode in _str_to_casting_mode.values()
-         ), f"Invalid casting mode: {casting_mode}"
+         assert casting_mode in _str_to_casting_mode.values(), f"Invalid casting mode: {casting_mode}"

      shape = X.shape
      dim = shape[-1]
@@ -225,17 +217,11 @@ def rms_norm_forward(X, W, eps, offset, casting_mode):
      Y = torch.empty((n_rows, n_cols), dtype=X.dtype, device=X.device)
      # RSTD is to cache rstd for each row
      # RSTD is always computed/stored in fp32 if we are using Llama or Gemma casting mode
-     rstd_dtype = (
-         torch.float32
-         if casting_mode in (_CASTING_MODE_LLAMA.value, _CASTING_MODE_GEMMA.value)
-         else X.dtype
-     )
+     rstd_dtype = torch.float32 if casting_mode in (_CASTING_MODE_LLAMA.value, _CASTING_MODE_GEMMA.value) else X.dtype
      RSTD = torch.empty(n_rows, dtype=rstd_dtype, device=X.device)

      # Check constraints.
-     assert (
-         X.shape[1] == W.shape[0]
-     ), "Incompatible hidden size dimension between tensor1.shape[1] and tensor2.shape[0]"
+     assert X.shape[1] == W.shape[0], "Incompatible hidden size dimension between tensor1.shape[1] and tensor2.shape[0]"

      _rms_norm_forward_kernel[(n_rows,)](
          Y,
@@ -256,9 +242,7 @@ def rms_norm_forward(X, W, eps, offset, casting_mode):
      return Y.view(*shape), X, RSTD, BLOCK_SIZE, num_warps, casting_mode


- def rms_norm_backward(
-     dY, X, W, RSTD, offset, casting_mode, BLOCK_SIZE, num_warps, in_place
- ):
+ def rms_norm_backward(dY, X, W, RSTD, offset, casting_mode, BLOCK_SIZE, num_warps, in_place):
      shape = dY.shape
      dim = shape[-1]
      dY = dY.view(-1, dim)
@@ -340,9 +324,7 @@ class LigerRMSNormFunction(torch.autograd.Function):
          X: (B, T, H) or (BxT, H)
          W: (H,)
          """
-         Y, X, RSTD, BLOCK_SIZE, num_warps, casting_mode = rms_norm_forward(
-             X, W, eps, offset, casting_mode
-         )
+         Y, X, RSTD, BLOCK_SIZE, num_warps, casting_mode = rms_norm_forward(X, W, eps, offset, casting_mode)
          ctx.offset = offset
          ctx.casting_mode = casting_mode
          ctx.in_place = in_place