liger-kernel-nightly 0.5.2.dev20241223032630__py3-none-any.whl → 0.5.2.dev20241228022953__py3-none-any.whl
- liger_kernel/chunked_loss/cpo_loss.py +5 -12
- liger_kernel/chunked_loss/dpo_loss.py +1 -4
- liger_kernel/chunked_loss/fused_linear_distillation.py +37 -37
- liger_kernel/chunked_loss/fused_linear_preference.py +40 -64
- liger_kernel/chunked_loss/orpo_loss.py +2 -6
- liger_kernel/chunked_loss/simpo_loss.py +4 -8
- liger_kernel/env_report.py +4 -11
- liger_kernel/ops/cross_entropy.py +7 -10
- liger_kernel/ops/experimental/embedding.py +1 -3
- liger_kernel/ops/experimental/mm_int8int2.py +3 -9
- liger_kernel/ops/fused_linear_cross_entropy.py +12 -17
- liger_kernel/ops/fused_linear_jsd.py +11 -29
- liger_kernel/ops/geglu.py +6 -17
- liger_kernel/ops/group_norm.py +11 -28
- liger_kernel/ops/jsd.py +2 -6
- liger_kernel/ops/kl_div.py +4 -7
- liger_kernel/ops/layer_norm.py +3 -5
- liger_kernel/ops/qwen2vl_mrope.py +8 -25
- liger_kernel/ops/rms_norm.py +11 -29
- liger_kernel/ops/rope.py +8 -24
- liger_kernel/ops/swiglu.py +4 -8
- liger_kernel/ops/utils.py +2 -0
- liger_kernel/transformers/__init__.py +16 -24
- liger_kernel/transformers/auto_model.py +6 -13
- liger_kernel/transformers/cross_entropy.py +1 -3
- liger_kernel/transformers/experimental/embedding.py +1 -3
- liger_kernel/transformers/functional.py +2 -6
- liger_kernel/transformers/fused_linear_cross_entropy.py +2 -6
- liger_kernel/transformers/geglu.py +1 -4
- liger_kernel/transformers/group_norm.py +3 -9
- liger_kernel/transformers/jsd.py +1 -3
- liger_kernel/transformers/kl_div.py +1 -3
- liger_kernel/transformers/layer_norm.py +3 -9
- liger_kernel/transformers/model/gemma.py +18 -40
- liger_kernel/transformers/model/gemma2.py +19 -41
- liger_kernel/transformers/model/llama.py +22 -48
- liger_kernel/transformers/model/mistral.py +14 -26
- liger_kernel/transformers/model/mixtral.py +23 -53
- liger_kernel/transformers/model/mllama.py +16 -36
- liger_kernel/transformers/model/phi3.py +18 -40
- liger_kernel/transformers/model/qwen2.py +18 -40
- liger_kernel/transformers/model/qwen2_vl.py +16 -30
- liger_kernel/transformers/monkey_patch.py +43 -117
- liger_kernel/transformers/rms_norm.py +4 -4
- liger_kernel/transformers/swiglu.py +2 -8
- liger_kernel/transformers/trainer/__init__.py +1 -3
- liger_kernel/transformers/trainer/orpo_trainer.py +13 -16
- liger_kernel/triton/__init__.py +1 -3
- liger_kernel/triton/monkey_patch.py +1 -3
- {liger_kernel_nightly-0.5.2.dev20241223032630.dist-info → liger_kernel_nightly-0.5.2.dev20241228022953.dist-info}/METADATA +1 -1
- liger_kernel_nightly-0.5.2.dev20241228022953.dist-info/RECORD +66 -0
- liger_kernel_nightly-0.5.2.dev20241223032630.dist-info/RECORD +0 -66
- {liger_kernel_nightly-0.5.2.dev20241223032630.dist-info → liger_kernel_nightly-0.5.2.dev20241228022953.dist-info}/LICENSE +0 -0
- {liger_kernel_nightly-0.5.2.dev20241223032630.dist-info → liger_kernel_nightly-0.5.2.dev20241228022953.dist-info}/NOTICE +0 -0
- {liger_kernel_nightly-0.5.2.dev20241223032630.dist-info → liger_kernel_nightly-0.5.2.dev20241228022953.dist-info}/WHEEL +0 -0
- {liger_kernel_nightly-0.5.2.dev20241223032630.dist-info → liger_kernel_nightly-0.5.2.dev20241228022953.dist-info}/top_level.txt +0 -0
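Judging from the hunks below, nearly all of the churn between these two nightlies is a mechanical style pass rather than a functional change: parenthesized multi-name imports are split into one `from ... import name` statement per line, and statements that were wrapped across several lines are collapsed back onto a single, longer line. A small before/after sketch of the import pattern, using `liger_kernel.ops.utils` names that appear in the hunks (the sketch is illustrative, not a quote from either wheel):

```python
# Style in 0.5.2.dev20241223032630: one parenthesized import, several names.
from liger_kernel.ops.utils import (
    calculate_settings,
    compare_version,
    ensure_contiguous,
)

# Style in 0.5.2.dev20241228022953: one import per name, which keeps future
# diffs to a single line when a name is added or removed.
from liger_kernel.ops.utils import calculate_settings
from liger_kernel.ops.utils import compare_version
from liger_kernel.ops.utils import ensure_contiguous
```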
@@ -4,12 +4,10 @@ import torch
 import triton

 from liger_kernel.ops.jsd import _jsd_kernel
-from liger_kernel.ops.utils import (
-    amp_custom_bwd,
-    amp_custom_fwd,
-    element_mul_kernel,
-    is_hip,
-)
+from liger_kernel.ops.utils import amp_custom_bwd
+from liger_kernel.ops.utils import amp_custom_fwd
+from liger_kernel.ops.utils import element_mul_kernel
+from liger_kernel.ops.utils import is_hip

 # The hard limit of TRITON_MAX_TENSOR_NUMEL is 1048576 https://github.com/triton-lang/triton/blob/ba42a5c68fd0505f8c42f4202d53be0f8d9a5fe0/python/triton/language/core.py#L19
 # However, setting limit as 65536 as in LayerNorm tutorial is faster because of less register spilling
@@ -43,16 +41,10 @@ def fused_linear_jsd_forward(
     BLOCK_SIZE = min(MAX_FUSED_SIZE, triton.next_power_of_2(V))

     inc_factor = triton.cdiv(V, H)  # (V + H - 1) // H
-    chunk_size = triton.next_power_of_2(
-        triton.cdiv(BT, inc_factor)
-    )  # (BT + inc_factor - 1) // inc_factor
+    chunk_size = triton.next_power_of_2(triton.cdiv(BT, inc_factor))  # (BT + inc_factor - 1) // inc_factor
     num_chunks = triton.cdiv(BT, chunk_size)  # (BT + chunk_size - 1) // chunk_size

-    grad_weight = (
-        torch.zeros_like(student_weight, device=device)
-        if student_weight.requires_grad
-        else None
-    )
+    grad_weight = torch.zeros_like(student_weight, device=device) if student_weight.requires_grad else None
     grad_input = torch.zeros_like(student_input)
     # we use fp32 for loss accumulator
     loss_1d = torch.zeros((BT, V), dtype=torch.float32, device=device)
@@ -73,12 +65,8 @@ def fused_linear_jsd_forward(
         # shape: chunk_size x V
         # For anything starting from logits to the final JSD loss, we do computation
         # in FP32 to avoid losing numerical stability.
-        student_logits_chunk = (student_input_chunk @ student_weight.t()).to(
-            torch.float32
-        )
-        teacher_logits_chunk = (teacher_input_chunk @ teacher_weight.t()).to(
-            torch.float32
-        )
+        student_logits_chunk = (student_input_chunk @ student_weight.t()).to(torch.float32)
+        teacher_logits_chunk = (teacher_input_chunk @ teacher_weight.t()).to(torch.float32)
         chunk_n_rows = student_logits_chunk.shape[0]

         # unreduced loss
@@ -104,9 +92,7 @@ def fused_linear_jsd_forward(
             dX_ptr=student_prob_chunk,
             dX_stride=student_prob_chunk.stride(-2),
             label_ptr=(
-                shift_labels[start_idx:end_idx]
-                if has_label
-                else torch.empty(1, device=device)
+                shift_labels[start_idx:end_idx] if has_label else torch.empty(1, device=device)
             ),  # dummy ptr if no label
             beta=jsd_beta,
             n_non_ignore=n_non_ignore,
@@ -121,9 +107,7 @@ def fused_linear_jsd_forward(
         student_logits_chunk = (
             student_prob_chunk
             - torch.softmax(student_logits_chunk, dim=-1)
-            * student_prob_chunk.sum(dim=-1, keepdim=True).broadcast_to(
-                student_prob_chunk.shape
-            )
+            * student_prob_chunk.sum(dim=-1, keepdim=True).broadcast_to(student_prob_chunk.shape)
         ) / temperature
         # now we traverse back to grad w.r.t. input to `lm_head` and grad
         # w.r.t. `lm_head` which should be computed in original dtype
@@ -239,7 +223,5 @@ class LigerFusedLinearJSDFunction(torch.autograd.Function):
     @amp_custom_bwd
     def backward(ctx, grad_output):
         (grad_input, grad_weight) = ctx.saved_tensors
-        grad_input, grad_weight = fused_linear_jsd_backward(
-            grad_output, grad_input, grad_weight
-        )
+        grad_input, grad_weight = fused_linear_jsd_backward(grad_output, grad_input, grad_weight)
         return (grad_input, grad_weight, None, None, None, None, None, None)
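The `fused_linear_jsd_forward` hunks above only reflow the chunking arithmetic; the logic is unchanged. A plain-Python sketch of that arithmetic, with `cdiv` and `next_power_of_2` as stand-ins for the `triton` helpers and example sizes that are assumptions chosen for illustration (not values taken from the package):

```python
def cdiv(a: int, b: int) -> int:
    # stand-in for triton.cdiv: ceiling division
    return (a + b - 1) // b


def next_power_of_2(n: int) -> int:
    # stand-in for triton.next_power_of_2
    return 1 << (n - 1).bit_length()


# Hypothetical sizes: BT = flattened batch*sequence rows, H = hidden size, V = vocab size.
BT, H, V = 4096, 4096, 128256
MAX_FUSED_SIZE = 65536 // 2  # illustrative cap only; the real constant lives in the module

BLOCK_SIZE = min(MAX_FUSED_SIZE, next_power_of_2(V))
inc_factor = cdiv(V, H)                              # how much wider the logits are than the input
chunk_size = next_power_of_2(cdiv(BT, inc_factor))   # rows of logits materialized per chunk
num_chunks = cdiv(BT, chunk_size)                    # chunks needed to cover all BT rows

print(BLOCK_SIZE, inc_factor, chunk_size, num_chunks)  # 32768 32 128 32
```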
liger_kernel/ops/geglu.py
CHANGED
@@ -4,11 +4,9 @@ import torch
 import triton
 import triton.language as tl

-from liger_kernel.ops.utils import (
-    calculate_settings,
-    compare_version,
-    ensure_contiguous,
-)
+from liger_kernel.ops.utils import calculate_settings
+from liger_kernel.ops.utils import compare_version
+from liger_kernel.ops.utils import ensure_contiguous

 if compare_version("triton", operator.ge, "3.0.0"):
     try:
@@ -22,9 +20,7 @@ else:


 @triton.jit
-def _geglu_tanh_forward_kernel(
-    a, b, c, stride, n_cols: tl.constexpr, BLOCK_SIZE: tl.constexpr
-):
+def _geglu_tanh_forward_kernel(a, b, c, stride, n_cols: tl.constexpr, BLOCK_SIZE: tl.constexpr):
     program_id = tl.program_id(0).to(tl.int64)

     # locate start index
@@ -49,9 +45,7 @@ def _geglu_tanh_forward_kernel(


 @triton.jit
-def _geglu_tanh_backward_kernel(
-    dc, a, b, stride, n_cols: tl.constexpr, BLOCK_SIZE: tl.constexpr
-):
+def _geglu_tanh_backward_kernel(dc, a, b, stride, n_cols: tl.constexpr, BLOCK_SIZE: tl.constexpr):
     program_id = tl.program_id(0).to(tl.int64)

     # locate start index
@@ -80,12 +74,7 @@ def _geglu_tanh_backward_kernel(
     # where z = sqrt(2/pi) * (a + 0.044715 * a^3)
     term1 = 0.5 * (1 + tanh_result)
     tanh_sq = tanh_result * tanh_result
-    term2 = (
-        0.5
-        * a_row
-        * (1 - tanh_sq)
-        * (sqrt_2_over_pi * (1 + 3 * 0.044715 * a_row * a_row))
-    )
+    term2 = 0.5 * a_row * (1 - tanh_sq) * (sqrt_2_over_pi * (1 + 3 * 0.044715 * a_row * a_row))
     da_row = dc_row * b_row * (term1 + term2)

     tl.store(a + col_offsets, da_row, mask=mask)
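The collapsed `term2` expression above is the derivative of the tanh-approximated GELU used by `_geglu_tanh_backward_kernel`. A plain-PyTorch restatement of that gradient, given only as a sketch; it assumes the usual GEGLU forward `c = gelu_tanh(a) * b` implied by the kernel's `a`, `b`, `c` arguments:

```python
import math

import torch


def geglu_tanh_backward_reference(dc, a, b):
    # dc is the upstream gradient w.r.t. c, where c = gelu_tanh(a) * b (assumed forward).
    sqrt_2_over_pi = math.sqrt(2.0 / math.pi)
    z = sqrt_2_over_pi * (a + 0.044715 * a * a * a)
    tanh_z = torch.tanh(z)

    # Same decomposition as the kernel: d(gelu_tanh)/da = term1 + term2.
    term1 = 0.5 * (1 + tanh_z)
    term2 = 0.5 * a * (1 - tanh_z * tanh_z) * (sqrt_2_over_pi * (1 + 3 * 0.044715 * a * a))

    da = dc * b * (term1 + term2)     # mirrors da_row in the kernel
    db = dc * 0.5 * a * (1 + tanh_z)  # gradient w.r.t. b from the assumed forward
    return da, db
```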
liger_kernel/ops/group_norm.py
CHANGED
@@ -4,7 +4,8 @@ import torch
 import triton
 import triton.language as tl

-from liger_kernel.ops.utils import compare_version, ensure_contiguous
+from liger_kernel.ops.utils import compare_version
+from liger_kernel.ops.utils import ensure_contiguous

 if compare_version("triton", operator.ge, "3.0.0"):
     try:
@@ -73,9 +74,7 @@ def _group_norm_forward_kernel(

     # Normalize
     hidden_size_per_channel = hidden_size // channels_per_group
-    for channel_idx in tl.range(
-        group_idx * channels_per_group, (group_idx + 1) * channels_per_group
-    ):
+    for channel_idx in tl.range(group_idx * channels_per_group, (group_idx + 1) * channels_per_group):
         W = tl.load(W_ptr + channel_idx)
         B = tl.load(B_ptr + channel_idx)
         for i in range(0, hidden_size_per_channel, BLOCK_SIZE):
@@ -132,21 +131,15 @@ def _group_norm_backward_kernel(
     UPSTREAM_ptr += batch_idx * X_row_stride

     # Mean and rstd are the same shape so have the same strides
-    mean = tl.load(
-        Mean_ptr + batch_idx * Mean_ptr_row_stride + group_idx * Mean_ptr_col_stride
-    )
-    rstd = tl.load(
-        RSTD_ptr + batch_idx * Mean_ptr_row_stride + group_idx * Mean_ptr_col_stride
-    )
+    mean = tl.load(Mean_ptr + batch_idx * Mean_ptr_row_stride + group_idx * Mean_ptr_col_stride)
+    rstd = tl.load(RSTD_ptr + batch_idx * Mean_ptr_row_stride + group_idx * Mean_ptr_col_stride)

     c1 = 0.0
     c2 = 0.0
     block_range = tl.arange(0, BLOCK_SIZE)

     # We need to compute the sum terms of the backprop equations across all channels in the group
-    for channel_idx in range(
-        group_idx * channels_per_group, (group_idx + 1) * channels_per_group
-    ):
+    for channel_idx in range(group_idx * channels_per_group, (group_idx + 1) * channels_per_group):
         dW = 0.0
         dB = 0.0
         # Move the pointers to the correct channel
@@ -181,9 +174,7 @@ def _group_norm_backward_kernel(
     c1 = c1 / N
     c2 = c2 / N

-    for channel_idx in tl.range(
-        group_idx * channels_per_group, (group_idx + 1) * channels_per_group
-    ):
+    for channel_idx in tl.range(group_idx * channels_per_group, (group_idx + 1) * channels_per_group):
         # Move the pointers to the correct channel
         W = tl.load(W_ptr + channel_idx)
         for i in range(0, hidden_size, BLOCK_SIZE):
@@ -203,9 +194,7 @@ def _group_norm_backward_kernel(
             x_hat = (X - mean) * rstd
             wdy = W * UPSTREAM_grad
             dx = (wdy - (x_hat * c1 + c2)) * rstd
-            tl.store(
-                DX_ptr + channel_idx * X_col_stride + hidden_size_offsets, dx, mask=mask
-            )
+            tl.store(DX_ptr + channel_idx * X_col_stride + hidden_size_offsets, dx, mask=mask)


 def group_norm_forward(X, num_channels, num_groups, W, B, eps):
@@ -216,9 +205,7 @@ def group_norm_forward(X, num_channels, num_groups, W, B, eps):
     X = X.view(batch_size, num_groups, -1).contiguous()
     hidden_size = X.shape[-1]
     BLOCK_SIZE = min(MAX_FUSED_SIZE, triton.next_power_of_2(hidden_size))
-    Y = torch.empty(
-        (batch_size, num_groups, hidden_size), dtype=X.dtype, device=X.device
-    )
+    Y = torch.empty((batch_size, num_groups, hidden_size), dtype=X.dtype, device=X.device)
     Mean = torch.zeros((batch_size, num_groups), dtype=X.dtype, device=X.device)
     RSTD = torch.zeros((batch_size, num_groups), dtype=X.dtype, device=X.device)

@@ -307,16 +294,12 @@ class LigerGroupNormFunction(torch.autograd.Function):
         )
         ctx.num_channels = num_channels
         ctx.num_groups = num_groups
-        ctx.save_for_backward(
-            X, affine_scaling_weight, affine_shifting_bias, Mean, RSTD
-        )
+        ctx.save_for_backward(X, affine_scaling_weight, affine_shifting_bias, Mean, RSTD)
         return Y

     @staticmethod
     @ensure_contiguous
     def backward(ctx, dY):
         X, W, B, Mean, RSTD = ctx.saved_tensors
-        DX, DW, DB = group_norm_backward(
-            dY, X, W, B, Mean, RSTD, ctx.num_channels, ctx.num_groups
-        )
+        DX, DW, DB = group_norm_backward(dY, X, W, B, Mean, RSTD, ctx.num_channels, ctx.num_groups)
         return DX, DW, DB, None, None, None
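As a reading aid for the `group_norm_forward` and `_group_norm_backward_kernel` hunks above, here is a plain-PyTorch sketch of the normalization they implement, using the same `(batch_size, num_groups, -1)` reshape; statistics are computed in float32 here purely for simplicity, which is an assumption rather than a claim about the kernel's dtype handling:

```python
import torch


def group_norm_reference(X, num_channels, num_groups, W, B, eps):
    # X: (batch, num_channels, *spatial); W, B: per-channel scale and shift.
    batch_size = X.shape[0]
    x = X.reshape(batch_size, num_groups, -1).float()

    mean = x.mean(dim=-1, keepdim=True)
    var = x.var(dim=-1, unbiased=False, keepdim=True)
    rstd = torch.rsqrt(var + eps)

    # x_hat = (X - mean) * rstd, as in the backward kernel, followed by the
    # per-channel affine transform the forward kernel applies with W and B.
    x_hat = ((x - mean) * rstd).reshape(batch_size, num_channels, -1)
    y = x_hat * W.view(1, -1, 1) + B.view(1, -1, 1)
    return y.reshape(X.shape).to(X.dtype)
```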
liger_kernel/ops/jsd.py
CHANGED
@@ -98,9 +98,7 @@ def jsd_forward(_input, target, shift_labels, beta, ignore_index, has_label):
         loss_stride=loss.stride(-2),
         dX_ptr=dX,
         dX_stride=dX.stride(-2),
-        label_ptr=(
-            shift_labels if has_label else torch.empty(1, device=_input.device)
-        ),  # dummy ptr if no label
+        label_ptr=(shift_labels if has_label else torch.empty(1, device=_input.device)),  # dummy ptr if no label
         beta=beta,
         n_non_ignore=n_non_ignore,
         ignore_index=ignore_index,
@@ -165,9 +163,7 @@ class LigerJSDFunction(torch.autograd.Function):
             shift_labels = shift_labels.contiguous()
             has_label = True

-        loss, dX = jsd_forward(
-            _input, target, shift_labels, beta, ignore_index, has_label
-        )
+        loss, dX = jsd_forward(_input, target, shift_labels, beta, ignore_index, has_label)
         ctx.save_for_backward(dX)
         return loss

liger_kernel/ops/kl_div.py
CHANGED
@@ -4,7 +4,8 @@ import torch
 import triton
 import triton.language as tl

-from liger_kernel.ops.utils import ensure_contiguous, is_hip
+from liger_kernel.ops.utils import ensure_contiguous
+from liger_kernel.ops.utils import is_hip


 def get_num_warps(BLOCK_SIZE):
@@ -218,9 +219,7 @@ class LigerKLDivLossFunction(torch.autograd.Function):
         ctx.save_for_backward(y_true)
         ctx.reduction = reduction
         ctx.log_target = log_target
-        return kldiv_forward_triton(
-            y_pred, y_true, log_target=log_target, reduction=reduction, eps=eps
-        )
+        return kldiv_forward_triton(y_pred, y_true, log_target=log_target, reduction=reduction, eps=eps)

     @staticmethod
     @ensure_contiguous
@@ -238,9 +237,7 @@ class LigerKLDivLossFunction(torch.autograd.Function):

         new_grads = torch.empty_like(y_true)

-        derivative = kldiv_backward_triton(
-            y_true, grad_output, new_grads, ctx.log_target
-        )
+        derivative = kldiv_backward_triton(y_true, grad_output, new_grads, ctx.log_target)

         if ctx.reduction == "batchmean":
             derivative = derivative / y_true.shape[0]
liger_kernel/ops/layer_norm.py
CHANGED
@@ -5,11 +5,9 @@ import torch
 import triton
 import triton.language as tl

-from liger_kernel.ops.utils import (
-    calculate_settings,
-    compare_version,
-    ensure_contiguous,
-)
+from liger_kernel.ops.utils import calculate_settings
+from liger_kernel.ops.utils import compare_version
+from liger_kernel.ops.utils import ensure_contiguous

 if compare_version("triton", operator.ge, "3.0.0"):
     try:
liger_kernel/ops/qwen2vl_mrope.py
CHANGED
@@ -67,36 +67,20 @@ def _triton_qwen2vl_mrope(
     # program instance (i.e. for the current token) separately
     # ####################################################################
     # left half of the head
-    first_half_q_offsets = (
-        tl.arange(0, pad_n_qh)[:, None] * hd + tl.arange(0, pad_hd // 2)[None, :]
-    )
-    first_half_k_offsets = (
-        tl.arange(0, pad_n_kh)[:, None] * hd + tl.arange(0, pad_hd // 2)[None, :]
-    )
-    first_q_mask = (tl.arange(0, pad_n_qh)[:, None] < n_qh) & (
-        tl.arange(0, pad_hd // 2)[None, :] < hd // 2
-    )
-    first_k_mask = (tl.arange(0, pad_n_kh)[:, None] < n_kh) & (
-        tl.arange(0, pad_hd // 2)[None, :] < hd // 2
-    )
-    q_tile_1 = tl.load(q_ptr + first_half_q_offsets, mask=first_q_mask, other=0).to(
-        sin_row.dtype
-    )
-    k_tile_1 = tl.load(k_ptr + first_half_k_offsets, mask=first_k_mask, other=0).to(
-        sin_row.dtype
-    )
+    first_half_q_offsets = tl.arange(0, pad_n_qh)[:, None] * hd + tl.arange(0, pad_hd // 2)[None, :]
+    first_half_k_offsets = tl.arange(0, pad_n_kh)[:, None] * hd + tl.arange(0, pad_hd // 2)[None, :]
+    first_q_mask = (tl.arange(0, pad_n_qh)[:, None] < n_qh) & (tl.arange(0, pad_hd // 2)[None, :] < hd // 2)
+    first_k_mask = (tl.arange(0, pad_n_kh)[:, None] < n_kh) & (tl.arange(0, pad_hd // 2)[None, :] < hd // 2)
+    q_tile_1 = tl.load(q_ptr + first_half_q_offsets, mask=first_q_mask, other=0).to(sin_row.dtype)
+    k_tile_1 = tl.load(k_ptr + first_half_k_offsets, mask=first_k_mask, other=0).to(sin_row.dtype)

     # right half of the head
     second_half_q_offsets = first_half_q_offsets + (hd // 2)
     second_half_k_offsets = first_half_k_offsets + (hd // 2)
     second_q_mask = first_q_mask
     second_k_mask = first_k_mask
-    q_tile_2 = tl.load(q_ptr + second_half_q_offsets, mask=second_q_mask, other=0).to(
-        sin_row.dtype
-    )
-    k_tile_2 = tl.load(k_ptr + second_half_k_offsets, mask=second_k_mask, other=0).to(
-        sin_row.dtype
-    )
+    q_tile_2 = tl.load(q_ptr + second_half_q_offsets, mask=second_q_mask, other=0).to(sin_row.dtype)
+    k_tile_2 = tl.load(k_ptr + second_half_k_offsets, mask=second_k_mask, other=0).to(sin_row.dtype)

     if not BACKWARD_PASS:
         # y = [x1, x2] * [cos, cos] + [-x2, x1] * [sin, sin]
@@ -124,7 +108,6 @@ def _triton_qwen2vl_mrope(


 def qwen2vl_mrope_forward(q, k, cos, sin, mrope_section):
-
     # transpose it back to the physical shape because Triton looks at the physical storage
     # note: q and k are incontiguous before the transformation and will become contiguous after transpose
     q = q.transpose(1, 2)
liger_kernel/ops/rms_norm.py
CHANGED
@@ -17,12 +17,10 @@ import torch
 import triton
 import triton.language as tl

-from liger_kernel.ops.utils import (
-    calculate_settings,
-    compare_version,
-    ensure_contiguous,
-    torch_to_triton_dtype,
-)
+from liger_kernel.ops.utils import calculate_settings
+from liger_kernel.ops.utils import compare_version
+from liger_kernel.ops.utils import ensure_contiguous
+from liger_kernel.ops.utils import torch_to_triton_dtype

 if compare_version("triton", operator.ge, "3.0.0"):
     try:
@@ -177,9 +175,7 @@ def _rms_norm_backward_kernel(

         dX_row = rstd_row * m

-        dX_row += (rstd_row) * (
-            -(1 / n_cols) * rstd_row * rstd_row * tl.sum(m * X_row, axis=0) * X_row
-        )
+        dX_row += (rstd_row) * (-(1 / n_cols) * rstd_row * rstd_row * tl.sum(m * X_row, axis=0) * X_row)

         # calculate the gradient of W
         if casting_mode == _CASTING_MODE_LLAMA:
@@ -207,14 +203,10 @@ _str_to_casting_mode = {

 def rms_norm_forward(X, W, eps, offset, casting_mode):
     if not isinstance(casting_mode, int):
-        assert (
-            casting_mode in _str_to_casting_mode
-        ), f"Invalid casting mode: {casting_mode}"
+        assert casting_mode in _str_to_casting_mode, f"Invalid casting mode: {casting_mode}"
         casting_mode = _str_to_casting_mode[casting_mode]
     else:
-        assert (
-            casting_mode in _str_to_casting_mode.values()
-        ), f"Invalid casting mode: {casting_mode}"
+        assert casting_mode in _str_to_casting_mode.values(), f"Invalid casting mode: {casting_mode}"

     shape = X.shape
     dim = shape[-1]
@@ -225,17 +217,11 @@ def rms_norm_forward(X, W, eps, offset, casting_mode):
     Y = torch.empty((n_rows, n_cols), dtype=X.dtype, device=X.device)
     # RSTD is to cache rstd for each row
     # RSTD is always computed/stored in fp32 if we are using Llama or Gemma casting mode
-    rstd_dtype = (
-        torch.float32
-        if casting_mode in (_CASTING_MODE_LLAMA.value, _CASTING_MODE_GEMMA.value)
-        else X.dtype
-    )
+    rstd_dtype = torch.float32 if casting_mode in (_CASTING_MODE_LLAMA.value, _CASTING_MODE_GEMMA.value) else X.dtype
     RSTD = torch.empty(n_rows, dtype=rstd_dtype, device=X.device)

     # Check constraints.
-    assert (
-        X.shape[1] == W.shape[0]
-    ), "Incompatible hidden size dimension between tensor1.shape[1] and tensor2.shape[0]"
+    assert X.shape[1] == W.shape[0], "Incompatible hidden size dimension between tensor1.shape[1] and tensor2.shape[0]"

     _rms_norm_forward_kernel[(n_rows,)](
         Y,
@@ -256,9 +242,7 @@ def rms_norm_forward(X, W, eps, offset, casting_mode):
     return Y.view(*shape), X, RSTD, BLOCK_SIZE, num_warps, casting_mode


-def rms_norm_backward(
-    dY, X, W, RSTD, offset, casting_mode, BLOCK_SIZE, num_warps, in_place
-):
+def rms_norm_backward(dY, X, W, RSTD, offset, casting_mode, BLOCK_SIZE, num_warps, in_place):
     shape = dY.shape
     dim = shape[-1]
     dY = dY.view(-1, dim)
@@ -340,9 +324,7 @@ class LigerRMSNormFunction(torch.autograd.Function):
         X: (B, T, H) or (BxT, H)
         W: (H,)
         """
-        Y, X, RSTD, BLOCK_SIZE, num_warps, casting_mode = rms_norm_forward(
-            X, W, eps, offset, casting_mode
-        )
+        Y, X, RSTD, BLOCK_SIZE, num_warps, casting_mode = rms_norm_forward(X, W, eps, offset, casting_mode)
         ctx.offset = offset
         ctx.casting_mode = casting_mode
         ctx.in_place = in_place
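For orientation, a plain-PyTorch sketch of the forward pass that `rms_norm_forward` fuses. The `(offset + W)` scaling is an assumption inferred from the `offset` argument visible in the hunks (offset 0.0 for the plain weight, a non-zero offset for the `(1 + w)` convention); the casting-mode details are deliberately left out:

```python
import torch


def rms_norm_reference(X, W, eps, offset=0.0):
    # rstd = 1 / sqrt(mean(x^2) + eps), computed in float32 in the spirit of the
    # fp32 RSTD cache mentioned in the hunk, then cast back to the input dtype.
    x32 = X.float()
    rstd = torch.rsqrt(x32.pow(2).mean(dim=-1, keepdim=True) + eps)
    return (x32 * rstd).to(X.dtype) * (offset + W)
```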
liger_kernel/ops/rope.py
CHANGED
@@ -72,36 +72,20 @@ def _triton_rope(
     # program instance (i.e. for the current token) separately
     # ####################################################################
     # left half of the head
-    first_half_q_offsets = (
-        tl.arange(0, pad_n_qh)[:, None] * hd + tl.arange(0, pad_hd // 2)[None, :]
-    )
-    first_half_k_offsets = (
-        tl.arange(0, pad_n_kh)[:, None] * hd + tl.arange(0, pad_hd // 2)[None, :]
-    )
-    first_q_mask = (tl.arange(0, pad_n_qh)[:, None] < n_qh) & (
-        tl.arange(0, pad_hd // 2)[None, :] < hd // 2
-    )
-    first_k_mask = (tl.arange(0, pad_n_kh)[:, None] < n_kh) & (
-        tl.arange(0, pad_hd // 2)[None, :] < hd // 2
-    )
-    q_tile_1 = tl.load(q_ptr + first_half_q_offsets, mask=first_q_mask, other=0).to(
-        sin_row.dtype
-    )
-    k_tile_1 = tl.load(k_ptr + first_half_k_offsets, mask=first_k_mask, other=0).to(
-        sin_row.dtype
-    )
+    first_half_q_offsets = tl.arange(0, pad_n_qh)[:, None] * hd + tl.arange(0, pad_hd // 2)[None, :]
+    first_half_k_offsets = tl.arange(0, pad_n_kh)[:, None] * hd + tl.arange(0, pad_hd // 2)[None, :]
+    first_q_mask = (tl.arange(0, pad_n_qh)[:, None] < n_qh) & (tl.arange(0, pad_hd // 2)[None, :] < hd // 2)
+    first_k_mask = (tl.arange(0, pad_n_kh)[:, None] < n_kh) & (tl.arange(0, pad_hd // 2)[None, :] < hd // 2)
+    q_tile_1 = tl.load(q_ptr + first_half_q_offsets, mask=first_q_mask, other=0).to(sin_row.dtype)
+    k_tile_1 = tl.load(k_ptr + first_half_k_offsets, mask=first_k_mask, other=0).to(sin_row.dtype)

     # right half of the head
     second_half_q_offsets = first_half_q_offsets + (hd // 2)
     second_half_k_offsets = first_half_k_offsets + (hd // 2)
     second_q_mask = first_q_mask
     second_k_mask = first_k_mask
-    q_tile_2 = tl.load(q_ptr + second_half_q_offsets, mask=second_q_mask, other=0).to(
-        sin_row.dtype
-    )
-    k_tile_2 = tl.load(k_ptr + second_half_k_offsets, mask=second_k_mask, other=0).to(
-        sin_row.dtype
-    )
+    q_tile_2 = tl.load(q_ptr + second_half_q_offsets, mask=second_q_mask, other=0).to(sin_row.dtype)
+    k_tile_2 = tl.load(k_ptr + second_half_k_offsets, mask=second_k_mask, other=0).to(sin_row.dtype)

     if not BACKWARD_PASS:
         # y = [x1, x2] * [cos, cos] + [-x2, x1] * [sin, sin]
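The kernel comment kept as context above, `y = [x1, x2] * [cos, cos] + [-x2, x1] * [sin, sin]`, is the standard rotate-half form of RoPE. A plain-PyTorch restatement, as a sketch that assumes `cos` and `sin` are already broadcast to the head dimension (the Triton kernel additionally handles padded head counts and the two halves explicitly):

```python
import torch


def apply_rope_reference(x, cos, sin):
    # x: (..., head_dim); cos, sin broadcastable to x and shared by both halves.
    x1, x2 = x.chunk(2, dim=-1)             # left / right half of each head
    rotated = torch.cat((-x2, x1), dim=-1)  # [-x2, x1]
    return x * cos + rotated * sin          # [x1, x2]*[cos, cos] + [-x2, x1]*[sin, sin]
```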
liger_kernel/ops/swiglu.py
CHANGED
@@ -2,7 +2,8 @@ import torch
 import triton
 import triton.language as tl

-from liger_kernel.ops.utils import calculate_settings, ensure_contiguous
+from liger_kernel.ops.utils import calculate_settings
+from liger_kernel.ops.utils import ensure_contiguous


 @triton.jit
@@ -11,9 +12,7 @@ def silu(x):


 @triton.jit
-def _swiglu_forward_kernel(
-    a_ptr, b_ptr, c_ptr, stride, n_cols: tl.constexpr, BLOCK_SIZE: tl.constexpr
-):
+def _swiglu_forward_kernel(a_ptr, b_ptr, c_ptr, stride, n_cols: tl.constexpr, BLOCK_SIZE: tl.constexpr):
     program_id = tl.program_id(0).to(tl.int64)

     # locate start index
@@ -32,9 +31,7 @@ def _swiglu_forward_kernel(


 @triton.jit
-def _swiglu_backward_kernel(
-    dc_ptr, a_ptr, b_ptr, stride, n_cols: tl.constexpr, BLOCK_SIZE: tl.constexpr
-):
+def _swiglu_backward_kernel(dc_ptr, a_ptr, b_ptr, stride, n_cols: tl.constexpr, BLOCK_SIZE: tl.constexpr):
     program_id = tl.program_id(0).to(tl.int64)

     # locate start index
@@ -84,7 +81,6 @@ def swiglu_forward(a, b):


 def swiglu_backward(a, b, dc):
-
     ori_shape = dc.shape
     n_cols = ori_shape[-1]
     dc = dc.view(-1, n_cols)
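The element-wise operation that `_swiglu_forward_kernel` fuses is SiLU-gated multiplication; a one-line PyTorch equivalent, given here only as a sketch of the math and ignoring the kernel's block-wise memory access:

```python
import torch.nn.functional as F


def swiglu_reference(a, b):
    # c = silu(a) * b, matching the kernel's a/b/c arguments.
    return F.silu(a) * b
```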
liger_kernel/ops/utils.py
CHANGED
@@ -13,11 +13,13 @@ Modifications made by Yanning Chen, 2024.
 import functools
 import importlib
 import operator
+
 from typing import Callable

 import torch
 import triton
 import triton.language as tl
+
 from packaging.version import Version

 from liger_kernel.utils import infer_device
liger_kernel/transformers/__init__.py
CHANGED
@@ -1,31 +1,23 @@
-from liger_kernel.transformers.auto_model import (
-    AutoLigerKernelForCausalLM,
-)
+from liger_kernel.transformers.auto_model import AutoLigerKernelForCausalLM  # noqa: F401
 from liger_kernel.transformers.cross_entropy import LigerCrossEntropyLoss  # noqa: F401
-from liger_kernel.transformers.fused_linear_cross_entropy import (
-    LigerFusedLinearCrossEntropyLoss,
-)
+from liger_kernel.transformers.fused_linear_cross_entropy import LigerFusedLinearCrossEntropyLoss  # noqa: F401
 from liger_kernel.transformers.fused_linear_jsd import LigerFusedLinearJSD  # noqa: F401
 from liger_kernel.transformers.geglu import LigerGEGLUMLP  # noqa: F401
 from liger_kernel.transformers.jsd import LigerJSD  # noqa: F401
 from liger_kernel.transformers.layer_norm import LigerLayerNorm  # noqa: F401
-from liger_kernel.transformers.monkey_patch import (
-    _apply_liger_kernel,
-    _apply_liger_kernel_to_instance,
-    apply_liger_kernel_to_gemma,
-    apply_liger_kernel_to_gemma2,
-    apply_liger_kernel_to_llama,
-    apply_liger_kernel_to_mistral,
-    apply_liger_kernel_to_mixtral,
-    apply_liger_kernel_to_mllama,
-    apply_liger_kernel_to_phi3,
-    apply_liger_kernel_to_qwen2,
-    apply_liger_kernel_to_qwen2_vl,
-)
+from liger_kernel.transformers.monkey_patch import _apply_liger_kernel  # noqa: F401
+from liger_kernel.transformers.monkey_patch import _apply_liger_kernel_to_instance  # noqa: F401
+from liger_kernel.transformers.monkey_patch import apply_liger_kernel_to_gemma  # noqa: F401
+from liger_kernel.transformers.monkey_patch import apply_liger_kernel_to_gemma2  # noqa: F401
+from liger_kernel.transformers.monkey_patch import apply_liger_kernel_to_llama  # noqa: F401
+from liger_kernel.transformers.monkey_patch import apply_liger_kernel_to_mistral  # noqa: F401
+from liger_kernel.transformers.monkey_patch import apply_liger_kernel_to_mixtral  # noqa: F401
+from liger_kernel.transformers.monkey_patch import apply_liger_kernel_to_mllama  # noqa: F401
+from liger_kernel.transformers.monkey_patch import apply_liger_kernel_to_phi3  # noqa: F401
+from liger_kernel.transformers.monkey_patch import apply_liger_kernel_to_qwen2  # noqa: F401
+from liger_kernel.transformers.monkey_patch import apply_liger_kernel_to_qwen2_vl  # noqa: F401
 from liger_kernel.transformers.rms_norm import LigerRMSNorm  # noqa: F401
 from liger_kernel.transformers.rope import liger_rotary_pos_emb  # noqa: F401
-from liger_kernel.transformers.swiglu import (
-    LigerBlockSparseTop2MLP,
-    LigerPhi3SwiGLUMLP,
-    LigerSwiGLUMLP,
-)
+from liger_kernel.transformers.swiglu import LigerBlockSparseTop2MLP  # noqa: F401
+from liger_kernel.transformers.swiglu import LigerPhi3SwiGLUMLP  # noqa: F401
+from liger_kernel.transformers.swiglu import LigerSwiGLUMLP  # noqa: F401
liger_kernel/transformers/auto_model.py
CHANGED
@@ -1,11 +1,10 @@
 import inspect

-from transformers import AutoConfig, AutoModelForCausalLM
+from transformers import AutoConfig
+from transformers import AutoModelForCausalLM

-from liger_kernel.transformers.monkey_patch import (
-    MODEL_TYPE_TO_APPLY_LIGER_FN,
-    _apply_liger_kernel,
-)
+from liger_kernel.transformers.monkey_patch import MODEL_TYPE_TO_APPLY_LIGER_FN
+from liger_kernel.transformers.monkey_patch import _apply_liger_kernel


 def _get_model_config(model_dir, **model_init_kwargs):
@@ -34,12 +33,6 @@ class AutoLigerKernelForCausalLM(AutoModelForCausalLM):
         apply_fn = MODEL_TYPE_TO_APPLY_LIGER_FN[model_type]
         apply_fn_signature = inspect.signature(apply_fn)

-        applicable_kwargs = {
-            key: value
-            for key, value in kwargs.items()
-            if key not in apply_fn_signature.parameters
-        }
+        applicable_kwargs = {key: value for key, value in kwargs.items() if key not in apply_fn_signature.parameters}

-        return super().from_pretrained(
-            pretrained_model_name_or_path, *model_args, **applicable_kwargs
-        )
+        return super().from_pretrained(pretrained_model_name_or_path, *model_args, **applicable_kwargs)
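The dictionary comprehension above is the kwargs split in `AutoLigerKernelForCausalLM.from_pretrained`: keyword arguments whose names match parameters of the matched `apply_liger_kernel_to_*` function are filtered out of the `from_pretrained` call (presumably consumed by the patching step, which this hunk does not show), and everything else is forwarded unchanged. A usage sketch with an illustrative checkpoint id and assumed patching flags:

```python
import torch

from liger_kernel.transformers import AutoLigerKernelForCausalLM

model = AutoLigerKernelForCausalLM.from_pretrained(
    "Qwen/Qwen2-1.5B",           # placeholder checkpoint
    rope=True,                   # assumed patching flag, filtered out before from_pretrained
    swiglu=True,                 # assumed patching flag, filtered out before from_pretrained
    torch_dtype=torch.bfloat16,  # ordinary from_pretrained kwarg, passed through
)
```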
liger_kernel/transformers/cross_entropy.py
CHANGED
@@ -27,9 +27,7 @@ class LigerCrossEntropyLoss(torch.nn.Module):
             "sum",
             "none",
         }, f"reduction must be one of 'mean', 'sum', or 'none'. Got: {reduction}"
-        assert (
-            softcap is None or softcap > 0
-        ), f"softcap must greater than 0.0 or None. Got: {softcap}"
+        assert softcap is None or softcap > 0, f"softcap must greater than 0.0 or None. Got: {softcap}"
         self.ignore_index = ignore_index
         self.lse_square_scale = lse_square_scale
         self.label_smoothing = label_smoothing
liger_kernel/transformers/experimental/embedding.py
CHANGED
@@ -7,9 +7,7 @@ from liger_kernel.ops.experimental.embedding import LigerEmbeddingFunction


 class LigerEmbedding(nn.Module):
-    def __init__(
-        self, num_embeddings, embedding_dim, padding_idx: Optional[int] = None
-    ):
+    def __init__(self, num_embeddings, embedding_dim, padding_idx: Optional[int] = None):
         super().__init__()
         self.num_embeddings = num_embeddings
         self.embedding_dim = embedding_dim
liger_kernel/transformers/functional.py
CHANGED
@@ -1,9 +1,7 @@
 from typing import Optional

 from liger_kernel.ops.cross_entropy import LigerCrossEntropyFunction
-from liger_kernel.ops.fused_linear_cross_entropy import (
-    LigerFusedLinearCrossEntropyFunction,
-)
+from liger_kernel.ops.fused_linear_cross_entropy import LigerFusedLinearCrossEntropyFunction
 from liger_kernel.ops.fused_linear_jsd import LigerFusedLinearJSDFunction
 from liger_kernel.ops.geglu import LigerGELUMulFunction
 from liger_kernel.ops.group_norm import LigerGroupNormFunction
@@ -159,9 +157,7 @@ def liger_qwen2vl_mrope(q, k, cos, sin, mrope_section, unsqueeze_dim=1):
     return LigerQwen2VLMRopeFunction.apply(q, k, cos, sin, mrope_section, unsqueeze_dim)


-def liger_rms_norm(
-    X, W, eps, offset: float = 0.0, casting_mode: str = "llama", in_place: bool = True
-):
+def liger_rms_norm(X, W, eps, offset: float = 0.0, casting_mode: str = "llama", in_place: bool = True):
     return LigerRMSNormFunction.apply(X, W, eps, offset, casting_mode, in_place)
