liger-kernel-nightly 0.5.2.dev20241223032015__py3-none-any.whl → 0.5.2.dev20241223042135__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (57)
  1. liger_kernel/chunked_loss/cpo_loss.py +5 -11
  2. liger_kernel/chunked_loss/dpo_loss.py +1 -4
  3. liger_kernel/chunked_loss/fused_linear_distillation.py +37 -37
  4. liger_kernel/chunked_loss/fused_linear_preference.py +40 -64
  5. liger_kernel/chunked_loss/orpo_loss.py +2 -6
  6. liger_kernel/chunked_loss/simpo_loss.py +4 -8
  7. liger_kernel/env_report.py +4 -11
  8. liger_kernel/ops/cross_entropy.py +7 -10
  9. liger_kernel/ops/experimental/embedding.py +1 -3
  10. liger_kernel/ops/experimental/mm_int8int2.py +3 -9
  11. liger_kernel/ops/fused_linear_cross_entropy.py +7 -15
  12. liger_kernel/ops/fused_linear_jsd.py +11 -29
  13. liger_kernel/ops/geglu.py +6 -17
  14. liger_kernel/ops/group_norm.py +11 -28
  15. liger_kernel/ops/jsd.py +2 -6
  16. liger_kernel/ops/kl_div.py +4 -7
  17. liger_kernel/ops/layer_norm.py +3 -5
  18. liger_kernel/ops/qwen2vl_mrope.py +8 -25
  19. liger_kernel/ops/rms_norm.py +11 -29
  20. liger_kernel/ops/rope.py +31 -33
  21. liger_kernel/ops/swiglu.py +4 -8
  22. liger_kernel/ops/utils.py +2 -0
  23. liger_kernel/transformers/__init__.py +16 -24
  24. liger_kernel/transformers/auto_model.py +6 -13
  25. liger_kernel/transformers/cross_entropy.py +1 -3
  26. liger_kernel/transformers/experimental/embedding.py +1 -3
  27. liger_kernel/transformers/functional.py +2 -6
  28. liger_kernel/transformers/fused_linear_cross_entropy.py +2 -6
  29. liger_kernel/transformers/geglu.py +1 -4
  30. liger_kernel/transformers/group_norm.py +3 -9
  31. liger_kernel/transformers/jsd.py +1 -3
  32. liger_kernel/transformers/kl_div.py +1 -3
  33. liger_kernel/transformers/layer_norm.py +3 -9
  34. liger_kernel/transformers/model/gemma.py +18 -40
  35. liger_kernel/transformers/model/gemma2.py +19 -41
  36. liger_kernel/transformers/model/llama.py +22 -48
  37. liger_kernel/transformers/model/mistral.py +14 -26
  38. liger_kernel/transformers/model/mixtral.py +23 -53
  39. liger_kernel/transformers/model/mllama.py +16 -36
  40. liger_kernel/transformers/model/phi3.py +18 -40
  41. liger_kernel/transformers/model/qwen2.py +18 -40
  42. liger_kernel/transformers/model/qwen2_vl.py +16 -30
  43. liger_kernel/transformers/monkey_patch.py +43 -117
  44. liger_kernel/transformers/rms_norm.py +4 -4
  45. liger_kernel/transformers/rope.py +2 -2
  46. liger_kernel/transformers/swiglu.py +2 -8
  47. liger_kernel/transformers/trainer/__init__.py +1 -3
  48. liger_kernel/transformers/trainer/orpo_trainer.py +13 -16
  49. liger_kernel/triton/__init__.py +1 -3
  50. liger_kernel/triton/monkey_patch.py +1 -3
  51. {liger_kernel_nightly-0.5.2.dev20241223032015.dist-info → liger_kernel_nightly-0.5.2.dev20241223042135.dist-info}/METADATA +1 -1
  52. liger_kernel_nightly-0.5.2.dev20241223042135.dist-info/RECORD +66 -0
  53. liger_kernel_nightly-0.5.2.dev20241223032015.dist-info/RECORD +0 -66
  54. {liger_kernel_nightly-0.5.2.dev20241223032015.dist-info → liger_kernel_nightly-0.5.2.dev20241223042135.dist-info}/LICENSE +0 -0
  55. {liger_kernel_nightly-0.5.2.dev20241223032015.dist-info → liger_kernel_nightly-0.5.2.dev20241223042135.dist-info}/NOTICE +0 -0
  56. {liger_kernel_nightly-0.5.2.dev20241223032015.dist-info → liger_kernel_nightly-0.5.2.dev20241223042135.dist-info}/WHEEL +0 -0
  57. {liger_kernel_nightly-0.5.2.dev20241223032015.dist-info → liger_kernel_nightly-0.5.2.dev20241223042135.dist-info}/top_level.txt +0 -0
liger_kernel/ops/geglu.py CHANGED
@@ -4,11 +4,9 @@ import torch
 import triton
 import triton.language as tl
 
-from liger_kernel.ops.utils import (
-    calculate_settings,
-    compare_version,
-    ensure_contiguous,
-)
+from liger_kernel.ops.utils import calculate_settings
+from liger_kernel.ops.utils import compare_version
+from liger_kernel.ops.utils import ensure_contiguous
 
 if compare_version("triton", operator.ge, "3.0.0"):
     try:
@@ -22,9 +20,7 @@ else:
 
 
 @triton.jit
-def _geglu_tanh_forward_kernel(
-    a, b, c, stride, n_cols: tl.constexpr, BLOCK_SIZE: tl.constexpr
-):
+def _geglu_tanh_forward_kernel(a, b, c, stride, n_cols: tl.constexpr, BLOCK_SIZE: tl.constexpr):
     program_id = tl.program_id(0).to(tl.int64)
 
     # locate start index
@@ -49,9 +45,7 @@ def _geglu_tanh_forward_kernel(
 
 
 @triton.jit
-def _geglu_tanh_backward_kernel(
-    dc, a, b, stride, n_cols: tl.constexpr, BLOCK_SIZE: tl.constexpr
-):
+def _geglu_tanh_backward_kernel(dc, a, b, stride, n_cols: tl.constexpr, BLOCK_SIZE: tl.constexpr):
     program_id = tl.program_id(0).to(tl.int64)
 
     # locate start index
@@ -80,12 +74,7 @@ def _geglu_tanh_backward_kernel(
     # where z = sqrt(2/pi) * (a + 0.044715 * a^3)
     term1 = 0.5 * (1 + tanh_result)
     tanh_sq = tanh_result * tanh_result
-    term2 = (
-        0.5
-        * a_row
-        * (1 - tanh_sq)
-        * (sqrt_2_over_pi * (1 + 3 * 0.044715 * a_row * a_row))
-    )
+    term2 = 0.5 * a_row * (1 - tanh_sq) * (sqrt_2_over_pi * (1 + 3 * 0.044715 * a_row * a_row))
     da_row = dc_row * b_row * (term1 + term2)
 
     tl.store(a + col_offsets, da_row, mask=mask)
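For readers parsing the backward hunk above: `term1` and `term2` together form the derivative of the tanh-approximated GELU, which the kernel then multiplies by the upstream gradient and the gate input `b`. A sketch of the math the code encodes, with z defined as in the comment above:

\frac{d}{da}\Big[\frac{a}{2}\big(1+\tanh z\big)\Big]
  = \underbrace{\frac{1}{2}\big(1+\tanh z\big)}_{\texttt{term1}}
  + \underbrace{\frac{a}{2}\big(1-\tanh^{2} z\big)\,\sqrt{2/\pi}\,\big(1+3\cdot 0.044715\,a^{2}\big)}_{\texttt{term2}},
\qquad z=\sqrt{2/\pi}\,\big(a+0.044715\,a^{3}\big)

so `da_row = dc_row * b_row * (term1 + term2)` is simply the chain rule for c = b * gelu_tanh(a).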
liger_kernel/ops/group_norm.py CHANGED
@@ -4,7 +4,8 @@ import torch
 import triton
 import triton.language as tl
 
-from liger_kernel.ops.utils import compare_version, ensure_contiguous
+from liger_kernel.ops.utils import compare_version
+from liger_kernel.ops.utils import ensure_contiguous
 
 if compare_version("triton", operator.ge, "3.0.0"):
     try:
@@ -73,9 +74,7 @@ def _group_norm_forward_kernel(
 
     # Normalize
     hidden_size_per_channel = hidden_size // channels_per_group
-    for channel_idx in tl.range(
-        group_idx * channels_per_group, (group_idx + 1) * channels_per_group
-    ):
+    for channel_idx in tl.range(group_idx * channels_per_group, (group_idx + 1) * channels_per_group):
        W = tl.load(W_ptr + channel_idx)
        B = tl.load(B_ptr + channel_idx)
        for i in range(0, hidden_size_per_channel, BLOCK_SIZE):
@@ -132,21 +131,15 @@ def _group_norm_backward_kernel(
     UPSTREAM_ptr += batch_idx * X_row_stride
 
     # Mean and rstd are the same shape so have the same strides
-    mean = tl.load(
-        Mean_ptr + batch_idx * Mean_ptr_row_stride + group_idx * Mean_ptr_col_stride
-    )
-    rstd = tl.load(
-        RSTD_ptr + batch_idx * Mean_ptr_row_stride + group_idx * Mean_ptr_col_stride
-    )
+    mean = tl.load(Mean_ptr + batch_idx * Mean_ptr_row_stride + group_idx * Mean_ptr_col_stride)
+    rstd = tl.load(RSTD_ptr + batch_idx * Mean_ptr_row_stride + group_idx * Mean_ptr_col_stride)
 
     c1 = 0.0
     c2 = 0.0
     block_range = tl.arange(0, BLOCK_SIZE)
 
     # We need to compute the sum terms of the backprop equations across all channels in the group
-    for channel_idx in range(
-        group_idx * channels_per_group, (group_idx + 1) * channels_per_group
-    ):
+    for channel_idx in range(group_idx * channels_per_group, (group_idx + 1) * channels_per_group):
        dW = 0.0
        dB = 0.0
        # Move the pointers to the correct channel
@@ -181,9 +174,7 @@ def _group_norm_backward_kernel(
     c1 = c1 / N
     c2 = c2 / N
 
-    for channel_idx in tl.range(
-        group_idx * channels_per_group, (group_idx + 1) * channels_per_group
-    ):
+    for channel_idx in tl.range(group_idx * channels_per_group, (group_idx + 1) * channels_per_group):
        # Move the pointers to the correct channel
        W = tl.load(W_ptr + channel_idx)
        for i in range(0, hidden_size, BLOCK_SIZE):
@@ -203,9 +194,7 @@ def _group_norm_backward_kernel(
            x_hat = (X - mean) * rstd
            wdy = W * UPSTREAM_grad
            dx = (wdy - (x_hat * c1 + c2)) * rstd
-           tl.store(
-               DX_ptr + channel_idx * X_col_stride + hidden_size_offsets, dx, mask=mask
-           )
+           tl.store(DX_ptr + channel_idx * X_col_stride + hidden_size_offsets, dx, mask=mask)
 
 
 def group_norm_forward(X, num_channels, num_groups, W, B, eps):
@@ -216,9 +205,7 @@ def group_norm_forward(X, num_channels, num_groups, W, B, eps):
     X = X.view(batch_size, num_groups, -1).contiguous()
     hidden_size = X.shape[-1]
     BLOCK_SIZE = min(MAX_FUSED_SIZE, triton.next_power_of_2(hidden_size))
-    Y = torch.empty(
-        (batch_size, num_groups, hidden_size), dtype=X.dtype, device=X.device
-    )
+    Y = torch.empty((batch_size, num_groups, hidden_size), dtype=X.dtype, device=X.device)
     Mean = torch.zeros((batch_size, num_groups), dtype=X.dtype, device=X.device)
     RSTD = torch.zeros((batch_size, num_groups), dtype=X.dtype, device=X.device)
 
@@ -307,16 +294,12 @@ class LigerGroupNormFunction(torch.autograd.Function):
        )
        ctx.num_channels = num_channels
        ctx.num_groups = num_groups
-       ctx.save_for_backward(
-           X, affine_scaling_weight, affine_shifting_bias, Mean, RSTD
-       )
+       ctx.save_for_backward(X, affine_scaling_weight, affine_shifting_bias, Mean, RSTD)
        return Y
 
    @staticmethod
    @ensure_contiguous
    def backward(ctx, dY):
        X, W, B, Mean, RSTD = ctx.saved_tensors
-       DX, DW, DB = group_norm_backward(
-           dY, X, W, B, Mean, RSTD, ctx.num_channels, ctx.num_groups
-       )
+       DX, DW, DB = group_norm_backward(dY, X, W, B, Mean, RSTD, ctx.num_channels, ctx.num_groups)
        return DX, DW, DB, None, None, None
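A side note on the `dx` line in the backward hunks above: it is the standard group-norm input gradient. A sketch, under the assumption (the accumulation loop itself is not shown in this diff) that `c1` and `c2` hold the per-group means of x̂·w·∂y and w·∂y over the N elements of a group:

\hat{x} = (x-\mu)\,r,\qquad
c_1 = \frac{1}{N}\sum_{\text{group}}\hat{x}\,w\,\partial y,\qquad
c_2 = \frac{1}{N}\sum_{\text{group}} w\,\partial y,\qquad
\partial x = \big(w\,\partial y - (\hat{x}\,c_1 + c_2)\big)\,r

where r is the cached rstd value loaded at the top of the kernel.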
liger_kernel/ops/jsd.py CHANGED
@@ -98,9 +98,7 @@ def jsd_forward(_input, target, shift_labels, beta, ignore_index, has_label):
        loss_stride=loss.stride(-2),
        dX_ptr=dX,
        dX_stride=dX.stride(-2),
-       label_ptr=(
-           shift_labels if has_label else torch.empty(1, device=_input.device)
-       ),  # dummy ptr if no label
+       label_ptr=(shift_labels if has_label else torch.empty(1, device=_input.device)),  # dummy ptr if no label
        beta=beta,
        n_non_ignore=n_non_ignore,
        ignore_index=ignore_index,
@@ -165,9 +163,7 @@ class LigerJSDFunction(torch.autograd.Function):
            shift_labels = shift_labels.contiguous()
            has_label = True
 
-       loss, dX = jsd_forward(
-           _input, target, shift_labels, beta, ignore_index, has_label
-       )
+       loss, dX = jsd_forward(_input, target, shift_labels, beta, ignore_index, has_label)
        ctx.save_for_backward(dX)
        return loss
 
liger_kernel/ops/kl_div.py CHANGED
@@ -4,7 +4,8 @@ import torch
 import triton
 import triton.language as tl
 
-from liger_kernel.ops.utils import ensure_contiguous, is_hip
+from liger_kernel.ops.utils import ensure_contiguous
+from liger_kernel.ops.utils import is_hip
 
 
 def get_num_warps(BLOCK_SIZE):
@@ -218,9 +219,7 @@ class LigerKLDivLossFunction(torch.autograd.Function):
        ctx.save_for_backward(y_true)
        ctx.reduction = reduction
        ctx.log_target = log_target
-       return kldiv_forward_triton(
-           y_pred, y_true, log_target=log_target, reduction=reduction, eps=eps
-       )
+       return kldiv_forward_triton(y_pred, y_true, log_target=log_target, reduction=reduction, eps=eps)
 
    @staticmethod
    @ensure_contiguous
@@ -238,9 +237,7 @@ class LigerKLDivLossFunction(torch.autograd.Function):
 
        new_grads = torch.empty_like(y_true)
 
-       derivative = kldiv_backward_triton(
-           y_true, grad_output, new_grads, ctx.log_target
-       )
+       derivative = kldiv_backward_triton(y_true, grad_output, new_grads, ctx.log_target)
 
        if ctx.reduction == "batchmean":
            derivative = derivative / y_true.shape[0]
liger_kernel/ops/layer_norm.py CHANGED
@@ -5,11 +5,9 @@ import torch
 import triton
 import triton.language as tl
 
-from liger_kernel.ops.utils import (
-    calculate_settings,
-    compare_version,
-    ensure_contiguous,
-)
+from liger_kernel.ops.utils import calculate_settings
+from liger_kernel.ops.utils import compare_version
+from liger_kernel.ops.utils import ensure_contiguous
 
 if compare_version("triton", operator.ge, "3.0.0"):
     try:
liger_kernel/ops/qwen2vl_mrope.py CHANGED
@@ -67,36 +67,20 @@ def _triton_qwen2vl_mrope(
     # program instance (i.e. for the current token) separately
     # ####################################################################
     # left half of the head
-    first_half_q_offsets = (
-        tl.arange(0, pad_n_qh)[:, None] * hd + tl.arange(0, pad_hd // 2)[None, :]
-    )
-    first_half_k_offsets = (
-        tl.arange(0, pad_n_kh)[:, None] * hd + tl.arange(0, pad_hd // 2)[None, :]
-    )
-    first_q_mask = (tl.arange(0, pad_n_qh)[:, None] < n_qh) & (
-        tl.arange(0, pad_hd // 2)[None, :] < hd // 2
-    )
-    first_k_mask = (tl.arange(0, pad_n_kh)[:, None] < n_kh) & (
-        tl.arange(0, pad_hd // 2)[None, :] < hd // 2
-    )
-    q_tile_1 = tl.load(q_ptr + first_half_q_offsets, mask=first_q_mask, other=0).to(
-        sin_row.dtype
-    )
-    k_tile_1 = tl.load(k_ptr + first_half_k_offsets, mask=first_k_mask, other=0).to(
-        sin_row.dtype
-    )
+    first_half_q_offsets = tl.arange(0, pad_n_qh)[:, None] * hd + tl.arange(0, pad_hd // 2)[None, :]
+    first_half_k_offsets = tl.arange(0, pad_n_kh)[:, None] * hd + tl.arange(0, pad_hd // 2)[None, :]
+    first_q_mask = (tl.arange(0, pad_n_qh)[:, None] < n_qh) & (tl.arange(0, pad_hd // 2)[None, :] < hd // 2)
+    first_k_mask = (tl.arange(0, pad_n_kh)[:, None] < n_kh) & (tl.arange(0, pad_hd // 2)[None, :] < hd // 2)
+    q_tile_1 = tl.load(q_ptr + first_half_q_offsets, mask=first_q_mask, other=0).to(sin_row.dtype)
+    k_tile_1 = tl.load(k_ptr + first_half_k_offsets, mask=first_k_mask, other=0).to(sin_row.dtype)
 
     # right half of the head
     second_half_q_offsets = first_half_q_offsets + (hd // 2)
     second_half_k_offsets = first_half_k_offsets + (hd // 2)
     second_q_mask = first_q_mask
     second_k_mask = first_k_mask
-    q_tile_2 = tl.load(q_ptr + second_half_q_offsets, mask=second_q_mask, other=0).to(
-        sin_row.dtype
-    )
-    k_tile_2 = tl.load(k_ptr + second_half_k_offsets, mask=second_k_mask, other=0).to(
-        sin_row.dtype
-    )
+    q_tile_2 = tl.load(q_ptr + second_half_q_offsets, mask=second_q_mask, other=0).to(sin_row.dtype)
+    k_tile_2 = tl.load(k_ptr + second_half_k_offsets, mask=second_k_mask, other=0).to(sin_row.dtype)
 
     if not BACKWARD_PASS:
         # y = [x1, x2] * [cos, cos] + [-x2, x1] * [sin, sin]
@@ -124,7 +108,6 @@ def _triton_qwen2vl_mrope(
 
 
 def qwen2vl_mrope_forward(q, k, cos, sin, mrope_section):
-
     # transpose it back to the physical shape because Triton looks at the physical storage
     # note: q and k are incontiguous before the transformation and will become contiguous after transpose
     q = q.transpose(1, 2)
liger_kernel/ops/rms_norm.py CHANGED
@@ -17,12 +17,10 @@ import torch
 import triton
 import triton.language as tl
 
-from liger_kernel.ops.utils import (
-    calculate_settings,
-    compare_version,
-    ensure_contiguous,
-    torch_to_triton_dtype,
-)
+from liger_kernel.ops.utils import calculate_settings
+from liger_kernel.ops.utils import compare_version
+from liger_kernel.ops.utils import ensure_contiguous
+from liger_kernel.ops.utils import torch_to_triton_dtype
 
 if compare_version("triton", operator.ge, "3.0.0"):
     try:
@@ -177,9 +175,7 @@ def _rms_norm_backward_kernel(
 
        dX_row = rstd_row * m
 
-       dX_row += (rstd_row) * (
-           -(1 / n_cols) * rstd_row * rstd_row * tl.sum(m * X_row, axis=0) * X_row
-       )
+       dX_row += (rstd_row) * (-(1 / n_cols) * rstd_row * rstd_row * tl.sum(m * X_row, axis=0) * X_row)
 
        # calculate the gradient of W
        if casting_mode == _CASTING_MODE_LLAMA:
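A reading aid for the `dX_row` update above (not part of the diff itself): writing m for the W-scaled upstream gradient row and r for the cached `rstd_row`, the two statements together compute the usual RMSNorm input gradient:

\partial X = r\,m \;-\; \frac{r^{3}}{n}\Big(\sum_{j} m_j X_j\Big)X,
\qquad r=\frac{1}{\sqrt{\tfrac{1}{n}\sum_{j}X_{j}^{2}+\varepsilon}}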
@@ -207,14 +203,10 @@ _str_to_casting_mode = {
 
 def rms_norm_forward(X, W, eps, offset, casting_mode):
     if not isinstance(casting_mode, int):
-        assert (
-            casting_mode in _str_to_casting_mode
-        ), f"Invalid casting mode: {casting_mode}"
+        assert casting_mode in _str_to_casting_mode, f"Invalid casting mode: {casting_mode}"
         casting_mode = _str_to_casting_mode[casting_mode]
     else:
-        assert (
-            casting_mode in _str_to_casting_mode.values()
-        ), f"Invalid casting mode: {casting_mode}"
+        assert casting_mode in _str_to_casting_mode.values(), f"Invalid casting mode: {casting_mode}"
 
     shape = X.shape
     dim = shape[-1]
@@ -225,17 +217,11 @@ def rms_norm_forward(X, W, eps, offset, casting_mode):
     Y = torch.empty((n_rows, n_cols), dtype=X.dtype, device=X.device)
     # RSTD is to cache rstd for each row
     # RSTD is always computed/stored in fp32 if we are using Llama or Gemma casting mode
-    rstd_dtype = (
-        torch.float32
-        if casting_mode in (_CASTING_MODE_LLAMA.value, _CASTING_MODE_GEMMA.value)
-        else X.dtype
-    )
+    rstd_dtype = torch.float32 if casting_mode in (_CASTING_MODE_LLAMA.value, _CASTING_MODE_GEMMA.value) else X.dtype
     RSTD = torch.empty(n_rows, dtype=rstd_dtype, device=X.device)
 
     # Check constraints.
-    assert (
-        X.shape[1] == W.shape[0]
-    ), "Incompatible hidden size dimension between tensor1.shape[1] and tensor2.shape[0]"
+    assert X.shape[1] == W.shape[0], "Incompatible hidden size dimension between tensor1.shape[1] and tensor2.shape[0]"
 
     _rms_norm_forward_kernel[(n_rows,)](
         Y,
@@ -256,9 +242,7 @@ def rms_norm_forward(X, W, eps, offset, casting_mode):
     return Y.view(*shape), X, RSTD, BLOCK_SIZE, num_warps, casting_mode
 
 
-def rms_norm_backward(
-    dY, X, W, RSTD, offset, casting_mode, BLOCK_SIZE, num_warps, in_place
-):
+def rms_norm_backward(dY, X, W, RSTD, offset, casting_mode, BLOCK_SIZE, num_warps, in_place):
     shape = dY.shape
     dim = shape[-1]
     dY = dY.view(-1, dim)
@@ -340,9 +324,7 @@ class LigerRMSNormFunction(torch.autograd.Function):
        X: (B, T, H) or (BxT, H)
        W: (H,)
        """
-       Y, X, RSTD, BLOCK_SIZE, num_warps, casting_mode = rms_norm_forward(
-           X, W, eps, offset, casting_mode
-       )
+       Y, X, RSTD, BLOCK_SIZE, num_warps, casting_mode = rms_norm_forward(X, W, eps, offset, casting_mode)
        ctx.offset = offset
        ctx.casting_mode = casting_mode
        ctx.in_place = in_place
liger_kernel/ops/rope.py CHANGED
@@ -15,6 +15,7 @@ def _triton_rope(
     sin_row_stride,
     sl,
     bs: tl.constexpr,
+    cos_bs: tl.constexpr,
     n_qh: tl.constexpr,
     n_kh: tl.constexpr,
     hd: tl.constexpr,
@@ -29,7 +30,7 @@ def _triton_rope(
     # k size: (bsz, seq_len, num_kv_heads, head_dim)
     # k stride: (seq_len * num_kv_heads * head_dim, num_kv_heads * head_dim, head_dim, 1)
 
-    # cos size: (1, seq_len, head_dim)
+    # cos size: (1, seq_len, head_dim) or (bsz, seq_len, head_dim)
     # stride: (seq_len * head_dim, head_dim, 1)
     pid = tl.program_id(0)
 
@@ -48,9 +49,19 @@ def _triton_rope(
     # and pid % sl to get the sequence index.
     # 2. We only need the left half of cos and sin matrix because the right half is just
     # a clone of the left half.
-    cos_row_idx = pid % (sl)
-    cos = cos + cos_row_idx * cos_row_stride
-    sin = sin + cos_row_idx * sin_row_stride
+    batch_idx = pid // sl
+    cos_row_idx = pid % sl
+    cos = cos + tl.where(
+        cos_bs == 1,
+        cos_row_idx * cos_row_stride,
+        batch_idx * (sl * cos_row_stride) + cos_row_idx * cos_row_stride,
+    )
+    sin = sin + tl.where(
+        cos_bs == 1,
+        cos_row_idx * sin_row_stride,
+        batch_idx * (sl * sin_row_stride) + cos_row_idx * sin_row_stride,
+    )
+
     cos_offsets = tl.arange(0, pad_hd // 2)
     cos_mask = cos_offsets < hd // 2
     cos_row = tl.load(cos + cos_offsets, mask=cos_mask, other=0)
@@ -61,36 +72,20 @@ def _triton_rope(
     # program instance (i.e. for the current token) separately
     # ####################################################################
     # left half of the head
-    first_half_q_offsets = (
-        tl.arange(0, pad_n_qh)[:, None] * hd + tl.arange(0, pad_hd // 2)[None, :]
-    )
-    first_half_k_offsets = (
-        tl.arange(0, pad_n_kh)[:, None] * hd + tl.arange(0, pad_hd // 2)[None, :]
-    )
-    first_q_mask = (tl.arange(0, pad_n_qh)[:, None] < n_qh) & (
-        tl.arange(0, pad_hd // 2)[None, :] < hd // 2
-    )
-    first_k_mask = (tl.arange(0, pad_n_kh)[:, None] < n_kh) & (
-        tl.arange(0, pad_hd // 2)[None, :] < hd // 2
-    )
-    q_tile_1 = tl.load(q_ptr + first_half_q_offsets, mask=first_q_mask, other=0).to(
-        sin_row.dtype
-    )
-    k_tile_1 = tl.load(k_ptr + first_half_k_offsets, mask=first_k_mask, other=0).to(
-        sin_row.dtype
-    )
+    first_half_q_offsets = tl.arange(0, pad_n_qh)[:, None] * hd + tl.arange(0, pad_hd // 2)[None, :]
+    first_half_k_offsets = tl.arange(0, pad_n_kh)[:, None] * hd + tl.arange(0, pad_hd // 2)[None, :]
+    first_q_mask = (tl.arange(0, pad_n_qh)[:, None] < n_qh) & (tl.arange(0, pad_hd // 2)[None, :] < hd // 2)
+    first_k_mask = (tl.arange(0, pad_n_kh)[:, None] < n_kh) & (tl.arange(0, pad_hd // 2)[None, :] < hd // 2)
+    q_tile_1 = tl.load(q_ptr + first_half_q_offsets, mask=first_q_mask, other=0).to(sin_row.dtype)
+    k_tile_1 = tl.load(k_ptr + first_half_k_offsets, mask=first_k_mask, other=0).to(sin_row.dtype)
 
     # right half of the head
     second_half_q_offsets = first_half_q_offsets + (hd // 2)
     second_half_k_offsets = first_half_k_offsets + (hd // 2)
     second_q_mask = first_q_mask
     second_k_mask = first_k_mask
-    q_tile_2 = tl.load(q_ptr + second_half_q_offsets, mask=second_q_mask, other=0).to(
-        sin_row.dtype
-    )
-    k_tile_2 = tl.load(k_ptr + second_half_k_offsets, mask=second_k_mask, other=0).to(
-        sin_row.dtype
-    )
+    q_tile_2 = tl.load(q_ptr + second_half_q_offsets, mask=second_q_mask, other=0).to(sin_row.dtype)
+    k_tile_2 = tl.load(k_ptr + second_half_k_offsets, mask=second_k_mask, other=0).to(sin_row.dtype)
 
     if not BACKWARD_PASS:
         # y = [x1, x2] * [cos, cos] + [-x2, x1] * [sin, sin]
@@ -118,7 +113,6 @@ def _triton_rope(
 
 
 def rope_forward(q, k, cos, sin):
-
     # transpose it back to the physical shape because Triton looks at the physical storage
     # note: q and k are incontiguous before the transformation and will become contiguous after transpose
     q = q.transpose(1, 2)
@@ -138,6 +132,7 @@ def rope_forward(q, k, cos, sin):
     k = k.contiguous()
     cos = cos.contiguous()
     sin = sin.contiguous()
+    cos_batch_size = cos.shape[0]
 
     _triton_rope[(n_row,)](
         q,
@@ -150,6 +145,7 @@ def rope_forward(q, k, cos, sin):
         sin.stride(-2),
         seq_len,
         batch_size,
+        cos_batch_size,
         n_q_head,
         n_kv_head,
         head_dim,
@@ -167,6 +163,7 @@ def rope_backward(dq, dk, cos, sin):
     dk = dk.transpose(1, 2)
 
     batch_size, seq_len, n_q_head, head_dim = dq.shape
+    cos_batch_size = cos.shape[0]
     n_kv_head = dk.shape[2]
     pad_hd = triton.next_power_of_2(head_dim)
     pad_n_q_head = triton.next_power_of_2(n_q_head)
@@ -191,6 +188,7 @@ def rope_backward(dq, dk, cos, sin):
         sin.stride(-2),
         seq_len,
         batch_size,
+        cos_batch_size,
         n_q_head,
         n_kv_head,
         head_dim,
@@ -221,8 +219,8 @@ class LigerRopeFunction(torch.autograd.Function):
        """
        q size: (bsz, n_q_head, seq_len, head_dim)
        k size: (bsz, n_kv_head, seq_len, head_dim)
-       cos size: (1, seq_len, head_dim)
-       sin size: (1, seq_len, head_dim)
+       cos size: (1, seq_len, head_dim) or (bsz, seq_len, head_dim)
+       sin size: (1, seq_len, head_dim) or (bsz, seq_len, head_dim)
        """
        q, k, cos, sin = rope_forward(q, k, cos, sin)
        ctx.save_for_backward(cos, sin)
@@ -232,8 +230,8 @@ class LigerRopeFunction(torch.autograd.Function):
        """
        dq size: (bsz, n_q_head, seq_len, head_dim)
        dk size: (bsz, n_kv_head, seq_len, head_dim)
-       cos size: (1, seq_len, head_dim)
-       sin size: (1, seq_len, head_dim)
+       cos size: (1, seq_len, head_dim) or (bsz, seq_len, head_dim)
+       sin size: (1, seq_len, head_dim) or (bsz, seq_len, head_dim)
        """
 
        cos, sin = ctx.saved_tensors
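The rope.py hunks above let `cos`/`sin` carry a real batch dimension (`cos_bs`) instead of requiring a broadcastable batch of 1. A minimal, hypothetical usage sketch of the autograd entry point, assuming a CUDA device and that `LigerRopeFunction.apply(q, k, cos, sin)` returns the rotated `(q, k)` pair; shapes follow the docstring above, all values are illustrative only:

import torch

from liger_kernel.ops.rope import LigerRopeFunction

bsz, n_q_head, n_kv_head, seq_len, head_dim = 2, 8, 2, 16, 64
q = torch.randn(bsz, n_q_head, seq_len, head_dim, device="cuda")
k = torch.randn(bsz, n_kv_head, seq_len, head_dim, device="cuda")

# cos/sin may now be (1, seq_len, head_dim) or (bsz, seq_len, head_dim);
# with a leading batch dimension, each batch element gets its own rotation table.
cos = torch.randn(bsz, seq_len, head_dim, device="cuda")
sin = torch.randn(bsz, seq_len, head_dim, device="cuda")

q_rot, k_rot = LigerRopeFunction.apply(q, k, cos, sin)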
liger_kernel/ops/swiglu.py CHANGED
@@ -2,7 +2,8 @@ import torch
 import triton
 import triton.language as tl
 
-from liger_kernel.ops.utils import calculate_settings, ensure_contiguous
+from liger_kernel.ops.utils import calculate_settings
+from liger_kernel.ops.utils import ensure_contiguous
 
 
 @triton.jit
@@ -11,9 +12,7 @@ def silu(x):
 
 
 @triton.jit
-def _swiglu_forward_kernel(
-    a_ptr, b_ptr, c_ptr, stride, n_cols: tl.constexpr, BLOCK_SIZE: tl.constexpr
-):
+def _swiglu_forward_kernel(a_ptr, b_ptr, c_ptr, stride, n_cols: tl.constexpr, BLOCK_SIZE: tl.constexpr):
     program_id = tl.program_id(0).to(tl.int64)
 
     # locate start index
@@ -32,9 +31,7 @@ def _swiglu_forward_kernel(
 
 
 @triton.jit
-def _swiglu_backward_kernel(
-    dc_ptr, a_ptr, b_ptr, stride, n_cols: tl.constexpr, BLOCK_SIZE: tl.constexpr
-):
+def _swiglu_backward_kernel(dc_ptr, a_ptr, b_ptr, stride, n_cols: tl.constexpr, BLOCK_SIZE: tl.constexpr):
     program_id = tl.program_id(0).to(tl.int64)
 
     # locate start index
@@ -84,7 +81,6 @@ def swiglu_forward(a, b):
 
 
 def swiglu_backward(a, b, dc):
-
     ori_shape = dc.shape
     n_cols = ori_shape[-1]
     dc = dc.view(-1, n_cols)
liger_kernel/ops/utils.py CHANGED
@@ -13,11 +13,13 @@ Modifications made by Yanning Chen, 2024.
 import functools
 import importlib
 import operator
+
 from typing import Callable
 
 import torch
 import triton
 import triton.language as tl
+
 from packaging.version import Version
 
 from liger_kernel.utils import infer_device
liger_kernel/transformers/__init__.py CHANGED
@@ -1,31 +1,23 @@
-from liger_kernel.transformers.auto_model import (  # noqa: F401
-    AutoLigerKernelForCausalLM,
-)
+from liger_kernel.transformers.auto_model import AutoLigerKernelForCausalLM  # noqa: F401
 from liger_kernel.transformers.cross_entropy import LigerCrossEntropyLoss  # noqa: F401
-from liger_kernel.transformers.fused_linear_cross_entropy import (  # noqa: F401
-    LigerFusedLinearCrossEntropyLoss,
-)
+from liger_kernel.transformers.fused_linear_cross_entropy import LigerFusedLinearCrossEntropyLoss  # noqa: F401
 from liger_kernel.transformers.fused_linear_jsd import LigerFusedLinearJSD  # noqa: F401
 from liger_kernel.transformers.geglu import LigerGEGLUMLP  # noqa: F401
 from liger_kernel.transformers.jsd import LigerJSD  # noqa: F401
 from liger_kernel.transformers.layer_norm import LigerLayerNorm  # noqa: F401
-from liger_kernel.transformers.monkey_patch import (  # noqa: F401
-    _apply_liger_kernel,
-    _apply_liger_kernel_to_instance,
-    apply_liger_kernel_to_gemma,
-    apply_liger_kernel_to_gemma2,
-    apply_liger_kernel_to_llama,
-    apply_liger_kernel_to_mistral,
-    apply_liger_kernel_to_mixtral,
-    apply_liger_kernel_to_mllama,
-    apply_liger_kernel_to_phi3,
-    apply_liger_kernel_to_qwen2,
-    apply_liger_kernel_to_qwen2_vl,
-)
+from liger_kernel.transformers.monkey_patch import _apply_liger_kernel  # noqa: F401
+from liger_kernel.transformers.monkey_patch import _apply_liger_kernel_to_instance  # noqa: F401
+from liger_kernel.transformers.monkey_patch import apply_liger_kernel_to_gemma  # noqa: F401
+from liger_kernel.transformers.monkey_patch import apply_liger_kernel_to_gemma2  # noqa: F401
+from liger_kernel.transformers.monkey_patch import apply_liger_kernel_to_llama  # noqa: F401
+from liger_kernel.transformers.monkey_patch import apply_liger_kernel_to_mistral  # noqa: F401
+from liger_kernel.transformers.monkey_patch import apply_liger_kernel_to_mixtral  # noqa: F401
+from liger_kernel.transformers.monkey_patch import apply_liger_kernel_to_mllama  # noqa: F401
+from liger_kernel.transformers.monkey_patch import apply_liger_kernel_to_phi3  # noqa: F401
+from liger_kernel.transformers.monkey_patch import apply_liger_kernel_to_qwen2  # noqa: F401
+from liger_kernel.transformers.monkey_patch import apply_liger_kernel_to_qwen2_vl  # noqa: F401
 from liger_kernel.transformers.rms_norm import LigerRMSNorm  # noqa: F401
 from liger_kernel.transformers.rope import liger_rotary_pos_emb  # noqa: F401
-from liger_kernel.transformers.swiglu import (  # noqa: F401
-    LigerBlockSparseTop2MLP,
-    LigerPhi3SwiGLUMLP,
-    LigerSwiGLUMLP,
-)
+from liger_kernel.transformers.swiglu import LigerBlockSparseTop2MLP  # noqa: F401
+from liger_kernel.transformers.swiglu import LigerPhi3SwiGLUMLP  # noqa: F401
+from liger_kernel.transformers.swiglu import LigerSwiGLUMLP  # noqa: F401
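For context on the functions re-exported here: the `apply_liger_kernel_to_*` entry points monkey-patch the corresponding Hugging Face model modules in place and are meant to be called before the model is constructed. A minimal, hypothetical sketch (the model id is illustrative only):

from transformers import AutoModelForCausalLM

from liger_kernel.transformers import apply_liger_kernel_to_llama

# Patch the Hugging Face Llama modeling code with Liger kernels,
# then instantiate the model as usual.
apply_liger_kernel_to_llama()
model = AutoModelForCausalLM.from_pretrained("meta-llama/Llama-3.1-8B-Instruct")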
liger_kernel/transformers/auto_model.py CHANGED
@@ -1,11 +1,10 @@
 import inspect
 
-from transformers import AutoConfig, AutoModelForCausalLM
+from transformers import AutoConfig
+from transformers import AutoModelForCausalLM
 
-from liger_kernel.transformers.monkey_patch import (
-    MODEL_TYPE_TO_APPLY_LIGER_FN,
-    _apply_liger_kernel,
-)
+from liger_kernel.transformers.monkey_patch import MODEL_TYPE_TO_APPLY_LIGER_FN
+from liger_kernel.transformers.monkey_patch import _apply_liger_kernel
 
 
 def _get_model_config(model_dir, **model_init_kwargs):
@@ -34,12 +33,6 @@ class AutoLigerKernelForCausalLM(AutoModelForCausalLM):
        apply_fn = MODEL_TYPE_TO_APPLY_LIGER_FN[model_type]
        apply_fn_signature = inspect.signature(apply_fn)
 
-       applicable_kwargs = {
-           key: value
-           for key, value in kwargs.items()
-           if key not in apply_fn_signature.parameters
-       }
+       applicable_kwargs = {key: value for key, value in kwargs.items() if key not in apply_fn_signature.parameters}
 
-       return super().from_pretrained(
-           pretrained_model_name_or_path, *model_args, **applicable_kwargs
-       )
+       return super().from_pretrained(pretrained_model_name_or_path, *model_args, **applicable_kwargs)
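For context, the hunk above belongs to `AutoLigerKernelForCausalLM`, which looks up the patch function for the detected model type, applies it, and forwards the remaining kwargs to `from_pretrained`. A minimal, hypothetical sketch (the model id is illustrative only):

from liger_kernel.transformers import AutoLigerKernelForCausalLM

# Drop-in replacement for transformers.AutoModelForCausalLM: the matching
# apply_liger_kernel_to_* patch is applied based on the checkpoint's config
# before the weights are loaded.
model = AutoLigerKernelForCausalLM.from_pretrained("meta-llama/Llama-3.1-8B-Instruct")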
liger_kernel/transformers/cross_entropy.py CHANGED
@@ -27,9 +27,7 @@ class LigerCrossEntropyLoss(torch.nn.Module):
            "sum",
            "none",
        }, f"reduction must be one of 'mean', 'sum', or 'none'. Got: {reduction}"
-       assert (
-           softcap is None or softcap > 0
-       ), f"softcap must greater than 0.0 or None. Got: {softcap}"
+       assert softcap is None or softcap > 0, f"softcap must greater than 0.0 or None. Got: {softcap}"
        self.ignore_index = ignore_index
        self.lse_square_scale = lse_square_scale
        self.label_smoothing = label_smoothing
liger_kernel/transformers/experimental/embedding.py CHANGED
@@ -7,9 +7,7 @@ from liger_kernel.ops.experimental.embedding import LigerEmbeddingFunction
 
 
 class LigerEmbedding(nn.Module):
-    def __init__(
-        self, num_embeddings, embedding_dim, padding_idx: Optional[int] = None
-    ):
+    def __init__(self, num_embeddings, embedding_dim, padding_idx: Optional[int] = None):
        super().__init__()
        self.num_embeddings = num_embeddings
        self.embedding_dim = embedding_dim