liger-kernel-nightly 0.5.2.dev20241223032015__py3-none-any.whl → 0.5.2.dev20241223032630__py3-none-any.whl
Sign up to get free protection for your applications and to get access to all the features.
- liger_kernel/ops/rope.py +23 -9
- liger_kernel/transformers/rope.py +2 -2
- {liger_kernel_nightly-0.5.2.dev20241223032015.dist-info → liger_kernel_nightly-0.5.2.dev20241223032630.dist-info}/METADATA +1 -1
- {liger_kernel_nightly-0.5.2.dev20241223032015.dist-info → liger_kernel_nightly-0.5.2.dev20241223032630.dist-info}/RECORD +8 -8
- {liger_kernel_nightly-0.5.2.dev20241223032015.dist-info → liger_kernel_nightly-0.5.2.dev20241223032630.dist-info}/LICENSE +0 -0
- {liger_kernel_nightly-0.5.2.dev20241223032015.dist-info → liger_kernel_nightly-0.5.2.dev20241223032630.dist-info}/NOTICE +0 -0
- {liger_kernel_nightly-0.5.2.dev20241223032015.dist-info → liger_kernel_nightly-0.5.2.dev20241223032630.dist-info}/WHEEL +0 -0
- {liger_kernel_nightly-0.5.2.dev20241223032015.dist-info → liger_kernel_nightly-0.5.2.dev20241223032630.dist-info}/top_level.txt +0 -0
liger_kernel/ops/rope.py
CHANGED
@@ -15,6 +15,7 @@ def _triton_rope(
|
|
15
15
|
sin_row_stride,
|
16
16
|
sl,
|
17
17
|
bs: tl.constexpr,
|
18
|
+
cos_bs: tl.constexpr,
|
18
19
|
n_qh: tl.constexpr,
|
19
20
|
n_kh: tl.constexpr,
|
20
21
|
hd: tl.constexpr,
|
@@ -29,7 +30,7 @@ def _triton_rope(
|
|
29
30
|
# k size: (bsz, seq_len, num_kv_heads, head_dim)
|
30
31
|
# k stride: (seq_len * num_kv_heads * head_dim, num_kv_heads * head_dim, head_dim, 1)
|
31
32
|
|
32
|
-
# cos size: (1, seq_len, head_dim)
|
33
|
+
# cos size: (1, seq_len, head_dim) or (bsz, seq_len, head_dim)
|
33
34
|
# stride: (seq_len * head_dim, head_dim, 1)
|
34
35
|
pid = tl.program_id(0)
|
35
36
|
|
@@ -48,9 +49,19 @@ def _triton_rope(
|
|
48
49
|
# and pid % sl to get the sequence index.
|
49
50
|
# 2. We only need the left half of cos and sin matrix because the right half is just
|
50
51
|
# a clone of the left half.
|
51
|
-
|
52
|
-
|
53
|
-
|
52
|
+
batch_idx = pid // sl
|
53
|
+
cos_row_idx = pid % sl
|
54
|
+
cos = cos + tl.where(
|
55
|
+
cos_bs == 1,
|
56
|
+
cos_row_idx * cos_row_stride,
|
57
|
+
batch_idx * (sl * cos_row_stride) + cos_row_idx * cos_row_stride,
|
58
|
+
)
|
59
|
+
sin = sin + tl.where(
|
60
|
+
cos_bs == 1,
|
61
|
+
cos_row_idx * sin_row_stride,
|
62
|
+
batch_idx * (sl * sin_row_stride) + cos_row_idx * sin_row_stride,
|
63
|
+
)
|
64
|
+
|
54
65
|
cos_offsets = tl.arange(0, pad_hd // 2)
|
55
66
|
cos_mask = cos_offsets < hd // 2
|
56
67
|
cos_row = tl.load(cos + cos_offsets, mask=cos_mask, other=0)
|
@@ -118,7 +129,6 @@ def _triton_rope(
|
|
118
129
|
|
119
130
|
|
120
131
|
def rope_forward(q, k, cos, sin):
|
121
|
-
|
122
132
|
# transpose it back to the physical shape because Triton looks at the physical storage
|
123
133
|
# note: q and k are incontiguous before the transformation and will become contiguous after transpose
|
124
134
|
q = q.transpose(1, 2)
|
@@ -138,6 +148,7 @@ def rope_forward(q, k, cos, sin):
|
|
138
148
|
k = k.contiguous()
|
139
149
|
cos = cos.contiguous()
|
140
150
|
sin = sin.contiguous()
|
151
|
+
cos_batch_size = cos.shape[0]
|
141
152
|
|
142
153
|
_triton_rope[(n_row,)](
|
143
154
|
q,
|
@@ -150,6 +161,7 @@ def rope_forward(q, k, cos, sin):
|
|
150
161
|
sin.stride(-2),
|
151
162
|
seq_len,
|
152
163
|
batch_size,
|
164
|
+
cos_batch_size,
|
153
165
|
n_q_head,
|
154
166
|
n_kv_head,
|
155
167
|
head_dim,
|
@@ -167,6 +179,7 @@ def rope_backward(dq, dk, cos, sin):
|
|
167
179
|
dk = dk.transpose(1, 2)
|
168
180
|
|
169
181
|
batch_size, seq_len, n_q_head, head_dim = dq.shape
|
182
|
+
cos_batch_size = cos.shape[0]
|
170
183
|
n_kv_head = dk.shape[2]
|
171
184
|
pad_hd = triton.next_power_of_2(head_dim)
|
172
185
|
pad_n_q_head = triton.next_power_of_2(n_q_head)
|
@@ -191,6 +204,7 @@ def rope_backward(dq, dk, cos, sin):
|
|
191
204
|
sin.stride(-2),
|
192
205
|
seq_len,
|
193
206
|
batch_size,
|
207
|
+
cos_batch_size,
|
194
208
|
n_q_head,
|
195
209
|
n_kv_head,
|
196
210
|
head_dim,
|
@@ -221,8 +235,8 @@ class LigerRopeFunction(torch.autograd.Function):
|
|
221
235
|
"""
|
222
236
|
q size: (bsz, n_q_head, seq_len, head_dim)
|
223
237
|
k size: (bsz, n_kv_head, seq_len, head_dim)
|
224
|
-
cos size: (1, seq_len, head_dim)
|
225
|
-
sin size: (1, seq_len, head_dim)
|
238
|
+
cos size: (1, seq_len, head_dim) or (bsz, seq_len, head_dim)
|
239
|
+
sin size: (1, seq_len, head_dim) or (bsz, seq_len, head_dim)
|
226
240
|
"""
|
227
241
|
q, k, cos, sin = rope_forward(q, k, cos, sin)
|
228
242
|
ctx.save_for_backward(cos, sin)
|
@@ -232,8 +246,8 @@ class LigerRopeFunction(torch.autograd.Function):
|
|
232
246
|
"""
|
233
247
|
dq size: (bsz, n_q_head, seq_len, head_dim)
|
234
248
|
dk size: (bsz, n_kv_head, seq_len, head_dim)
|
235
|
-
cos size: (1, seq_len, head_dim)
|
236
|
-
sin size: (1, seq_len, head_dim)
|
249
|
+
cos size: (1, seq_len, head_dim) or (bsz, seq_len, head_dim)
|
250
|
+
sin size: (1, seq_len, head_dim) or (bsz, seq_len, head_dim)
|
237
251
|
"""
|
238
252
|
|
239
253
|
cos, sin = ctx.saved_tensors
|
@@ -8,8 +8,8 @@ def liger_rotary_pos_emb(q, k, cos, sin, position_ids=None, unsqueeze_dim=1):
|
|
8
8
|
Args:
|
9
9
|
q (torch.Tensor): The query tensor of shape (bsz, n_q_head, seq_len, head_dim).
|
10
10
|
k (torch.Tensor): The key tensor of shape (bsz, n_kv_head, seq_len, head_dim).
|
11
|
-
cos (torch.Tensor): The cosine tensor of shape (1, seq_len, head_dim).
|
12
|
-
sin (torch.Tensor): The sine tensor of shape (1, seq_len, head_dim).
|
11
|
+
cos (torch.Tensor): The cosine tensor of shape (1, seq_len, head_dim) or (bsz, seq_len, head_dim).
|
12
|
+
sin (torch.Tensor): The sine tensor of shape (1, seq_len, head_dim) or (bsz, seq_len, head_dim).
|
13
13
|
position_ids (torch.Tensor, optional): The position ids tensor. Defaults to None.
|
14
14
|
unsqueeze_dim (int, optional): The dimension to unsqueeze. Defaults to 1.
|
15
15
|
|
@@ -21,7 +21,7 @@ liger_kernel/ops/kl_div.py,sha256=vBz1ieu_sPcFbgG_wL0SwrbSQ6xVDK51_FNo-yf7CjY,84
|
|
21
21
|
liger_kernel/ops/layer_norm.py,sha256=_CZggw3GNEIUx5weDzadFit5I-Lzosoo8prgeJzcViY,7589
|
22
22
|
liger_kernel/ops/qwen2vl_mrope.py,sha256=GvP4Cg-2ClYyiqbe7bB_OMvnlZooBmqP2-9V8RMPde4,8598
|
23
23
|
liger_kernel/ops/rms_norm.py,sha256=bleuRC9IS_P3zEX07b0LZ_cpgeTH8l5sdvkelucpRgM,11792
|
24
|
-
liger_kernel/ops/rope.py,sha256=
|
24
|
+
liger_kernel/ops/rope.py,sha256=KyjbI6ya6bDwmdBJKK1IamuTUMpAmfdsHFYRJ4d9cP8,9059
|
25
25
|
liger_kernel/ops/swiglu.py,sha256=Fwxtd76rhHKT9ShQAGca9RsnASplAVxtYKHmiT73_yA,2994
|
26
26
|
liger_kernel/ops/utils.py,sha256=_VQvd1PX5JXm5xaiBrk2gANp3qr4kM7qYG3ypkBwkMs,3850
|
27
27
|
liger_kernel/ops/experimental/embedding.py,sha256=LYR66dB-jhvhtUjeV4PnNro-n77J1mdlmpSLSxB3Y6U,4186
|
@@ -40,7 +40,7 @@ liger_kernel/transformers/layer_norm.py,sha256=fd6o4kSHJWolQMWxh-l1qObfgL08ruNbU
|
|
40
40
|
liger_kernel/transformers/monkey_patch.py,sha256=Fk2v4GZQDJzfh3Cpc6BHNJbs_tungDyWmqS9nuG9Lc4,38406
|
41
41
|
liger_kernel/transformers/qwen2vl_mrope.py,sha256=5EwSqrMdsL9MYspeBMXBsNJKvH0MOmRrtJXAJlnnlOI,1047
|
42
42
|
liger_kernel/transformers/rms_norm.py,sha256=AHstklNIO1PLHjjCBU-TPuUD-Fl_pycJUTLlJNojbV8,1189
|
43
|
-
liger_kernel/transformers/rope.py,sha256=
|
43
|
+
liger_kernel/transformers/rope.py,sha256=ZTrTORSAyfcFIKjk6XEeYmk4ROH7xXED9L4g2NFntlE,999
|
44
44
|
liger_kernel/transformers/swiglu.py,sha256=0-tVJ8xEYfhxnduc16PflXFj8sZPxdx9sHUn3hfwCI4,2468
|
45
45
|
liger_kernel/transformers/trainer_integration.py,sha256=W3ON51O5GkyzNJsItz0y5rKx-uy2f2cFfveZpqbUdhw,123
|
46
46
|
liger_kernel/transformers/experimental/embedding.py,sha256=HpckiAMKM8-SRxKDcGTqortVxnjhwpZsfsp9lfjqfeM,895
|
@@ -58,9 +58,9 @@ liger_kernel/transformers/trainer/__init__.py,sha256=c4OQVJmhNOloj0JYSEc0j_cQuBb
|
|
58
58
|
liger_kernel/transformers/trainer/orpo_trainer.py,sha256=O2k2vdHl-O1S-U61aEmyUFu3QrEuNAipQa2oUBb3HAA,7679
|
59
59
|
liger_kernel/triton/__init__.py,sha256=yfRe0zMb47QnqjecZWG7LnanfCTzeku7SgWRAwNVmzU,101
|
60
60
|
liger_kernel/triton/monkey_patch.py,sha256=5BcGKTtdqeYchypBIBopGIWPx1-cFALz7sOKoEsqXJ0,1584
|
61
|
-
liger_kernel_nightly-0.5.2.
|
62
|
-
liger_kernel_nightly-0.5.2.
|
63
|
-
liger_kernel_nightly-0.5.2.
|
64
|
-
liger_kernel_nightly-0.5.2.
|
65
|
-
liger_kernel_nightly-0.5.2.
|
66
|
-
liger_kernel_nightly-0.5.2.
|
61
|
+
liger_kernel_nightly-0.5.2.dev20241223032630.dist-info/LICENSE,sha256=OhzLDHJ0to4a8sodVLELZiCFylZ1NAAYLs-HrjPy0ag,1312
|
62
|
+
liger_kernel_nightly-0.5.2.dev20241223032630.dist-info/METADATA,sha256=rY2y3vkXwGKfZpmRsIIbD9BwAVpeYe6wbVwKJbMWB8k,21055
|
63
|
+
liger_kernel_nightly-0.5.2.dev20241223032630.dist-info/NOTICE,sha256=njwnoPZLh9AN8SJQzxvCGLHi-8X__AvWRze6joNXIY8,2066
|
64
|
+
liger_kernel_nightly-0.5.2.dev20241223032630.dist-info/WHEEL,sha256=P9jw-gEje8ByB7_hXoICnHtVCrEwMQh-630tKvQWehc,91
|
65
|
+
liger_kernel_nightly-0.5.2.dev20241223032630.dist-info/top_level.txt,sha256=2eghu4hA3LnkM7ElW92tQ8zegWKgSbeo-k-aGe1YnvY,13
|
66
|
+
liger_kernel_nightly-0.5.2.dev20241223032630.dist-info/RECORD,,
|
File without changes
|
File without changes
|
File without changes
|