liger-kernel 0.5.1__py3-none-any.whl → 0.5.3__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- liger_kernel/chunked_loss/README.md +25 -0
- liger_kernel/chunked_loss/__init__.py +2 -0
- liger_kernel/chunked_loss/cpo_loss.py +18 -8
- liger_kernel/chunked_loss/dpo_loss.py +20 -10
- liger_kernel/chunked_loss/functional.py +4 -0
- liger_kernel/chunked_loss/fused_linear_distillation.py +58 -44
- liger_kernel/chunked_loss/fused_linear_preference.py +108 -60
- liger_kernel/chunked_loss/fused_linear_unpaired_preference.py +246 -0
- liger_kernel/chunked_loss/jsd_loss.py +154 -0
- liger_kernel/chunked_loss/kto_loss.py +172 -0
- liger_kernel/chunked_loss/orpo_loss.py +8 -9
- liger_kernel/chunked_loss/simpo_loss.py +22 -8
- liger_kernel/env_report.py +5 -12
- liger_kernel/ops/cross_entropy.py +102 -51
- liger_kernel/ops/experimental/embedding.py +1 -3
- liger_kernel/ops/experimental/mm_int8int2.py +3 -9
- liger_kernel/ops/fused_linear_cross_entropy.py +89 -55
- liger_kernel/ops/fused_linear_jsd.py +11 -29
- liger_kernel/ops/geglu.py +6 -17
- liger_kernel/ops/group_norm.py +11 -28
- liger_kernel/ops/jsd.py +2 -6
- liger_kernel/ops/kl_div.py +8 -11
- liger_kernel/ops/layer_norm.py +3 -5
- liger_kernel/ops/qwen2vl_mrope.py +21 -37
- liger_kernel/ops/rms_norm.py +14 -32
- liger_kernel/ops/rope.py +31 -33
- liger_kernel/ops/swiglu.py +4 -8
- liger_kernel/ops/utils.py +2 -0
- liger_kernel/transformers/__init__.py +16 -24
- liger_kernel/transformers/auto_model.py +6 -13
- liger_kernel/transformers/cross_entropy.py +4 -6
- liger_kernel/transformers/experimental/embedding.py +1 -3
- liger_kernel/transformers/functional.py +11 -7
- liger_kernel/transformers/fused_linear_cross_entropy.py +12 -7
- liger_kernel/transformers/geglu.py +1 -4
- liger_kernel/transformers/group_norm.py +3 -9
- liger_kernel/transformers/jsd.py +1 -3
- liger_kernel/transformers/kl_div.py +1 -3
- liger_kernel/transformers/layer_norm.py +3 -9
- liger_kernel/transformers/model/gemma.py +18 -40
- liger_kernel/transformers/model/gemma2.py +19 -41
- liger_kernel/transformers/model/llama.py +22 -48
- liger_kernel/transformers/model/mistral.py +14 -26
- liger_kernel/transformers/model/mixtral.py +24 -54
- liger_kernel/transformers/model/mllama.py +16 -36
- liger_kernel/transformers/model/phi3.py +18 -40
- liger_kernel/transformers/model/qwen2.py +18 -40
- liger_kernel/transformers/model/qwen2_vl.py +36 -32
- liger_kernel/transformers/monkey_patch.py +43 -117
- liger_kernel/transformers/qwen2vl_mrope.py +2 -2
- liger_kernel/transformers/rms_norm.py +4 -4
- liger_kernel/transformers/rope.py +2 -2
- liger_kernel/transformers/swiglu.py +2 -8
- liger_kernel/transformers/trainer/__init__.py +1 -3
- liger_kernel/transformers/trainer/orpo_trainer.py +31 -18
- liger_kernel/triton/__init__.py +1 -3
- liger_kernel/triton/monkey_patch.py +1 -3
- {liger_kernel-0.5.1.dist-info → liger_kernel-0.5.3.dist-info}/METADATA +38 -25
- liger_kernel-0.5.3.dist-info/RECORD +69 -0
- {liger_kernel-0.5.1.dist-info → liger_kernel-0.5.3.dist-info}/WHEEL +1 -1
- liger_kernel-0.5.1.dist-info/RECORD +0 -65
- {liger_kernel-0.5.1.dist-info → liger_kernel-0.5.3.dist-info}/LICENSE +0 -0
- {liger_kernel-0.5.1.dist-info → liger_kernel-0.5.3.dist-info}/NOTICE +0 -0
- {liger_kernel-0.5.1.dist-info → liger_kernel-0.5.3.dist-info}/top_level.txt +0 -0
liger_kernel/ops/fused_linear_jsd.py
CHANGED
@@ -4,12 +4,10 @@ import torch
 import triton

 from liger_kernel.ops.jsd import _jsd_kernel
-from liger_kernel.ops.utils import (
-    amp_custom_bwd,
-    amp_custom_fwd,
-    element_mul_kernel,
-    is_hip,
-)
+from liger_kernel.ops.utils import amp_custom_bwd
+from liger_kernel.ops.utils import amp_custom_fwd
+from liger_kernel.ops.utils import element_mul_kernel
+from liger_kernel.ops.utils import is_hip

 # The hard limit of TRITON_MAX_TENSOR_NUMEL is 1048576 https://github.com/triton-lang/triton/blob/ba42a5c68fd0505f8c42f4202d53be0f8d9a5fe0/python/triton/language/core.py#L19
 # However, setting limit as 65536 as in LayerNorm tutorial is faster because of less register spilling
@@ -43,16 +41,10 @@ def fused_linear_jsd_forward(
     BLOCK_SIZE = min(MAX_FUSED_SIZE, triton.next_power_of_2(V))

     inc_factor = triton.cdiv(V, H)  # (V + H - 1) // H
-    chunk_size = triton.next_power_of_2(
-        triton.cdiv(BT, inc_factor)
-    )  # (BT + inc_factor - 1) // inc_factor
+    chunk_size = triton.next_power_of_2(triton.cdiv(BT, inc_factor))  # (BT + inc_factor - 1) // inc_factor
     num_chunks = triton.cdiv(BT, chunk_size)  # (BT + chunk_size - 1) // chunk_size

-    grad_weight = (
-        torch.zeros_like(student_weight, device=device)
-        if student_weight.requires_grad
-        else None
-    )
+    grad_weight = torch.zeros_like(student_weight, device=device) if student_weight.requires_grad else None
     grad_input = torch.zeros_like(student_input)
     # we use fp32 for loss accumulator
     loss_1d = torch.zeros((BT, V), dtype=torch.float32, device=device)
@@ -73,12 +65,8 @@ def fused_linear_jsd_forward(
         # shape: chunk_size x V
         # For anything starting from logits to the final JSD loss, we do computation
         # in FP32 to avoid losing numerical stability.
-        student_logits_chunk = (student_input_chunk @ student_weight.t()).to(
-            torch.float32
-        )
-        teacher_logits_chunk = (teacher_input_chunk @ teacher_weight.t()).to(
-            torch.float32
-        )
+        student_logits_chunk = (student_input_chunk @ student_weight.t()).to(torch.float32)
+        teacher_logits_chunk = (teacher_input_chunk @ teacher_weight.t()).to(torch.float32)
         chunk_n_rows = student_logits_chunk.shape[0]

         # unreduced loss
@@ -104,9 +92,7 @@ def fused_linear_jsd_forward(
             dX_ptr=student_prob_chunk,
             dX_stride=student_prob_chunk.stride(-2),
             label_ptr=(
-                shift_labels[start_idx:end_idx]
-                if has_label
-                else torch.empty(1, device=device)
+                shift_labels[start_idx:end_idx] if has_label else torch.empty(1, device=device)
             ),  # dummy ptr if no label
             beta=jsd_beta,
             n_non_ignore=n_non_ignore,
@@ -121,9 +107,7 @@ def fused_linear_jsd_forward(
         student_logits_chunk = (
             student_prob_chunk
             - torch.softmax(student_logits_chunk, dim=-1)
-            * student_prob_chunk.sum(dim=-1, keepdim=True).broadcast_to(
-                student_prob_chunk.shape
-            )
+            * student_prob_chunk.sum(dim=-1, keepdim=True).broadcast_to(student_prob_chunk.shape)
         ) / temperature
         # now we traverse back to grad w.r.t. input to `lm_head` and grad
         # w.r.t. `lm_head` which should be computed in original dtype
@@ -239,7 +223,5 @@ class LigerFusedLinearJSDFunction(torch.autograd.Function):
     @amp_custom_bwd
     def backward(ctx, grad_output):
         (grad_input, grad_weight) = ctx.saved_tensors
-        grad_input, grad_weight = fused_linear_jsd_backward(
-            grad_output, grad_input, grad_weight
-        )
+        grad_input, grad_weight = fused_linear_jsd_backward(grad_output, grad_input, grad_weight)
         return (grad_input, grad_weight, None, None, None, None, None, None)
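The collapsed chunking lines above keep the same arithmetic as before. A worked example with illustrative sizes (a sketch for orientation, not code shipped in the wheel):

    import triton

    BT, H, V = 4096, 4096, 131072  # tokens, hidden size, vocab size (illustrative)
    inc_factor = triton.cdiv(V, H)  # 32: logit rows are ~32x wider than the hidden states
    chunk_size = triton.next_power_of_2(triton.cdiv(BT, inc_factor))  # 128 rows per chunk
    num_chunks = triton.cdiv(BT, chunk_size)  # 32 chunks, bounding the fp32 logit buffer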
liger_kernel/ops/geglu.py
CHANGED
@@ -4,11 +4,9 @@ import torch
 import triton
 import triton.language as tl

-from liger_kernel.ops.utils import (
-    calculate_settings,
-    compare_version,
-    ensure_contiguous,
-)
+from liger_kernel.ops.utils import calculate_settings
+from liger_kernel.ops.utils import compare_version
+from liger_kernel.ops.utils import ensure_contiguous

 if compare_version("triton", operator.ge, "3.0.0"):
     try:
@@ -22,9 +20,7 @@ else:


 @triton.jit
-def _geglu_tanh_forward_kernel(
-    a, b, c, stride, n_cols: tl.constexpr, BLOCK_SIZE: tl.constexpr
-):
+def _geglu_tanh_forward_kernel(a, b, c, stride, n_cols: tl.constexpr, BLOCK_SIZE: tl.constexpr):
     program_id = tl.program_id(0).to(tl.int64)

     # locate start index
@@ -49,9 +45,7 @@ def _geglu_tanh_forward_kernel(


 @triton.jit
-def _geglu_tanh_backward_kernel(
-    dc, a, b, stride, n_cols: tl.constexpr, BLOCK_SIZE: tl.constexpr
-):
+def _geglu_tanh_backward_kernel(dc, a, b, stride, n_cols: tl.constexpr, BLOCK_SIZE: tl.constexpr):
     program_id = tl.program_id(0).to(tl.int64)

     # locate start index
@@ -80,12 +74,7 @@ def _geglu_tanh_backward_kernel(
     # where z = sqrt(2/pi) * (a + 0.044715 * a^3)
     term1 = 0.5 * (1 + tanh_result)
     tanh_sq = tanh_result * tanh_result
-    term2 = (
-        0.5
-        * a_row
-        * (1 - tanh_sq)
-        * (sqrt_2_over_pi * (1 + 3 * 0.044715 * a_row * a_row))
-    )
+    term2 = 0.5 * a_row * (1 - tanh_sq) * (sqrt_2_over_pi * (1 + 3 * 0.044715 * a_row * a_row))
     da_row = dc_row * b_row * (term1 + term2)

     tl.store(a + col_offsets, da_row, mask=mask)
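The term1/term2 expressions reformatted above are the derivative of the tanh-approximated GELU that GEGLU gates with. A plain-PyTorch sketch of the same formula, useful for sanity checks (not part of the package):

    import math
    import torch

    def gelu_tanh_grad(a: torch.Tensor) -> torch.Tensor:
        # d/da of 0.5 * a * (1 + tanh(z)), where z = sqrt(2/pi) * (a + 0.044715 * a^3)
        sqrt_2_over_pi = math.sqrt(2.0 / math.pi)
        z = sqrt_2_over_pi * (a + 0.044715 * a.pow(3))
        tanh_z = torch.tanh(z)
        term1 = 0.5 * (1 + tanh_z)
        term2 = 0.5 * a * (1 - tanh_z * tanh_z) * (sqrt_2_over_pi * (1 + 3 * 0.044715 * a * a))
        return term1 + term2  # the kernel then multiplies this by dc_row * b_row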
liger_kernel/ops/group_norm.py
CHANGED
@@ -4,7 +4,8 @@ import torch
 import triton
 import triton.language as tl

-from liger_kernel.ops.utils import compare_version
+from liger_kernel.ops.utils import compare_version
+from liger_kernel.ops.utils import ensure_contiguous

 if compare_version("triton", operator.ge, "3.0.0"):
     try:
@@ -73,9 +74,7 @@ def _group_norm_forward_kernel(

     # Normalize
     hidden_size_per_channel = hidden_size // channels_per_group
-    for channel_idx in tl.range(
-        group_idx * channels_per_group, (group_idx + 1) * channels_per_group
-    ):
+    for channel_idx in tl.range(group_idx * channels_per_group, (group_idx + 1) * channels_per_group):
         W = tl.load(W_ptr + channel_idx)
         B = tl.load(B_ptr + channel_idx)
         for i in range(0, hidden_size_per_channel, BLOCK_SIZE):
@@ -132,21 +131,15 @@ def _group_norm_backward_kernel(
     UPSTREAM_ptr += batch_idx * X_row_stride

     # Mean and rstd are the same shape so have the same strides
-    mean = tl.load(
-        Mean_ptr + batch_idx * Mean_ptr_row_stride + group_idx * Mean_ptr_col_stride
-    )
-    rstd = tl.load(
-        RSTD_ptr + batch_idx * Mean_ptr_row_stride + group_idx * Mean_ptr_col_stride
-    )
+    mean = tl.load(Mean_ptr + batch_idx * Mean_ptr_row_stride + group_idx * Mean_ptr_col_stride)
+    rstd = tl.load(RSTD_ptr + batch_idx * Mean_ptr_row_stride + group_idx * Mean_ptr_col_stride)

     c1 = 0.0
     c2 = 0.0
     block_range = tl.arange(0, BLOCK_SIZE)

     # We need to compute the sum terms of the backprop equations across all channels in the group
-    for channel_idx in range(
-        group_idx * channels_per_group, (group_idx + 1) * channels_per_group
-    ):
+    for channel_idx in range(group_idx * channels_per_group, (group_idx + 1) * channels_per_group):
         dW = 0.0
         dB = 0.0
         # Move the pointers to the correct channel
@@ -181,9 +174,7 @@ def _group_norm_backward_kernel(
     c1 = c1 / N
     c2 = c2 / N

-    for channel_idx in tl.range(
-        group_idx * channels_per_group, (group_idx + 1) * channels_per_group
-    ):
+    for channel_idx in tl.range(group_idx * channels_per_group, (group_idx + 1) * channels_per_group):
         # Move the pointers to the correct channel
         W = tl.load(W_ptr + channel_idx)
         for i in range(0, hidden_size, BLOCK_SIZE):
@@ -203,9 +194,7 @@ def _group_norm_backward_kernel(
             x_hat = (X - mean) * rstd
             wdy = W * UPSTREAM_grad
             dx = (wdy - (x_hat * c1 + c2)) * rstd
-            tl.store(
-                DX_ptr + channel_idx * X_col_stride + hidden_size_offsets, dx, mask=mask
-            )
+            tl.store(DX_ptr + channel_idx * X_col_stride + hidden_size_offsets, dx, mask=mask)


 def group_norm_forward(X, num_channels, num_groups, W, B, eps):
@@ -216,9 +205,7 @@ def group_norm_forward(X, num_channels, num_groups, W, B, eps):
     X = X.view(batch_size, num_groups, -1).contiguous()
     hidden_size = X.shape[-1]
     BLOCK_SIZE = min(MAX_FUSED_SIZE, triton.next_power_of_2(hidden_size))
-    Y = torch.empty(
-        (batch_size, num_groups, hidden_size), dtype=X.dtype, device=X.device
-    )
+    Y = torch.empty((batch_size, num_groups, hidden_size), dtype=X.dtype, device=X.device)
     Mean = torch.zeros((batch_size, num_groups), dtype=X.dtype, device=X.device)
     RSTD = torch.zeros((batch_size, num_groups), dtype=X.dtype, device=X.device)

@@ -307,16 +294,12 @@ class LigerGroupNormFunction(torch.autograd.Function):
         )
         ctx.num_channels = num_channels
         ctx.num_groups = num_groups
-        ctx.save_for_backward(
-            X, affine_scaling_weight, affine_shifting_bias, Mean, RSTD
-        )
+        ctx.save_for_backward(X, affine_scaling_weight, affine_shifting_bias, Mean, RSTD)
         return Y

     @staticmethod
     @ensure_contiguous
     def backward(ctx, dY):
         X, W, B, Mean, RSTD = ctx.saved_tensors
-        DX, DW, DB = group_norm_backward(
-            dY, X, W, B, Mean, RSTD, ctx.num_channels, ctx.num_groups
-        )
+        DX, DW, DB = group_norm_backward(dY, X, W, B, Mean, RSTD, ctx.num_channels, ctx.num_groups)
         return DX, DW, DB, None, None, None
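For orientation, the kernels above cache one mean and one rstd per (batch, group) and then apply a per-channel affine. A minimal PyTorch reference of that forward computation, assuming an input of shape (batch, channels, *spatial); this is a comparison sketch, not the package's code:

    import torch

    def group_norm_reference(X, num_groups, W, B, eps=1e-6):
        batch, channels = X.shape[:2]
        x = X.reshape(batch, num_groups, -1)
        mean = x.mean(dim=-1, keepdim=True)                                    # cached as Mean
        rstd = torch.rsqrt(x.var(dim=-1, unbiased=False, keepdim=True) + eps)  # cached as RSTD
        x_hat = ((x - mean) * rstd).reshape(batch, channels, -1)
        return (x_hat * W.view(1, -1, 1) + B.view(1, -1, 1)).reshape(X.shape)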
liger_kernel/ops/jsd.py
CHANGED
@@ -98,9 +98,7 @@ def jsd_forward(_input, target, shift_labels, beta, ignore_index, has_label):
         loss_stride=loss.stride(-2),
         dX_ptr=dX,
         dX_stride=dX.stride(-2),
-        label_ptr=(
-            shift_labels if has_label else torch.empty(1, device=_input.device)
-        ),  # dummy ptr if no label
+        label_ptr=(shift_labels if has_label else torch.empty(1, device=_input.device)),  # dummy ptr if no label
         beta=beta,
         n_non_ignore=n_non_ignore,
         ignore_index=ignore_index,
@@ -165,9 +163,7 @@ class LigerJSDFunction(torch.autograd.Function):
             shift_labels = shift_labels.contiguous()
             has_label = True

-        loss, dX = jsd_forward(
-            _input, target, shift_labels, beta, ignore_index, has_label
-        )
+        loss, dX = jsd_forward(_input, target, shift_labels, beta, ignore_index, has_label)
         ctx.save_for_backward(dX)
         return loss
liger_kernel/ops/kl_div.py
CHANGED
@@ -4,7 +4,8 @@ import torch
 import triton
 import triton.language as tl

-from liger_kernel.ops.utils import ensure_contiguous
+from liger_kernel.ops.utils import ensure_contiguous
+from liger_kernel.ops.utils import is_hip


 def get_num_warps(BLOCK_SIZE):
@@ -23,10 +24,10 @@ MAX_FUSED_SIZE = 65536 // 4  # 65536 // 4 or 8 works the best

 REDUCTION_LITERAL = Literal["none", "sum", "mean", "batchmean"]

-_REDUCTION_MODE_NONE = tl.constexpr(0)
-_REDUCTION_MODE_SUM = tl.constexpr(1)
-_REDUCTION_MODE_MEAN = tl.constexpr(2)
-_REDUCTION_MODE_BATCHMEAN = tl.constexpr(3)
+_REDUCTION_MODE_NONE: tl.constexpr = tl.constexpr(0)
+_REDUCTION_MODE_SUM: tl.constexpr = tl.constexpr(1)
+_REDUCTION_MODE_MEAN: tl.constexpr = tl.constexpr(2)
+_REDUCTION_MODE_BATCHMEAN: tl.constexpr = tl.constexpr(3)

 _str_to_reduction_mode = {
     "none": _REDUCTION_MODE_NONE.value,
@@ -218,9 +219,7 @@ class LigerKLDivLossFunction(torch.autograd.Function):
         ctx.save_for_backward(y_true)
         ctx.reduction = reduction
         ctx.log_target = log_target
-        return kldiv_forward_triton(
-            y_pred, y_true, log_target=log_target, reduction=reduction, eps=eps
-        )
+        return kldiv_forward_triton(y_pred, y_true, log_target=log_target, reduction=reduction, eps=eps)

     @staticmethod
     @ensure_contiguous
@@ -238,9 +237,7 @@ class LigerKLDivLossFunction(torch.autograd.Function):

         new_grads = torch.empty_like(y_true)

-        derivative = kldiv_backward_triton(
-            y_true, grad_output, new_grads, ctx.log_target
-        )
+        derivative = kldiv_backward_triton(y_true, grad_output, new_grads, ctx.log_target)

         if ctx.reduction == "batchmean":
             derivative = derivative / y_true.shape[0]
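The forward call refactored above mirrors the semantics of torch.nn.functional.kl_div. A tiny reference showing the expected inputs and flags (illustrative only; this is not the Liger entry point):

    import torch
    import torch.nn.functional as F

    y_pred = torch.randn(4, 128).log_softmax(dim=-1)  # log-probabilities
    y_true = torch.randn(4, 128).softmax(dim=-1)      # probabilities, i.e. log_target=False
    ref = F.kl_div(y_pred, y_true, reduction="batchmean", log_target=False)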
liger_kernel/ops/layer_norm.py
CHANGED
@@ -5,11 +5,9 @@ import torch
 import triton
 import triton.language as tl

-from liger_kernel.ops.utils import (
-    calculate_settings,
-    compare_version,
-    ensure_contiguous,
-)
+from liger_kernel.ops.utils import calculate_settings
+from liger_kernel.ops.utils import compare_version
+from liger_kernel.ops.utils import ensure_contiguous

 if compare_version("triton", operator.ge, "3.0.0"):
     try:
liger_kernel/ops/qwen2vl_mrope.py
CHANGED
@@ -10,6 +10,7 @@ def _triton_qwen2vl_mrope(
     cos,
     sin,
     sl,
+    bs: tl.constexpr,
     n_qh: tl.constexpr,
     n_kh: tl.constexpr,
     hd: tl.constexpr,
@@ -41,13 +42,12 @@ def _triton_qwen2vl_mrope(
     t_end = mrope_section_t
     h_end = t_end + mrope_section_h

-    t_cos = cos + pid * hd
-    h_cos = t_cos + sl * hd
-    w_cos = h_cos + sl * hd
-    t_sin = sin + pid * hd
-    h_sin = t_sin + sl * hd
-    w_sin = h_sin + sl * hd
+    t_cos = cos + pid * hd
+    h_cos = t_cos + bs * sl * hd
+    w_cos = h_cos + bs * sl * hd
+    t_sin = sin + pid * hd
+    h_sin = t_sin + bs * sl * hd
+    w_sin = h_sin + bs * sl * hd

     cos_offsets = tl.arange(0, pad_hd // 2)
     t_mask = cos_offsets < t_end
@@ -67,36 +67,20 @@ def _triton_qwen2vl_mrope(
     # program instance (i.e. for the current token) separately
     # ####################################################################
     # left half of the head
-    first_half_q_offsets = (
-        tl.arange(0, pad_n_qh)[:, None] * hd + tl.arange(0, pad_hd // 2)[None, :]
-    )
-    first_half_k_offsets = (
-        tl.arange(0, pad_n_kh)[:, None] * hd + tl.arange(0, pad_hd // 2)[None, :]
-    )
-    first_q_mask = (tl.arange(0, pad_n_qh)[:, None] < n_qh) & (
-        tl.arange(0, pad_hd // 2)[None, :] < hd // 2
-    )
-    first_k_mask = (tl.arange(0, pad_n_kh)[:, None] < n_kh) & (
-        tl.arange(0, pad_hd // 2)[None, :] < hd // 2
-    )
-    q_tile_1 = tl.load(q_ptr + first_half_q_offsets, mask=first_q_mask, other=0).to(
-        sin_row.dtype
-    )
-    k_tile_1 = tl.load(k_ptr + first_half_k_offsets, mask=first_k_mask, other=0).to(
-        sin_row.dtype
-    )
+    first_half_q_offsets = tl.arange(0, pad_n_qh)[:, None] * hd + tl.arange(0, pad_hd // 2)[None, :]
+    first_half_k_offsets = tl.arange(0, pad_n_kh)[:, None] * hd + tl.arange(0, pad_hd // 2)[None, :]
+    first_q_mask = (tl.arange(0, pad_n_qh)[:, None] < n_qh) & (tl.arange(0, pad_hd // 2)[None, :] < hd // 2)
+    first_k_mask = (tl.arange(0, pad_n_kh)[:, None] < n_kh) & (tl.arange(0, pad_hd // 2)[None, :] < hd // 2)
+    q_tile_1 = tl.load(q_ptr + first_half_q_offsets, mask=first_q_mask, other=0).to(sin_row.dtype)
+    k_tile_1 = tl.load(k_ptr + first_half_k_offsets, mask=first_k_mask, other=0).to(sin_row.dtype)

     # right half of the head
     second_half_q_offsets = first_half_q_offsets + (hd // 2)
     second_half_k_offsets = first_half_k_offsets + (hd // 2)
     second_q_mask = first_q_mask
     second_k_mask = first_k_mask
-    q_tile_2 = tl.load(q_ptr + second_half_q_offsets, mask=second_q_mask, other=0).to(
-        sin_row.dtype
-    )
-    k_tile_2 = tl.load(k_ptr + second_half_k_offsets, mask=second_k_mask, other=0).to(
-        sin_row.dtype
-    )
+    q_tile_2 = tl.load(q_ptr + second_half_q_offsets, mask=second_q_mask, other=0).to(sin_row.dtype)
+    k_tile_2 = tl.load(k_ptr + second_half_k_offsets, mask=second_k_mask, other=0).to(sin_row.dtype)

     if not BACKWARD_PASS:
         # y = [x1, x2] * [cos, cos] + [-x2, x1] * [sin, sin]
@@ -124,7 +108,6 @@ def _triton_qwen2vl_mrope(


 def qwen2vl_mrope_forward(q, k, cos, sin, mrope_section):
-
     # transpose it back to the physical shape because Triton looks at the physical storage
     # note: q and k are incontiguous before the transformation and will become contiguous after transpose
     q = q.transpose(1, 2)
@@ -151,6 +134,7 @@ def qwen2vl_mrope_forward(q, k, cos, sin, mrope_section):
         cos,
         sin,
         seq_len,
+        batch_size,
         n_q_head,
         n_kv_head,
         head_dim,
@@ -189,6 +173,7 @@ def qwen2vl_mrope_backward(dq, dk, cos, sin, mrope_section):
         cos,
         sin,
         seq_len,
+        batch_size,
         n_q_head,
         n_kv_head,
         head_dim,
@@ -216,8 +201,8 @@ class LigerQwen2VLMRopeFunction(torch.autograd.Function):
         """
         q size: (bsz, n_q_head, seq_len, head_dim)
         k size: (bsz, n_kv_head, seq_len, head_dim)
-        cos size: (3,
-        sin size: (3,
+        cos size: (3, bsz, seq_len, head_dim)
+        sin size: (3, bsz, seq_len, head_dim)
         """
         q, k, cos, sin = qwen2vl_mrope_forward(q, k, cos, sin, mrope_section)
         ctx.save_for_backward(cos, sin)
@@ -228,10 +213,9 @@ class LigerQwen2VLMRopeFunction(torch.autograd.Function):
         """
         dq size: (bsz, n_q_head, seq_len, head_dim)
         dk size: (bsz, n_kv_head, seq_len, head_dim)
-        cos size: (3,
-        sin size: (3,
+        cos size: (3, bsz, seq_len, head_dim)
+        sin size: (3, bsz, seq_len, head_dim)
         """
-
         cos, sin = ctx.saved_tensors
         mrope_section = ctx.mrope_section
         dq, dk = qwen2vl_mrope_backward(dq, dk, cos, sin, mrope_section)
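The net effect of the changes above is that the multimodal rotary cache now carries an explicit batch dimension, (3, bsz, seq_len, head_dim). A hedged usage sketch; shapes are illustrative and the autograd function may accept further optional arguments not visible in this diff:

    import torch
    from liger_kernel.ops.qwen2vl_mrope import LigerQwen2VLMRopeFunction

    bsz, n_q_head, n_kv_head, seq_len, head_dim = 2, 16, 2, 128, 128
    q = torch.randn(bsz, n_q_head, seq_len, head_dim, device="cuda", dtype=torch.bfloat16)
    k = torch.randn(bsz, n_kv_head, seq_len, head_dim, device="cuda", dtype=torch.bfloat16)
    cos = torch.randn(3, bsz, seq_len, head_dim, device="cuda", dtype=torch.bfloat16)
    sin = torch.randn(3, bsz, seq_len, head_dim, device="cuda", dtype=torch.bfloat16)
    mrope_section = [16, 24, 24]  # temporal/height/width split over head_dim // 2
    q_rot, k_rot = LigerQwen2VLMRopeFunction.apply(q, k, cos, sin, mrope_section)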
liger_kernel/ops/rms_norm.py
CHANGED
@@ -17,12 +17,10 @@ import torch
 import triton
 import triton.language as tl

-from liger_kernel.ops.utils import (
-    calculate_settings,
-    compare_version,
-    ensure_contiguous,
-    torch_to_triton_dtype,
-)
+from liger_kernel.ops.utils import calculate_settings
+from liger_kernel.ops.utils import compare_version
+from liger_kernel.ops.utils import ensure_contiguous
+from liger_kernel.ops.utils import torch_to_triton_dtype

 if compare_version("triton", operator.ge, "3.0.0"):
     try:
@@ -35,9 +33,9 @@ else:
     from triton.language.math import rsqrt


-_CASTING_MODE_NONE = tl.constexpr(-1)
-_CASTING_MODE_LLAMA = tl.constexpr(0)
-_CASTING_MODE_GEMMA = tl.constexpr(1)
+_CASTING_MODE_NONE: tl.constexpr = tl.constexpr(-1)
+_CASTING_MODE_LLAMA: tl.constexpr = tl.constexpr(0)
+_CASTING_MODE_GEMMA: tl.constexpr = tl.constexpr(1)


 @triton.jit
@@ -177,9 +175,7 @@ def _rms_norm_backward_kernel(

         dX_row = rstd_row * m

-        dX_row += (rstd_row) * (
-            -(1 / n_cols) * rstd_row * rstd_row * tl.sum(m * X_row, axis=0) * X_row
-        )
+        dX_row += (rstd_row) * (-(1 / n_cols) * rstd_row * rstd_row * tl.sum(m * X_row, axis=0) * X_row)

         # calculate the gradient of W
         if casting_mode == _CASTING_MODE_LLAMA:
@@ -207,14 +203,10 @@ _str_to_casting_mode = {

 def rms_norm_forward(X, W, eps, offset, casting_mode):
     if not isinstance(casting_mode, int):
-        assert (
-            casting_mode in _str_to_casting_mode
-        ), f"Invalid casting mode: {casting_mode}"
+        assert casting_mode in _str_to_casting_mode, f"Invalid casting mode: {casting_mode}"
         casting_mode = _str_to_casting_mode[casting_mode]
     else:
-        assert (
-            casting_mode in _str_to_casting_mode.values()
-        ), f"Invalid casting mode: {casting_mode}"
+        assert casting_mode in _str_to_casting_mode.values(), f"Invalid casting mode: {casting_mode}"

     shape = X.shape
     dim = shape[-1]
@@ -225,17 +217,11 @@ def rms_norm_forward(X, W, eps, offset, casting_mode):
     Y = torch.empty((n_rows, n_cols), dtype=X.dtype, device=X.device)
     # RSTD is to cache rstd for each row
     # RSTD is always computed/stored in fp32 if we are using Llama or Gemma casting mode
-    rstd_dtype = (
-        torch.float32
-        if casting_mode in (_CASTING_MODE_LLAMA.value, _CASTING_MODE_GEMMA.value)
-        else X.dtype
-    )
+    rstd_dtype = torch.float32 if casting_mode in (_CASTING_MODE_LLAMA.value, _CASTING_MODE_GEMMA.value) else X.dtype
     RSTD = torch.empty(n_rows, dtype=rstd_dtype, device=X.device)

     # Check constraints.
-    assert (
-        X.shape[1] == W.shape[0]
-    ), "Incompatible hidden size dimension between tensor1.shape[1] and tensor2.shape[0]"
+    assert X.shape[1] == W.shape[0], "Incompatible hidden size dimension between tensor1.shape[1] and tensor2.shape[0]"

     _rms_norm_forward_kernel[(n_rows,)](
         Y,
@@ -256,9 +242,7 @@ def rms_norm_forward(X, W, eps, offset, casting_mode):
     return Y.view(*shape), X, RSTD, BLOCK_SIZE, num_warps, casting_mode


-def rms_norm_backward(
-    dY, X, W, RSTD, offset, casting_mode, BLOCK_SIZE, num_warps, in_place
-):
+def rms_norm_backward(dY, X, W, RSTD, offset, casting_mode, BLOCK_SIZE, num_warps, in_place):
     shape = dY.shape
     dim = shape[-1]
     dY = dY.view(-1, dim)
@@ -340,9 +324,7 @@ class LigerRMSNormFunction(torch.autograd.Function):
         X: (B, T, H) or (BxT, H)
         W: (H,)
         """
-        Y, X, RSTD, BLOCK_SIZE, num_warps, casting_mode = rms_norm_forward(
-            X, W, eps, offset, casting_mode
-        )
+        Y, X, RSTD, BLOCK_SIZE, num_warps, casting_mode = rms_norm_forward(X, W, eps, offset, casting_mode)
         ctx.offset = offset
         ctx.casting_mode = casting_mode
         ctx.in_place = in_place
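The single-line asserts above gate the casting modes ("llama", "gemma", or none) that decide how much of the normalization runs in fp32. A plain-PyTorch sketch of the Llama-style mode for reference: rstd is computed in fp32 and the result is cast back before the weight is applied, with `offset` added to the weight (presumably how the Gemma-style (1 + w) scaling is expressed). This is a sketch, not the kernel's code path:

    import torch

    def rms_norm_reference(X, W, eps=1e-6, offset=0.0):
        x_fp32 = X.float()
        rstd = torch.rsqrt(x_fp32.pow(2).mean(dim=-1, keepdim=True) + eps)
        return (x_fp32 * rstd).to(X.dtype) * (offset + W)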
liger_kernel/ops/rope.py
CHANGED
@@ -15,6 +15,7 @@ def _triton_rope(
     sin_row_stride,
     sl,
     bs: tl.constexpr,
+    cos_bs: tl.constexpr,
     n_qh: tl.constexpr,
     n_kh: tl.constexpr,
     hd: tl.constexpr,
@@ -29,7 +30,7 @@ def _triton_rope(
     # k size: (bsz, seq_len, num_kv_heads, head_dim)
     # k stride: (seq_len * num_kv_heads * head_dim, num_kv_heads * head_dim, head_dim, 1)

-    # cos size: (1, seq_len, head_dim)
+    # cos size: (1, seq_len, head_dim) or (bsz, seq_len, head_dim)
     # stride: (seq_len * head_dim, head_dim, 1)
     pid = tl.program_id(0)

@@ -48,9 +49,19 @@ def _triton_rope(
     # and pid % sl to get the sequence index.
     # 2. We only need the left half of cos and sin matrix because the right half is just
     # a clone of the left half.
-    cos_row_idx = pid % sl
-    cos = cos + cos_row_idx * cos_row_stride
-    sin = sin + cos_row_idx * sin_row_stride
+    batch_idx = pid // sl
+    cos_row_idx = pid % sl
+    cos = cos + tl.where(
+        cos_bs == 1,
+        cos_row_idx * cos_row_stride,
+        batch_idx * (sl * cos_row_stride) + cos_row_idx * cos_row_stride,
+    )
+    sin = sin + tl.where(
+        cos_bs == 1,
+        cos_row_idx * sin_row_stride,
+        batch_idx * (sl * sin_row_stride) + cos_row_idx * sin_row_stride,
+    )
+
     cos_offsets = tl.arange(0, pad_hd // 2)
     cos_mask = cos_offsets < hd // 2
     cos_row = tl.load(cos + cos_offsets, mask=cos_mask, other=0)
@@ -61,36 +72,20 @@ def _triton_rope(
     # program instance (i.e. for the current token) separately
     # ####################################################################
     # left half of the head
-    first_half_q_offsets = (
-        tl.arange(0, pad_n_qh)[:, None] * hd + tl.arange(0, pad_hd // 2)[None, :]
-    )
-    first_half_k_offsets = (
-        tl.arange(0, pad_n_kh)[:, None] * hd + tl.arange(0, pad_hd // 2)[None, :]
-    )
-    first_q_mask = (tl.arange(0, pad_n_qh)[:, None] < n_qh) & (
-        tl.arange(0, pad_hd // 2)[None, :] < hd // 2
-    )
-    first_k_mask = (tl.arange(0, pad_n_kh)[:, None] < n_kh) & (
-        tl.arange(0, pad_hd // 2)[None, :] < hd // 2
-    )
-    q_tile_1 = tl.load(q_ptr + first_half_q_offsets, mask=first_q_mask, other=0).to(
-        sin_row.dtype
-    )
-    k_tile_1 = tl.load(k_ptr + first_half_k_offsets, mask=first_k_mask, other=0).to(
-        sin_row.dtype
-    )
+    first_half_q_offsets = tl.arange(0, pad_n_qh)[:, None] * hd + tl.arange(0, pad_hd // 2)[None, :]
+    first_half_k_offsets = tl.arange(0, pad_n_kh)[:, None] * hd + tl.arange(0, pad_hd // 2)[None, :]
+    first_q_mask = (tl.arange(0, pad_n_qh)[:, None] < n_qh) & (tl.arange(0, pad_hd // 2)[None, :] < hd // 2)
+    first_k_mask = (tl.arange(0, pad_n_kh)[:, None] < n_kh) & (tl.arange(0, pad_hd // 2)[None, :] < hd // 2)
+    q_tile_1 = tl.load(q_ptr + first_half_q_offsets, mask=first_q_mask, other=0).to(sin_row.dtype)
+    k_tile_1 = tl.load(k_ptr + first_half_k_offsets, mask=first_k_mask, other=0).to(sin_row.dtype)

     # right half of the head
     second_half_q_offsets = first_half_q_offsets + (hd // 2)
     second_half_k_offsets = first_half_k_offsets + (hd // 2)
     second_q_mask = first_q_mask
     second_k_mask = first_k_mask
-    q_tile_2 = tl.load(q_ptr + second_half_q_offsets, mask=second_q_mask, other=0).to(
-        sin_row.dtype
-    )
-    k_tile_2 = tl.load(k_ptr + second_half_k_offsets, mask=second_k_mask, other=0).to(
-        sin_row.dtype
-    )
+    q_tile_2 = tl.load(q_ptr + second_half_q_offsets, mask=second_q_mask, other=0).to(sin_row.dtype)
+    k_tile_2 = tl.load(k_ptr + second_half_k_offsets, mask=second_k_mask, other=0).to(sin_row.dtype)

     if not BACKWARD_PASS:
         # y = [x1, x2] * [cos, cos] + [-x2, x1] * [sin, sin]
@@ -118,7 +113,6 @@ def _triton_rope(


 def rope_forward(q, k, cos, sin):
-
     # transpose it back to the physical shape because Triton looks at the physical storage
     # note: q and k are incontiguous before the transformation and will become contiguous after transpose
     q = q.transpose(1, 2)
@@ -138,6 +132,7 @@ def rope_forward(q, k, cos, sin):
     k = k.contiguous()
     cos = cos.contiguous()
     sin = sin.contiguous()
+    cos_batch_size = cos.shape[0]

     _triton_rope[(n_row,)](
         q,
@@ -150,6 +145,7 @@ def rope_forward(q, k, cos, sin):
         sin.stride(-2),
         seq_len,
         batch_size,
+        cos_batch_size,
         n_q_head,
         n_kv_head,
         head_dim,
@@ -167,6 +163,7 @@ def rope_backward(dq, dk, cos, sin):
     dk = dk.transpose(1, 2)

     batch_size, seq_len, n_q_head, head_dim = dq.shape
+    cos_batch_size = cos.shape[0]
     n_kv_head = dk.shape[2]
     pad_hd = triton.next_power_of_2(head_dim)
     pad_n_q_head = triton.next_power_of_2(n_q_head)
@@ -191,6 +188,7 @@ def rope_backward(dq, dk, cos, sin):
         sin.stride(-2),
         seq_len,
         batch_size,
+        cos_batch_size,
         n_q_head,
         n_kv_head,
         head_dim,
@@ -221,8 +219,8 @@ class LigerRopeFunction(torch.autograd.Function):
         """
         q size: (bsz, n_q_head, seq_len, head_dim)
         k size: (bsz, n_kv_head, seq_len, head_dim)
-        cos size: (1, seq_len, head_dim)
-        sin size: (1, seq_len, head_dim)
+        cos size: (1, seq_len, head_dim) or (bsz, seq_len, head_dim)
+        sin size: (1, seq_len, head_dim) or (bsz, seq_len, head_dim)
         """
         q, k, cos, sin = rope_forward(q, k, cos, sin)
         ctx.save_for_backward(cos, sin)
@@ -232,8 +230,8 @@ class LigerRopeFunction(torch.autograd.Function):
         """
        dq size: (bsz, n_q_head, seq_len, head_dim)
         dk size: (bsz, n_kv_head, seq_len, head_dim)
-        cos size: (1, seq_len, head_dim)
-        sin size: (1, seq_len, head_dim)
+        cos size: (1, seq_len, head_dim) or (bsz, seq_len, head_dim)
+        sin size: (1, seq_len, head_dim) or (bsz, seq_len, head_dim)
         """

         cos, sin = ctx.saved_tensors