liger-kernel 0.5.1__py3-none-any.whl → 0.5.3__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (64)
  1. liger_kernel/chunked_loss/README.md +25 -0
  2. liger_kernel/chunked_loss/__init__.py +2 -0
  3. liger_kernel/chunked_loss/cpo_loss.py +18 -8
  4. liger_kernel/chunked_loss/dpo_loss.py +20 -10
  5. liger_kernel/chunked_loss/functional.py +4 -0
  6. liger_kernel/chunked_loss/fused_linear_distillation.py +58 -44
  7. liger_kernel/chunked_loss/fused_linear_preference.py +108 -60
  8. liger_kernel/chunked_loss/fused_linear_unpaired_preference.py +246 -0
  9. liger_kernel/chunked_loss/jsd_loss.py +154 -0
  10. liger_kernel/chunked_loss/kto_loss.py +172 -0
  11. liger_kernel/chunked_loss/orpo_loss.py +8 -9
  12. liger_kernel/chunked_loss/simpo_loss.py +22 -8
  13. liger_kernel/env_report.py +5 -12
  14. liger_kernel/ops/cross_entropy.py +102 -51
  15. liger_kernel/ops/experimental/embedding.py +1 -3
  16. liger_kernel/ops/experimental/mm_int8int2.py +3 -9
  17. liger_kernel/ops/fused_linear_cross_entropy.py +89 -55
  18. liger_kernel/ops/fused_linear_jsd.py +11 -29
  19. liger_kernel/ops/geglu.py +6 -17
  20. liger_kernel/ops/group_norm.py +11 -28
  21. liger_kernel/ops/jsd.py +2 -6
  22. liger_kernel/ops/kl_div.py +8 -11
  23. liger_kernel/ops/layer_norm.py +3 -5
  24. liger_kernel/ops/qwen2vl_mrope.py +21 -37
  25. liger_kernel/ops/rms_norm.py +14 -32
  26. liger_kernel/ops/rope.py +31 -33
  27. liger_kernel/ops/swiglu.py +4 -8
  28. liger_kernel/ops/utils.py +2 -0
  29. liger_kernel/transformers/__init__.py +16 -24
  30. liger_kernel/transformers/auto_model.py +6 -13
  31. liger_kernel/transformers/cross_entropy.py +4 -6
  32. liger_kernel/transformers/experimental/embedding.py +1 -3
  33. liger_kernel/transformers/functional.py +11 -7
  34. liger_kernel/transformers/fused_linear_cross_entropy.py +12 -7
  35. liger_kernel/transformers/geglu.py +1 -4
  36. liger_kernel/transformers/group_norm.py +3 -9
  37. liger_kernel/transformers/jsd.py +1 -3
  38. liger_kernel/transformers/kl_div.py +1 -3
  39. liger_kernel/transformers/layer_norm.py +3 -9
  40. liger_kernel/transformers/model/gemma.py +18 -40
  41. liger_kernel/transformers/model/gemma2.py +19 -41
  42. liger_kernel/transformers/model/llama.py +22 -48
  43. liger_kernel/transformers/model/mistral.py +14 -26
  44. liger_kernel/transformers/model/mixtral.py +24 -54
  45. liger_kernel/transformers/model/mllama.py +16 -36
  46. liger_kernel/transformers/model/phi3.py +18 -40
  47. liger_kernel/transformers/model/qwen2.py +18 -40
  48. liger_kernel/transformers/model/qwen2_vl.py +36 -32
  49. liger_kernel/transformers/monkey_patch.py +43 -117
  50. liger_kernel/transformers/qwen2vl_mrope.py +2 -2
  51. liger_kernel/transformers/rms_norm.py +4 -4
  52. liger_kernel/transformers/rope.py +2 -2
  53. liger_kernel/transformers/swiglu.py +2 -8
  54. liger_kernel/transformers/trainer/__init__.py +1 -3
  55. liger_kernel/transformers/trainer/orpo_trainer.py +31 -18
  56. liger_kernel/triton/__init__.py +1 -3
  57. liger_kernel/triton/monkey_patch.py +1 -3
  58. {liger_kernel-0.5.1.dist-info → liger_kernel-0.5.3.dist-info}/METADATA +38 -25
  59. liger_kernel-0.5.3.dist-info/RECORD +69 -0
  60. {liger_kernel-0.5.1.dist-info → liger_kernel-0.5.3.dist-info}/WHEEL +1 -1
  61. liger_kernel-0.5.1.dist-info/RECORD +0 -65
  62. {liger_kernel-0.5.1.dist-info → liger_kernel-0.5.3.dist-info}/LICENSE +0 -0
  63. {liger_kernel-0.5.1.dist-info → liger_kernel-0.5.3.dist-info}/NOTICE +0 -0
  64. {liger_kernel-0.5.1.dist-info → liger_kernel-0.5.3.dist-info}/top_level.txt +0 -0
liger_kernel/ops/fused_linear_jsd.py CHANGED
@@ -4,12 +4,10 @@ import torch
  import triton

  from liger_kernel.ops.jsd import _jsd_kernel
- from liger_kernel.ops.utils import (
- amp_custom_bwd,
- amp_custom_fwd,
- element_mul_kernel,
- is_hip,
- )
+ from liger_kernel.ops.utils import amp_custom_bwd
+ from liger_kernel.ops.utils import amp_custom_fwd
+ from liger_kernel.ops.utils import element_mul_kernel
+ from liger_kernel.ops.utils import is_hip

  # The hard limit of TRITON_MAX_TENSOR_NUMEL is 1048576 https://github.com/triton-lang/triton/blob/ba42a5c68fd0505f8c42f4202d53be0f8d9a5fe0/python/triton/language/core.py#L19
  # However, setting limit as 65536 as in LayerNorm tutorial is faster because of less register spilling
@@ -43,16 +41,10 @@ def fused_linear_jsd_forward(
  BLOCK_SIZE = min(MAX_FUSED_SIZE, triton.next_power_of_2(V))

  inc_factor = triton.cdiv(V, H) # (V + H - 1) // H
- chunk_size = triton.next_power_of_2(
- triton.cdiv(BT, inc_factor)
- ) # (BT + inc_factor - 1) // inc_factor
+ chunk_size = triton.next_power_of_2(triton.cdiv(BT, inc_factor)) # (BT + inc_factor - 1) // inc_factor
  num_chunks = triton.cdiv(BT, chunk_size) # (BT + chunk_size - 1) // chunk_size

- grad_weight = (
- torch.zeros_like(student_weight, device=device)
- if student_weight.requires_grad
- else None
- )
+ grad_weight = torch.zeros_like(student_weight, device=device) if student_weight.requires_grad else None
  grad_input = torch.zeros_like(student_input)
  # we use fp32 for loss accumulator
  loss_1d = torch.zeros((BT, V), dtype=torch.float32, device=device)
@@ -73,12 +65,8 @@ def fused_linear_jsd_forward(
  # shape: chunk_size x V
  # For anything starting from logits to the final JSD loss, we do computation
  # in FP32 to avoid losing numerical stability.
- student_logits_chunk = (student_input_chunk @ student_weight.t()).to(
- torch.float32
- )
- teacher_logits_chunk = (teacher_input_chunk @ teacher_weight.t()).to(
- torch.float32
- )
+ student_logits_chunk = (student_input_chunk @ student_weight.t()).to(torch.float32)
+ teacher_logits_chunk = (teacher_input_chunk @ teacher_weight.t()).to(torch.float32)
  chunk_n_rows = student_logits_chunk.shape[0]

  # unreduced loss
@@ -104,9 +92,7 @@ def fused_linear_jsd_forward(
  dX_ptr=student_prob_chunk,
  dX_stride=student_prob_chunk.stride(-2),
  label_ptr=(
- shift_labels[start_idx:end_idx]
- if has_label
- else torch.empty(1, device=device)
+ shift_labels[start_idx:end_idx] if has_label else torch.empty(1, device=device)
  ), # dummy ptr if no label
  beta=jsd_beta,
  n_non_ignore=n_non_ignore,
@@ -121,9 +107,7 @@ def fused_linear_jsd_forward(
  student_logits_chunk = (
  student_prob_chunk
  - torch.softmax(student_logits_chunk, dim=-1)
- * student_prob_chunk.sum(dim=-1, keepdim=True).broadcast_to(
- student_prob_chunk.shape
- )
+ * student_prob_chunk.sum(dim=-1, keepdim=True).broadcast_to(student_prob_chunk.shape)
  ) / temperature
  # now we traverse back to grad w.r.t. input to `lm_head` and grad
  # w.r.t. `lm_head` which should be computed in original dtype
@@ -239,7 +223,5 @@ class LigerFusedLinearJSDFunction(torch.autograd.Function):
  @amp_custom_bwd
  def backward(ctx, grad_output):
  (grad_input, grad_weight) = ctx.saved_tensors
- grad_input, grad_weight = fused_linear_jsd_backward(
- grad_output, grad_input, grad_weight
- )
+ grad_input, grad_weight = fused_linear_jsd_backward(grad_output, grad_input, grad_weight)
  return (grad_input, grad_weight, None, None, None, None, None, None)
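The reshaped `fused_linear_jsd_forward` hunks keep the same chunking scheme: the vocab-to-hidden ratio `inc_factor` decides how many chunks the BT token rows are split into, so the materialized logits stay roughly the size of the input chunk. A minimal sketch of that arithmetic, using illustrative sizes (the concrete BT/H/V values below are assumptions, not taken from the diff):

```python
# Mirrors the comments in the hunk above: cdiv(a, b) is (a + b - 1) // b.
def cdiv(a, b):  # same result as triton.cdiv for positive ints
    return (a + b - 1) // b

def next_power_of_2(n):  # same result as triton.next_power_of_2 for positive ints
    return 1 << (n - 1).bit_length()

BT, H, V = 4096, 4096, 128256             # tokens (batch*seq), hidden size, vocab size (assumed)
inc_factor = cdiv(V, H)                    # 32
chunk_size = next_power_of_2(cdiv(BT, inc_factor))  # cdiv(4096, 32) = 128 -> 128
num_chunks = cdiv(BT, chunk_size)          # 32 chunks of at most 128 rows each
print(inc_factor, chunk_size, num_chunks)  # 32 128 32
```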
liger_kernel/ops/geglu.py CHANGED
@@ -4,11 +4,9 @@ import torch
  import triton
  import triton.language as tl

- from liger_kernel.ops.utils import (
- calculate_settings,
- compare_version,
- ensure_contiguous,
- )
+ from liger_kernel.ops.utils import calculate_settings
+ from liger_kernel.ops.utils import compare_version
+ from liger_kernel.ops.utils import ensure_contiguous

  if compare_version("triton", operator.ge, "3.0.0"):
  try:
@@ -22,9 +20,7 @@ else:


  @triton.jit
- def _geglu_tanh_forward_kernel(
- a, b, c, stride, n_cols: tl.constexpr, BLOCK_SIZE: tl.constexpr
- ):
+ def _geglu_tanh_forward_kernel(a, b, c, stride, n_cols: tl.constexpr, BLOCK_SIZE: tl.constexpr):
  program_id = tl.program_id(0).to(tl.int64)

  # locate start index
@@ -49,9 +45,7 @@ def _geglu_tanh_forward_kernel(


  @triton.jit
- def _geglu_tanh_backward_kernel(
- dc, a, b, stride, n_cols: tl.constexpr, BLOCK_SIZE: tl.constexpr
- ):
+ def _geglu_tanh_backward_kernel(dc, a, b, stride, n_cols: tl.constexpr, BLOCK_SIZE: tl.constexpr):
  program_id = tl.program_id(0).to(tl.int64)

  # locate start index
@@ -80,12 +74,7 @@ def _geglu_tanh_backward_kernel(
  # where z = sqrt(2/pi) * (a + 0.044715 * a^3)
  term1 = 0.5 * (1 + tanh_result)
  tanh_sq = tanh_result * tanh_result
- term2 = (
- 0.5
- * a_row
- * (1 - tanh_sq)
- * (sqrt_2_over_pi * (1 + 3 * 0.044715 * a_row * a_row))
- )
+ term2 = 0.5 * a_row * (1 - tanh_sq) * (sqrt_2_over_pi * (1 + 3 * 0.044715 * a_row * a_row))
  da_row = dc_row * b_row * (term1 + term2)

  tl.store(a + col_offsets, da_row, mask=mask)
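The collapsed `term2` line above is the derivative of the tanh-approximated GELU used by the GeGLU gate. A plain-PyTorch sketch of the same per-element math, for readability only (not the Triton kernel itself):

```python
import math

import torch

def geglu_tanh_grad_a(dc, a, b):
    # Gradient w.r.t. the gate input `a` for c = b * gelu_tanh(a),
    # following the term1/term2 formulation in the kernel above.
    sqrt_2_over_pi = math.sqrt(2.0 / math.pi)
    z = sqrt_2_over_pi * (a + 0.044715 * a * a * a)
    tanh_z = torch.tanh(z)
    term1 = 0.5 * (1 + tanh_z)
    term2 = 0.5 * a * (1 - tanh_z * tanh_z) * (sqrt_2_over_pi * (1 + 3 * 0.044715 * a * a))
    return dc * b * (term1 + term2)
```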
liger_kernel/ops/group_norm.py CHANGED
@@ -4,7 +4,8 @@ import torch
  import triton
  import triton.language as tl

- from liger_kernel.ops.utils import compare_version, ensure_contiguous
+ from liger_kernel.ops.utils import compare_version
+ from liger_kernel.ops.utils import ensure_contiguous

  if compare_version("triton", operator.ge, "3.0.0"):
  try:
@@ -73,9 +74,7 @@ def _group_norm_forward_kernel(

  # Normalize
  hidden_size_per_channel = hidden_size // channels_per_group
- for channel_idx in tl.range(
- group_idx * channels_per_group, (group_idx + 1) * channels_per_group
- ):
+ for channel_idx in tl.range(group_idx * channels_per_group, (group_idx + 1) * channels_per_group):
  W = tl.load(W_ptr + channel_idx)
  B = tl.load(B_ptr + channel_idx)
  for i in range(0, hidden_size_per_channel, BLOCK_SIZE):
@@ -132,21 +131,15 @@ def _group_norm_backward_kernel(
  UPSTREAM_ptr += batch_idx * X_row_stride

  # Mean and rstd are the same shape so have the same strides
- mean = tl.load(
- Mean_ptr + batch_idx * Mean_ptr_row_stride + group_idx * Mean_ptr_col_stride
- )
- rstd = tl.load(
- RSTD_ptr + batch_idx * Mean_ptr_row_stride + group_idx * Mean_ptr_col_stride
- )
+ mean = tl.load(Mean_ptr + batch_idx * Mean_ptr_row_stride + group_idx * Mean_ptr_col_stride)
+ rstd = tl.load(RSTD_ptr + batch_idx * Mean_ptr_row_stride + group_idx * Mean_ptr_col_stride)

  c1 = 0.0
  c2 = 0.0
  block_range = tl.arange(0, BLOCK_SIZE)

  # We need to compute the sum terms of the backprop equations across all channels in the group
- for channel_idx in range(
- group_idx * channels_per_group, (group_idx + 1) * channels_per_group
- ):
+ for channel_idx in range(group_idx * channels_per_group, (group_idx + 1) * channels_per_group):
  dW = 0.0
  dB = 0.0
  # Move the pointers to the correct channel
@@ -181,9 +174,7 @@ def _group_norm_backward_kernel(
  c1 = c1 / N
  c2 = c2 / N

- for channel_idx in tl.range(
- group_idx * channels_per_group, (group_idx + 1) * channels_per_group
- ):
+ for channel_idx in tl.range(group_idx * channels_per_group, (group_idx + 1) * channels_per_group):
  # Move the pointers to the correct channel
  W = tl.load(W_ptr + channel_idx)
  for i in range(0, hidden_size, BLOCK_SIZE):
@@ -203,9 +194,7 @@ def _group_norm_backward_kernel(
  x_hat = (X - mean) * rstd
  wdy = W * UPSTREAM_grad
  dx = (wdy - (x_hat * c1 + c2)) * rstd
- tl.store(
- DX_ptr + channel_idx * X_col_stride + hidden_size_offsets, dx, mask=mask
- )
+ tl.store(DX_ptr + channel_idx * X_col_stride + hidden_size_offsets, dx, mask=mask)


  def group_norm_forward(X, num_channels, num_groups, W, B, eps):
@@ -216,9 +205,7 @@ def group_norm_forward(X, num_channels, num_groups, W, B, eps):
  X = X.view(batch_size, num_groups, -1).contiguous()
  hidden_size = X.shape[-1]
  BLOCK_SIZE = min(MAX_FUSED_SIZE, triton.next_power_of_2(hidden_size))
- Y = torch.empty(
- (batch_size, num_groups, hidden_size), dtype=X.dtype, device=X.device
- )
+ Y = torch.empty((batch_size, num_groups, hidden_size), dtype=X.dtype, device=X.device)
  Mean = torch.zeros((batch_size, num_groups), dtype=X.dtype, device=X.device)
  RSTD = torch.zeros((batch_size, num_groups), dtype=X.dtype, device=X.device)

@@ -307,16 +294,12 @@ class LigerGroupNormFunction(torch.autograd.Function):
  )
  ctx.num_channels = num_channels
  ctx.num_groups = num_groups
- ctx.save_for_backward(
- X, affine_scaling_weight, affine_shifting_bias, Mean, RSTD
- )
+ ctx.save_for_backward(X, affine_scaling_weight, affine_shifting_bias, Mean, RSTD)
  return Y

  @staticmethod
  @ensure_contiguous
  def backward(ctx, dY):
  X, W, B, Mean, RSTD = ctx.saved_tensors
- DX, DW, DB = group_norm_backward(
- dY, X, W, B, Mean, RSTD, ctx.num_channels, ctx.num_groups
- )
+ DX, DW, DB = group_norm_backward(dY, X, W, B, Mean, RSTD, ctx.num_channels, ctx.num_groups)
  return DX, DW, DB, None, None, None
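For context on the `group_norm_forward` lines above: the input is viewed as (batch, groups, everything else), and Mean/RSTD hold one statistic per (batch, group) pair. A small shape sketch with assumed example sizes:

```python
import torch

# Assumed example sizes: 2 samples, 32 channels, 16x16 spatial, 8 groups.
X = torch.randn(2, 32, 16, 16)
num_channels, num_groups = 32, 8

batch_size = X.shape[0]
Xg = X.view(batch_size, num_groups, -1).contiguous()  # (2, 8, 4 * 16 * 16) = (2, 8, 1024)
hidden_size = Xg.shape[-1]                            # 1024 values normalized per (batch, group)
Mean = torch.zeros((batch_size, num_groups))          # one mean per (batch, group)
RSTD = torch.zeros((batch_size, num_groups))          # one 1/std per (batch, group)
print(Xg.shape, Mean.shape, RSTD.shape)               # (2, 8, 1024) (2, 8) (2, 8)
```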
liger_kernel/ops/jsd.py CHANGED
@@ -98,9 +98,7 @@ def jsd_forward(_input, target, shift_labels, beta, ignore_index, has_label):
  loss_stride=loss.stride(-2),
  dX_ptr=dX,
  dX_stride=dX.stride(-2),
- label_ptr=(
- shift_labels if has_label else torch.empty(1, device=_input.device)
- ), # dummy ptr if no label
+ label_ptr=(shift_labels if has_label else torch.empty(1, device=_input.device)), # dummy ptr if no label
  beta=beta,
  n_non_ignore=n_non_ignore,
  ignore_index=ignore_index,
@@ -165,9 +163,7 @@ class LigerJSDFunction(torch.autograd.Function):
  shift_labels = shift_labels.contiguous()
  has_label = True

- loss, dX = jsd_forward(
- _input, target, shift_labels, beta, ignore_index, has_label
- )
+ loss, dX = jsd_forward(_input, target, shift_labels, beta, ignore_index, has_label)
  ctx.save_for_backward(dX)
  return loss
liger_kernel/ops/kl_div.py CHANGED
@@ -4,7 +4,8 @@ import torch
  import triton
  import triton.language as tl

- from liger_kernel.ops.utils import ensure_contiguous, is_hip
+ from liger_kernel.ops.utils import ensure_contiguous
+ from liger_kernel.ops.utils import is_hip


  def get_num_warps(BLOCK_SIZE):
@@ -23,10 +24,10 @@ MAX_FUSED_SIZE = 65536 // 4 # 65536 // 4 or 8 works the best

  REDUCTION_LITERAL = Literal["none", "sum", "mean", "batchmean"]

- _REDUCTION_MODE_NONE = tl.constexpr(0)
- _REDUCTION_MODE_SUM = tl.constexpr(1)
- _REDUCTION_MODE_MEAN = tl.constexpr(2)
- _REDUCTION_MODE_BATCHMEAN = tl.constexpr(3)
+ _REDUCTION_MODE_NONE: tl.constexpr = tl.constexpr(0)
+ _REDUCTION_MODE_SUM: tl.constexpr = tl.constexpr(1)
+ _REDUCTION_MODE_MEAN: tl.constexpr = tl.constexpr(2)
+ _REDUCTION_MODE_BATCHMEAN: tl.constexpr = tl.constexpr(3)

  _str_to_reduction_mode = {
  "none": _REDUCTION_MODE_NONE.value,
@@ -218,9 +219,7 @@ class LigerKLDivLossFunction(torch.autograd.Function):
  ctx.save_for_backward(y_true)
  ctx.reduction = reduction
  ctx.log_target = log_target
- return kldiv_forward_triton(
- y_pred, y_true, log_target=log_target, reduction=reduction, eps=eps
- )
+ return kldiv_forward_triton(y_pred, y_true, log_target=log_target, reduction=reduction, eps=eps)

  @staticmethod
  @ensure_contiguous
@@ -238,9 +237,7 @@ class LigerKLDivLossFunction(torch.autograd.Function):

  new_grads = torch.empty_like(y_true)

- derivative = kldiv_backward_triton(
- y_true, grad_output, new_grads, ctx.log_target
- )
+ derivative = kldiv_backward_triton(y_true, grad_output, new_grads, ctx.log_target)

  if ctx.reduction == "batchmean":
  derivative = derivative / y_true.shape[0]
liger_kernel/ops/layer_norm.py CHANGED
@@ -5,11 +5,9 @@ import torch
  import triton
  import triton.language as tl

- from liger_kernel.ops.utils import (
- calculate_settings,
- compare_version,
- ensure_contiguous,
- )
+ from liger_kernel.ops.utils import calculate_settings
+ from liger_kernel.ops.utils import compare_version
+ from liger_kernel.ops.utils import ensure_contiguous

  if compare_version("triton", operator.ge, "3.0.0"):
  try:
liger_kernel/ops/qwen2vl_mrope.py CHANGED
@@ -10,6 +10,7 @@ def _triton_qwen2vl_mrope(
  cos,
  sin,
  sl,
+ bs: tl.constexpr,
  n_qh: tl.constexpr,
  n_kh: tl.constexpr,
  hd: tl.constexpr,
@@ -41,13 +42,12 @@ def _triton_qwen2vl_mrope(
  t_end = mrope_section_t
  h_end = t_end + mrope_section_h

- cos_row_idx = pid % sl
- t_cos = cos + cos_row_idx * hd
- h_cos = t_cos + sl * hd
- w_cos = h_cos + sl * hd
- t_sin = sin + cos_row_idx * hd
- h_sin = t_sin + sl * hd
- w_sin = h_sin + sl * hd
+ t_cos = cos + pid * hd
+ h_cos = t_cos + bs * sl * hd
+ w_cos = h_cos + bs * sl * hd
+ t_sin = sin + pid * hd
+ h_sin = t_sin + bs * sl * hd
+ w_sin = h_sin + bs * sl * hd

  cos_offsets = tl.arange(0, pad_hd // 2)
  t_mask = cos_offsets < t_end
@@ -67,36 +67,20 @@ def _triton_qwen2vl_mrope(
  # program instance (i.e. for the current token) separately
  # ####################################################################
  # left half of the head
- first_half_q_offsets = (
- tl.arange(0, pad_n_qh)[:, None] * hd + tl.arange(0, pad_hd // 2)[None, :]
- )
- first_half_k_offsets = (
- tl.arange(0, pad_n_kh)[:, None] * hd + tl.arange(0, pad_hd // 2)[None, :]
- )
- first_q_mask = (tl.arange(0, pad_n_qh)[:, None] < n_qh) & (
- tl.arange(0, pad_hd // 2)[None, :] < hd // 2
- )
- first_k_mask = (tl.arange(0, pad_n_kh)[:, None] < n_kh) & (
- tl.arange(0, pad_hd // 2)[None, :] < hd // 2
- )
- q_tile_1 = tl.load(q_ptr + first_half_q_offsets, mask=first_q_mask, other=0).to(
- sin_row.dtype
- )
- k_tile_1 = tl.load(k_ptr + first_half_k_offsets, mask=first_k_mask, other=0).to(
- sin_row.dtype
- )
+ first_half_q_offsets = tl.arange(0, pad_n_qh)[:, None] * hd + tl.arange(0, pad_hd // 2)[None, :]
+ first_half_k_offsets = tl.arange(0, pad_n_kh)[:, None] * hd + tl.arange(0, pad_hd // 2)[None, :]
+ first_q_mask = (tl.arange(0, pad_n_qh)[:, None] < n_qh) & (tl.arange(0, pad_hd // 2)[None, :] < hd // 2)
+ first_k_mask = (tl.arange(0, pad_n_kh)[:, None] < n_kh) & (tl.arange(0, pad_hd // 2)[None, :] < hd // 2)
+ q_tile_1 = tl.load(q_ptr + first_half_q_offsets, mask=first_q_mask, other=0).to(sin_row.dtype)
+ k_tile_1 = tl.load(k_ptr + first_half_k_offsets, mask=first_k_mask, other=0).to(sin_row.dtype)

  # right half of the head
  second_half_q_offsets = first_half_q_offsets + (hd // 2)
  second_half_k_offsets = first_half_k_offsets + (hd // 2)
  second_q_mask = first_q_mask
  second_k_mask = first_k_mask
- q_tile_2 = tl.load(q_ptr + second_half_q_offsets, mask=second_q_mask, other=0).to(
- sin_row.dtype
- )
- k_tile_2 = tl.load(k_ptr + second_half_k_offsets, mask=second_k_mask, other=0).to(
- sin_row.dtype
- )
+ q_tile_2 = tl.load(q_ptr + second_half_q_offsets, mask=second_q_mask, other=0).to(sin_row.dtype)
+ k_tile_2 = tl.load(k_ptr + second_half_k_offsets, mask=second_k_mask, other=0).to(sin_row.dtype)

  if not BACKWARD_PASS:
  # y = [x1, x2] * [cos, cos] + [-x2, x1] * [sin, sin]
@@ -124,7 +108,6 @@ def _triton_qwen2vl_mrope(


  def qwen2vl_mrope_forward(q, k, cos, sin, mrope_section):
-
  # transpose it back to the physical shape because Triton looks at the physical storage
  # note: q and k are incontiguous before the transformation and will become contiguous after transpose
  q = q.transpose(1, 2)
@@ -151,6 +134,7 @@ def qwen2vl_mrope_forward(q, k, cos, sin, mrope_section):
  cos,
  sin,
  seq_len,
+ batch_size,
  n_q_head,
  n_kv_head,
  head_dim,
@@ -189,6 +173,7 @@ def qwen2vl_mrope_backward(dq, dk, cos, sin, mrope_section):
  cos,
  sin,
  seq_len,
+ batch_size,
  n_q_head,
  n_kv_head,
  head_dim,
@@ -216,8 +201,8 @@ class LigerQwen2VLMRopeFunction(torch.autograd.Function):
  """
  q size: (bsz, n_q_head, seq_len, head_dim)
  k size: (bsz, n_kv_head, seq_len, head_dim)
- cos size: (3, 1, seq_len, head_dim)
- sin size: (3, 1, seq_len, head_dim)
+ cos size: (3, bsz, seq_len, head_dim)
+ sin size: (3, bsz, seq_len, head_dim)
  """
  q, k, cos, sin = qwen2vl_mrope_forward(q, k, cos, sin, mrope_section)
  ctx.save_for_backward(cos, sin)
@@ -228,10 +213,9 @@ class LigerQwen2VLMRopeFunction(torch.autograd.Function):
  """
  dq size: (bsz, n_q_head, seq_len, head_dim)
  dk size: (bsz, n_kv_head, seq_len, head_dim)
- cos size: (3, 1, seq_len, head_dim)
- sin size: (3, 1, seq_len, head_dim)
+ cos size: (3, bsz, seq_len, head_dim)
+ sin size: (3, bsz, seq_len, head_dim)
  """
-
  cos, sin = ctx.saved_tensors
  mrope_section = ctx.mrope_section
  dq, dk = qwen2vl_mrope_backward(dq, dk, cos, sin, mrope_section)
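The mrope hunks above change how the kernel indexes the rotary tables: each program id now addresses a (batch, position) row directly, so cos/sin are laid out as (3, bsz, seq_len, head_dim) instead of being broadcast from a batch of 1. A hedged usage sketch; the head counts, head_dim, dtype and mrope_section below are assumed Qwen2-VL-style values, and the fused kernel needs a CUDA device:

```python
import torch

from liger_kernel.ops.qwen2vl_mrope import qwen2vl_mrope_forward

bsz, seq_len, n_q_head, n_kv_head, head_dim = 2, 128, 28, 4, 128
mrope_section = [16, 24, 24]  # assumed t/h/w split; sums to head_dim // 2

q = torch.randn(bsz, n_q_head, seq_len, head_dim, device="cuda", dtype=torch.bfloat16)
k = torch.randn(bsz, n_kv_head, seq_len, head_dim, device="cuda", dtype=torch.bfloat16)
# Per-batch position tables, one set each for the t/h/w sections.
cos = torch.randn(3, bsz, seq_len, head_dim, device="cuda", dtype=torch.bfloat16)
sin = torch.randn(3, bsz, seq_len, head_dim, device="cuda", dtype=torch.bfloat16)

q_rot, k_rot, cos, sin = qwen2vl_mrope_forward(q, k, cos, sin, mrope_section)
print(q_rot.shape, k_rot.shape)  # rotated q and k
```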
liger_kernel/ops/rms_norm.py CHANGED
@@ -17,12 +17,10 @@ import torch
  import triton
  import triton.language as tl

- from liger_kernel.ops.utils import (
- calculate_settings,
- compare_version,
- ensure_contiguous,
- torch_to_triton_dtype,
- )
+ from liger_kernel.ops.utils import calculate_settings
+ from liger_kernel.ops.utils import compare_version
+ from liger_kernel.ops.utils import ensure_contiguous
+ from liger_kernel.ops.utils import torch_to_triton_dtype

  if compare_version("triton", operator.ge, "3.0.0"):
  try:
@@ -35,9 +33,9 @@ else:
  from triton.language.math import rsqrt


- _CASTING_MODE_NONE = tl.constexpr(-1)
- _CASTING_MODE_LLAMA = tl.constexpr(0)
- _CASTING_MODE_GEMMA = tl.constexpr(1)
+ _CASTING_MODE_NONE: tl.constexpr = tl.constexpr(-1)
+ _CASTING_MODE_LLAMA: tl.constexpr = tl.constexpr(0)
+ _CASTING_MODE_GEMMA: tl.constexpr = tl.constexpr(1)


  @triton.jit
@@ -177,9 +175,7 @@ def _rms_norm_backward_kernel(

  dX_row = rstd_row * m

- dX_row += (rstd_row) * (
- -(1 / n_cols) * rstd_row * rstd_row * tl.sum(m * X_row, axis=0) * X_row
- )
+ dX_row += (rstd_row) * (-(1 / n_cols) * rstd_row * rstd_row * tl.sum(m * X_row, axis=0) * X_row)

  # calculate the gradient of W
  if casting_mode == _CASTING_MODE_LLAMA:
@@ -207,14 +203,10 @@ _str_to_casting_mode = {

  def rms_norm_forward(X, W, eps, offset, casting_mode):
  if not isinstance(casting_mode, int):
- assert (
- casting_mode in _str_to_casting_mode
- ), f"Invalid casting mode: {casting_mode}"
+ assert casting_mode in _str_to_casting_mode, f"Invalid casting mode: {casting_mode}"
  casting_mode = _str_to_casting_mode[casting_mode]
  else:
- assert (
- casting_mode in _str_to_casting_mode.values()
- ), f"Invalid casting mode: {casting_mode}"
+ assert casting_mode in _str_to_casting_mode.values(), f"Invalid casting mode: {casting_mode}"

  shape = X.shape
  dim = shape[-1]
@@ -225,17 +217,11 @@ def rms_norm_forward(X, W, eps, offset, casting_mode):
  Y = torch.empty((n_rows, n_cols), dtype=X.dtype, device=X.device)
  # RSTD is to cache rstd for each row
  # RSTD is always computed/stored in fp32 if we are using Llama or Gemma casting mode
- rstd_dtype = (
- torch.float32
- if casting_mode in (_CASTING_MODE_LLAMA.value, _CASTING_MODE_GEMMA.value)
- else X.dtype
- )
+ rstd_dtype = torch.float32 if casting_mode in (_CASTING_MODE_LLAMA.value, _CASTING_MODE_GEMMA.value) else X.dtype
  RSTD = torch.empty(n_rows, dtype=rstd_dtype, device=X.device)

  # Check constraints.
- assert (
- X.shape[1] == W.shape[0]
- ), "Incompatible hidden size dimension between tensor1.shape[1] and tensor2.shape[0]"
+ assert X.shape[1] == W.shape[0], "Incompatible hidden size dimension between tensor1.shape[1] and tensor2.shape[0]"

  _rms_norm_forward_kernel[(n_rows,)](
  Y,
@@ -256,9 +242,7 @@ def rms_norm_forward(X, W, eps, offset, casting_mode):
  return Y.view(*shape), X, RSTD, BLOCK_SIZE, num_warps, casting_mode


- def rms_norm_backward(
- dY, X, W, RSTD, offset, casting_mode, BLOCK_SIZE, num_warps, in_place
- ):
+ def rms_norm_backward(dY, X, W, RSTD, offset, casting_mode, BLOCK_SIZE, num_warps, in_place):
  shape = dY.shape
  dim = shape[-1]
  dY = dY.view(-1, dim)
@@ -340,9 +324,7 @@ class LigerRMSNormFunction(torch.autograd.Function):
  X: (B, T, H) or (BxT, H)
  W: (H,)
  """
- Y, X, RSTD, BLOCK_SIZE, num_warps, casting_mode = rms_norm_forward(
- X, W, eps, offset, casting_mode
- )
+ Y, X, RSTD, BLOCK_SIZE, num_warps, casting_mode = rms_norm_forward(X, W, eps, offset, casting_mode)
  ctx.offset = offset
  ctx.casting_mode = casting_mode
  ctx.in_place = in_place
liger_kernel/ops/rope.py CHANGED
@@ -15,6 +15,7 @@ def _triton_rope(
  sin_row_stride,
  sl,
  bs: tl.constexpr,
+ cos_bs: tl.constexpr,
  n_qh: tl.constexpr,
  n_kh: tl.constexpr,
  hd: tl.constexpr,
@@ -29,7 +30,7 @@ def _triton_rope(
  # k size: (bsz, seq_len, num_kv_heads, head_dim)
  # k stride: (seq_len * num_kv_heads * head_dim, num_kv_heads * head_dim, head_dim, 1)

- # cos size: (1, seq_len, head_dim)
+ # cos size: (1, seq_len, head_dim) or (bsz, seq_len, head_dim)
  # stride: (seq_len * head_dim, head_dim, 1)
  pid = tl.program_id(0)

@@ -48,9 +49,19 @@ def _triton_rope(
  # and pid % sl to get the sequence index.
  # 2. We only need the left half of cos and sin matrix because the right half is just
  # a clone of the left half.
- cos_row_idx = pid % (sl)
- cos = cos + cos_row_idx * cos_row_stride
- sin = sin + cos_row_idx * sin_row_stride
+ batch_idx = pid // sl
+ cos_row_idx = pid % sl
+ cos = cos + tl.where(
+ cos_bs == 1,
+ cos_row_idx * cos_row_stride,
+ batch_idx * (sl * cos_row_stride) + cos_row_idx * cos_row_stride,
+ )
+ sin = sin + tl.where(
+ cos_bs == 1,
+ cos_row_idx * sin_row_stride,
+ batch_idx * (sl * sin_row_stride) + cos_row_idx * sin_row_stride,
+ )
+
  cos_offsets = tl.arange(0, pad_hd // 2)
  cos_mask = cos_offsets < hd // 2
  cos_row = tl.load(cos + cos_offsets, mask=cos_mask, other=0)
@@ -61,36 +72,20 @@ def _triton_rope(
  # program instance (i.e. for the current token) separately
  # ####################################################################
  # left half of the head
- first_half_q_offsets = (
- tl.arange(0, pad_n_qh)[:, None] * hd + tl.arange(0, pad_hd // 2)[None, :]
- )
- first_half_k_offsets = (
- tl.arange(0, pad_n_kh)[:, None] * hd + tl.arange(0, pad_hd // 2)[None, :]
- )
- first_q_mask = (tl.arange(0, pad_n_qh)[:, None] < n_qh) & (
- tl.arange(0, pad_hd // 2)[None, :] < hd // 2
- )
- first_k_mask = (tl.arange(0, pad_n_kh)[:, None] < n_kh) & (
- tl.arange(0, pad_hd // 2)[None, :] < hd // 2
- )
- q_tile_1 = tl.load(q_ptr + first_half_q_offsets, mask=first_q_mask, other=0).to(
- sin_row.dtype
- )
- k_tile_1 = tl.load(k_ptr + first_half_k_offsets, mask=first_k_mask, other=0).to(
- sin_row.dtype
- )
+ first_half_q_offsets = tl.arange(0, pad_n_qh)[:, None] * hd + tl.arange(0, pad_hd // 2)[None, :]
+ first_half_k_offsets = tl.arange(0, pad_n_kh)[:, None] * hd + tl.arange(0, pad_hd // 2)[None, :]
+ first_q_mask = (tl.arange(0, pad_n_qh)[:, None] < n_qh) & (tl.arange(0, pad_hd // 2)[None, :] < hd // 2)
+ first_k_mask = (tl.arange(0, pad_n_kh)[:, None] < n_kh) & (tl.arange(0, pad_hd // 2)[None, :] < hd // 2)
+ q_tile_1 = tl.load(q_ptr + first_half_q_offsets, mask=first_q_mask, other=0).to(sin_row.dtype)
+ k_tile_1 = tl.load(k_ptr + first_half_k_offsets, mask=first_k_mask, other=0).to(sin_row.dtype)

  # right half of the head
  second_half_q_offsets = first_half_q_offsets + (hd // 2)
  second_half_k_offsets = first_half_k_offsets + (hd // 2)
  second_q_mask = first_q_mask
  second_k_mask = first_k_mask
- q_tile_2 = tl.load(q_ptr + second_half_q_offsets, mask=second_q_mask, other=0).to(
- sin_row.dtype
- )
- k_tile_2 = tl.load(k_ptr + second_half_k_offsets, mask=second_k_mask, other=0).to(
- sin_row.dtype
- )
+ q_tile_2 = tl.load(q_ptr + second_half_q_offsets, mask=second_q_mask, other=0).to(sin_row.dtype)
+ k_tile_2 = tl.load(k_ptr + second_half_k_offsets, mask=second_k_mask, other=0).to(sin_row.dtype)

  if not BACKWARD_PASS:
  # y = [x1, x2] * [cos, cos] + [-x2, x1] * [sin, sin]
@@ -118,7 +113,6 @@ def _triton_rope(


  def rope_forward(q, k, cos, sin):
-
  # transpose it back to the physical shape because Triton looks at the physical storage
  # note: q and k are incontiguous before the transformation and will become contiguous after transpose
  q = q.transpose(1, 2)
@@ -138,6 +132,7 @@ def rope_forward(q, k, cos, sin):
  k = k.contiguous()
  cos = cos.contiguous()
  sin = sin.contiguous()
+ cos_batch_size = cos.shape[0]

  _triton_rope[(n_row,)](
  q,
@@ -150,6 +145,7 @@ def rope_forward(q, k, cos, sin):
  sin.stride(-2),
  seq_len,
  batch_size,
+ cos_batch_size,
  n_q_head,
  n_kv_head,
  head_dim,
@@ -167,6 +163,7 @@ def rope_backward(dq, dk, cos, sin):
  dk = dk.transpose(1, 2)

  batch_size, seq_len, n_q_head, head_dim = dq.shape
+ cos_batch_size = cos.shape[0]
  n_kv_head = dk.shape[2]
  pad_hd = triton.next_power_of_2(head_dim)
  pad_n_q_head = triton.next_power_of_2(n_q_head)
@@ -191,6 +188,7 @@ def rope_backward(dq, dk, cos, sin):
  sin.stride(-2),
  seq_len,
  batch_size,
+ cos_batch_size,
  n_q_head,
  n_kv_head,
  head_dim,
@@ -221,8 +219,8 @@ class LigerRopeFunction(torch.autograd.Function):
  """
  q size: (bsz, n_q_head, seq_len, head_dim)
  k size: (bsz, n_kv_head, seq_len, head_dim)
- cos size: (1, seq_len, head_dim)
- sin size: (1, seq_len, head_dim)
+ cos size: (1, seq_len, head_dim) or (bsz, seq_len, head_dim)
+ sin size: (1, seq_len, head_dim) or (bsz, seq_len, head_dim)
  """
  q, k, cos, sin = rope_forward(q, k, cos, sin)
  ctx.save_for_backward(cos, sin)
@@ -232,8 +230,8 @@ class LigerRopeFunction(torch.autograd.Function):
  """
  dq size: (bsz, n_q_head, seq_len, head_dim)
  dk size: (bsz, n_kv_head, seq_len, head_dim)
- cos size: (1, seq_len, head_dim)
- sin size: (1, seq_len, head_dim)
+ cos size: (1, seq_len, head_dim) or (bsz, seq_len, head_dim)
+ sin size: (1, seq_len, head_dim) or (bsz, seq_len, head_dim)
  """

  cos, sin = ctx.saved_tensors
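To make the new `cos_bs` handling in `_triton_rope` concrete, here is a pure-Python restatement of the offset selection that the `tl.where` above performs (illustrative seq_len and stride values only):

```python
def cos_row_offset(pid, sl, cos_row_stride, cos_bs):
    # pid walks batch * seq_len rows; the batch offset is applied only when the
    # cos/sin tables actually carry a batch dimension (cos_bs > 1).
    batch_idx = pid // sl
    cos_row_idx = pid % sl
    if cos_bs == 1:
        return cos_row_idx * cos_row_stride  # shared table, broadcast over the batch
    return batch_idx * (sl * cos_row_stride) + cos_row_idx * cos_row_stride  # per-batch row

# Example: seq_len=4, row stride=8, token pid=6 (batch 1, position 2).
print(cos_row_offset(6, 4, 8, cos_bs=1))  # 16 -> same row reused for every batch
print(cos_row_offset(6, 4, 8, cos_bs=2))  # 48 -> batch 1's own row
```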