liger-kernel-nightly 0.5.2.dev20241223032015__py3-none-any.whl → 0.5.2.dev20241223042135__py3-none-any.whl
- liger_kernel/chunked_loss/cpo_loss.py +5 -11
- liger_kernel/chunked_loss/dpo_loss.py +1 -4
- liger_kernel/chunked_loss/fused_linear_distillation.py +37 -37
- liger_kernel/chunked_loss/fused_linear_preference.py +40 -64
- liger_kernel/chunked_loss/orpo_loss.py +2 -6
- liger_kernel/chunked_loss/simpo_loss.py +4 -8
- liger_kernel/env_report.py +4 -11
- liger_kernel/ops/cross_entropy.py +7 -10
- liger_kernel/ops/experimental/embedding.py +1 -3
- liger_kernel/ops/experimental/mm_int8int2.py +3 -9
- liger_kernel/ops/fused_linear_cross_entropy.py +7 -15
- liger_kernel/ops/fused_linear_jsd.py +11 -29
- liger_kernel/ops/geglu.py +6 -17
- liger_kernel/ops/group_norm.py +11 -28
- liger_kernel/ops/jsd.py +2 -6
- liger_kernel/ops/kl_div.py +4 -7
- liger_kernel/ops/layer_norm.py +3 -5
- liger_kernel/ops/qwen2vl_mrope.py +8 -25
- liger_kernel/ops/rms_norm.py +11 -29
- liger_kernel/ops/rope.py +31 -33
- liger_kernel/ops/swiglu.py +4 -8
- liger_kernel/ops/utils.py +2 -0
- liger_kernel/transformers/__init__.py +16 -24
- liger_kernel/transformers/auto_model.py +6 -13
- liger_kernel/transformers/cross_entropy.py +1 -3
- liger_kernel/transformers/experimental/embedding.py +1 -3
- liger_kernel/transformers/functional.py +2 -6
- liger_kernel/transformers/fused_linear_cross_entropy.py +2 -6
- liger_kernel/transformers/geglu.py +1 -4
- liger_kernel/transformers/group_norm.py +3 -9
- liger_kernel/transformers/jsd.py +1 -3
- liger_kernel/transformers/kl_div.py +1 -3
- liger_kernel/transformers/layer_norm.py +3 -9
- liger_kernel/transformers/model/gemma.py +18 -40
- liger_kernel/transformers/model/gemma2.py +19 -41
- liger_kernel/transformers/model/llama.py +22 -48
- liger_kernel/transformers/model/mistral.py +14 -26
- liger_kernel/transformers/model/mixtral.py +23 -53
- liger_kernel/transformers/model/mllama.py +16 -36
- liger_kernel/transformers/model/phi3.py +18 -40
- liger_kernel/transformers/model/qwen2.py +18 -40
- liger_kernel/transformers/model/qwen2_vl.py +16 -30
- liger_kernel/transformers/monkey_patch.py +43 -117
- liger_kernel/transformers/rms_norm.py +4 -4
- liger_kernel/transformers/rope.py +2 -2
- liger_kernel/transformers/swiglu.py +2 -8
- liger_kernel/transformers/trainer/__init__.py +1 -3
- liger_kernel/transformers/trainer/orpo_trainer.py +13 -16
- liger_kernel/triton/__init__.py +1 -3
- liger_kernel/triton/monkey_patch.py +1 -3
- {liger_kernel_nightly-0.5.2.dev20241223032015.dist-info → liger_kernel_nightly-0.5.2.dev20241223042135.dist-info}/METADATA +1 -1
- liger_kernel_nightly-0.5.2.dev20241223042135.dist-info/RECORD +66 -0
- liger_kernel_nightly-0.5.2.dev20241223032015.dist-info/RECORD +0 -66
- {liger_kernel_nightly-0.5.2.dev20241223032015.dist-info → liger_kernel_nightly-0.5.2.dev20241223042135.dist-info}/LICENSE +0 -0
- {liger_kernel_nightly-0.5.2.dev20241223032015.dist-info → liger_kernel_nightly-0.5.2.dev20241223042135.dist-info}/NOTICE +0 -0
- {liger_kernel_nightly-0.5.2.dev20241223032015.dist-info → liger_kernel_nightly-0.5.2.dev20241223042135.dist-info}/WHEEL +0 -0
- {liger_kernel_nightly-0.5.2.dev20241223032015.dist-info → liger_kernel_nightly-0.5.2.dev20241223042135.dist-info}/top_level.txt +0 -0
liger_kernel/ops/geglu.py
CHANGED
@@ -4,11 +4,9 @@ import torch
 import triton
 import triton.language as tl
 
-from liger_kernel.ops.utils import (
-    calculate_settings,
-    compare_version,
-    ensure_contiguous,
-)
+from liger_kernel.ops.utils import calculate_settings
+from liger_kernel.ops.utils import compare_version
+from liger_kernel.ops.utils import ensure_contiguous
 
 if compare_version("triton", operator.ge, "3.0.0"):
     try:
@@ -22,9 +20,7 @@ else:
 
 
 @triton.jit
-def _geglu_tanh_forward_kernel(
-    a, b, c, stride, n_cols: tl.constexpr, BLOCK_SIZE: tl.constexpr
-):
+def _geglu_tanh_forward_kernel(a, b, c, stride, n_cols: tl.constexpr, BLOCK_SIZE: tl.constexpr):
     program_id = tl.program_id(0).to(tl.int64)
 
     # locate start index
@@ -49,9 +45,7 @@ def _geglu_tanh_forward_kernel(
 
 
 @triton.jit
-def _geglu_tanh_backward_kernel(
-    dc, a, b, stride, n_cols: tl.constexpr, BLOCK_SIZE: tl.constexpr
-):
+def _geglu_tanh_backward_kernel(dc, a, b, stride, n_cols: tl.constexpr, BLOCK_SIZE: tl.constexpr):
     program_id = tl.program_id(0).to(tl.int64)
 
     # locate start index
@@ -80,12 +74,7 @@ def _geglu_tanh_backward_kernel(
     # where z = sqrt(2/pi) * (a + 0.044715 * a^3)
     term1 = 0.5 * (1 + tanh_result)
     tanh_sq = tanh_result * tanh_result
-    term2 = (
-        0.5
-        * a_row
-        * (1 - tanh_sq)
-        * (sqrt_2_over_pi * (1 + 3 * 0.044715 * a_row * a_row))
-    )
+    term2 = 0.5 * a_row * (1 - tanh_sq) * (sqrt_2_over_pi * (1 + 3 * 0.044715 * a_row * a_row))
     da_row = dc_row * b_row * (term1 + term2)
 
     tl.store(a + col_offsets, da_row, mask=mask)
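
The collapsed term1/term2 lines above are the derivative of the tanh-approximated GELU that gates GEGLU (c = b * gelu_tanh(a)). As a sanity reference, here is a minimal PyTorch sketch of the same per-row math; geglu_tanh_backward_reference is an illustrative name, not part of the library.

import math

import torch


def geglu_tanh_backward_reference(dc: torch.Tensor, a: torch.Tensor, b: torch.Tensor):
    # Sketch only: GEGLU computes c = b * gelu_tanh(a), with
    # gelu_tanh(a) = 0.5 * a * (1 + tanh(z)) and z = sqrt(2/pi) * (a + 0.044715 * a^3),
    # using the same constants that appear in the Triton kernel above.
    sqrt_2_over_pi = math.sqrt(2.0 / math.pi)
    tanh_result = torch.tanh(sqrt_2_over_pi * (a + 0.044715 * a * a * a))
    term1 = 0.5 * (1 + tanh_result)  # derivative contribution of the 0.5 * a factor
    term2 = 0.5 * a * (1 - tanh_result * tanh_result) * (sqrt_2_over_pi * (1 + 3 * 0.044715 * a * a))
    da = dc * b * (term1 + term2)        # matches da_row = dc_row * b_row * (term1 + term2)
    db = dc * 0.5 * a * (1 + tanh_result)  # gradient w.r.t. the linear branch, dc * gelu_tanh(a)
    return da, db
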
liger_kernel/ops/group_norm.py
CHANGED
@@ -4,7 +4,8 @@ import torch
 import triton
 import triton.language as tl
 
-from liger_kernel.ops.utils import compare_version, ensure_contiguous
+from liger_kernel.ops.utils import compare_version
+from liger_kernel.ops.utils import ensure_contiguous
 
 if compare_version("triton", operator.ge, "3.0.0"):
     try:
@@ -73,9 +74,7 @@ def _group_norm_forward_kernel(
 
     # Normalize
     hidden_size_per_channel = hidden_size // channels_per_group
-    for channel_idx in tl.range(
-        group_idx * channels_per_group, (group_idx + 1) * channels_per_group
-    ):
+    for channel_idx in tl.range(group_idx * channels_per_group, (group_idx + 1) * channels_per_group):
         W = tl.load(W_ptr + channel_idx)
         B = tl.load(B_ptr + channel_idx)
         for i in range(0, hidden_size_per_channel, BLOCK_SIZE):
@@ -132,21 +131,15 @@ def _group_norm_backward_kernel(
     UPSTREAM_ptr += batch_idx * X_row_stride
 
     # Mean and rstd are the same shape so have the same strides
-    mean = tl.load(
-        Mean_ptr + batch_idx * Mean_ptr_row_stride + group_idx * Mean_ptr_col_stride
-    )
-    rstd = tl.load(
-        RSTD_ptr + batch_idx * Mean_ptr_row_stride + group_idx * Mean_ptr_col_stride
-    )
+    mean = tl.load(Mean_ptr + batch_idx * Mean_ptr_row_stride + group_idx * Mean_ptr_col_stride)
+    rstd = tl.load(RSTD_ptr + batch_idx * Mean_ptr_row_stride + group_idx * Mean_ptr_col_stride)
 
     c1 = 0.0
     c2 = 0.0
     block_range = tl.arange(0, BLOCK_SIZE)
 
     # We need to compute the sum terms of the backprop equations across all channels in the group
-    for channel_idx in range(
-        group_idx * channels_per_group, (group_idx + 1) * channels_per_group
-    ):
+    for channel_idx in range(group_idx * channels_per_group, (group_idx + 1) * channels_per_group):
         dW = 0.0
         dB = 0.0
         # Move the pointers to the correct channel
@@ -181,9 +174,7 @@ def _group_norm_backward_kernel(
     c1 = c1 / N
     c2 = c2 / N
 
-    for channel_idx in tl.range(
-        group_idx * channels_per_group, (group_idx + 1) * channels_per_group
-    ):
+    for channel_idx in tl.range(group_idx * channels_per_group, (group_idx + 1) * channels_per_group):
         # Move the pointers to the correct channel
         W = tl.load(W_ptr + channel_idx)
         for i in range(0, hidden_size, BLOCK_SIZE):
@@ -203,9 +194,7 @@ def _group_norm_backward_kernel(
             x_hat = (X - mean) * rstd
            wdy = W * UPSTREAM_grad
             dx = (wdy - (x_hat * c1 + c2)) * rstd
-            tl.store(
-                DX_ptr + channel_idx * X_col_stride + hidden_size_offsets, dx, mask=mask
-            )
+            tl.store(DX_ptr + channel_idx * X_col_stride + hidden_size_offsets, dx, mask=mask)
 
 
 def group_norm_forward(X, num_channels, num_groups, W, B, eps):
@@ -216,9 +205,7 @@ def group_norm_forward(X, num_channels, num_groups, W, B, eps):
     X = X.view(batch_size, num_groups, -1).contiguous()
     hidden_size = X.shape[-1]
     BLOCK_SIZE = min(MAX_FUSED_SIZE, triton.next_power_of_2(hidden_size))
-    Y = torch.empty(
-        (batch_size, num_groups, hidden_size), dtype=X.dtype, device=X.device
-    )
+    Y = torch.empty((batch_size, num_groups, hidden_size), dtype=X.dtype, device=X.device)
     Mean = torch.zeros((batch_size, num_groups), dtype=X.dtype, device=X.device)
     RSTD = torch.zeros((batch_size, num_groups), dtype=X.dtype, device=X.device)
 
@@ -307,16 +294,12 @@ class LigerGroupNormFunction(torch.autograd.Function):
         )
         ctx.num_channels = num_channels
         ctx.num_groups = num_groups
-        ctx.save_for_backward(
-            X, affine_scaling_weight, affine_shifting_bias, Mean, RSTD
-        )
+        ctx.save_for_backward(X, affine_scaling_weight, affine_shifting_bias, Mean, RSTD)
         return Y
 
     @staticmethod
     @ensure_contiguous
     def backward(ctx, dY):
         X, W, B, Mean, RSTD = ctx.saved_tensors
-        DX, DW, DB = group_norm_backward(
-            dY, X, W, B, Mean, RSTD, ctx.num_channels, ctx.num_groups
-        )
+        DX, DW, DB = group_norm_backward(dY, X, W, B, Mean, RSTD, ctx.num_channels, ctx.num_groups)
         return DX, DW, DB, None, None, None
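
For readers checking the backward math above (x_hat = (X - mean) * rstd, dx = (wdy - (x_hat * c1 + c2)) * rstd), a plain-PyTorch sketch of the forward pass these kernels fuse may help. It mirrors the (batch_size, num_groups, hidden_size) view used in group_norm_forward; group_norm_reference is a made-up name, not the library's API.

import torch


def group_norm_reference(X: torch.Tensor, W: torch.Tensor, B: torch.Tensor, num_groups: int, eps: float = 1e-6):
    # X: (batch_size, num_channels, *spatial); W, B: (num_channels,) affine scale/shift.
    batch_size, num_channels = X.shape[0], X.shape[1]
    x = X.reshape(batch_size, num_groups, -1)              # same view as group_norm_forward
    mean = x.mean(dim=-1, keepdim=True)                    # per-(batch, group) statistics
    rstd = torch.rsqrt(x.var(dim=-1, unbiased=False, keepdim=True) + eps)
    x_hat = (x - mean) * rstd                              # matches x_hat in the backward kernel
    y = x_hat.reshape(batch_size, num_channels, -1)
    return (y * W.view(1, -1, 1) + B.view(1, -1, 1)).reshape(X.shape)
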
liger_kernel/ops/jsd.py
CHANGED
@@ -98,9 +98,7 @@ def jsd_forward(_input, target, shift_labels, beta, ignore_index, has_label):
         loss_stride=loss.stride(-2),
         dX_ptr=dX,
         dX_stride=dX.stride(-2),
-        label_ptr=(
-            shift_labels if has_label else torch.empty(1, device=_input.device)
-        ),  # dummy ptr if no label
+        label_ptr=(shift_labels if has_label else torch.empty(1, device=_input.device)),  # dummy ptr if no label
         beta=beta,
         n_non_ignore=n_non_ignore,
         ignore_index=ignore_index,
@@ -165,9 +163,7 @@ class LigerJSDFunction(torch.autograd.Function):
             shift_labels = shift_labels.contiguous()
             has_label = True
 
-        loss, dX = jsd_forward(
-            _input, target, shift_labels, beta, ignore_index, has_label
-        )
+        loss, dX = jsd_forward(_input, target, shift_labels, beta, ignore_index, has_label)
         ctx.save_for_backward(dX)
         return loss
 
liger_kernel/ops/kl_div.py
CHANGED
@@ -4,7 +4,8 @@ import torch
 import triton
 import triton.language as tl
 
-from liger_kernel.ops.utils import ensure_contiguous, is_hip
+from liger_kernel.ops.utils import ensure_contiguous
+from liger_kernel.ops.utils import is_hip
 
 
 def get_num_warps(BLOCK_SIZE):
@@ -218,9 +219,7 @@ class LigerKLDivLossFunction(torch.autograd.Function):
         ctx.save_for_backward(y_true)
         ctx.reduction = reduction
         ctx.log_target = log_target
-        return kldiv_forward_triton(
-            y_pred, y_true, log_target=log_target, reduction=reduction, eps=eps
-        )
+        return kldiv_forward_triton(y_pred, y_true, log_target=log_target, reduction=reduction, eps=eps)
 
     @staticmethod
     @ensure_contiguous
@@ -238,9 +237,7 @@ class LigerKLDivLossFunction(torch.autograd.Function):
 
         new_grads = torch.empty_like(y_true)
 
-        derivative = kldiv_backward_triton(
-            y_true, grad_output, new_grads, ctx.log_target
-        )
+        derivative = kldiv_backward_triton(y_true, grad_output, new_grads, ctx.log_target)
 
         if ctx.reduction == "batchmean":
             derivative = derivative / y_true.shape[0]
liger_kernel/ops/layer_norm.py
CHANGED
@@ -5,11 +5,9 @@ import torch
 import triton
 import triton.language as tl
 
-from liger_kernel.ops.utils import (
-    calculate_settings,
-    compare_version,
-    ensure_contiguous,
-)
+from liger_kernel.ops.utils import calculate_settings
+from liger_kernel.ops.utils import compare_version
+from liger_kernel.ops.utils import ensure_contiguous
 
 if compare_version("triton", operator.ge, "3.0.0"):
     try:
liger_kernel/ops/qwen2vl_mrope.py
CHANGED
@@ -67,36 +67,20 @@ def _triton_qwen2vl_mrope(
     # program instance (i.e. for the current token) separately
     # ####################################################################
     # left half of the head
-    first_half_q_offsets = (
-        tl.arange(0, pad_n_qh)[:, None] * hd + tl.arange(0, pad_hd // 2)[None, :]
-    )
-    first_half_k_offsets = (
-        tl.arange(0, pad_n_kh)[:, None] * hd + tl.arange(0, pad_hd // 2)[None, :]
-    )
-    first_q_mask = (tl.arange(0, pad_n_qh)[:, None] < n_qh) & (
-        tl.arange(0, pad_hd // 2)[None, :] < hd // 2
-    )
-    first_k_mask = (tl.arange(0, pad_n_kh)[:, None] < n_kh) & (
-        tl.arange(0, pad_hd // 2)[None, :] < hd // 2
-    )
-    q_tile_1 = tl.load(q_ptr + first_half_q_offsets, mask=first_q_mask, other=0).to(
-        sin_row.dtype
-    )
-    k_tile_1 = tl.load(k_ptr + first_half_k_offsets, mask=first_k_mask, other=0).to(
-        sin_row.dtype
-    )
+    first_half_q_offsets = tl.arange(0, pad_n_qh)[:, None] * hd + tl.arange(0, pad_hd // 2)[None, :]
+    first_half_k_offsets = tl.arange(0, pad_n_kh)[:, None] * hd + tl.arange(0, pad_hd // 2)[None, :]
+    first_q_mask = (tl.arange(0, pad_n_qh)[:, None] < n_qh) & (tl.arange(0, pad_hd // 2)[None, :] < hd // 2)
+    first_k_mask = (tl.arange(0, pad_n_kh)[:, None] < n_kh) & (tl.arange(0, pad_hd // 2)[None, :] < hd // 2)
+    q_tile_1 = tl.load(q_ptr + first_half_q_offsets, mask=first_q_mask, other=0).to(sin_row.dtype)
+    k_tile_1 = tl.load(k_ptr + first_half_k_offsets, mask=first_k_mask, other=0).to(sin_row.dtype)
 
     # right half of the head
     second_half_q_offsets = first_half_q_offsets + (hd // 2)
     second_half_k_offsets = first_half_k_offsets + (hd // 2)
     second_q_mask = first_q_mask
     second_k_mask = first_k_mask
-    q_tile_2 = tl.load(q_ptr + second_half_q_offsets, mask=second_q_mask, other=0).to(
-        sin_row.dtype
-    )
-    k_tile_2 = tl.load(k_ptr + second_half_k_offsets, mask=second_k_mask, other=0).to(
-        sin_row.dtype
-    )
+    q_tile_2 = tl.load(q_ptr + second_half_q_offsets, mask=second_q_mask, other=0).to(sin_row.dtype)
+    k_tile_2 = tl.load(k_ptr + second_half_k_offsets, mask=second_k_mask, other=0).to(sin_row.dtype)
 
     if not BACKWARD_PASS:
         # y = [x1, x2] * [cos, cos] + [-x2, x1] * [sin, sin]
@@ -124,7 +108,6 @@ def _triton_qwen2vl_mrope(
 
 
 def qwen2vl_mrope_forward(q, k, cos, sin, mrope_section):
-
     # transpose it back to the physical shape because Triton looks at the physical storage
     # note: q and k are incontiguous before the transformation and will become contiguous after transpose
     q = q.transpose(1, 2)
liger_kernel/ops/rms_norm.py
CHANGED
@@ -17,12 +17,10 @@ import torch
 import triton
 import triton.language as tl
 
-from liger_kernel.ops.utils import (
-    calculate_settings,
-    compare_version,
-    ensure_contiguous,
-    torch_to_triton_dtype,
-)
+from liger_kernel.ops.utils import calculate_settings
+from liger_kernel.ops.utils import compare_version
+from liger_kernel.ops.utils import ensure_contiguous
+from liger_kernel.ops.utils import torch_to_triton_dtype
 
 if compare_version("triton", operator.ge, "3.0.0"):
     try:
@@ -177,9 +175,7 @@ def _rms_norm_backward_kernel(
 
         dX_row = rstd_row * m
 
-        dX_row += (rstd_row) * (
-            -(1 / n_cols) * rstd_row * rstd_row * tl.sum(m * X_row, axis=0) * X_row
-        )
+        dX_row += (rstd_row) * (-(1 / n_cols) * rstd_row * rstd_row * tl.sum(m * X_row, axis=0) * X_row)
 
         # calculate the gradient of W
         if casting_mode == _CASTING_MODE_LLAMA:
@@ -207,14 +203,10 @@ _str_to_casting_mode = {
 
 def rms_norm_forward(X, W, eps, offset, casting_mode):
     if not isinstance(casting_mode, int):
-        assert (
-            casting_mode in _str_to_casting_mode
-        ), f"Invalid casting mode: {casting_mode}"
+        assert casting_mode in _str_to_casting_mode, f"Invalid casting mode: {casting_mode}"
         casting_mode = _str_to_casting_mode[casting_mode]
     else:
-        assert (
-            casting_mode in _str_to_casting_mode.values()
-        ), f"Invalid casting mode: {casting_mode}"
+        assert casting_mode in _str_to_casting_mode.values(), f"Invalid casting mode: {casting_mode}"
 
     shape = X.shape
     dim = shape[-1]
@@ -225,17 +217,11 @@ def rms_norm_forward(X, W, eps, offset, casting_mode):
     Y = torch.empty((n_rows, n_cols), dtype=X.dtype, device=X.device)
     # RSTD is to cache rstd for each row
     # RSTD is always computed/stored in fp32 if we are using Llama or Gemma casting mode
-    rstd_dtype = (
-        torch.float32
-        if casting_mode in (_CASTING_MODE_LLAMA.value, _CASTING_MODE_GEMMA.value)
-        else X.dtype
-    )
+    rstd_dtype = torch.float32 if casting_mode in (_CASTING_MODE_LLAMA.value, _CASTING_MODE_GEMMA.value) else X.dtype
     RSTD = torch.empty(n_rows, dtype=rstd_dtype, device=X.device)
 
     # Check constraints.
-    assert (
-        X.shape[1] == W.shape[0]
-    ), "Incompatible hidden size dimension between tensor1.shape[1] and tensor2.shape[0]"
+    assert X.shape[1] == W.shape[0], "Incompatible hidden size dimension between tensor1.shape[1] and tensor2.shape[0]"
 
     _rms_norm_forward_kernel[(n_rows,)](
         Y,
@@ -256,9 +242,7 @@ def rms_norm_forward(X, W, eps, offset, casting_mode):
     return Y.view(*shape), X, RSTD, BLOCK_SIZE, num_warps, casting_mode
 
 
-def rms_norm_backward(
-    dY, X, W, RSTD, offset, casting_mode, BLOCK_SIZE, num_warps, in_place
-):
+def rms_norm_backward(dY, X, W, RSTD, offset, casting_mode, BLOCK_SIZE, num_warps, in_place):
     shape = dY.shape
     dim = shape[-1]
     dY = dY.view(-1, dim)
@@ -340,9 +324,7 @@ class LigerRMSNormFunction(torch.autograd.Function):
         X: (B, T, H) or (BxT, H)
         W: (H,)
         """
-        Y, X, RSTD, BLOCK_SIZE, num_warps, casting_mode = rms_norm_forward(
-            X, W, eps, offset, casting_mode
-        )
+        Y, X, RSTD, BLOCK_SIZE, num_warps, casting_mode = rms_norm_forward(X, W, eps, offset, casting_mode)
         ctx.offset = offset
         ctx.casting_mode = casting_mode
         ctx.in_place = in_place
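
The reshuffled lines above concern bookkeeping (which dtype the cached per-row statistic RSTD uses for the Llama/Gemma casting modes, and the offset/casting_mode arguments); the underlying math is plain RMSNorm. A minimal sketch of that math, assuming the common "(offset + weight)" scaling used for Gemma-style norms and ignoring the dtype casting details; rms_norm_reference is an illustrative name only.

import torch


def rms_norm_reference(X: torch.Tensor, W: torch.Tensor, eps: float = 1e-6, offset: float = 0.0) -> torch.Tensor:
    # rstd is the per-row statistic the Triton kernel caches as RSTD.
    rstd = torch.rsqrt(X.pow(2).mean(dim=-1, keepdim=True) + eps)
    return X * rstd * (offset + W)  # offset=1.0 reproduces the "1 + weight" style of scaling
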
liger_kernel/ops/rope.py
CHANGED
@@ -15,6 +15,7 @@ def _triton_rope(
     sin_row_stride,
     sl,
     bs: tl.constexpr,
+    cos_bs: tl.constexpr,
     n_qh: tl.constexpr,
     n_kh: tl.constexpr,
     hd: tl.constexpr,
@@ -29,7 +30,7 @@ def _triton_rope(
     # k size: (bsz, seq_len, num_kv_heads, head_dim)
     # k stride: (seq_len * num_kv_heads * head_dim, num_kv_heads * head_dim, head_dim, 1)
 
-    # cos size: (1, seq_len, head_dim)
+    # cos size: (1, seq_len, head_dim) or (bsz, seq_len, head_dim)
     # stride: (seq_len * head_dim, head_dim, 1)
     pid = tl.program_id(0)
 
@@ -48,9 +49,19 @@ def _triton_rope(
     # and pid % sl to get the sequence index.
     # 2. We only need the left half of cos and sin matrix because the right half is just
     # a clone of the left half.
-    cos_row_idx = pid % sl
-    cos = cos + cos_row_idx * cos_row_stride
-    sin = sin + cos_row_idx * sin_row_stride
+    batch_idx = pid // sl
+    cos_row_idx = pid % sl
+    cos = cos + tl.where(
+        cos_bs == 1,
+        cos_row_idx * cos_row_stride,
+        batch_idx * (sl * cos_row_stride) + cos_row_idx * cos_row_stride,
+    )
+    sin = sin + tl.where(
+        cos_bs == 1,
+        cos_row_idx * sin_row_stride,
+        batch_idx * (sl * sin_row_stride) + cos_row_idx * sin_row_stride,
+    )
+
     cos_offsets = tl.arange(0, pad_hd // 2)
     cos_mask = cos_offsets < hd // 2
     cos_row = tl.load(cos + cos_offsets, mask=cos_mask, other=0)
@@ -61,36 +72,20 @@ def _triton_rope(
     # program instance (i.e. for the current token) separately
     # ####################################################################
     # left half of the head
-    first_half_q_offsets = (
-        tl.arange(0, pad_n_qh)[:, None] * hd + tl.arange(0, pad_hd // 2)[None, :]
-    )
-    first_half_k_offsets = (
-        tl.arange(0, pad_n_kh)[:, None] * hd + tl.arange(0, pad_hd // 2)[None, :]
-    )
-    first_q_mask = (tl.arange(0, pad_n_qh)[:, None] < n_qh) & (
-        tl.arange(0, pad_hd // 2)[None, :] < hd // 2
-    )
-    first_k_mask = (tl.arange(0, pad_n_kh)[:, None] < n_kh) & (
-        tl.arange(0, pad_hd // 2)[None, :] < hd // 2
-    )
-    q_tile_1 = tl.load(q_ptr + first_half_q_offsets, mask=first_q_mask, other=0).to(
-        sin_row.dtype
-    )
-    k_tile_1 = tl.load(k_ptr + first_half_k_offsets, mask=first_k_mask, other=0).to(
-        sin_row.dtype
-    )
+    first_half_q_offsets = tl.arange(0, pad_n_qh)[:, None] * hd + tl.arange(0, pad_hd // 2)[None, :]
+    first_half_k_offsets = tl.arange(0, pad_n_kh)[:, None] * hd + tl.arange(0, pad_hd // 2)[None, :]
+    first_q_mask = (tl.arange(0, pad_n_qh)[:, None] < n_qh) & (tl.arange(0, pad_hd // 2)[None, :] < hd // 2)
+    first_k_mask = (tl.arange(0, pad_n_kh)[:, None] < n_kh) & (tl.arange(0, pad_hd // 2)[None, :] < hd // 2)
+    q_tile_1 = tl.load(q_ptr + first_half_q_offsets, mask=first_q_mask, other=0).to(sin_row.dtype)
+    k_tile_1 = tl.load(k_ptr + first_half_k_offsets, mask=first_k_mask, other=0).to(sin_row.dtype)
 
     # right half of the head
     second_half_q_offsets = first_half_q_offsets + (hd // 2)
     second_half_k_offsets = first_half_k_offsets + (hd // 2)
     second_q_mask = first_q_mask
     second_k_mask = first_k_mask
-    q_tile_2 = tl.load(q_ptr + second_half_q_offsets, mask=second_q_mask, other=0).to(
-        sin_row.dtype
-    )
-    k_tile_2 = tl.load(k_ptr + second_half_k_offsets, mask=second_k_mask, other=0).to(
-        sin_row.dtype
-    )
+    q_tile_2 = tl.load(q_ptr + second_half_q_offsets, mask=second_q_mask, other=0).to(sin_row.dtype)
+    k_tile_2 = tl.load(k_ptr + second_half_k_offsets, mask=second_k_mask, other=0).to(sin_row.dtype)
 
     if not BACKWARD_PASS:
         # y = [x1, x2] * [cos, cos] + [-x2, x1] * [sin, sin]
@@ -118,7 +113,6 @@ def _triton_rope(
 
 
 def rope_forward(q, k, cos, sin):
-
     # transpose it back to the physical shape because Triton looks at the physical storage
     # note: q and k are incontiguous before the transformation and will become contiguous after transpose
     q = q.transpose(1, 2)
@@ -138,6 +132,7 @@ def rope_forward(q, k, cos, sin):
     k = k.contiguous()
     cos = cos.contiguous()
     sin = sin.contiguous()
+    cos_batch_size = cos.shape[0]
 
     _triton_rope[(n_row,)](
         q,
@@ -150,6 +145,7 @@ def rope_forward(q, k, cos, sin):
         sin.stride(-2),
         seq_len,
         batch_size,
+        cos_batch_size,
         n_q_head,
         n_kv_head,
         head_dim,
@@ -167,6 +163,7 @@ def rope_backward(dq, dk, cos, sin):
     dk = dk.transpose(1, 2)
 
     batch_size, seq_len, n_q_head, head_dim = dq.shape
+    cos_batch_size = cos.shape[0]
     n_kv_head = dk.shape[2]
     pad_hd = triton.next_power_of_2(head_dim)
     pad_n_q_head = triton.next_power_of_2(n_q_head)
@@ -191,6 +188,7 @@ def rope_backward(dq, dk, cos, sin):
         sin.stride(-2),
         seq_len,
         batch_size,
+        cos_batch_size,
         n_q_head,
         n_kv_head,
         head_dim,
@@ -221,8 +219,8 @@ class LigerRopeFunction(torch.autograd.Function):
         """
         q size: (bsz, n_q_head, seq_len, head_dim)
        k size: (bsz, n_kv_head, seq_len, head_dim)
-        cos size: (1, seq_len, head_dim)
-        sin size: (1, seq_len, head_dim)
+        cos size: (1, seq_len, head_dim) or (bsz, seq_len, head_dim)
+        sin size: (1, seq_len, head_dim) or (bsz, seq_len, head_dim)
        """
         q, k, cos, sin = rope_forward(q, k, cos, sin)
         ctx.save_for_backward(cos, sin)
@@ -232,8 +230,8 @@ class LigerRopeFunction(torch.autograd.Function):
         """
         dq size: (bsz, n_q_head, seq_len, head_dim)
        dk size: (bsz, n_kv_head, seq_len, head_dim)
-        cos size: (1, seq_len, head_dim)
-        sin size: (1, seq_len, head_dim)
+        cos size: (1, seq_len, head_dim) or (bsz, seq_len, head_dim)
+        sin size: (1, seq_len, head_dim) or (bsz, seq_len, head_dim)
        """
 
        cos, sin = ctx.saved_tensors
liger_kernel/ops/swiglu.py
CHANGED
@@ -2,7 +2,8 @@ import torch
 import triton
 import triton.language as tl
 
-from liger_kernel.ops.utils import calculate_settings, ensure_contiguous
+from liger_kernel.ops.utils import calculate_settings
+from liger_kernel.ops.utils import ensure_contiguous
 
 
 @triton.jit
@@ -11,9 +12,7 @@ def silu(x):
 
 
 @triton.jit
-def _swiglu_forward_kernel(
-    a_ptr, b_ptr, c_ptr, stride, n_cols: tl.constexpr, BLOCK_SIZE: tl.constexpr
-):
+def _swiglu_forward_kernel(a_ptr, b_ptr, c_ptr, stride, n_cols: tl.constexpr, BLOCK_SIZE: tl.constexpr):
     program_id = tl.program_id(0).to(tl.int64)
 
     # locate start index
@@ -32,9 +31,7 @@ def _swiglu_forward_kernel(
 
 
 @triton.jit
-def _swiglu_backward_kernel(
-    dc_ptr, a_ptr, b_ptr, stride, n_cols: tl.constexpr, BLOCK_SIZE: tl.constexpr
-):
+def _swiglu_backward_kernel(dc_ptr, a_ptr, b_ptr, stride, n_cols: tl.constexpr, BLOCK_SIZE: tl.constexpr):
     program_id = tl.program_id(0).to(tl.int64)
 
     # locate start index
@@ -84,7 +81,6 @@ def swiglu_forward(a, b):
 
 
 def swiglu_backward(a, b, dc):
-
     ori_shape = dc.shape
     n_cols = ori_shape[-1]
     dc = dc.view(-1, n_cols)
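
For reference, the fused _swiglu_forward_kernel computes the standard SwiGLU combination of a SiLU-gated branch with a linear branch; a one-line PyTorch equivalent (a sketch, not the kernel itself):

import torch
import torch.nn.functional as F


def swiglu_reference(a: torch.Tensor, b: torch.Tensor) -> torch.Tensor:
    # silu(a) * b: the gate path times the up-projection path.
    return F.silu(a) * b
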
liger_kernel/ops/utils.py
CHANGED
@@ -13,11 +13,13 @@ Modifications made by Yanning Chen, 2024.
 import functools
 import importlib
 import operator
+
 from typing import Callable
 
 import torch
 import triton
 import triton.language as tl
+
 from packaging.version import Version
 
 from liger_kernel.utils import infer_device
liger_kernel/transformers/__init__.py
CHANGED
@@ -1,31 +1,23 @@
-from liger_kernel.transformers.auto_model import (
-    AutoLigerKernelForCausalLM,
-)
+from liger_kernel.transformers.auto_model import AutoLigerKernelForCausalLM  # noqa: F401
 from liger_kernel.transformers.cross_entropy import LigerCrossEntropyLoss  # noqa: F401
-from liger_kernel.transformers.fused_linear_cross_entropy import (
-    LigerFusedLinearCrossEntropyLoss,
-)
+from liger_kernel.transformers.fused_linear_cross_entropy import LigerFusedLinearCrossEntropyLoss  # noqa: F401
 from liger_kernel.transformers.fused_linear_jsd import LigerFusedLinearJSD  # noqa: F401
 from liger_kernel.transformers.geglu import LigerGEGLUMLP  # noqa: F401
 from liger_kernel.transformers.jsd import LigerJSD  # noqa: F401
 from liger_kernel.transformers.layer_norm import LigerLayerNorm  # noqa: F401
-from liger_kernel.transformers.monkey_patch import (
-    _apply_liger_kernel,
-    _apply_liger_kernel_to_instance,
-    apply_liger_kernel_to_gemma,
-    apply_liger_kernel_to_gemma2,
-    apply_liger_kernel_to_llama,
-    apply_liger_kernel_to_mistral,
-    apply_liger_kernel_to_mixtral,
-    apply_liger_kernel_to_mllama,
-    apply_liger_kernel_to_phi3,
-    apply_liger_kernel_to_qwen2,
-    apply_liger_kernel_to_qwen2_vl,
-)
+from liger_kernel.transformers.monkey_patch import _apply_liger_kernel  # noqa: F401
+from liger_kernel.transformers.monkey_patch import _apply_liger_kernel_to_instance  # noqa: F401
+from liger_kernel.transformers.monkey_patch import apply_liger_kernel_to_gemma  # noqa: F401
+from liger_kernel.transformers.monkey_patch import apply_liger_kernel_to_gemma2  # noqa: F401
+from liger_kernel.transformers.monkey_patch import apply_liger_kernel_to_llama  # noqa: F401
+from liger_kernel.transformers.monkey_patch import apply_liger_kernel_to_mistral  # noqa: F401
+from liger_kernel.transformers.monkey_patch import apply_liger_kernel_to_mixtral  # noqa: F401
+from liger_kernel.transformers.monkey_patch import apply_liger_kernel_to_mllama  # noqa: F401
+from liger_kernel.transformers.monkey_patch import apply_liger_kernel_to_phi3  # noqa: F401
+from liger_kernel.transformers.monkey_patch import apply_liger_kernel_to_qwen2  # noqa: F401
+from liger_kernel.transformers.monkey_patch import apply_liger_kernel_to_qwen2_vl  # noqa: F401
 from liger_kernel.transformers.rms_norm import LigerRMSNorm  # noqa: F401
 from liger_kernel.transformers.rope import liger_rotary_pos_emb  # noqa: F401
-from liger_kernel.transformers.swiglu import (
-    LigerBlockSparseTop2MLP,
-    LigerPhi3SwiGLUMLP,
-    LigerSwiGLUMLP,
-)
+from liger_kernel.transformers.swiglu import LigerBlockSparseTop2MLP  # noqa: F401
+from liger_kernel.transformers.swiglu import LigerPhi3SwiGLUMLP  # noqa: F401
+from liger_kernel.transformers.swiglu import LigerSwiGLUMLP  # noqa: F401
liger_kernel/transformers/auto_model.py
CHANGED
@@ -1,11 +1,10 @@
 import inspect
 
-from transformers import AutoConfig, AutoModelForCausalLM
+from transformers import AutoConfig
+from transformers import AutoModelForCausalLM
 
-from liger_kernel.transformers.monkey_patch import (
-    MODEL_TYPE_TO_APPLY_LIGER_FN,
-    _apply_liger_kernel,
-)
+from liger_kernel.transformers.monkey_patch import MODEL_TYPE_TO_APPLY_LIGER_FN
+from liger_kernel.transformers.monkey_patch import _apply_liger_kernel
 
 
 def _get_model_config(model_dir, **model_init_kwargs):
@@ -34,12 +33,6 @@ class AutoLigerKernelForCausalLM(AutoModelForCausalLM):
         apply_fn = MODEL_TYPE_TO_APPLY_LIGER_FN[model_type]
         apply_fn_signature = inspect.signature(apply_fn)
 
-        applicable_kwargs = {
-            key: value
-            for key, value in kwargs.items()
-            if key not in apply_fn_signature.parameters
-        }
+        applicable_kwargs = {key: value for key, value in kwargs.items() if key not in apply_fn_signature.parameters}
 
-        return super().from_pretrained(
-            pretrained_model_name_or_path, *model_args, **applicable_kwargs
-        )
+        return super().from_pretrained(pretrained_model_name_or_path, *model_args, **applicable_kwargs)
liger_kernel/transformers/cross_entropy.py
CHANGED
@@ -27,9 +27,7 @@ class LigerCrossEntropyLoss(torch.nn.Module):
             "sum",
             "none",
         }, f"reduction must be one of 'mean', 'sum', or 'none'. Got: {reduction}"
-        assert (
-            softcap is None or softcap > 0
-        ), f"softcap must greater than 0.0 or None. Got: {softcap}"
+        assert softcap is None or softcap > 0, f"softcap must greater than 0.0 or None. Got: {softcap}"
         self.ignore_index = ignore_index
         self.lse_square_scale = lse_square_scale
         self.label_smoothing = label_smoothing
liger_kernel/transformers/experimental/embedding.py
CHANGED
@@ -7,9 +7,7 @@ from liger_kernel.ops.experimental.embedding import LigerEmbeddingFunction
 
 
 class LigerEmbedding(nn.Module):
-    def __init__(
-        self, num_embeddings, embedding_dim, padding_idx: Optional[int] = None
-    ):
+    def __init__(self, num_embeddings, embedding_dim, padding_idx: Optional[int] = None):
         super().__init__()
         self.num_embeddings = num_embeddings
         self.embedding_dim = embedding_dim