liger-kernel-nightly 0.5.2.dev20241223032630__py3-none-any.whl → 0.5.2.dev20241228022953__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (56)
  1. liger_kernel/chunked_loss/cpo_loss.py +5 -12
  2. liger_kernel/chunked_loss/dpo_loss.py +1 -4
  3. liger_kernel/chunked_loss/fused_linear_distillation.py +37 -37
  4. liger_kernel/chunked_loss/fused_linear_preference.py +40 -64
  5. liger_kernel/chunked_loss/orpo_loss.py +2 -6
  6. liger_kernel/chunked_loss/simpo_loss.py +4 -8
  7. liger_kernel/env_report.py +4 -11
  8. liger_kernel/ops/cross_entropy.py +7 -10
  9. liger_kernel/ops/experimental/embedding.py +1 -3
  10. liger_kernel/ops/experimental/mm_int8int2.py +3 -9
  11. liger_kernel/ops/fused_linear_cross_entropy.py +12 -17
  12. liger_kernel/ops/fused_linear_jsd.py +11 -29
  13. liger_kernel/ops/geglu.py +6 -17
  14. liger_kernel/ops/group_norm.py +11 -28
  15. liger_kernel/ops/jsd.py +2 -6
  16. liger_kernel/ops/kl_div.py +4 -7
  17. liger_kernel/ops/layer_norm.py +3 -5
  18. liger_kernel/ops/qwen2vl_mrope.py +8 -25
  19. liger_kernel/ops/rms_norm.py +11 -29
  20. liger_kernel/ops/rope.py +8 -24
  21. liger_kernel/ops/swiglu.py +4 -8
  22. liger_kernel/ops/utils.py +2 -0
  23. liger_kernel/transformers/__init__.py +16 -24
  24. liger_kernel/transformers/auto_model.py +6 -13
  25. liger_kernel/transformers/cross_entropy.py +1 -3
  26. liger_kernel/transformers/experimental/embedding.py +1 -3
  27. liger_kernel/transformers/functional.py +2 -6
  28. liger_kernel/transformers/fused_linear_cross_entropy.py +2 -6
  29. liger_kernel/transformers/geglu.py +1 -4
  30. liger_kernel/transformers/group_norm.py +3 -9
  31. liger_kernel/transformers/jsd.py +1 -3
  32. liger_kernel/transformers/kl_div.py +1 -3
  33. liger_kernel/transformers/layer_norm.py +3 -9
  34. liger_kernel/transformers/model/gemma.py +18 -40
  35. liger_kernel/transformers/model/gemma2.py +19 -41
  36. liger_kernel/transformers/model/llama.py +22 -48
  37. liger_kernel/transformers/model/mistral.py +14 -26
  38. liger_kernel/transformers/model/mixtral.py +23 -53
  39. liger_kernel/transformers/model/mllama.py +16 -36
  40. liger_kernel/transformers/model/phi3.py +18 -40
  41. liger_kernel/transformers/model/qwen2.py +18 -40
  42. liger_kernel/transformers/model/qwen2_vl.py +16 -30
  43. liger_kernel/transformers/monkey_patch.py +43 -117
  44. liger_kernel/transformers/rms_norm.py +4 -4
  45. liger_kernel/transformers/swiglu.py +2 -8
  46. liger_kernel/transformers/trainer/__init__.py +1 -3
  47. liger_kernel/transformers/trainer/orpo_trainer.py +13 -16
  48. liger_kernel/triton/__init__.py +1 -3
  49. liger_kernel/triton/monkey_patch.py +1 -3
  50. {liger_kernel_nightly-0.5.2.dev20241223032630.dist-info → liger_kernel_nightly-0.5.2.dev20241228022953.dist-info}/METADATA +1 -1
  51. liger_kernel_nightly-0.5.2.dev20241228022953.dist-info/RECORD +66 -0
  52. liger_kernel_nightly-0.5.2.dev20241223032630.dist-info/RECORD +0 -66
  53. {liger_kernel_nightly-0.5.2.dev20241223032630.dist-info → liger_kernel_nightly-0.5.2.dev20241228022953.dist-info}/LICENSE +0 -0
  54. {liger_kernel_nightly-0.5.2.dev20241223032630.dist-info → liger_kernel_nightly-0.5.2.dev20241228022953.dist-info}/NOTICE +0 -0
  55. {liger_kernel_nightly-0.5.2.dev20241223032630.dist-info → liger_kernel_nightly-0.5.2.dev20241228022953.dist-info}/WHEEL +0 -0
  56. {liger_kernel_nightly-0.5.2.dev20241223032630.dist-info → liger_kernel_nightly-0.5.2.dev20241228022953.dist-info}/top_level.txt +0 -0
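Nearly all of the source changes below are mechanical reformatting: imports that were grouped in parentheses are split into one `from ... import ...` statement per name, and expressions that were previously wrapped across several lines are collapsed onto a single longer line. This is consistent with a formatter/import-sorter configuration that uses single-line imports and a longer line length, though the tooling change itself is not visible in this diff. The recurring import pattern, illustrated with the liger_kernel.ops.utils imports that appear throughout:

Before (0.5.2.dev20241223032630):

    from liger_kernel.ops.utils import (
        calculate_settings,
        compare_version,
        ensure_contiguous,
    )

After (0.5.2.dev20241228022953):

    from liger_kernel.ops.utils import calculate_settings
    from liger_kernel.ops.utils import compare_version
    from liger_kernel.ops.utils import ensure_contiguous

Both forms bind the same names; only the layout differs, which is why the line counts above shift even though the hunks shown in this diff are formatting-only.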
@@ -4,12 +4,10 @@ import torch
  import triton
 
  from liger_kernel.ops.jsd import _jsd_kernel
- from liger_kernel.ops.utils import (
- amp_custom_bwd,
- amp_custom_fwd,
- element_mul_kernel,
- is_hip,
- )
+ from liger_kernel.ops.utils import amp_custom_bwd
+ from liger_kernel.ops.utils import amp_custom_fwd
+ from liger_kernel.ops.utils import element_mul_kernel
+ from liger_kernel.ops.utils import is_hip
 
  # The hard limit of TRITON_MAX_TENSOR_NUMEL is 1048576 https://github.com/triton-lang/triton/blob/ba42a5c68fd0505f8c42f4202d53be0f8d9a5fe0/python/triton/language/core.py#L19
  # However, setting limit as 65536 as in LayerNorm tutorial is faster because of less register spilling
@@ -43,16 +41,10 @@ def fused_linear_jsd_forward(
  BLOCK_SIZE = min(MAX_FUSED_SIZE, triton.next_power_of_2(V))
 
  inc_factor = triton.cdiv(V, H) # (V + H - 1) // H
- chunk_size = triton.next_power_of_2(
- triton.cdiv(BT, inc_factor)
- ) # (BT + inc_factor - 1) // inc_factor
+ chunk_size = triton.next_power_of_2(triton.cdiv(BT, inc_factor)) # (BT + inc_factor - 1) // inc_factor
  num_chunks = triton.cdiv(BT, chunk_size) # (BT + chunk_size - 1) // chunk_size
 
- grad_weight = (
- torch.zeros_like(student_weight, device=device)
- if student_weight.requires_grad
- else None
- )
+ grad_weight = torch.zeros_like(student_weight, device=device) if student_weight.requires_grad else None
  grad_input = torch.zeros_like(student_input)
  # we use fp32 for loss accumulator
  loss_1d = torch.zeros((BT, V), dtype=torch.float32, device=device)
@@ -73,12 +65,8 @@ def fused_linear_jsd_forward(
  # shape: chunk_size x V
  # For anything starting from logits to the final JSD loss, we do computation
  # in FP32 to avoid losing numerical stability.
- student_logits_chunk = (student_input_chunk @ student_weight.t()).to(
- torch.float32
- )
- teacher_logits_chunk = (teacher_input_chunk @ teacher_weight.t()).to(
- torch.float32
- )
+ student_logits_chunk = (student_input_chunk @ student_weight.t()).to(torch.float32)
+ teacher_logits_chunk = (teacher_input_chunk @ teacher_weight.t()).to(torch.float32)
  chunk_n_rows = student_logits_chunk.shape[0]
 
  # unreduced loss
@@ -104,9 +92,7 @@ def fused_linear_jsd_forward(
  dX_ptr=student_prob_chunk,
  dX_stride=student_prob_chunk.stride(-2),
  label_ptr=(
- shift_labels[start_idx:end_idx]
- if has_label
- else torch.empty(1, device=device)
+ shift_labels[start_idx:end_idx] if has_label else torch.empty(1, device=device)
  ), # dummy ptr if no label
  beta=jsd_beta,
  n_non_ignore=n_non_ignore,
@@ -121,9 +107,7 @@ def fused_linear_jsd_forward(
  student_logits_chunk = (
  student_prob_chunk
  - torch.softmax(student_logits_chunk, dim=-1)
- * student_prob_chunk.sum(dim=-1, keepdim=True).broadcast_to(
- student_prob_chunk.shape
- )
+ * student_prob_chunk.sum(dim=-1, keepdim=True).broadcast_to(student_prob_chunk.shape)
  ) / temperature
  # now we traverse back to grad w.r.t. input to `lm_head` and grad
  # w.r.t. `lm_head` which should be computed in original dtype
@@ -239,7 +223,5 @@ class LigerFusedLinearJSDFunction(torch.autograd.Function):
  @amp_custom_bwd
  def backward(ctx, grad_output):
  (grad_input, grad_weight) = ctx.saved_tensors
- grad_input, grad_weight = fused_linear_jsd_backward(
- grad_output, grad_input, grad_weight
- )
+ grad_input, grad_weight = fused_linear_jsd_backward(grad_output, grad_input, grad_weight)
  return (grad_input, grad_weight, None, None, None, None, None, None)
liger_kernel/ops/geglu.py CHANGED
@@ -4,11 +4,9 @@ import torch
  import triton
  import triton.language as tl
 
- from liger_kernel.ops.utils import (
- calculate_settings,
- compare_version,
- ensure_contiguous,
- )
+ from liger_kernel.ops.utils import calculate_settings
+ from liger_kernel.ops.utils import compare_version
+ from liger_kernel.ops.utils import ensure_contiguous
 
  if compare_version("triton", operator.ge, "3.0.0"):
  try:
@@ -22,9 +20,7 @@ else:
 
 
  @triton.jit
- def _geglu_tanh_forward_kernel(
- a, b, c, stride, n_cols: tl.constexpr, BLOCK_SIZE: tl.constexpr
- ):
+ def _geglu_tanh_forward_kernel(a, b, c, stride, n_cols: tl.constexpr, BLOCK_SIZE: tl.constexpr):
  program_id = tl.program_id(0).to(tl.int64)
 
  # locate start index
@@ -49,9 +45,7 @@ def _geglu_tanh_forward_kernel(
 
 
  @triton.jit
- def _geglu_tanh_backward_kernel(
- dc, a, b, stride, n_cols: tl.constexpr, BLOCK_SIZE: tl.constexpr
- ):
+ def _geglu_tanh_backward_kernel(dc, a, b, stride, n_cols: tl.constexpr, BLOCK_SIZE: tl.constexpr):
  program_id = tl.program_id(0).to(tl.int64)
 
  # locate start index
@@ -80,12 +74,7 @@ def _geglu_tanh_backward_kernel(
  # where z = sqrt(2/pi) * (a + 0.044715 * a^3)
  term1 = 0.5 * (1 + tanh_result)
  tanh_sq = tanh_result * tanh_result
- term2 = (
- 0.5
- * a_row
- * (1 - tanh_sq)
- * (sqrt_2_over_pi * (1 + 3 * 0.044715 * a_row * a_row))
- )
+ term2 = 0.5 * a_row * (1 - tanh_sq) * (sqrt_2_over_pi * (1 + 3 * 0.044715 * a_row * a_row))
  da_row = dc_row * b_row * (term1 + term2)
 
  tl.store(a + col_offsets, da_row, mask=mask)
@@ -4,7 +4,8 @@ import torch
  import triton
  import triton.language as tl
 
- from liger_kernel.ops.utils import compare_version, ensure_contiguous
+ from liger_kernel.ops.utils import compare_version
+ from liger_kernel.ops.utils import ensure_contiguous
 
  if compare_version("triton", operator.ge, "3.0.0"):
  try:
@@ -73,9 +74,7 @@ def _group_norm_forward_kernel(
 
  # Normalize
  hidden_size_per_channel = hidden_size // channels_per_group
- for channel_idx in tl.range(
- group_idx * channels_per_group, (group_idx + 1) * channels_per_group
- ):
+ for channel_idx in tl.range(group_idx * channels_per_group, (group_idx + 1) * channels_per_group):
  W = tl.load(W_ptr + channel_idx)
  B = tl.load(B_ptr + channel_idx)
  for i in range(0, hidden_size_per_channel, BLOCK_SIZE):
@@ -132,21 +131,15 @@ def _group_norm_backward_kernel(
  UPSTREAM_ptr += batch_idx * X_row_stride
 
  # Mean and rstd are the same shape so have the same strides
- mean = tl.load(
- Mean_ptr + batch_idx * Mean_ptr_row_stride + group_idx * Mean_ptr_col_stride
- )
- rstd = tl.load(
- RSTD_ptr + batch_idx * Mean_ptr_row_stride + group_idx * Mean_ptr_col_stride
- )
+ mean = tl.load(Mean_ptr + batch_idx * Mean_ptr_row_stride + group_idx * Mean_ptr_col_stride)
+ rstd = tl.load(RSTD_ptr + batch_idx * Mean_ptr_row_stride + group_idx * Mean_ptr_col_stride)
 
  c1 = 0.0
  c2 = 0.0
  block_range = tl.arange(0, BLOCK_SIZE)
 
  # We need to compute the sum terms of the backprop equations across all channels in the group
- for channel_idx in range(
- group_idx * channels_per_group, (group_idx + 1) * channels_per_group
- ):
+ for channel_idx in range(group_idx * channels_per_group, (group_idx + 1) * channels_per_group):
  dW = 0.0
  dB = 0.0
  # Move the pointers to the correct channel
@@ -181,9 +174,7 @@ def _group_norm_backward_kernel(
  c1 = c1 / N
  c2 = c2 / N
 
- for channel_idx in tl.range(
- group_idx * channels_per_group, (group_idx + 1) * channels_per_group
- ):
+ for channel_idx in tl.range(group_idx * channels_per_group, (group_idx + 1) * channels_per_group):
  # Move the pointers to the correct channel
  W = tl.load(W_ptr + channel_idx)
  for i in range(0, hidden_size, BLOCK_SIZE):
@@ -203,9 +194,7 @@ def _group_norm_backward_kernel(
  x_hat = (X - mean) * rstd
  wdy = W * UPSTREAM_grad
  dx = (wdy - (x_hat * c1 + c2)) * rstd
- tl.store(
- DX_ptr + channel_idx * X_col_stride + hidden_size_offsets, dx, mask=mask
- )
+ tl.store(DX_ptr + channel_idx * X_col_stride + hidden_size_offsets, dx, mask=mask)
 
 
  def group_norm_forward(X, num_channels, num_groups, W, B, eps):
@@ -216,9 +205,7 @@ def group_norm_forward(X, num_channels, num_groups, W, B, eps):
  X = X.view(batch_size, num_groups, -1).contiguous()
  hidden_size = X.shape[-1]
  BLOCK_SIZE = min(MAX_FUSED_SIZE, triton.next_power_of_2(hidden_size))
- Y = torch.empty(
- (batch_size, num_groups, hidden_size), dtype=X.dtype, device=X.device
- )
+ Y = torch.empty((batch_size, num_groups, hidden_size), dtype=X.dtype, device=X.device)
  Mean = torch.zeros((batch_size, num_groups), dtype=X.dtype, device=X.device)
  RSTD = torch.zeros((batch_size, num_groups), dtype=X.dtype, device=X.device)
 
@@ -307,16 +294,12 @@ class LigerGroupNormFunction(torch.autograd.Function):
  )
  ctx.num_channels = num_channels
  ctx.num_groups = num_groups
- ctx.save_for_backward(
- X, affine_scaling_weight, affine_shifting_bias, Mean, RSTD
- )
+ ctx.save_for_backward(X, affine_scaling_weight, affine_shifting_bias, Mean, RSTD)
  return Y
 
  @staticmethod
  @ensure_contiguous
  def backward(ctx, dY):
  X, W, B, Mean, RSTD = ctx.saved_tensors
- DX, DW, DB = group_norm_backward(
- dY, X, W, B, Mean, RSTD, ctx.num_channels, ctx.num_groups
- )
+ DX, DW, DB = group_norm_backward(dY, X, W, B, Mean, RSTD, ctx.num_channels, ctx.num_groups)
  return DX, DW, DB, None, None, None
liger_kernel/ops/jsd.py CHANGED
@@ -98,9 +98,7 @@ def jsd_forward(_input, target, shift_labels, beta, ignore_index, has_label):
  loss_stride=loss.stride(-2),
  dX_ptr=dX,
  dX_stride=dX.stride(-2),
- label_ptr=(
- shift_labels if has_label else torch.empty(1, device=_input.device)
- ), # dummy ptr if no label
+ label_ptr=(shift_labels if has_label else torch.empty(1, device=_input.device)), # dummy ptr if no label
  beta=beta,
  n_non_ignore=n_non_ignore,
  ignore_index=ignore_index,
@@ -165,9 +163,7 @@ class LigerJSDFunction(torch.autograd.Function):
  shift_labels = shift_labels.contiguous()
  has_label = True
 
- loss, dX = jsd_forward(
- _input, target, shift_labels, beta, ignore_index, has_label
- )
+ loss, dX = jsd_forward(_input, target, shift_labels, beta, ignore_index, has_label)
  ctx.save_for_backward(dX)
  return loss
 
@@ -4,7 +4,8 @@ import torch
  import triton
  import triton.language as tl
 
- from liger_kernel.ops.utils import ensure_contiguous, is_hip
+ from liger_kernel.ops.utils import ensure_contiguous
+ from liger_kernel.ops.utils import is_hip
 
 
  def get_num_warps(BLOCK_SIZE):
@@ -218,9 +219,7 @@ class LigerKLDivLossFunction(torch.autograd.Function):
  ctx.save_for_backward(y_true)
  ctx.reduction = reduction
  ctx.log_target = log_target
- return kldiv_forward_triton(
- y_pred, y_true, log_target=log_target, reduction=reduction, eps=eps
- )
+ return kldiv_forward_triton(y_pred, y_true, log_target=log_target, reduction=reduction, eps=eps)
 
  @staticmethod
  @ensure_contiguous
@@ -238,9 +237,7 @@ class LigerKLDivLossFunction(torch.autograd.Function):
 
  new_grads = torch.empty_like(y_true)
 
- derivative = kldiv_backward_triton(
- y_true, grad_output, new_grads, ctx.log_target
- )
+ derivative = kldiv_backward_triton(y_true, grad_output, new_grads, ctx.log_target)
 
  if ctx.reduction == "batchmean":
  derivative = derivative / y_true.shape[0]
@@ -5,11 +5,9 @@ import torch
  import triton
  import triton.language as tl
 
- from liger_kernel.ops.utils import (
- calculate_settings,
- compare_version,
- ensure_contiguous,
- )
+ from liger_kernel.ops.utils import calculate_settings
+ from liger_kernel.ops.utils import compare_version
+ from liger_kernel.ops.utils import ensure_contiguous
 
  if compare_version("triton", operator.ge, "3.0.0"):
  try:
@@ -67,36 +67,20 @@ def _triton_qwen2vl_mrope(
  # program instance (i.e. for the current token) separately
  # ####################################################################
  # left half of the head
- first_half_q_offsets = (
- tl.arange(0, pad_n_qh)[:, None] * hd + tl.arange(0, pad_hd // 2)[None, :]
- )
- first_half_k_offsets = (
- tl.arange(0, pad_n_kh)[:, None] * hd + tl.arange(0, pad_hd // 2)[None, :]
- )
- first_q_mask = (tl.arange(0, pad_n_qh)[:, None] < n_qh) & (
- tl.arange(0, pad_hd // 2)[None, :] < hd // 2
- )
- first_k_mask = (tl.arange(0, pad_n_kh)[:, None] < n_kh) & (
- tl.arange(0, pad_hd // 2)[None, :] < hd // 2
- )
- q_tile_1 = tl.load(q_ptr + first_half_q_offsets, mask=first_q_mask, other=0).to(
- sin_row.dtype
- )
- k_tile_1 = tl.load(k_ptr + first_half_k_offsets, mask=first_k_mask, other=0).to(
- sin_row.dtype
- )
+ first_half_q_offsets = tl.arange(0, pad_n_qh)[:, None] * hd + tl.arange(0, pad_hd // 2)[None, :]
+ first_half_k_offsets = tl.arange(0, pad_n_kh)[:, None] * hd + tl.arange(0, pad_hd // 2)[None, :]
+ first_q_mask = (tl.arange(0, pad_n_qh)[:, None] < n_qh) & (tl.arange(0, pad_hd // 2)[None, :] < hd // 2)
+ first_k_mask = (tl.arange(0, pad_n_kh)[:, None] < n_kh) & (tl.arange(0, pad_hd // 2)[None, :] < hd // 2)
+ q_tile_1 = tl.load(q_ptr + first_half_q_offsets, mask=first_q_mask, other=0).to(sin_row.dtype)
+ k_tile_1 = tl.load(k_ptr + first_half_k_offsets, mask=first_k_mask, other=0).to(sin_row.dtype)
 
  # right half of the head
  second_half_q_offsets = first_half_q_offsets + (hd // 2)
  second_half_k_offsets = first_half_k_offsets + (hd // 2)
  second_q_mask = first_q_mask
  second_k_mask = first_k_mask
- q_tile_2 = tl.load(q_ptr + second_half_q_offsets, mask=second_q_mask, other=0).to(
- sin_row.dtype
- )
- k_tile_2 = tl.load(k_ptr + second_half_k_offsets, mask=second_k_mask, other=0).to(
- sin_row.dtype
- )
+ q_tile_2 = tl.load(q_ptr + second_half_q_offsets, mask=second_q_mask, other=0).to(sin_row.dtype)
+ k_tile_2 = tl.load(k_ptr + second_half_k_offsets, mask=second_k_mask, other=0).to(sin_row.dtype)
 
  if not BACKWARD_PASS:
  # y = [x1, x2] * [cos, cos] + [-x2, x1] * [sin, sin]
@@ -124,7 +108,6 @@ def _triton_qwen2vl_mrope(
 
 
  def qwen2vl_mrope_forward(q, k, cos, sin, mrope_section):
-
  # transpose it back to the physical shape because Triton looks at the physical storage
  # note: q and k are incontiguous before the transformation and will become contiguous after transpose
  q = q.transpose(1, 2)
@@ -17,12 +17,10 @@ import torch
  import triton
  import triton.language as tl
 
- from liger_kernel.ops.utils import (
- calculate_settings,
- compare_version,
- ensure_contiguous,
- torch_to_triton_dtype,
- )
+ from liger_kernel.ops.utils import calculate_settings
+ from liger_kernel.ops.utils import compare_version
+ from liger_kernel.ops.utils import ensure_contiguous
+ from liger_kernel.ops.utils import torch_to_triton_dtype
 
  if compare_version("triton", operator.ge, "3.0.0"):
  try:
@@ -177,9 +175,7 @@ def _rms_norm_backward_kernel(
 
  dX_row = rstd_row * m
 
- dX_row += (rstd_row) * (
- -(1 / n_cols) * rstd_row * rstd_row * tl.sum(m * X_row, axis=0) * X_row
- )
+ dX_row += (rstd_row) * (-(1 / n_cols) * rstd_row * rstd_row * tl.sum(m * X_row, axis=0) * X_row)
 
  # calculate the gradient of W
  if casting_mode == _CASTING_MODE_LLAMA:
@@ -207,14 +203,10 @@ _str_to_casting_mode = {
 
  def rms_norm_forward(X, W, eps, offset, casting_mode):
  if not isinstance(casting_mode, int):
- assert (
- casting_mode in _str_to_casting_mode
- ), f"Invalid casting mode: {casting_mode}"
+ assert casting_mode in _str_to_casting_mode, f"Invalid casting mode: {casting_mode}"
  casting_mode = _str_to_casting_mode[casting_mode]
  else:
- assert (
- casting_mode in _str_to_casting_mode.values()
- ), f"Invalid casting mode: {casting_mode}"
+ assert casting_mode in _str_to_casting_mode.values(), f"Invalid casting mode: {casting_mode}"
 
  shape = X.shape
  dim = shape[-1]
@@ -225,17 +217,11 @@ def rms_norm_forward(X, W, eps, offset, casting_mode):
  Y = torch.empty((n_rows, n_cols), dtype=X.dtype, device=X.device)
  # RSTD is to cache rstd for each row
  # RSTD is always computed/stored in fp32 if we are using Llama or Gemma casting mode
- rstd_dtype = (
- torch.float32
- if casting_mode in (_CASTING_MODE_LLAMA.value, _CASTING_MODE_GEMMA.value)
- else X.dtype
- )
+ rstd_dtype = torch.float32 if casting_mode in (_CASTING_MODE_LLAMA.value, _CASTING_MODE_GEMMA.value) else X.dtype
  RSTD = torch.empty(n_rows, dtype=rstd_dtype, device=X.device)
 
  # Check constraints.
- assert (
- X.shape[1] == W.shape[0]
- ), "Incompatible hidden size dimension between tensor1.shape[1] and tensor2.shape[0]"
+ assert X.shape[1] == W.shape[0], "Incompatible hidden size dimension between tensor1.shape[1] and tensor2.shape[0]"
 
  _rms_norm_forward_kernel[(n_rows,)](
  Y,
@@ -256,9 +242,7 @@ def rms_norm_forward(X, W, eps, offset, casting_mode):
  return Y.view(*shape), X, RSTD, BLOCK_SIZE, num_warps, casting_mode
 
 
- def rms_norm_backward(
- dY, X, W, RSTD, offset, casting_mode, BLOCK_SIZE, num_warps, in_place
- ):
+ def rms_norm_backward(dY, X, W, RSTD, offset, casting_mode, BLOCK_SIZE, num_warps, in_place):
  shape = dY.shape
  dim = shape[-1]
  dY = dY.view(-1, dim)
@@ -340,9 +324,7 @@ class LigerRMSNormFunction(torch.autograd.Function):
  X: (B, T, H) or (BxT, H)
  W: (H,)
  """
- Y, X, RSTD, BLOCK_SIZE, num_warps, casting_mode = rms_norm_forward(
- X, W, eps, offset, casting_mode
- )
+ Y, X, RSTD, BLOCK_SIZE, num_warps, casting_mode = rms_norm_forward(X, W, eps, offset, casting_mode)
  ctx.offset = offset
  ctx.casting_mode = casting_mode
  ctx.in_place = in_place
liger_kernel/ops/rope.py CHANGED
@@ -72,36 +72,20 @@ def _triton_rope(
  # program instance (i.e. for the current token) separately
  # ####################################################################
  # left half of the head
- first_half_q_offsets = (
- tl.arange(0, pad_n_qh)[:, None] * hd + tl.arange(0, pad_hd // 2)[None, :]
- )
- first_half_k_offsets = (
- tl.arange(0, pad_n_kh)[:, None] * hd + tl.arange(0, pad_hd // 2)[None, :]
- )
- first_q_mask = (tl.arange(0, pad_n_qh)[:, None] < n_qh) & (
- tl.arange(0, pad_hd // 2)[None, :] < hd // 2
- )
- first_k_mask = (tl.arange(0, pad_n_kh)[:, None] < n_kh) & (
- tl.arange(0, pad_hd // 2)[None, :] < hd // 2
- )
- q_tile_1 = tl.load(q_ptr + first_half_q_offsets, mask=first_q_mask, other=0).to(
- sin_row.dtype
- )
- k_tile_1 = tl.load(k_ptr + first_half_k_offsets, mask=first_k_mask, other=0).to(
- sin_row.dtype
- )
+ first_half_q_offsets = tl.arange(0, pad_n_qh)[:, None] * hd + tl.arange(0, pad_hd // 2)[None, :]
+ first_half_k_offsets = tl.arange(0, pad_n_kh)[:, None] * hd + tl.arange(0, pad_hd // 2)[None, :]
+ first_q_mask = (tl.arange(0, pad_n_qh)[:, None] < n_qh) & (tl.arange(0, pad_hd // 2)[None, :] < hd // 2)
+ first_k_mask = (tl.arange(0, pad_n_kh)[:, None] < n_kh) & (tl.arange(0, pad_hd // 2)[None, :] < hd // 2)
+ q_tile_1 = tl.load(q_ptr + first_half_q_offsets, mask=first_q_mask, other=0).to(sin_row.dtype)
+ k_tile_1 = tl.load(k_ptr + first_half_k_offsets, mask=first_k_mask, other=0).to(sin_row.dtype)
 
  # right half of the head
  second_half_q_offsets = first_half_q_offsets + (hd // 2)
  second_half_k_offsets = first_half_k_offsets + (hd // 2)
  second_q_mask = first_q_mask
  second_k_mask = first_k_mask
- q_tile_2 = tl.load(q_ptr + second_half_q_offsets, mask=second_q_mask, other=0).to(
- sin_row.dtype
- )
- k_tile_2 = tl.load(k_ptr + second_half_k_offsets, mask=second_k_mask, other=0).to(
- sin_row.dtype
- )
+ q_tile_2 = tl.load(q_ptr + second_half_q_offsets, mask=second_q_mask, other=0).to(sin_row.dtype)
+ k_tile_2 = tl.load(k_ptr + second_half_k_offsets, mask=second_k_mask, other=0).to(sin_row.dtype)
 
  if not BACKWARD_PASS:
  # y = [x1, x2] * [cos, cos] + [-x2, x1] * [sin, sin]
@@ -2,7 +2,8 @@ import torch
  import triton
  import triton.language as tl
 
- from liger_kernel.ops.utils import calculate_settings, ensure_contiguous
+ from liger_kernel.ops.utils import calculate_settings
+ from liger_kernel.ops.utils import ensure_contiguous
 
 
  @triton.jit
@@ -11,9 +12,7 @@ def silu(x):
 
 
  @triton.jit
- def _swiglu_forward_kernel(
- a_ptr, b_ptr, c_ptr, stride, n_cols: tl.constexpr, BLOCK_SIZE: tl.constexpr
- ):
+ def _swiglu_forward_kernel(a_ptr, b_ptr, c_ptr, stride, n_cols: tl.constexpr, BLOCK_SIZE: tl.constexpr):
  program_id = tl.program_id(0).to(tl.int64)
 
  # locate start index
@@ -32,9 +31,7 @@ def _swiglu_forward_kernel(
 
 
  @triton.jit
- def _swiglu_backward_kernel(
- dc_ptr, a_ptr, b_ptr, stride, n_cols: tl.constexpr, BLOCK_SIZE: tl.constexpr
- ):
+ def _swiglu_backward_kernel(dc_ptr, a_ptr, b_ptr, stride, n_cols: tl.constexpr, BLOCK_SIZE: tl.constexpr):
  program_id = tl.program_id(0).to(tl.int64)
 
  # locate start index
@@ -84,7 +81,6 @@ def swiglu_forward(a, b):
 
 
  def swiglu_backward(a, b, dc):
-
  ori_shape = dc.shape
  n_cols = ori_shape[-1]
  dc = dc.view(-1, n_cols)
liger_kernel/ops/utils.py CHANGED
@@ -13,11 +13,13 @@ Modifications made by Yanning Chen, 2024.
  import functools
  import importlib
  import operator
+
  from typing import Callable
 
  import torch
  import triton
  import triton.language as tl
+
  from packaging.version import Version
 
  from liger_kernel.utils import infer_device
@@ -1,31 +1,23 @@
- from liger_kernel.transformers.auto_model import ( # noqa: F401
- AutoLigerKernelForCausalLM,
- )
+ from liger_kernel.transformers.auto_model import AutoLigerKernelForCausalLM # noqa: F401
  from liger_kernel.transformers.cross_entropy import LigerCrossEntropyLoss # noqa: F401
- from liger_kernel.transformers.fused_linear_cross_entropy import ( # noqa: F401
- LigerFusedLinearCrossEntropyLoss,
- )
+ from liger_kernel.transformers.fused_linear_cross_entropy import LigerFusedLinearCrossEntropyLoss # noqa: F401
  from liger_kernel.transformers.fused_linear_jsd import LigerFusedLinearJSD # noqa: F401
  from liger_kernel.transformers.geglu import LigerGEGLUMLP # noqa: F401
  from liger_kernel.transformers.jsd import LigerJSD # noqa: F401
  from liger_kernel.transformers.layer_norm import LigerLayerNorm # noqa: F401
- from liger_kernel.transformers.monkey_patch import ( # noqa: F401
- _apply_liger_kernel,
- _apply_liger_kernel_to_instance,
- apply_liger_kernel_to_gemma,
- apply_liger_kernel_to_gemma2,
- apply_liger_kernel_to_llama,
- apply_liger_kernel_to_mistral,
- apply_liger_kernel_to_mixtral,
- apply_liger_kernel_to_mllama,
- apply_liger_kernel_to_phi3,
- apply_liger_kernel_to_qwen2,
- apply_liger_kernel_to_qwen2_vl,
- )
+ from liger_kernel.transformers.monkey_patch import _apply_liger_kernel # noqa: F401
+ from liger_kernel.transformers.monkey_patch import _apply_liger_kernel_to_instance # noqa: F401
+ from liger_kernel.transformers.monkey_patch import apply_liger_kernel_to_gemma # noqa: F401
+ from liger_kernel.transformers.monkey_patch import apply_liger_kernel_to_gemma2 # noqa: F401
+ from liger_kernel.transformers.monkey_patch import apply_liger_kernel_to_llama # noqa: F401
+ from liger_kernel.transformers.monkey_patch import apply_liger_kernel_to_mistral # noqa: F401
+ from liger_kernel.transformers.monkey_patch import apply_liger_kernel_to_mixtral # noqa: F401
+ from liger_kernel.transformers.monkey_patch import apply_liger_kernel_to_mllama # noqa: F401
+ from liger_kernel.transformers.monkey_patch import apply_liger_kernel_to_phi3 # noqa: F401
+ from liger_kernel.transformers.monkey_patch import apply_liger_kernel_to_qwen2 # noqa: F401
+ from liger_kernel.transformers.monkey_patch import apply_liger_kernel_to_qwen2_vl # noqa: F401
  from liger_kernel.transformers.rms_norm import LigerRMSNorm # noqa: F401
  from liger_kernel.transformers.rope import liger_rotary_pos_emb # noqa: F401
- from liger_kernel.transformers.swiglu import ( # noqa: F401
- LigerBlockSparseTop2MLP,
- LigerPhi3SwiGLUMLP,
- LigerSwiGLUMLP,
- )
+ from liger_kernel.transformers.swiglu import LigerBlockSparseTop2MLP # noqa: F401
+ from liger_kernel.transformers.swiglu import LigerPhi3SwiGLUMLP # noqa: F401
+ from liger_kernel.transformers.swiglu import LigerSwiGLUMLP # noqa: F401
@@ -1,11 +1,10 @@
  import inspect
 
- from transformers import AutoConfig, AutoModelForCausalLM
+ from transformers import AutoConfig
+ from transformers import AutoModelForCausalLM
 
- from liger_kernel.transformers.monkey_patch import (
- MODEL_TYPE_TO_APPLY_LIGER_FN,
- _apply_liger_kernel,
- )
+ from liger_kernel.transformers.monkey_patch import MODEL_TYPE_TO_APPLY_LIGER_FN
+ from liger_kernel.transformers.monkey_patch import _apply_liger_kernel
 
 
  def _get_model_config(model_dir, **model_init_kwargs):
@@ -34,12 +33,6 @@ class AutoLigerKernelForCausalLM(AutoModelForCausalLM):
  apply_fn = MODEL_TYPE_TO_APPLY_LIGER_FN[model_type]
  apply_fn_signature = inspect.signature(apply_fn)
 
- applicable_kwargs = {
- key: value
- for key, value in kwargs.items()
- if key not in apply_fn_signature.parameters
- }
+ applicable_kwargs = {key: value for key, value in kwargs.items() if key not in apply_fn_signature.parameters}
 
- return super().from_pretrained(
- pretrained_model_name_or_path, *model_args, **applicable_kwargs
- )
+ return super().from_pretrained(pretrained_model_name_or_path, *model_args, **applicable_kwargs)
@@ -27,9 +27,7 @@ class LigerCrossEntropyLoss(torch.nn.Module):
  "sum",
  "none",
  }, f"reduction must be one of 'mean', 'sum', or 'none'. Got: {reduction}"
- assert (
- softcap is None or softcap > 0
- ), f"softcap must greater than 0.0 or None. Got: {softcap}"
+ assert softcap is None or softcap > 0, f"softcap must greater than 0.0 or None. Got: {softcap}"
  self.ignore_index = ignore_index
  self.lse_square_scale = lse_square_scale
  self.label_smoothing = label_smoothing
@@ -7,9 +7,7 @@ from liger_kernel.ops.experimental.embedding import LigerEmbeddingFunction
 
 
  class LigerEmbedding(nn.Module):
- def __init__(
- self, num_embeddings, embedding_dim, padding_idx: Optional[int] = None
- ):
+ def __init__(self, num_embeddings, embedding_dim, padding_idx: Optional[int] = None):
  super().__init__()
  self.num_embeddings = num_embeddings
  self.embedding_dim = embedding_dim
@@ -1,9 +1,7 @@
  from typing import Optional
 
  from liger_kernel.ops.cross_entropy import LigerCrossEntropyFunction
- from liger_kernel.ops.fused_linear_cross_entropy import (
- LigerFusedLinearCrossEntropyFunction,
- )
+ from liger_kernel.ops.fused_linear_cross_entropy import LigerFusedLinearCrossEntropyFunction
  from liger_kernel.ops.fused_linear_jsd import LigerFusedLinearJSDFunction
  from liger_kernel.ops.geglu import LigerGELUMulFunction
  from liger_kernel.ops.group_norm import LigerGroupNormFunction
@@ -159,9 +157,7 @@ def liger_qwen2vl_mrope(q, k, cos, sin, mrope_section, unsqueeze_dim=1):
  return LigerQwen2VLMRopeFunction.apply(q, k, cos, sin, mrope_section, unsqueeze_dim)
 
 
- def liger_rms_norm(
- X, W, eps, offset: float = 0.0, casting_mode: str = "llama", in_place: bool = True
- ):
+ def liger_rms_norm(X, W, eps, offset: float = 0.0, casting_mode: str = "llama", in_place: bool = True):
  return LigerRMSNormFunction.apply(X, W, eps, offset, casting_mode, in_place)