liger-kernel-nightly 0.5.2.dev20241223032630__py3-none-any.whl → 0.5.2.dev20241228022953__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- liger_kernel/chunked_loss/cpo_loss.py +5 -12
- liger_kernel/chunked_loss/dpo_loss.py +1 -4
- liger_kernel/chunked_loss/fused_linear_distillation.py +37 -37
- liger_kernel/chunked_loss/fused_linear_preference.py +40 -64
- liger_kernel/chunked_loss/orpo_loss.py +2 -6
- liger_kernel/chunked_loss/simpo_loss.py +4 -8
- liger_kernel/env_report.py +4 -11
- liger_kernel/ops/cross_entropy.py +7 -10
- liger_kernel/ops/experimental/embedding.py +1 -3
- liger_kernel/ops/experimental/mm_int8int2.py +3 -9
- liger_kernel/ops/fused_linear_cross_entropy.py +12 -17
- liger_kernel/ops/fused_linear_jsd.py +11 -29
- liger_kernel/ops/geglu.py +6 -17
- liger_kernel/ops/group_norm.py +11 -28
- liger_kernel/ops/jsd.py +2 -6
- liger_kernel/ops/kl_div.py +4 -7
- liger_kernel/ops/layer_norm.py +3 -5
- liger_kernel/ops/qwen2vl_mrope.py +8 -25
- liger_kernel/ops/rms_norm.py +11 -29
- liger_kernel/ops/rope.py +8 -24
- liger_kernel/ops/swiglu.py +4 -8
- liger_kernel/ops/utils.py +2 -0
- liger_kernel/transformers/__init__.py +16 -24
- liger_kernel/transformers/auto_model.py +6 -13
- liger_kernel/transformers/cross_entropy.py +1 -3
- liger_kernel/transformers/experimental/embedding.py +1 -3
- liger_kernel/transformers/functional.py +2 -6
- liger_kernel/transformers/fused_linear_cross_entropy.py +2 -6
- liger_kernel/transformers/geglu.py +1 -4
- liger_kernel/transformers/group_norm.py +3 -9
- liger_kernel/transformers/jsd.py +1 -3
- liger_kernel/transformers/kl_div.py +1 -3
- liger_kernel/transformers/layer_norm.py +3 -9
- liger_kernel/transformers/model/gemma.py +18 -40
- liger_kernel/transformers/model/gemma2.py +19 -41
- liger_kernel/transformers/model/llama.py +22 -48
- liger_kernel/transformers/model/mistral.py +14 -26
- liger_kernel/transformers/model/mixtral.py +23 -53
- liger_kernel/transformers/model/mllama.py +16 -36
- liger_kernel/transformers/model/phi3.py +18 -40
- liger_kernel/transformers/model/qwen2.py +18 -40
- liger_kernel/transformers/model/qwen2_vl.py +16 -30
- liger_kernel/transformers/monkey_patch.py +43 -117
- liger_kernel/transformers/rms_norm.py +4 -4
- liger_kernel/transformers/swiglu.py +2 -8
- liger_kernel/transformers/trainer/__init__.py +1 -3
- liger_kernel/transformers/trainer/orpo_trainer.py +13 -16
- liger_kernel/triton/__init__.py +1 -3
- liger_kernel/triton/monkey_patch.py +1 -3
- {liger_kernel_nightly-0.5.2.dev20241223032630.dist-info → liger_kernel_nightly-0.5.2.dev20241228022953.dist-info}/METADATA +1 -1
- liger_kernel_nightly-0.5.2.dev20241228022953.dist-info/RECORD +66 -0
- liger_kernel_nightly-0.5.2.dev20241223032630.dist-info/RECORD +0 -66
- {liger_kernel_nightly-0.5.2.dev20241223032630.dist-info → liger_kernel_nightly-0.5.2.dev20241228022953.dist-info}/LICENSE +0 -0
- {liger_kernel_nightly-0.5.2.dev20241223032630.dist-info → liger_kernel_nightly-0.5.2.dev20241228022953.dist-info}/NOTICE +0 -0
- {liger_kernel_nightly-0.5.2.dev20241223032630.dist-info → liger_kernel_nightly-0.5.2.dev20241228022953.dist-info}/WHEEL +0 -0
- {liger_kernel_nightly-0.5.2.dev20241223032630.dist-info → liger_kernel_nightly-0.5.2.dev20241228022953.dist-info}/top_level.txt +0 -0
liger_kernel/ops/fused_linear_jsd.py
CHANGED
@@ -4,12 +4,10 @@ import torch
 import triton
 
 from liger_kernel.ops.jsd import _jsd_kernel
-from liger_kernel.ops.utils import (
-    amp_custom_bwd,
-    amp_custom_fwd,
-    element_mul_kernel,
-    is_hip,
-)
+from liger_kernel.ops.utils import amp_custom_bwd
+from liger_kernel.ops.utils import amp_custom_fwd
+from liger_kernel.ops.utils import element_mul_kernel
+from liger_kernel.ops.utils import is_hip
 
 # The hard limit of TRITON_MAX_TENSOR_NUMEL is 1048576 https://github.com/triton-lang/triton/blob/ba42a5c68fd0505f8c42f4202d53be0f8d9a5fe0/python/triton/language/core.py#L19
 # However, setting limit as 65536 as in LayerNorm tutorial is faster because of less register spilling
@@ -43,16 +41,10 @@ def fused_linear_jsd_forward(
     BLOCK_SIZE = min(MAX_FUSED_SIZE, triton.next_power_of_2(V))
 
     inc_factor = triton.cdiv(V, H)  # (V + H - 1) // H
-    chunk_size = triton.next_power_of_2(
-        triton.cdiv(BT, inc_factor)
-    )  # (BT + inc_factor - 1) // inc_factor
+    chunk_size = triton.next_power_of_2(triton.cdiv(BT, inc_factor))  # (BT + inc_factor - 1) // inc_factor
     num_chunks = triton.cdiv(BT, chunk_size)  # (BT + chunk_size - 1) // chunk_size
 
-    grad_weight = (
-        torch.zeros_like(student_weight, device=device)
-        if student_weight.requires_grad
-        else None
-    )
+    grad_weight = torch.zeros_like(student_weight, device=device) if student_weight.requires_grad else None
     grad_input = torch.zeros_like(student_input)
     # we use fp32 for loss accumulator
     loss_1d = torch.zeros((BT, V), dtype=torch.float32, device=device)
@@ -73,12 +65,8 @@ def fused_linear_jsd_forward(
         # shape: chunk_size x V
         # For anything starting from logits to the final JSD loss, we do computation
         # in FP32 to avoid losing numerical stability.
-        student_logits_chunk = (student_input_chunk @ student_weight.t()).to(
-            torch.float32
-        )
-        teacher_logits_chunk = (teacher_input_chunk @ teacher_weight.t()).to(
-            torch.float32
-        )
+        student_logits_chunk = (student_input_chunk @ student_weight.t()).to(torch.float32)
+        teacher_logits_chunk = (teacher_input_chunk @ teacher_weight.t()).to(torch.float32)
         chunk_n_rows = student_logits_chunk.shape[0]
 
         # unreduced loss
@@ -104,9 +92,7 @@ def fused_linear_jsd_forward(
             dX_ptr=student_prob_chunk,
             dX_stride=student_prob_chunk.stride(-2),
             label_ptr=(
-                shift_labels[start_idx:end_idx]
-                if has_label
-                else torch.empty(1, device=device)
+                shift_labels[start_idx:end_idx] if has_label else torch.empty(1, device=device)
             ),  # dummy ptr if no label
             beta=jsd_beta,
             n_non_ignore=n_non_ignore,
@@ -121,9 +107,7 @@ def fused_linear_jsd_forward(
         student_logits_chunk = (
             student_prob_chunk
             - torch.softmax(student_logits_chunk, dim=-1)
-            * student_prob_chunk.sum(dim=-1, keepdim=True).broadcast_to(
-                student_prob_chunk.shape
-            )
+            * student_prob_chunk.sum(dim=-1, keepdim=True).broadcast_to(student_prob_chunk.shape)
         ) / temperature
         # now we traverse back to grad w.r.t. input to `lm_head` and grad
         # w.r.t. `lm_head` which should be computed in original dtype
@@ -239,7 +223,5 @@ class LigerFusedLinearJSDFunction(torch.autograd.Function):
     @amp_custom_bwd
     def backward(ctx, grad_output):
         (grad_input, grad_weight) = ctx.saved_tensors
-        grad_input, grad_weight = fused_linear_jsd_backward(
-            grad_output, grad_input, grad_weight
-        )
+        grad_input, grad_weight = fused_linear_jsd_backward(grad_output, grad_input, grad_weight)
         return (grad_input, grad_weight, None, None, None, None, None, None)
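The chunk-size arithmetic reformatted in the hunk above is easier to follow with concrete numbers. A minimal stand-alone sketch (plain Python, no Triton; jsd_chunking and the sizes BT, H, V below are illustrative and not part of the package):

import math

def next_power_of_2(n: int) -> int:
    # same result as triton.next_power_of_2 for positive n
    return 1 << (n - 1).bit_length()

def jsd_chunking(BT: int, H: int, V: int):
    inc_factor = math.ceil(V / H)                             # (V + H - 1) // H
    chunk_size = next_power_of_2(math.ceil(BT / inc_factor))  # (BT + inc_factor - 1) // inc_factor
    num_chunks = math.ceil(BT / chunk_size)                   # (BT + chunk_size - 1) // chunk_size
    return inc_factor, chunk_size, num_chunks

# e.g. BT=4096 tokens, hidden size H=4096, vocab V=128256 -> (32, 128, 32)
print(jsd_chunking(4096, 4096, 128256))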
liger_kernel/ops/geglu.py
CHANGED
@@ -4,11 +4,9 @@ import torch
 import triton
 import triton.language as tl
 
-from liger_kernel.ops.utils import (
-    calculate_settings,
-    compare_version,
-    ensure_contiguous,
-)
+from liger_kernel.ops.utils import calculate_settings
+from liger_kernel.ops.utils import compare_version
+from liger_kernel.ops.utils import ensure_contiguous
 
 if compare_version("triton", operator.ge, "3.0.0"):
     try:
@@ -22,9 +20,7 @@ else:
 
 
 @triton.jit
-def _geglu_tanh_forward_kernel(
-    a, b, c, stride, n_cols: tl.constexpr, BLOCK_SIZE: tl.constexpr
-):
+def _geglu_tanh_forward_kernel(a, b, c, stride, n_cols: tl.constexpr, BLOCK_SIZE: tl.constexpr):
     program_id = tl.program_id(0).to(tl.int64)
 
     # locate start index
@@ -49,9 +45,7 @@ def _geglu_tanh_forward_kernel(
 
 
 @triton.jit
-def _geglu_tanh_backward_kernel(
-    dc, a, b, stride, n_cols: tl.constexpr, BLOCK_SIZE: tl.constexpr
-):
+def _geglu_tanh_backward_kernel(dc, a, b, stride, n_cols: tl.constexpr, BLOCK_SIZE: tl.constexpr):
     program_id = tl.program_id(0).to(tl.int64)
 
     # locate start index
@@ -80,12 +74,7 @@ def _geglu_tanh_backward_kernel(
     # where z = sqrt(2/pi) * (a + 0.044715 * a^3)
     term1 = 0.5 * (1 + tanh_result)
     tanh_sq = tanh_result * tanh_result
-    term2 = (
-        0.5
-        * a_row
-        * (1 - tanh_sq)
-        * (sqrt_2_over_pi * (1 + 3 * 0.044715 * a_row * a_row))
-    )
+    term2 = 0.5 * a_row * (1 - tanh_sq) * (sqrt_2_over_pi * (1 + 3 * 0.044715 * a_row * a_row))
     da_row = dc_row * b_row * (term1 + term2)
 
     tl.store(a + col_offsets, da_row, mask=mask)
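For reference, the term1/term2 expression reformatted above is the analytic derivative of the tanh-approximated GELU. A hedged plain-PyTorch sketch of the same math (the *_reference function names are illustrative, not part of the package):

import math
import torch

SQRT_2_OVER_PI = math.sqrt(2.0 / math.pi)

def geglu_tanh_forward_reference(a: torch.Tensor, b: torch.Tensor) -> torch.Tensor:
    # c = b * GELU_tanh(a), with GELU_tanh(a) = 0.5 * a * (1 + tanh(z)), z = sqrt(2/pi) * (a + 0.044715 * a^3)
    z = SQRT_2_OVER_PI * (a + 0.044715 * a.pow(3))
    return b * 0.5 * a * (1.0 + torch.tanh(z))

def geglu_tanh_backward_reference(dc, a, b):
    # d/da [0.5 * a * (1 + tanh(z))] = term1 + term2, mirroring the kernel above
    z = SQRT_2_OVER_PI * (a + 0.044715 * a.pow(3))
    tanh_z = torch.tanh(z)
    term1 = 0.5 * (1.0 + tanh_z)
    term2 = 0.5 * a * (1.0 - tanh_z * tanh_z) * (SQRT_2_OVER_PI * (1.0 + 3 * 0.044715 * a * a))
    da = dc * b * (term1 + term2)
    db = dc * 0.5 * a * (1.0 + tanh_z)
    return da, db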
liger_kernel/ops/group_norm.py
CHANGED
@@ -4,7 +4,8 @@ import torch
 import triton
 import triton.language as tl
 
-from liger_kernel.ops.utils import compare_version
+from liger_kernel.ops.utils import compare_version
+from liger_kernel.ops.utils import ensure_contiguous
 
 if compare_version("triton", operator.ge, "3.0.0"):
     try:
@@ -73,9 +74,7 @@ def _group_norm_forward_kernel(
 
     # Normalize
    hidden_size_per_channel = hidden_size // channels_per_group
-    for channel_idx in tl.range(
-        group_idx * channels_per_group, (group_idx + 1) * channels_per_group
-    ):
+    for channel_idx in tl.range(group_idx * channels_per_group, (group_idx + 1) * channels_per_group):
         W = tl.load(W_ptr + channel_idx)
         B = tl.load(B_ptr + channel_idx)
         for i in range(0, hidden_size_per_channel, BLOCK_SIZE):
@@ -132,21 +131,15 @@ def _group_norm_backward_kernel(
     UPSTREAM_ptr += batch_idx * X_row_stride
 
     # Mean and rstd are the same shape so have the same strides
-    mean = tl.load(
-        Mean_ptr + batch_idx * Mean_ptr_row_stride + group_idx * Mean_ptr_col_stride
-    )
-    rstd = tl.load(
-        RSTD_ptr + batch_idx * Mean_ptr_row_stride + group_idx * Mean_ptr_col_stride
-    )
+    mean = tl.load(Mean_ptr + batch_idx * Mean_ptr_row_stride + group_idx * Mean_ptr_col_stride)
+    rstd = tl.load(RSTD_ptr + batch_idx * Mean_ptr_row_stride + group_idx * Mean_ptr_col_stride)
 
     c1 = 0.0
     c2 = 0.0
     block_range = tl.arange(0, BLOCK_SIZE)
 
     # We need to compute the sum terms of the backprop equations across all channels in the group
-    for channel_idx in range(
-        group_idx * channels_per_group, (group_idx + 1) * channels_per_group
-    ):
+    for channel_idx in range(group_idx * channels_per_group, (group_idx + 1) * channels_per_group):
         dW = 0.0
         dB = 0.0
         # Move the pointers to the correct channel
@@ -181,9 +174,7 @@ def _group_norm_backward_kernel(
     c1 = c1 / N
     c2 = c2 / N
 
-    for channel_idx in tl.range(
-        group_idx * channels_per_group, (group_idx + 1) * channels_per_group
-    ):
+    for channel_idx in tl.range(group_idx * channels_per_group, (group_idx + 1) * channels_per_group):
         # Move the pointers to the correct channel
         W = tl.load(W_ptr + channel_idx)
         for i in range(0, hidden_size, BLOCK_SIZE):
@@ -203,9 +194,7 @@ def _group_norm_backward_kernel(
             x_hat = (X - mean) * rstd
             wdy = W * UPSTREAM_grad
             dx = (wdy - (x_hat * c1 + c2)) * rstd
-            tl.store(
-                DX_ptr + channel_idx * X_col_stride + hidden_size_offsets, dx, mask=mask
-            )
+            tl.store(DX_ptr + channel_idx * X_col_stride + hidden_size_offsets, dx, mask=mask)
 
 
 def group_norm_forward(X, num_channels, num_groups, W, B, eps):
@@ -216,9 +205,7 @@ def group_norm_forward(X, num_channels, num_groups, W, B, eps):
     X = X.view(batch_size, num_groups, -1).contiguous()
     hidden_size = X.shape[-1]
     BLOCK_SIZE = min(MAX_FUSED_SIZE, triton.next_power_of_2(hidden_size))
-    Y = torch.empty(
-        (batch_size, num_groups, hidden_size), dtype=X.dtype, device=X.device
-    )
+    Y = torch.empty((batch_size, num_groups, hidden_size), dtype=X.dtype, device=X.device)
     Mean = torch.zeros((batch_size, num_groups), dtype=X.dtype, device=X.device)
     RSTD = torch.zeros((batch_size, num_groups), dtype=X.dtype, device=X.device)
 
@@ -307,16 +294,12 @@ class LigerGroupNormFunction(torch.autograd.Function):
         )
         ctx.num_channels = num_channels
         ctx.num_groups = num_groups
-        ctx.save_for_backward(
-            X, affine_scaling_weight, affine_shifting_bias, Mean, RSTD
-        )
+        ctx.save_for_backward(X, affine_scaling_weight, affine_shifting_bias, Mean, RSTD)
         return Y
 
     @staticmethod
     @ensure_contiguous
     def backward(ctx, dY):
         X, W, B, Mean, RSTD = ctx.saved_tensors
-        DX, DW, DB = group_norm_backward(
-            dY, X, W, B, Mean, RSTD, ctx.num_channels, ctx.num_groups
-        )
+        DX, DW, DB = group_norm_backward(dY, X, W, B, Mean, RSTD, ctx.num_channels, ctx.num_groups)
         return DX, DW, DB, None, None, None
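The (batch_size, num_groups, -1) view in group_norm_forward above groups channels before normalizing. A small plain-PyTorch sanity check of that grouping against torch.nn.functional.group_norm (shapes and values here are made up for illustration, not taken from the package):

import torch
import torch.nn.functional as F

torch.manual_seed(0)
x = torch.randn(2, 8, 16)          # (batch, channels, hidden)
num_groups = 4
w = torch.randn(8)
b = torch.randn(8)

ref = F.group_norm(x, num_groups, w, b, eps=1e-6)

xg = x.view(2, num_groups, -1)     # same grouping as the wrapper: channels_per_group * hidden per group
mean = xg.mean(dim=-1, keepdim=True)
rstd = torch.rsqrt(xg.var(dim=-1, unbiased=False, keepdim=True) + 1e-6)
out = ((xg - mean) * rstd).view(2, 8, 16) * w.view(1, -1, 1) + b.view(1, -1, 1)

assert torch.allclose(ref, out, atol=1e-5)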
liger_kernel/ops/jsd.py
CHANGED
@@ -98,9 +98,7 @@ def jsd_forward(_input, target, shift_labels, beta, ignore_index, has_label):
         loss_stride=loss.stride(-2),
         dX_ptr=dX,
         dX_stride=dX.stride(-2),
-        label_ptr=(
-            shift_labels if has_label else torch.empty(1, device=_input.device)
-        ),  # dummy ptr if no label
+        label_ptr=(shift_labels if has_label else torch.empty(1, device=_input.device)),  # dummy ptr if no label
         beta=beta,
         n_non_ignore=n_non_ignore,
         ignore_index=ignore_index,
@@ -165,9 +163,7 @@ class LigerJSDFunction(torch.autograd.Function):
             shift_labels = shift_labels.contiguous()
             has_label = True
 
-        loss, dX = jsd_forward(
-            _input, target, shift_labels, beta, ignore_index, has_label
-        )
+        loss, dX = jsd_forward(_input, target, shift_labels, beta, ignore_index, has_label)
         ctx.save_for_backward(dX)
         return loss
 
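As background for the jsd_forward call reformatted above: with beta = 0.5 the loss reduces to the standard Jensen-Shannon divergence (the general beta only reweights the mixture; see the kernel itself for the exact parameterization it uses):

\mathrm{JSD}(P \,\|\, Q) = \tfrac{1}{2}\,\mathrm{KL}(P \,\|\, M) + \tfrac{1}{2}\,\mathrm{KL}(Q \,\|\, M),
\qquad M = \tfrac{1}{2}\,(P + Q)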
liger_kernel/ops/kl_div.py
CHANGED
@@ -4,7 +4,8 @@ import torch
 import triton
 import triton.language as tl
 
-from liger_kernel.ops.utils import ensure_contiguous
+from liger_kernel.ops.utils import ensure_contiguous
+from liger_kernel.ops.utils import is_hip
 
 
 def get_num_warps(BLOCK_SIZE):
@@ -218,9 +219,7 @@ class LigerKLDivLossFunction(torch.autograd.Function):
         ctx.save_for_backward(y_true)
         ctx.reduction = reduction
         ctx.log_target = log_target
-        return kldiv_forward_triton(
-            y_pred, y_true, log_target=log_target, reduction=reduction, eps=eps
-        )
+        return kldiv_forward_triton(y_pred, y_true, log_target=log_target, reduction=reduction, eps=eps)
 
     @staticmethod
     @ensure_contiguous
@@ -238,9 +237,7 @@ class LigerKLDivLossFunction(torch.autograd.Function):
 
         new_grads = torch.empty_like(y_true)
 
-        derivative = kldiv_backward_triton(
-            y_true, grad_output, new_grads, ctx.log_target
-        )
+        derivative = kldiv_backward_triton(y_true, grad_output, new_grads, ctx.log_target)
 
         if ctx.reduction == "batchmean":
             derivative = derivative / y_true.shape[0]
liger_kernel/ops/layer_norm.py
CHANGED
@@ -5,11 +5,9 @@ import torch
 import triton
 import triton.language as tl
 
-from liger_kernel.ops.utils import (
-    calculate_settings,
-    compare_version,
-    ensure_contiguous,
-)
+from liger_kernel.ops.utils import calculate_settings
+from liger_kernel.ops.utils import compare_version
+from liger_kernel.ops.utils import ensure_contiguous
 
 if compare_version("triton", operator.ge, "3.0.0"):
     try:
liger_kernel/ops/qwen2vl_mrope.py
CHANGED
@@ -67,36 +67,20 @@ def _triton_qwen2vl_mrope(
     # program instance (i.e. for the current token) separately
     # ####################################################################
     # left half of the head
-    first_half_q_offsets = (
-        tl.arange(0, pad_n_qh)[:, None] * hd + tl.arange(0, pad_hd // 2)[None, :]
-    )
-    first_half_k_offsets = (
-        tl.arange(0, pad_n_kh)[:, None] * hd + tl.arange(0, pad_hd // 2)[None, :]
-    )
-    first_q_mask = (tl.arange(0, pad_n_qh)[:, None] < n_qh) & (
-        tl.arange(0, pad_hd // 2)[None, :] < hd // 2
-    )
-    first_k_mask = (tl.arange(0, pad_n_kh)[:, None] < n_kh) & (
-        tl.arange(0, pad_hd // 2)[None, :] < hd // 2
-    )
-    q_tile_1 = tl.load(q_ptr + first_half_q_offsets, mask=first_q_mask, other=0).to(
-        sin_row.dtype
-    )
-    k_tile_1 = tl.load(k_ptr + first_half_k_offsets, mask=first_k_mask, other=0).to(
-        sin_row.dtype
-    )
+    first_half_q_offsets = tl.arange(0, pad_n_qh)[:, None] * hd + tl.arange(0, pad_hd // 2)[None, :]
+    first_half_k_offsets = tl.arange(0, pad_n_kh)[:, None] * hd + tl.arange(0, pad_hd // 2)[None, :]
+    first_q_mask = (tl.arange(0, pad_n_qh)[:, None] < n_qh) & (tl.arange(0, pad_hd // 2)[None, :] < hd // 2)
+    first_k_mask = (tl.arange(0, pad_n_kh)[:, None] < n_kh) & (tl.arange(0, pad_hd // 2)[None, :] < hd // 2)
+    q_tile_1 = tl.load(q_ptr + first_half_q_offsets, mask=first_q_mask, other=0).to(sin_row.dtype)
+    k_tile_1 = tl.load(k_ptr + first_half_k_offsets, mask=first_k_mask, other=0).to(sin_row.dtype)
 
     # right half of the head
     second_half_q_offsets = first_half_q_offsets + (hd // 2)
     second_half_k_offsets = first_half_k_offsets + (hd // 2)
     second_q_mask = first_q_mask
     second_k_mask = first_k_mask
-    q_tile_2 = tl.load(q_ptr + second_half_q_offsets, mask=second_q_mask, other=0).to(
-        sin_row.dtype
-    )
-    k_tile_2 = tl.load(k_ptr + second_half_k_offsets, mask=second_k_mask, other=0).to(
-        sin_row.dtype
-    )
+    q_tile_2 = tl.load(q_ptr + second_half_q_offsets, mask=second_q_mask, other=0).to(sin_row.dtype)
+    k_tile_2 = tl.load(k_ptr + second_half_k_offsets, mask=second_k_mask, other=0).to(sin_row.dtype)
 
     if not BACKWARD_PASS:
         # y = [x1, x2] * [cos, cos] + [-x2, x1] * [sin, sin]
@@ -124,7 +108,6 @@ def _triton_qwen2vl_mrope(
 
 
 def qwen2vl_mrope_forward(q, k, cos, sin, mrope_section):
-
     # transpose it back to the physical shape because Triton looks at the physical storage
     # note: q and k are incontiguous before the transformation and will become contiguous after transpose
     q = q.transpose(1, 2)
liger_kernel/ops/rms_norm.py
CHANGED
@@ -17,12 +17,10 @@ import torch
 import triton
 import triton.language as tl
 
-from liger_kernel.ops.utils import (
-    calculate_settings,
-    compare_version,
-    ensure_contiguous,
-    torch_to_triton_dtype,
-)
+from liger_kernel.ops.utils import calculate_settings
+from liger_kernel.ops.utils import compare_version
+from liger_kernel.ops.utils import ensure_contiguous
+from liger_kernel.ops.utils import torch_to_triton_dtype
 
 if compare_version("triton", operator.ge, "3.0.0"):
     try:
@@ -177,9 +175,7 @@ def _rms_norm_backward_kernel(
 
         dX_row = rstd_row * m
 
-        dX_row += (rstd_row) * (
-            -(1 / n_cols) * rstd_row * rstd_row * tl.sum(m * X_row, axis=0) * X_row
-        )
+        dX_row += (rstd_row) * (-(1 / n_cols) * rstd_row * rstd_row * tl.sum(m * X_row, axis=0) * X_row)
 
         # calculate the gradient of W
         if casting_mode == _CASTING_MODE_LLAMA:
@@ -207,14 +203,10 @@ _str_to_casting_mode = {
 
 def rms_norm_forward(X, W, eps, offset, casting_mode):
     if not isinstance(casting_mode, int):
-        assert (
-            casting_mode in _str_to_casting_mode
-        ), f"Invalid casting mode: {casting_mode}"
+        assert casting_mode in _str_to_casting_mode, f"Invalid casting mode: {casting_mode}"
         casting_mode = _str_to_casting_mode[casting_mode]
     else:
-        assert (
-            casting_mode in _str_to_casting_mode.values()
-        ), f"Invalid casting mode: {casting_mode}"
+        assert casting_mode in _str_to_casting_mode.values(), f"Invalid casting mode: {casting_mode}"
 
     shape = X.shape
     dim = shape[-1]
@@ -225,17 +217,11 @@ def rms_norm_forward(X, W, eps, offset, casting_mode):
     Y = torch.empty((n_rows, n_cols), dtype=X.dtype, device=X.device)
     # RSTD is to cache rstd for each row
     # RSTD is always computed/stored in fp32 if we are using Llama or Gemma casting mode
-    rstd_dtype = (
-        torch.float32
-        if casting_mode in (_CASTING_MODE_LLAMA.value, _CASTING_MODE_GEMMA.value)
-        else X.dtype
-    )
+    rstd_dtype = torch.float32 if casting_mode in (_CASTING_MODE_LLAMA.value, _CASTING_MODE_GEMMA.value) else X.dtype
     RSTD = torch.empty(n_rows, dtype=rstd_dtype, device=X.device)
 
     # Check constraints.
-    assert (
-        X.shape[1] == W.shape[0]
-    ), "Incompatible hidden size dimension between tensor1.shape[1] and tensor2.shape[0]"
+    assert X.shape[1] == W.shape[0], "Incompatible hidden size dimension between tensor1.shape[1] and tensor2.shape[0]"
 
     _rms_norm_forward_kernel[(n_rows,)](
         Y,
@@ -256,9 +242,7 @@ def rms_norm_forward(X, W, eps, offset, casting_mode):
     return Y.view(*shape), X, RSTD, BLOCK_SIZE, num_warps, casting_mode
 
 
-def rms_norm_backward(
-    dY, X, W, RSTD, offset, casting_mode, BLOCK_SIZE, num_warps, in_place
-):
+def rms_norm_backward(dY, X, W, RSTD, offset, casting_mode, BLOCK_SIZE, num_warps, in_place):
     shape = dY.shape
     dim = shape[-1]
     dY = dY.view(-1, dim)
@@ -340,9 +324,7 @@ class LigerRMSNormFunction(torch.autograd.Function):
         X: (B, T, H) or (BxT, H)
         W: (H,)
         """
-        Y, X, RSTD, BLOCK_SIZE, num_warps, casting_mode = rms_norm_forward(
-            X, W, eps, offset, casting_mode
-        )
+        Y, X, RSTD, BLOCK_SIZE, num_warps, casting_mode = rms_norm_forward(X, W, eps, offset, casting_mode)
         ctx.offset = offset
         ctx.casting_mode = casting_mode
         ctx.in_place = in_place
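The inlined dX_row expression above is the RMSNorm input gradient; spelled out (math only, ignoring the offset and casting-mode handling, and writing m_j = w_j * dy_j for the weighted upstream gradient):

r = \Big(\tfrac{1}{n}\textstyle\sum_{j} x_j^{2} + \epsilon\Big)^{-1/2},
\qquad
\frac{\partial L}{\partial x_i} = r\, m_i \;-\; \frac{r^{3}}{n}\, x_i \sum_{j} m_j x_j

which is exactly rstd_row * m plus the -(1 / n_cols) * rstd_row^2 * sum(m * X_row) * X_row correction in the kernel.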
liger_kernel/ops/rope.py
CHANGED
@@ -72,36 +72,20 @@ def _triton_rope(
     # program instance (i.e. for the current token) separately
     # ####################################################################
     # left half of the head
-    first_half_q_offsets = (
-        tl.arange(0, pad_n_qh)[:, None] * hd + tl.arange(0, pad_hd // 2)[None, :]
-    )
-    first_half_k_offsets = (
-        tl.arange(0, pad_n_kh)[:, None] * hd + tl.arange(0, pad_hd // 2)[None, :]
-    )
-    first_q_mask = (tl.arange(0, pad_n_qh)[:, None] < n_qh) & (
-        tl.arange(0, pad_hd // 2)[None, :] < hd // 2
-    )
-    first_k_mask = (tl.arange(0, pad_n_kh)[:, None] < n_kh) & (
-        tl.arange(0, pad_hd // 2)[None, :] < hd // 2
-    )
-    q_tile_1 = tl.load(q_ptr + first_half_q_offsets, mask=first_q_mask, other=0).to(
-        sin_row.dtype
-    )
-    k_tile_1 = tl.load(k_ptr + first_half_k_offsets, mask=first_k_mask, other=0).to(
-        sin_row.dtype
-    )
+    first_half_q_offsets = tl.arange(0, pad_n_qh)[:, None] * hd + tl.arange(0, pad_hd // 2)[None, :]
+    first_half_k_offsets = tl.arange(0, pad_n_kh)[:, None] * hd + tl.arange(0, pad_hd // 2)[None, :]
+    first_q_mask = (tl.arange(0, pad_n_qh)[:, None] < n_qh) & (tl.arange(0, pad_hd // 2)[None, :] < hd // 2)
+    first_k_mask = (tl.arange(0, pad_n_kh)[:, None] < n_kh) & (tl.arange(0, pad_hd // 2)[None, :] < hd // 2)
+    q_tile_1 = tl.load(q_ptr + first_half_q_offsets, mask=first_q_mask, other=0).to(sin_row.dtype)
+    k_tile_1 = tl.load(k_ptr + first_half_k_offsets, mask=first_k_mask, other=0).to(sin_row.dtype)
 
     # right half of the head
     second_half_q_offsets = first_half_q_offsets + (hd // 2)
     second_half_k_offsets = first_half_k_offsets + (hd // 2)
     second_q_mask = first_q_mask
     second_k_mask = first_k_mask
-    q_tile_2 = tl.load(q_ptr + second_half_q_offsets, mask=second_q_mask, other=0).to(
-        sin_row.dtype
-    )
-    k_tile_2 = tl.load(k_ptr + second_half_k_offsets, mask=second_k_mask, other=0).to(
-        sin_row.dtype
-    )
+    q_tile_2 = tl.load(q_ptr + second_half_q_offsets, mask=second_q_mask, other=0).to(sin_row.dtype)
+    k_tile_2 = tl.load(k_ptr + second_half_k_offsets, mask=second_k_mask, other=0).to(sin_row.dtype)
 
     if not BACKWARD_PASS:
         # y = [x1, x2] * [cos, cos] + [-x2, x1] * [sin, sin]
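The y = [x1, x2] * [cos, cos] + [-x2, x1] * [sin, sin] comment in the hunk above is the half-rotation form of RoPE. A plain-PyTorch sketch of that update for a single head (function names and shapes are illustrative, not part of the package):

import torch

def rotate_half(x: torch.Tensor) -> torch.Tensor:
    # [x1, x2] -> [-x2, x1], splitting the head dimension in half
    x1, x2 = x.chunk(2, dim=-1)
    return torch.cat((-x2, x1), dim=-1)

def apply_rope_reference(q: torch.Tensor, cos: torch.Tensor, sin: torch.Tensor) -> torch.Tensor:
    # q: (..., head_dim); cos/sin: broadcastable to q, with each frequency duplicated across both halves
    return q * cos + rotate_half(q) * sin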
liger_kernel/ops/swiglu.py
CHANGED
@@ -2,7 +2,8 @@ import torch
 import triton
 import triton.language as tl
 
-from liger_kernel.ops.utils import calculate_settings
+from liger_kernel.ops.utils import calculate_settings
+from liger_kernel.ops.utils import ensure_contiguous
 
 
 @triton.jit
@@ -11,9 +12,7 @@ def silu(x):
 
 
 @triton.jit
-def _swiglu_forward_kernel(
-    a_ptr, b_ptr, c_ptr, stride, n_cols: tl.constexpr, BLOCK_SIZE: tl.constexpr
-):
+def _swiglu_forward_kernel(a_ptr, b_ptr, c_ptr, stride, n_cols: tl.constexpr, BLOCK_SIZE: tl.constexpr):
     program_id = tl.program_id(0).to(tl.int64)
 
     # locate start index
@@ -32,9 +31,7 @@ def _swiglu_forward_kernel(
 
 
 @triton.jit
-def _swiglu_backward_kernel(
-    dc_ptr, a_ptr, b_ptr, stride, n_cols: tl.constexpr, BLOCK_SIZE: tl.constexpr
-):
+def _swiglu_backward_kernel(dc_ptr, a_ptr, b_ptr, stride, n_cols: tl.constexpr, BLOCK_SIZE: tl.constexpr):
     program_id = tl.program_id(0).to(tl.int64)
 
     # locate start index
@@ -84,7 +81,6 @@ def swiglu_forward(a, b):
 
 
 def swiglu_backward(a, b, dc):
-
     ori_shape = dc.shape
     n_cols = ori_shape[-1]
     dc = dc.view(-1, n_cols)
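For reference, the forward/backward pair reformatted above implements c = silu(a) * b. A hedged plain-PyTorch sketch of the same math (not the Triton kernel; the *_reference names are illustrative):

import torch

def swiglu_forward_reference(a: torch.Tensor, b: torch.Tensor) -> torch.Tensor:
    return torch.nn.functional.silu(a) * b           # silu(a) = a * sigmoid(a)

def swiglu_backward_reference(dc, a, b):
    sig = torch.sigmoid(a)
    dsilu = sig * (1.0 + a * (1.0 - sig))            # d/da [a * sigmoid(a)]
    da = dc * b * dsilu
    db = dc * a * sig                                # dc * silu(a)
    return da, db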
liger_kernel/ops/utils.py
CHANGED
@@ -13,11 +13,13 @@ Modifications made by Yanning Chen, 2024.
 import functools
 import importlib
 import operator
+
 from typing import Callable
 
 import torch
 import triton
 import triton.language as tl
+
 from packaging.version import Version
 
 from liger_kernel.utils import infer_device
liger_kernel/transformers/__init__.py
CHANGED
@@ -1,31 +1,23 @@
-from liger_kernel.transformers.auto_model import (  # noqa: F401
-    AutoLigerKernelForCausalLM,
-)
+from liger_kernel.transformers.auto_model import AutoLigerKernelForCausalLM  # noqa: F401
 from liger_kernel.transformers.cross_entropy import LigerCrossEntropyLoss  # noqa: F401
-from liger_kernel.transformers.fused_linear_cross_entropy import (  # noqa: F401
-    LigerFusedLinearCrossEntropyLoss,
-)
+from liger_kernel.transformers.fused_linear_cross_entropy import LigerFusedLinearCrossEntropyLoss  # noqa: F401
 from liger_kernel.transformers.fused_linear_jsd import LigerFusedLinearJSD  # noqa: F401
 from liger_kernel.transformers.geglu import LigerGEGLUMLP  # noqa: F401
 from liger_kernel.transformers.jsd import LigerJSD  # noqa: F401
 from liger_kernel.transformers.layer_norm import LigerLayerNorm  # noqa: F401
-from liger_kernel.transformers.monkey_patch import (  # noqa: F401
-    _apply_liger_kernel,
-    _apply_liger_kernel_to_instance,
-    apply_liger_kernel_to_gemma,
-    apply_liger_kernel_to_gemma2,
-    apply_liger_kernel_to_llama,
-    apply_liger_kernel_to_mistral,
-    apply_liger_kernel_to_mixtral,
-    apply_liger_kernel_to_mllama,
-    apply_liger_kernel_to_phi3,
-    apply_liger_kernel_to_qwen2,
-    apply_liger_kernel_to_qwen2_vl,
-)
+from liger_kernel.transformers.monkey_patch import _apply_liger_kernel  # noqa: F401
+from liger_kernel.transformers.monkey_patch import _apply_liger_kernel_to_instance  # noqa: F401
+from liger_kernel.transformers.monkey_patch import apply_liger_kernel_to_gemma  # noqa: F401
+from liger_kernel.transformers.monkey_patch import apply_liger_kernel_to_gemma2  # noqa: F401
+from liger_kernel.transformers.monkey_patch import apply_liger_kernel_to_llama  # noqa: F401
+from liger_kernel.transformers.monkey_patch import apply_liger_kernel_to_mistral  # noqa: F401
+from liger_kernel.transformers.monkey_patch import apply_liger_kernel_to_mixtral  # noqa: F401
+from liger_kernel.transformers.monkey_patch import apply_liger_kernel_to_mllama  # noqa: F401
+from liger_kernel.transformers.monkey_patch import apply_liger_kernel_to_phi3  # noqa: F401
+from liger_kernel.transformers.monkey_patch import apply_liger_kernel_to_qwen2  # noqa: F401
+from liger_kernel.transformers.monkey_patch import apply_liger_kernel_to_qwen2_vl  # noqa: F401
 from liger_kernel.transformers.rms_norm import LigerRMSNorm  # noqa: F401
 from liger_kernel.transformers.rope import liger_rotary_pos_emb  # noqa: F401
-from liger_kernel.transformers.swiglu import (  # noqa: F401
-    LigerBlockSparseTop2MLP,
-    LigerPhi3SwiGLUMLP,
-    LigerSwiGLUMLP,
-)
+from liger_kernel.transformers.swiglu import LigerBlockSparseTop2MLP  # noqa: F401
+from liger_kernel.transformers.swiglu import LigerPhi3SwiGLUMLP  # noqa: F401
+from liger_kernel.transformers.swiglu import LigerSwiGLUMLP  # noqa: F401
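The re-exports above are now one flat import per symbol; typical usage is unchanged. A minimal, hedged example (patching Llama is just one of the patch functions listed above):

from liger_kernel.transformers import apply_liger_kernel_to_llama

# Patch the Hugging Face Llama modeling code in place before the model is created,
# so that RMSNorm, RoPE, SwiGLU and the cross-entropy path use Liger's Triton kernels.
apply_liger_kernel_to_llama()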
liger_kernel/transformers/auto_model.py
CHANGED
@@ -1,11 +1,10 @@
 import inspect
 
-from transformers import AutoConfig, AutoModelForCausalLM
+from transformers import AutoConfig
+from transformers import AutoModelForCausalLM
 
-from liger_kernel.transformers.monkey_patch import (
-    MODEL_TYPE_TO_APPLY_LIGER_FN,
-    _apply_liger_kernel,
-)
+from liger_kernel.transformers.monkey_patch import MODEL_TYPE_TO_APPLY_LIGER_FN
+from liger_kernel.transformers.monkey_patch import _apply_liger_kernel
 
 
 def _get_model_config(model_dir, **model_init_kwargs):
@@ -34,12 +33,6 @@ class AutoLigerKernelForCausalLM(AutoModelForCausalLM):
         apply_fn = MODEL_TYPE_TO_APPLY_LIGER_FN[model_type]
         apply_fn_signature = inspect.signature(apply_fn)
 
-        applicable_kwargs = {
-            key: value
-            for key, value in kwargs.items()
-            if key not in apply_fn_signature.parameters
-        }
+        applicable_kwargs = {key: value for key, value in kwargs.items() if key not in apply_fn_signature.parameters}
 
-        return super().from_pretrained(
-            pretrained_model_name_or_path, *model_args, **applicable_kwargs
-        )
+        return super().from_pretrained(pretrained_model_name_or_path, *model_args, **applicable_kwargs)
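Usage of the class whose kwargs filtering was inlined above stays the same; a hedged sketch (the model id is illustrative):

from liger_kernel.transformers import AutoLigerKernelForCausalLM

# Works like AutoModelForCausalLM.from_pretrained, but first applies the matching
# apply_liger_kernel_to_* patch for the detected model type; kwargs that belong to the
# patch function's signature are split off and only the rest are forwarded to from_pretrained.
model = AutoLigerKernelForCausalLM.from_pretrained("meta-llama/Llama-3.2-1B")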
liger_kernel/transformers/cross_entropy.py
CHANGED
@@ -27,9 +27,7 @@ class LigerCrossEntropyLoss(torch.nn.Module):
             "sum",
             "none",
         }, f"reduction must be one of 'mean', 'sum', or 'none'. Got: {reduction}"
-        assert (
-            softcap is None or softcap > 0
-        ), f"softcap must greater than 0.0 or None. Got: {softcap}"
+        assert softcap is None or softcap > 0, f"softcap must greater than 0.0 or None. Got: {softcap}"
         self.ignore_index = ignore_index
         self.lse_square_scale = lse_square_scale
         self.label_smoothing = label_smoothing
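The softcap checked by the assertion above refers to logit soft-capping (as used by e.g. Gemma-2). A hedged sketch of what a positive softcap does to the logits before the loss (the helper name is illustrative):

import torch

def softcap_logits(logits: torch.Tensor, softcap: float) -> torch.Tensor:
    # Squash logits smoothly into (-softcap, softcap); only meaningful for softcap > 0,
    # which is exactly what the assertion above enforces.
    return softcap * torch.tanh(logits / softcap)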
liger_kernel/transformers/experimental/embedding.py
CHANGED
@@ -7,9 +7,7 @@ from liger_kernel.ops.experimental.embedding import LigerEmbeddingFunction
 
 
 class LigerEmbedding(nn.Module):
-    def __init__(
-        self, num_embeddings, embedding_dim, padding_idx: Optional[int] = None
-    ):
+    def __init__(self, num_embeddings, embedding_dim, padding_idx: Optional[int] = None):
         super().__init__()
         self.num_embeddings = num_embeddings
         self.embedding_dim = embedding_dim
liger_kernel/transformers/functional.py
CHANGED
@@ -1,9 +1,7 @@
 from typing import Optional
 
 from liger_kernel.ops.cross_entropy import LigerCrossEntropyFunction
-from liger_kernel.ops.fused_linear_cross_entropy import (
-    LigerFusedLinearCrossEntropyFunction,
-)
+from liger_kernel.ops.fused_linear_cross_entropy import LigerFusedLinearCrossEntropyFunction
 from liger_kernel.ops.fused_linear_jsd import LigerFusedLinearJSDFunction
 from liger_kernel.ops.geglu import LigerGELUMulFunction
 from liger_kernel.ops.group_norm import LigerGroupNormFunction
@@ -159,9 +157,7 @@ def liger_qwen2vl_mrope(q, k, cos, sin, mrope_section, unsqueeze_dim=1):
     return LigerQwen2VLMRopeFunction.apply(q, k, cos, sin, mrope_section, unsqueeze_dim)
 
 
-def liger_rms_norm(
-    X, W, eps, offset: float = 0.0, casting_mode: str = "llama", in_place: bool = True
-):
+def liger_rms_norm(X, W, eps, offset: float = 0.0, casting_mode: str = "llama", in_place: bool = True):
     return LigerRMSNormFunction.apply(X, W, eps, offset, casting_mode, in_place)
 
 