liger-kernel-nightly 0.5.2.dev20241223032630__py3-none-any.whl → 0.5.2.dev20241228022953__py3-none-any.whl

Files changed (56)
  1. liger_kernel/chunked_loss/cpo_loss.py +5 -12
  2. liger_kernel/chunked_loss/dpo_loss.py +1 -4
  3. liger_kernel/chunked_loss/fused_linear_distillation.py +37 -37
  4. liger_kernel/chunked_loss/fused_linear_preference.py +40 -64
  5. liger_kernel/chunked_loss/orpo_loss.py +2 -6
  6. liger_kernel/chunked_loss/simpo_loss.py +4 -8
  7. liger_kernel/env_report.py +4 -11
  8. liger_kernel/ops/cross_entropy.py +7 -10
  9. liger_kernel/ops/experimental/embedding.py +1 -3
  10. liger_kernel/ops/experimental/mm_int8int2.py +3 -9
  11. liger_kernel/ops/fused_linear_cross_entropy.py +12 -17
  12. liger_kernel/ops/fused_linear_jsd.py +11 -29
  13. liger_kernel/ops/geglu.py +6 -17
  14. liger_kernel/ops/group_norm.py +11 -28
  15. liger_kernel/ops/jsd.py +2 -6
  16. liger_kernel/ops/kl_div.py +4 -7
  17. liger_kernel/ops/layer_norm.py +3 -5
  18. liger_kernel/ops/qwen2vl_mrope.py +8 -25
  19. liger_kernel/ops/rms_norm.py +11 -29
  20. liger_kernel/ops/rope.py +8 -24
  21. liger_kernel/ops/swiglu.py +4 -8
  22. liger_kernel/ops/utils.py +2 -0
  23. liger_kernel/transformers/__init__.py +16 -24
  24. liger_kernel/transformers/auto_model.py +6 -13
  25. liger_kernel/transformers/cross_entropy.py +1 -3
  26. liger_kernel/transformers/experimental/embedding.py +1 -3
  27. liger_kernel/transformers/functional.py +2 -6
  28. liger_kernel/transformers/fused_linear_cross_entropy.py +2 -6
  29. liger_kernel/transformers/geglu.py +1 -4
  30. liger_kernel/transformers/group_norm.py +3 -9
  31. liger_kernel/transformers/jsd.py +1 -3
  32. liger_kernel/transformers/kl_div.py +1 -3
  33. liger_kernel/transformers/layer_norm.py +3 -9
  34. liger_kernel/transformers/model/gemma.py +18 -40
  35. liger_kernel/transformers/model/gemma2.py +19 -41
  36. liger_kernel/transformers/model/llama.py +22 -48
  37. liger_kernel/transformers/model/mistral.py +14 -26
  38. liger_kernel/transformers/model/mixtral.py +23 -53
  39. liger_kernel/transformers/model/mllama.py +16 -36
  40. liger_kernel/transformers/model/phi3.py +18 -40
  41. liger_kernel/transformers/model/qwen2.py +18 -40
  42. liger_kernel/transformers/model/qwen2_vl.py +16 -30
  43. liger_kernel/transformers/monkey_patch.py +43 -117
  44. liger_kernel/transformers/rms_norm.py +4 -4
  45. liger_kernel/transformers/swiglu.py +2 -8
  46. liger_kernel/transformers/trainer/__init__.py +1 -3
  47. liger_kernel/transformers/trainer/orpo_trainer.py +13 -16
  48. liger_kernel/triton/__init__.py +1 -3
  49. liger_kernel/triton/monkey_patch.py +1 -3
  50. {liger_kernel_nightly-0.5.2.dev20241223032630.dist-info → liger_kernel_nightly-0.5.2.dev20241228022953.dist-info}/METADATA +1 -1
  51. liger_kernel_nightly-0.5.2.dev20241228022953.dist-info/RECORD +66 -0
  52. liger_kernel_nightly-0.5.2.dev20241223032630.dist-info/RECORD +0 -66
  53. {liger_kernel_nightly-0.5.2.dev20241223032630.dist-info → liger_kernel_nightly-0.5.2.dev20241228022953.dist-info}/LICENSE +0 -0
  54. {liger_kernel_nightly-0.5.2.dev20241223032630.dist-info → liger_kernel_nightly-0.5.2.dev20241228022953.dist-info}/NOTICE +0 -0
  55. {liger_kernel_nightly-0.5.2.dev20241223032630.dist-info → liger_kernel_nightly-0.5.2.dev20241228022953.dist-info}/WHEEL +0 -0
  56. {liger_kernel_nightly-0.5.2.dev20241223032630.dist-info → liger_kernel_nightly-0.5.2.dev20241228022953.dist-info}/top_level.txt +0 -0
liger_kernel/ops/fused_linear_jsd.py CHANGED
@@ -4,12 +4,10 @@ import torch
  import triton

  from liger_kernel.ops.jsd import _jsd_kernel
- from liger_kernel.ops.utils import (
- amp_custom_bwd,
- amp_custom_fwd,
- element_mul_kernel,
- is_hip,
- )
+ from liger_kernel.ops.utils import amp_custom_bwd
+ from liger_kernel.ops.utils import amp_custom_fwd
+ from liger_kernel.ops.utils import element_mul_kernel
+ from liger_kernel.ops.utils import is_hip

  # The hard limit of TRITON_MAX_TENSOR_NUMEL is 1048576 https://github.com/triton-lang/triton/blob/ba42a5c68fd0505f8c42f4202d53be0f8d9a5fe0/python/triton/language/core.py#L19
  # However, setting limit as 65536 as in LayerNorm tutorial is faster because of less register spilling
@@ -43,16 +41,10 @@ def fused_linear_jsd_forward(
  BLOCK_SIZE = min(MAX_FUSED_SIZE, triton.next_power_of_2(V))

  inc_factor = triton.cdiv(V, H) # (V + H - 1) // H
- chunk_size = triton.next_power_of_2(
- triton.cdiv(BT, inc_factor)
- ) # (BT + inc_factor - 1) // inc_factor
+ chunk_size = triton.next_power_of_2(triton.cdiv(BT, inc_factor)) # (BT + inc_factor - 1) // inc_factor
  num_chunks = triton.cdiv(BT, chunk_size) # (BT + chunk_size - 1) // chunk_size

- grad_weight = (
- torch.zeros_like(student_weight, device=device)
- if student_weight.requires_grad
- else None
- )
+ grad_weight = torch.zeros_like(student_weight, device=device) if student_weight.requires_grad else None
  grad_input = torch.zeros_like(student_input)
  # we use fp32 for loss accumulator
  loss_1d = torch.zeros((BT, V), dtype=torch.float32, device=device)
@@ -73,12 +65,8 @@ def fused_linear_jsd_forward(
  # shape: chunk_size x V
  # For anything starting from logits to the final JSD loss, we do computation
  # in FP32 to avoid losing numerical stability.
- student_logits_chunk = (student_input_chunk @ student_weight.t()).to(
- torch.float32
- )
- teacher_logits_chunk = (teacher_input_chunk @ teacher_weight.t()).to(
- torch.float32
- )
+ student_logits_chunk = (student_input_chunk @ student_weight.t()).to(torch.float32)
+ teacher_logits_chunk = (teacher_input_chunk @ teacher_weight.t()).to(torch.float32)
  chunk_n_rows = student_logits_chunk.shape[0]

  # unreduced loss
@@ -104,9 +92,7 @@ def fused_linear_jsd_forward(
  dX_ptr=student_prob_chunk,
  dX_stride=student_prob_chunk.stride(-2),
  label_ptr=(
- shift_labels[start_idx:end_idx]
- if has_label
- else torch.empty(1, device=device)
+ shift_labels[start_idx:end_idx] if has_label else torch.empty(1, device=device)
  ), # dummy ptr if no label
  beta=jsd_beta,
  n_non_ignore=n_non_ignore,
@@ -121,9 +107,7 @@ def fused_linear_jsd_forward(
  student_logits_chunk = (
  student_prob_chunk
  - torch.softmax(student_logits_chunk, dim=-1)
- * student_prob_chunk.sum(dim=-1, keepdim=True).broadcast_to(
- student_prob_chunk.shape
- )
+ * student_prob_chunk.sum(dim=-1, keepdim=True).broadcast_to(student_prob_chunk.shape)
  ) / temperature
  # now we traverse back to grad w.r.t. input to `lm_head` and grad
  # w.r.t. `lm_head` which should be computed in original dtype
@@ -239,7 +223,5 @@ class LigerFusedLinearJSDFunction(torch.autograd.Function):
  @amp_custom_bwd
  def backward(ctx, grad_output):
  (grad_input, grad_weight) = ctx.saved_tensors
- grad_input, grad_weight = fused_linear_jsd_backward(
- grad_output, grad_input, grad_weight
- )
+ grad_input, grad_weight = fused_linear_jsd_backward(grad_output, grad_input, grad_weight)
  return (grad_input, grad_weight, None, None, None, None, None, None)
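
A note on the chunking arithmetic in the hunks above: chunk_size is derived from the vocab-to-hidden ratio so that each materialized logits chunk (chunk_size x V) stays comparable in size to an inc_factor-scaled input slice. The following standalone sketch (not part of the wheel; example sizes are arbitrary) reproduces the three expressions with plain math.ceil in place of triton.cdiv and triton.next_power_of_2:

    import math

    BT, H, V = 4096, 4096, 128256  # arbitrary example: batch*seq, hidden size, vocab size

    inc_factor = math.ceil(V / H)                                       # triton.cdiv(V, H)           -> 32
    chunk_size = 2 ** math.ceil(math.log2(math.ceil(BT / inc_factor)))  # next_power_of_2(cdiv(...))  -> 128
    num_chunks = math.ceil(BT / chunk_size)                             # triton.cdiv(BT, chunk_size) -> 32
    print(inc_factor, chunk_size, num_chunks)  # 32 128 32
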
liger_kernel/ops/geglu.py CHANGED
@@ -4,11 +4,9 @@ import torch
  import triton
  import triton.language as tl

- from liger_kernel.ops.utils import (
- calculate_settings,
- compare_version,
- ensure_contiguous,
- )
+ from liger_kernel.ops.utils import calculate_settings
+ from liger_kernel.ops.utils import compare_version
+ from liger_kernel.ops.utils import ensure_contiguous

  if compare_version("triton", operator.ge, "3.0.0"):
  try:
@@ -22,9 +20,7 @@ else:


  @triton.jit
- def _geglu_tanh_forward_kernel(
- a, b, c, stride, n_cols: tl.constexpr, BLOCK_SIZE: tl.constexpr
- ):
+ def _geglu_tanh_forward_kernel(a, b, c, stride, n_cols: tl.constexpr, BLOCK_SIZE: tl.constexpr):
  program_id = tl.program_id(0).to(tl.int64)

  # locate start index
@@ -49,9 +45,7 @@ def _geglu_tanh_forward_kernel(


  @triton.jit
- def _geglu_tanh_backward_kernel(
- dc, a, b, stride, n_cols: tl.constexpr, BLOCK_SIZE: tl.constexpr
- ):
+ def _geglu_tanh_backward_kernel(dc, a, b, stride, n_cols: tl.constexpr, BLOCK_SIZE: tl.constexpr):
  program_id = tl.program_id(0).to(tl.int64)

  # locate start index
@@ -80,12 +74,7 @@ def _geglu_tanh_backward_kernel(
  # where z = sqrt(2/pi) * (a + 0.044715 * a^3)
  term1 = 0.5 * (1 + tanh_result)
  tanh_sq = tanh_result * tanh_result
- term2 = (
- 0.5
- * a_row
- * (1 - tanh_sq)
- * (sqrt_2_over_pi * (1 + 3 * 0.044715 * a_row * a_row))
- )
+ term2 = 0.5 * a_row * (1 - tanh_sq) * (sqrt_2_over_pi * (1 + 3 * 0.044715 * a_row * a_row))
  da_row = dc_row * b_row * (term1 + term2)

  tl.store(a + col_offsets, da_row, mask=mask)
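
The collapsed term2 expression above is the closed-form derivative of the tanh-approximated GELU gate. As a quick sanity check, here is a small standalone sketch (not from the wheel; a and b are arbitrary test tensors) comparing term1 + term2 against PyTorch autograd for c = b * gelu_tanh(a):

    import math
    import torch

    a = torch.randn(8, dtype=torch.float64, requires_grad=True)
    b = torch.randn(8, dtype=torch.float64)

    sqrt_2_over_pi = math.sqrt(2.0 / math.pi)
    z = sqrt_2_over_pi * (a + 0.044715 * a**3)
    c = b * 0.5 * a * (1 + torch.tanh(z))  # GEGLU output: b * gelu_tanh(a)
    c.sum().backward()                     # upstream gradient of ones

    tanh_z = torch.tanh(z).detach()
    term1 = 0.5 * (1 + tanh_z)
    term2 = 0.5 * a.detach() * (1 - tanh_z * tanh_z) * (sqrt_2_over_pi * (1 + 3 * 0.044715 * a.detach() ** 2))
    assert torch.allclose(a.grad, b * (term1 + term2))
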
liger_kernel/ops/group_norm.py CHANGED
@@ -4,7 +4,8 @@ import torch
  import triton
  import triton.language as tl

- from liger_kernel.ops.utils import compare_version, ensure_contiguous
+ from liger_kernel.ops.utils import compare_version
+ from liger_kernel.ops.utils import ensure_contiguous

  if compare_version("triton", operator.ge, "3.0.0"):
  try:
@@ -73,9 +74,7 @@ def _group_norm_forward_kernel(

  # Normalize
  hidden_size_per_channel = hidden_size // channels_per_group
- for channel_idx in tl.range(
- group_idx * channels_per_group, (group_idx + 1) * channels_per_group
- ):
+ for channel_idx in tl.range(group_idx * channels_per_group, (group_idx + 1) * channels_per_group):
  W = tl.load(W_ptr + channel_idx)
  B = tl.load(B_ptr + channel_idx)
  for i in range(0, hidden_size_per_channel, BLOCK_SIZE):
@@ -132,21 +131,15 @@ def _group_norm_backward_kernel(
  UPSTREAM_ptr += batch_idx * X_row_stride

  # Mean and rstd are the same shape so have the same strides
- mean = tl.load(
- Mean_ptr + batch_idx * Mean_ptr_row_stride + group_idx * Mean_ptr_col_stride
- )
- rstd = tl.load(
- RSTD_ptr + batch_idx * Mean_ptr_row_stride + group_idx * Mean_ptr_col_stride
- )
+ mean = tl.load(Mean_ptr + batch_idx * Mean_ptr_row_stride + group_idx * Mean_ptr_col_stride)
+ rstd = tl.load(RSTD_ptr + batch_idx * Mean_ptr_row_stride + group_idx * Mean_ptr_col_stride)

  c1 = 0.0
  c2 = 0.0
  block_range = tl.arange(0, BLOCK_SIZE)

  # We need to compute the sum terms of the backprop equations across all channels in the group
- for channel_idx in range(
- group_idx * channels_per_group, (group_idx + 1) * channels_per_group
- ):
+ for channel_idx in range(group_idx * channels_per_group, (group_idx + 1) * channels_per_group):
  dW = 0.0
  dB = 0.0
  # Move the pointers to the correct channel
@@ -181,9 +174,7 @@ def _group_norm_backward_kernel(
  c1 = c1 / N
  c2 = c2 / N

- for channel_idx in tl.range(
- group_idx * channels_per_group, (group_idx + 1) * channels_per_group
- ):
+ for channel_idx in tl.range(group_idx * channels_per_group, (group_idx + 1) * channels_per_group):
  # Move the pointers to the correct channel
  W = tl.load(W_ptr + channel_idx)
  for i in range(0, hidden_size, BLOCK_SIZE):
@@ -203,9 +194,7 @@ def _group_norm_backward_kernel(
  x_hat = (X - mean) * rstd
  wdy = W * UPSTREAM_grad
  dx = (wdy - (x_hat * c1 + c2)) * rstd
- tl.store(
- DX_ptr + channel_idx * X_col_stride + hidden_size_offsets, dx, mask=mask
- )
+ tl.store(DX_ptr + channel_idx * X_col_stride + hidden_size_offsets, dx, mask=mask)


  def group_norm_forward(X, num_channels, num_groups, W, B, eps):
@@ -216,9 +205,7 @@ def group_norm_forward(X, num_channels, num_groups, W, B, eps):
  X = X.view(batch_size, num_groups, -1).contiguous()
  hidden_size = X.shape[-1]
  BLOCK_SIZE = min(MAX_FUSED_SIZE, triton.next_power_of_2(hidden_size))
- Y = torch.empty(
- (batch_size, num_groups, hidden_size), dtype=X.dtype, device=X.device
- )
+ Y = torch.empty((batch_size, num_groups, hidden_size), dtype=X.dtype, device=X.device)
  Mean = torch.zeros((batch_size, num_groups), dtype=X.dtype, device=X.device)
  RSTD = torch.zeros((batch_size, num_groups), dtype=X.dtype, device=X.device)

@@ -307,16 +294,12 @@ class LigerGroupNormFunction(torch.autograd.Function):
  )
  ctx.num_channels = num_channels
  ctx.num_groups = num_groups
- ctx.save_for_backward(
- X, affine_scaling_weight, affine_shifting_bias, Mean, RSTD
- )
+ ctx.save_for_backward(X, affine_scaling_weight, affine_shifting_bias, Mean, RSTD)
  return Y

  @staticmethod
  @ensure_contiguous
  def backward(ctx, dY):
  X, W, B, Mean, RSTD = ctx.saved_tensors
- DX, DW, DB = group_norm_backward(
- dY, X, W, B, Mean, RSTD, ctx.num_channels, ctx.num_groups
- )
+ DX, DW, DB = group_norm_backward(dY, X, W, B, Mean, RSTD, ctx.num_channels, ctx.num_groups)
  return DX, DW, DB, None, None, None
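
For orientation, group_norm_forward above views the input as (batch_size, num_groups, -1) and normalizes each group with its own mean and rstd before the per-channel affine. A standalone reference sketch (not from the wheel; shapes are arbitrary) that matches torch.nn.functional.group_norm:

    import torch
    import torch.nn.functional as F

    batch_size, num_channels, num_groups, hw, eps = 2, 6, 3, 4, 1e-6
    X = torch.randn(batch_size, num_channels, hw, dtype=torch.float64)
    W = torch.randn(num_channels, dtype=torch.float64)
    B = torch.randn(num_channels, dtype=torch.float64)

    Xg = X.view(batch_size, num_groups, -1)  # same group view as group_norm_forward
    mean = Xg.mean(dim=-1, keepdim=True)
    rstd = torch.rsqrt(Xg.var(dim=-1, unbiased=False, keepdim=True) + eps)
    Y = ((Xg - mean) * rstd).view_as(X) * W[None, :, None] + B[None, :, None]

    assert torch.allclose(Y, F.group_norm(X, num_groups, W, B, eps))
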
liger_kernel/ops/jsd.py CHANGED
@@ -98,9 +98,7 @@ def jsd_forward(_input, target, shift_labels, beta, ignore_index, has_label):
  loss_stride=loss.stride(-2),
  dX_ptr=dX,
  dX_stride=dX.stride(-2),
- label_ptr=(
- shift_labels if has_label else torch.empty(1, device=_input.device)
- ), # dummy ptr if no label
+ label_ptr=(shift_labels if has_label else torch.empty(1, device=_input.device)), # dummy ptr if no label
  beta=beta,
  n_non_ignore=n_non_ignore,
  ignore_index=ignore_index,
@@ -165,9 +163,7 @@ class LigerJSDFunction(torch.autograd.Function):
  shift_labels = shift_labels.contiguous()
  has_label = True

- loss, dX = jsd_forward(
- _input, target, shift_labels, beta, ignore_index, has_label
- )
+ loss, dX = jsd_forward(_input, target, shift_labels, beta, ignore_index, has_label)
  ctx.save_for_backward(dX)
  return loss

liger_kernel/ops/kl_div.py CHANGED
@@ -4,7 +4,8 @@ import torch
  import triton
  import triton.language as tl

- from liger_kernel.ops.utils import ensure_contiguous, is_hip
+ from liger_kernel.ops.utils import ensure_contiguous
+ from liger_kernel.ops.utils import is_hip


  def get_num_warps(BLOCK_SIZE):
@@ -218,9 +219,7 @@ class LigerKLDivLossFunction(torch.autograd.Function):
  ctx.save_for_backward(y_true)
  ctx.reduction = reduction
  ctx.log_target = log_target
- return kldiv_forward_triton(
- y_pred, y_true, log_target=log_target, reduction=reduction, eps=eps
- )
+ return kldiv_forward_triton(y_pred, y_true, log_target=log_target, reduction=reduction, eps=eps)

  @staticmethod
  @ensure_contiguous
@@ -238,9 +237,7 @@ class LigerKLDivLossFunction(torch.autograd.Function):

  new_grads = torch.empty_like(y_true)

- derivative = kldiv_backward_triton(
- y_true, grad_output, new_grads, ctx.log_target
- )
+ derivative = kldiv_backward_triton(y_true, grad_output, new_grads, ctx.log_target)

  if ctx.reduction == "batchmean":
  derivative = derivative / y_true.shape[0]
liger_kernel/ops/qwen2vl_mrope.py CHANGED
@@ -5,11 +5,9 @@ import torch
  import triton
  import triton.language as tl

- from liger_kernel.ops.utils import (
- calculate_settings,
- compare_version,
- ensure_contiguous,
- )
+ from liger_kernel.ops.utils import calculate_settings
+ from liger_kernel.ops.utils import compare_version
+ from liger_kernel.ops.utils import ensure_contiguous

  if compare_version("triton", operator.ge, "3.0.0"):
  try:
@@ -67,36 +67,20 @@ def _triton_qwen2vl_mrope(
  # program instance (i.e. for the current token) separately
  # ####################################################################
  # left half of the head
- first_half_q_offsets = (
- tl.arange(0, pad_n_qh)[:, None] * hd + tl.arange(0, pad_hd // 2)[None, :]
- )
- first_half_k_offsets = (
- tl.arange(0, pad_n_kh)[:, None] * hd + tl.arange(0, pad_hd // 2)[None, :]
- )
- first_q_mask = (tl.arange(0, pad_n_qh)[:, None] < n_qh) & (
- tl.arange(0, pad_hd // 2)[None, :] < hd // 2
- )
- first_k_mask = (tl.arange(0, pad_n_kh)[:, None] < n_kh) & (
- tl.arange(0, pad_hd // 2)[None, :] < hd // 2
- )
- q_tile_1 = tl.load(q_ptr + first_half_q_offsets, mask=first_q_mask, other=0).to(
- sin_row.dtype
- )
- k_tile_1 = tl.load(k_ptr + first_half_k_offsets, mask=first_k_mask, other=0).to(
- sin_row.dtype
- )
+ first_half_q_offsets = tl.arange(0, pad_n_qh)[:, None] * hd + tl.arange(0, pad_hd // 2)[None, :]
+ first_half_k_offsets = tl.arange(0, pad_n_kh)[:, None] * hd + tl.arange(0, pad_hd // 2)[None, :]
+ first_q_mask = (tl.arange(0, pad_n_qh)[:, None] < n_qh) & (tl.arange(0, pad_hd // 2)[None, :] < hd // 2)
+ first_k_mask = (tl.arange(0, pad_n_kh)[:, None] < n_kh) & (tl.arange(0, pad_hd // 2)[None, :] < hd // 2)
+ q_tile_1 = tl.load(q_ptr + first_half_q_offsets, mask=first_q_mask, other=0).to(sin_row.dtype)
+ k_tile_1 = tl.load(k_ptr + first_half_k_offsets, mask=first_k_mask, other=0).to(sin_row.dtype)

  # right half of the head
  second_half_q_offsets = first_half_q_offsets + (hd // 2)
  second_half_k_offsets = first_half_k_offsets + (hd // 2)
  second_q_mask = first_q_mask
  second_k_mask = first_k_mask
- q_tile_2 = tl.load(q_ptr + second_half_q_offsets, mask=second_q_mask, other=0).to(
- sin_row.dtype
- )
- k_tile_2 = tl.load(k_ptr + second_half_k_offsets, mask=second_k_mask, other=0).to(
- sin_row.dtype
- )
+ q_tile_2 = tl.load(q_ptr + second_half_q_offsets, mask=second_q_mask, other=0).to(sin_row.dtype)
+ k_tile_2 = tl.load(k_ptr + second_half_k_offsets, mask=second_k_mask, other=0).to(sin_row.dtype)

  if not BACKWARD_PASS:
  # y = [x1, x2] * [cos, cos] + [-x2, x1] * [sin, sin]
@@ -124,7 +108,6 @@ def _triton_qwen2vl_mrope(


  def qwen2vl_mrope_forward(q, k, cos, sin, mrope_section):
-
  # transpose it back to the physical shape because Triton looks at the physical storage
  # note: q and k are incontiguous before the transformation and will become contiguous after transpose
  q = q.transpose(1, 2)
liger_kernel/ops/rms_norm.py CHANGED
@@ -17,12 +17,10 @@ import torch
  import triton
  import triton.language as tl

- from liger_kernel.ops.utils import (
- calculate_settings,
- compare_version,
- ensure_contiguous,
- torch_to_triton_dtype,
- )
+ from liger_kernel.ops.utils import calculate_settings
+ from liger_kernel.ops.utils import compare_version
+ from liger_kernel.ops.utils import ensure_contiguous
+ from liger_kernel.ops.utils import torch_to_triton_dtype

  if compare_version("triton", operator.ge, "3.0.0"):
  try:
@@ -177,9 +175,7 @@ def _rms_norm_backward_kernel(

  dX_row = rstd_row * m

- dX_row += (rstd_row) * (
- -(1 / n_cols) * rstd_row * rstd_row * tl.sum(m * X_row, axis=0) * X_row
- )
+ dX_row += (rstd_row) * (-(1 / n_cols) * rstd_row * rstd_row * tl.sum(m * X_row, axis=0) * X_row)

  # calculate the gradient of W
  if casting_mode == _CASTING_MODE_LLAMA:
@@ -207,14 +203,10 @@ _str_to_casting_mode = {

  def rms_norm_forward(X, W, eps, offset, casting_mode):
  if not isinstance(casting_mode, int):
- assert (
- casting_mode in _str_to_casting_mode
- ), f"Invalid casting mode: {casting_mode}"
+ assert casting_mode in _str_to_casting_mode, f"Invalid casting mode: {casting_mode}"
  casting_mode = _str_to_casting_mode[casting_mode]
  else:
- assert (
- casting_mode in _str_to_casting_mode.values()
- ), f"Invalid casting mode: {casting_mode}"
+ assert casting_mode in _str_to_casting_mode.values(), f"Invalid casting mode: {casting_mode}"

  shape = X.shape
  dim = shape[-1]
@@ -225,17 +217,11 @@ def rms_norm_forward(X, W, eps, offset, casting_mode):
  Y = torch.empty((n_rows, n_cols), dtype=X.dtype, device=X.device)
  # RSTD is to cache rstd for each row
  # RSTD is always computed/stored in fp32 if we are using Llama or Gemma casting mode
- rstd_dtype = (
- torch.float32
- if casting_mode in (_CASTING_MODE_LLAMA.value, _CASTING_MODE_GEMMA.value)
- else X.dtype
- )
+ rstd_dtype = torch.float32 if casting_mode in (_CASTING_MODE_LLAMA.value, _CASTING_MODE_GEMMA.value) else X.dtype
  RSTD = torch.empty(n_rows, dtype=rstd_dtype, device=X.device)

  # Check constraints.
- assert (
- X.shape[1] == W.shape[0]
- ), "Incompatible hidden size dimension between tensor1.shape[1] and tensor2.shape[0]"
+ assert X.shape[1] == W.shape[0], "Incompatible hidden size dimension between tensor1.shape[1] and tensor2.shape[0]"

  _rms_norm_forward_kernel[(n_rows,)](
  Y,
@@ -256,9 +242,7 @@ def rms_norm_forward(X, W, eps, offset, casting_mode):
  return Y.view(*shape), X, RSTD, BLOCK_SIZE, num_warps, casting_mode


- def rms_norm_backward(
- dY, X, W, RSTD, offset, casting_mode, BLOCK_SIZE, num_warps, in_place
- ):
+ def rms_norm_backward(dY, X, W, RSTD, offset, casting_mode, BLOCK_SIZE, num_warps, in_place):
  shape = dY.shape
  dim = shape[-1]
  dY = dY.view(-1, dim)
@@ -340,9 +324,7 @@ class LigerRMSNormFunction(torch.autograd.Function):
  X: (B, T, H) or (BxT, H)
  W: (H,)
  """
- Y, X, RSTD, BLOCK_SIZE, num_warps, casting_mode = rms_norm_forward(
- X, W, eps, offset, casting_mode
- )
+ Y, X, RSTD, BLOCK_SIZE, num_warps, casting_mode = rms_norm_forward(X, W, eps, offset, casting_mode)
  ctx.offset = offset
  ctx.casting_mode = casting_mode
  ctx.in_place = in_place
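
The dX_row expression in the backward hunk is the usual RMSNorm input gradient, dX = rstd * m - (rstd^3 / n) * X * sum(m * X), where m is the weighted upstream gradient. A small standalone check against autograd (not from the wheel; eps and the kernel's offset term are set to zero here for simplicity):

    import torch

    n = 16
    X = torch.randn(n, dtype=torch.float64, requires_grad=True)
    W = torch.randn(n, dtype=torch.float64)
    dY = torch.randn(n, dtype=torch.float64)

    rstd = torch.rsqrt(X.pow(2).mean())  # reference RMSNorm forward (eps = 0)
    (X * rstd * W).backward(dY)

    Xd = X.detach()
    rstd_val = torch.rsqrt(Xd.pow(2).mean())
    m = dY * W
    dX = rstd_val * m + rstd_val * (-(1 / n) * rstd_val * rstd_val * torch.sum(m * Xd) * Xd)
    assert torch.allclose(X.grad, dX)
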
liger_kernel/ops/rope.py CHANGED
@@ -72,36 +72,20 @@ def _triton_rope(
  # program instance (i.e. for the current token) separately
  # ####################################################################
  # left half of the head
- first_half_q_offsets = (
- tl.arange(0, pad_n_qh)[:, None] * hd + tl.arange(0, pad_hd // 2)[None, :]
- )
- first_half_k_offsets = (
- tl.arange(0, pad_n_kh)[:, None] * hd + tl.arange(0, pad_hd // 2)[None, :]
- )
- first_q_mask = (tl.arange(0, pad_n_qh)[:, None] < n_qh) & (
- tl.arange(0, pad_hd // 2)[None, :] < hd // 2
- )
- first_k_mask = (tl.arange(0, pad_n_kh)[:, None] < n_kh) & (
- tl.arange(0, pad_hd // 2)[None, :] < hd // 2
- )
- q_tile_1 = tl.load(q_ptr + first_half_q_offsets, mask=first_q_mask, other=0).to(
- sin_row.dtype
- )
- k_tile_1 = tl.load(k_ptr + first_half_k_offsets, mask=first_k_mask, other=0).to(
- sin_row.dtype
- )
+ first_half_q_offsets = tl.arange(0, pad_n_qh)[:, None] * hd + tl.arange(0, pad_hd // 2)[None, :]
+ first_half_k_offsets = tl.arange(0, pad_n_kh)[:, None] * hd + tl.arange(0, pad_hd // 2)[None, :]
+ first_q_mask = (tl.arange(0, pad_n_qh)[:, None] < n_qh) & (tl.arange(0, pad_hd // 2)[None, :] < hd // 2)
+ first_k_mask = (tl.arange(0, pad_n_kh)[:, None] < n_kh) & (tl.arange(0, pad_hd // 2)[None, :] < hd // 2)
+ q_tile_1 = tl.load(q_ptr + first_half_q_offsets, mask=first_q_mask, other=0).to(sin_row.dtype)
+ k_tile_1 = tl.load(k_ptr + first_half_k_offsets, mask=first_k_mask, other=0).to(sin_row.dtype)

  # right half of the head
  second_half_q_offsets = first_half_q_offsets + (hd // 2)
  second_half_k_offsets = first_half_k_offsets + (hd // 2)
  second_q_mask = first_q_mask
  second_k_mask = first_k_mask
- q_tile_2 = tl.load(q_ptr + second_half_q_offsets, mask=second_q_mask, other=0).to(
- sin_row.dtype
- )
- k_tile_2 = tl.load(k_ptr + second_half_k_offsets, mask=second_k_mask, other=0).to(
- sin_row.dtype
- )
+ q_tile_2 = tl.load(q_ptr + second_half_q_offsets, mask=second_q_mask, other=0).to(sin_row.dtype)
+ k_tile_2 = tl.load(k_ptr + second_half_k_offsets, mask=second_k_mask, other=0).to(sin_row.dtype)

  if not BACKWARD_PASS:
  # y = [x1, x2] * [cos, cos] + [-x2, x1] * [sin, sin]
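
The comment kept in the hunk above, "y = [x1, x2] * [cos, cos] + [-x2, x1] * [sin, sin]", is the rotate-half form of RoPE applied per head. A minimal standalone illustration on a single head vector (not from the wheel; sizes are arbitrary):

    import torch

    hd = 8                                 # head dimension (even)
    x = torch.randn(hd)
    theta = torch.randn(hd // 2)
    cos, sin = torch.cos(theta), torch.sin(theta)

    x1, x2 = x[: hd // 2], x[hd // 2:]
    y = torch.cat([x1 * cos - x2 * sin, x2 * cos + x1 * sin])  # 2D rotation per frequency pair

    # same result written as in the kernel comment
    assert torch.allclose(y, x * torch.cat([cos, cos]) + torch.cat([-x2, x1]) * torch.cat([sin, sin]))
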
liger_kernel/ops/swiglu.py CHANGED
@@ -2,7 +2,8 @@ import torch
  import triton
  import triton.language as tl

- from liger_kernel.ops.utils import calculate_settings, ensure_contiguous
+ from liger_kernel.ops.utils import calculate_settings
+ from liger_kernel.ops.utils import ensure_contiguous


  @triton.jit
@@ -11,9 +12,7 @@ def silu(x):


  @triton.jit
- def _swiglu_forward_kernel(
- a_ptr, b_ptr, c_ptr, stride, n_cols: tl.constexpr, BLOCK_SIZE: tl.constexpr
- ):
+ def _swiglu_forward_kernel(a_ptr, b_ptr, c_ptr, stride, n_cols: tl.constexpr, BLOCK_SIZE: tl.constexpr):
  program_id = tl.program_id(0).to(tl.int64)

  # locate start index
@@ -32,9 +31,7 @@ def _swiglu_forward_kernel(


  @triton.jit
- def _swiglu_backward_kernel(
- dc_ptr, a_ptr, b_ptr, stride, n_cols: tl.constexpr, BLOCK_SIZE: tl.constexpr
- ):
+ def _swiglu_backward_kernel(dc_ptr, a_ptr, b_ptr, stride, n_cols: tl.constexpr, BLOCK_SIZE: tl.constexpr):
  program_id = tl.program_id(0).to(tl.int64)

  # locate start index
@@ -84,7 +81,6 @@ def swiglu_forward(a, b):


  def swiglu_backward(a, b, dc):
-
  ori_shape = dc.shape
  n_cols = ori_shape[-1]
  dc = dc.view(-1, n_cols)
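
The reformatted kernels above fuse the element-wise SwiGLU operation (a silu gate times the linear branch, c = silu(a) * b) and its backward. A standalone autograd check of those gradients (not from the wheel; tensors are arbitrary):

    import torch
    import torch.nn.functional as F

    a = torch.randn(8, dtype=torch.float64, requires_grad=True)
    b = torch.randn(8, dtype=torch.float64, requires_grad=True)
    dc = torch.randn(8, dtype=torch.float64)

    c = F.silu(a) * b
    c.backward(dc)

    sig = torch.sigmoid(a.detach())
    da = dc * b.detach() * sig * (1 + a.detach() * (1 - sig))  # d/da silu(a) = sig * (1 + a * (1 - sig))
    db = dc * a.detach() * sig                                 # silu(a) = a * sigmoid(a)
    assert torch.allclose(a.grad, da) and torch.allclose(b.grad, db)
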
liger_kernel/ops/utils.py CHANGED
@@ -13,11 +13,13 @@ Modifications made by Yanning Chen, 2024.
  import functools
  import importlib
  import operator
+
  from typing import Callable

  import torch
  import triton
  import triton.language as tl
+
  from packaging.version import Version

  from liger_kernel.utils import infer_device
liger_kernel/transformers/__init__.py CHANGED
@@ -1,31 +1,23 @@
- from liger_kernel.transformers.auto_model import ( # noqa: F401
- AutoLigerKernelForCausalLM,
- )
+ from liger_kernel.transformers.auto_model import AutoLigerKernelForCausalLM # noqa: F401
  from liger_kernel.transformers.cross_entropy import LigerCrossEntropyLoss # noqa: F401
- from liger_kernel.transformers.fused_linear_cross_entropy import ( # noqa: F401
- LigerFusedLinearCrossEntropyLoss,
- )
+ from liger_kernel.transformers.fused_linear_cross_entropy import LigerFusedLinearCrossEntropyLoss # noqa: F401
  from liger_kernel.transformers.fused_linear_jsd import LigerFusedLinearJSD # noqa: F401
  from liger_kernel.transformers.geglu import LigerGEGLUMLP # noqa: F401
  from liger_kernel.transformers.jsd import LigerJSD # noqa: F401
  from liger_kernel.transformers.layer_norm import LigerLayerNorm # noqa: F401
- from liger_kernel.transformers.monkey_patch import ( # noqa: F401
- _apply_liger_kernel,
- _apply_liger_kernel_to_instance,
- apply_liger_kernel_to_gemma,
- apply_liger_kernel_to_gemma2,
- apply_liger_kernel_to_llama,
- apply_liger_kernel_to_mistral,
- apply_liger_kernel_to_mixtral,
- apply_liger_kernel_to_mllama,
- apply_liger_kernel_to_phi3,
- apply_liger_kernel_to_qwen2,
- apply_liger_kernel_to_qwen2_vl,
- )
+ from liger_kernel.transformers.monkey_patch import _apply_liger_kernel # noqa: F401
+ from liger_kernel.transformers.monkey_patch import _apply_liger_kernel_to_instance # noqa: F401
+ from liger_kernel.transformers.monkey_patch import apply_liger_kernel_to_gemma # noqa: F401
+ from liger_kernel.transformers.monkey_patch import apply_liger_kernel_to_gemma2 # noqa: F401
+ from liger_kernel.transformers.monkey_patch import apply_liger_kernel_to_llama # noqa: F401
+ from liger_kernel.transformers.monkey_patch import apply_liger_kernel_to_mistral # noqa: F401
+ from liger_kernel.transformers.monkey_patch import apply_liger_kernel_to_mixtral # noqa: F401
+ from liger_kernel.transformers.monkey_patch import apply_liger_kernel_to_mllama # noqa: F401
+ from liger_kernel.transformers.monkey_patch import apply_liger_kernel_to_phi3 # noqa: F401
+ from liger_kernel.transformers.monkey_patch import apply_liger_kernel_to_qwen2 # noqa: F401
+ from liger_kernel.transformers.monkey_patch import apply_liger_kernel_to_qwen2_vl # noqa: F401
  from liger_kernel.transformers.rms_norm import LigerRMSNorm # noqa: F401
  from liger_kernel.transformers.rope import liger_rotary_pos_emb # noqa: F401
- from liger_kernel.transformers.swiglu import ( # noqa: F401
- LigerBlockSparseTop2MLP,
- LigerPhi3SwiGLUMLP,
- LigerSwiGLUMLP,
- )
+ from liger_kernel.transformers.swiglu import LigerBlockSparseTop2MLP # noqa: F401
+ from liger_kernel.transformers.swiglu import LigerPhi3SwiGLUMLP # noqa: F401
+ from liger_kernel.transformers.swiglu import LigerSwiGLUMLP # noqa: F401
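
The re-exports above are only reshuffled into one import per line, so the public entry points are unchanged. For example, the usual patching flow should still read the same (a sketch, not taken from the diff; the model path is a placeholder):

    import transformers
    from liger_kernel.transformers import apply_liger_kernel_to_llama

    apply_liger_kernel_to_llama()  # patch the Hugging Face Llama modules before the model is created
    model = transformers.AutoModelForCausalLM.from_pretrained("path/to/llama")  # placeholder path
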
liger_kernel/transformers/auto_model.py CHANGED
@@ -1,11 +1,10 @@
  import inspect

- from transformers import AutoConfig, AutoModelForCausalLM
+ from transformers import AutoConfig
+ from transformers import AutoModelForCausalLM

- from liger_kernel.transformers.monkey_patch import (
- MODEL_TYPE_TO_APPLY_LIGER_FN,
- _apply_liger_kernel,
- )
+ from liger_kernel.transformers.monkey_patch import MODEL_TYPE_TO_APPLY_LIGER_FN
+ from liger_kernel.transformers.monkey_patch import _apply_liger_kernel


  def _get_model_config(model_dir, **model_init_kwargs):
@@ -34,12 +33,6 @@ class AutoLigerKernelForCausalLM(AutoModelForCausalLM):
  apply_fn = MODEL_TYPE_TO_APPLY_LIGER_FN[model_type]
  apply_fn_signature = inspect.signature(apply_fn)

- applicable_kwargs = {
- key: value
- for key, value in kwargs.items()
- if key not in apply_fn_signature.parameters
- }
+ applicable_kwargs = {key: value for key, value in kwargs.items() if key not in apply_fn_signature.parameters}

- return super().from_pretrained(
- pretrained_model_name_or_path, *model_args, **applicable_kwargs
- )
+ return super().from_pretrained(pretrained_model_name_or_path, *model_args, **applicable_kwargs)
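
A note on the collapsed dict comprehension above: the kwargs accepted by the matching apply_liger_kernel_to_* function are presumably consumed earlier in this method, and whatever is left over ("applicable_kwargs") is forwarded to from_pretrained. A self-contained sketch of that signature-based split (names here are made up, not the library's):

    import inspect

    def fake_apply_fn(rope=True, rms_norm=True):
        # stand-in for a MODEL_TYPE_TO_APPLY_LIGER_FN entry
        print(f"patching with rope={rope}, rms_norm={rms_norm}")

    kwargs = {"rope": False, "torch_dtype": "bfloat16"}
    params = inspect.signature(fake_apply_fn).parameters

    liger_kwargs = {k: v for k, v in kwargs.items() if k in params}      # -> {"rope": False}
    passthrough = {k: v for k, v in kwargs.items() if k not in params}   # -> {"torch_dtype": "bfloat16"}
    fake_apply_fn(**liger_kwargs)
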
liger_kernel/transformers/cross_entropy.py CHANGED
@@ -27,9 +27,7 @@ class LigerCrossEntropyLoss(torch.nn.Module):
  "sum",
  "none",
  }, f"reduction must be one of 'mean', 'sum', or 'none'. Got: {reduction}"
- assert (
- softcap is None or softcap > 0
- ), f"softcap must greater than 0.0 or None. Got: {softcap}"
+ assert softcap is None or softcap > 0, f"softcap must greater than 0.0 or None. Got: {softcap}"
  self.ignore_index = ignore_index
  self.lse_square_scale = lse_square_scale
  self.label_smoothing = label_smoothing
liger_kernel/transformers/experimental/embedding.py CHANGED
@@ -7,9 +7,7 @@ from liger_kernel.ops.experimental.embedding import LigerEmbeddingFunction


  class LigerEmbedding(nn.Module):
- def __init__(
- self, num_embeddings, embedding_dim, padding_idx: Optional[int] = None
- ):
+ def __init__(self, num_embeddings, embedding_dim, padding_idx: Optional[int] = None):
  super().__init__()
  self.num_embeddings = num_embeddings
  self.embedding_dim = embedding_dim
liger_kernel/transformers/functional.py CHANGED
@@ -1,9 +1,7 @@
  from typing import Optional

  from liger_kernel.ops.cross_entropy import LigerCrossEntropyFunction
- from liger_kernel.ops.fused_linear_cross_entropy import (
- LigerFusedLinearCrossEntropyFunction,
- )
+ from liger_kernel.ops.fused_linear_cross_entropy import LigerFusedLinearCrossEntropyFunction
  from liger_kernel.ops.fused_linear_jsd import LigerFusedLinearJSDFunction
  from liger_kernel.ops.geglu import LigerGELUMulFunction
  from liger_kernel.ops.group_norm import LigerGroupNormFunction
@@ -159,9 +157,7 @@ def liger_qwen2vl_mrope(q, k, cos, sin, mrope_section, unsqueeze_dim=1):
  return LigerQwen2VLMRopeFunction.apply(q, k, cos, sin, mrope_section, unsqueeze_dim)


- def liger_rms_norm(
- X, W, eps, offset: float = 0.0, casting_mode: str = "llama", in_place: bool = True
- ):
+ def liger_rms_norm(X, W, eps, offset: float = 0.0, casting_mode: str = "llama", in_place: bool = True):
  return LigerRMSNormFunction.apply(X, W, eps, offset, casting_mode, in_place)
