liger-kernel-nightly 0.5.2.dev20241223032015__py3-none-any.whl → 0.5.2.dev20241223032630__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- liger_kernel/ops/rope.py +23 -9
- liger_kernel/transformers/rope.py +2 -2
- {liger_kernel_nightly-0.5.2.dev20241223032015.dist-info → liger_kernel_nightly-0.5.2.dev20241223032630.dist-info}/METADATA +1 -1
- {liger_kernel_nightly-0.5.2.dev20241223032015.dist-info → liger_kernel_nightly-0.5.2.dev20241223032630.dist-info}/RECORD +8 -8
- {liger_kernel_nightly-0.5.2.dev20241223032015.dist-info → liger_kernel_nightly-0.5.2.dev20241223032630.dist-info}/LICENSE +0 -0
- {liger_kernel_nightly-0.5.2.dev20241223032015.dist-info → liger_kernel_nightly-0.5.2.dev20241223032630.dist-info}/NOTICE +0 -0
- {liger_kernel_nightly-0.5.2.dev20241223032015.dist-info → liger_kernel_nightly-0.5.2.dev20241223032630.dist-info}/WHEEL +0 -0
- {liger_kernel_nightly-0.5.2.dev20241223032015.dist-info → liger_kernel_nightly-0.5.2.dev20241223032630.dist-info}/top_level.txt +0 -0
liger_kernel/ops/rope.py
CHANGED
@@ -15,6 +15,7 @@ def _triton_rope(
|
|
15
15
|
sin_row_stride,
|
16
16
|
sl,
|
17
17
|
bs: tl.constexpr,
|
18
|
+
cos_bs: tl.constexpr,
|
18
19
|
n_qh: tl.constexpr,
|
19
20
|
n_kh: tl.constexpr,
|
20
21
|
hd: tl.constexpr,
|
@@ -29,7 +30,7 @@ def _triton_rope(
|
|
29
30
|
# k size: (bsz, seq_len, num_kv_heads, head_dim)
|
30
31
|
# k stride: (seq_len * num_kv_heads * head_dim, num_kv_heads * head_dim, head_dim, 1)
|
31
32
|
|
32
|
-
# cos size: (1, seq_len, head_dim)
|
33
|
+
# cos size: (1, seq_len, head_dim) or (bsz, seq_len, head_dim)
|
33
34
|
# stride: (seq_len * head_dim, head_dim, 1)
|
34
35
|
pid = tl.program_id(0)
|
35
36
|
|
@@ -48,9 +49,19 @@ def _triton_rope(
|
|
48
49
|
# and pid % sl to get the sequence index.
|
49
50
|
# 2. We only need the left half of cos and sin matrix because the right half is just
|
50
51
|
# a clone of the left half.
|
51
|
-
|
52
|
-
|
53
|
-
|
52
|
+
batch_idx = pid // sl
|
53
|
+
cos_row_idx = pid % sl
|
54
|
+
cos = cos + tl.where(
|
55
|
+
cos_bs == 1,
|
56
|
+
cos_row_idx * cos_row_stride,
|
57
|
+
batch_idx * (sl * cos_row_stride) + cos_row_idx * cos_row_stride,
|
58
|
+
)
|
59
|
+
sin = sin + tl.where(
|
60
|
+
cos_bs == 1,
|
61
|
+
cos_row_idx * sin_row_stride,
|
62
|
+
batch_idx * (sl * sin_row_stride) + cos_row_idx * sin_row_stride,
|
63
|
+
)
|
64
|
+
|
54
65
|
cos_offsets = tl.arange(0, pad_hd // 2)
|
55
66
|
cos_mask = cos_offsets < hd // 2
|
56
67
|
cos_row = tl.load(cos + cos_offsets, mask=cos_mask, other=0)
|
@@ -118,7 +129,6 @@ def _triton_rope(
|
|
118
129
|
|
119
130
|
|
120
131
|
def rope_forward(q, k, cos, sin):
|
121
|
-
|
122
132
|
# transpose it back to the physical shape because Triton looks at the physical storage
|
123
133
|
# note: q and k are incontiguous before the transformation and will become contiguous after transpose
|
124
134
|
q = q.transpose(1, 2)
|
@@ -138,6 +148,7 @@ def rope_forward(q, k, cos, sin):
|
|
138
148
|
k = k.contiguous()
|
139
149
|
cos = cos.contiguous()
|
140
150
|
sin = sin.contiguous()
|
151
|
+
cos_batch_size = cos.shape[0]
|
141
152
|
|
142
153
|
_triton_rope[(n_row,)](
|
143
154
|
q,
|
@@ -150,6 +161,7 @@ def rope_forward(q, k, cos, sin):
|
|
150
161
|
sin.stride(-2),
|
151
162
|
seq_len,
|
152
163
|
batch_size,
|
164
|
+
cos_batch_size,
|
153
165
|
n_q_head,
|
154
166
|
n_kv_head,
|
155
167
|
head_dim,
|
@@ -167,6 +179,7 @@ def rope_backward(dq, dk, cos, sin):
|
|
167
179
|
dk = dk.transpose(1, 2)
|
168
180
|
|
169
181
|
batch_size, seq_len, n_q_head, head_dim = dq.shape
|
182
|
+
cos_batch_size = cos.shape[0]
|
170
183
|
n_kv_head = dk.shape[2]
|
171
184
|
pad_hd = triton.next_power_of_2(head_dim)
|
172
185
|
pad_n_q_head = triton.next_power_of_2(n_q_head)
|
@@ -191,6 +204,7 @@ def rope_backward(dq, dk, cos, sin):
|
|
191
204
|
sin.stride(-2),
|
192
205
|
seq_len,
|
193
206
|
batch_size,
|
207
|
+
cos_batch_size,
|
194
208
|
n_q_head,
|
195
209
|
n_kv_head,
|
196
210
|
head_dim,
|
@@ -221,8 +235,8 @@ class LigerRopeFunction(torch.autograd.Function):
|
|
221
235
|
"""
|
222
236
|
q size: (bsz, n_q_head, seq_len, head_dim)
|
223
237
|
k size: (bsz, n_kv_head, seq_len, head_dim)
|
224
|
-
cos size: (1, seq_len, head_dim)
|
225
|
-
sin size: (1, seq_len, head_dim)
|
238
|
+
cos size: (1, seq_len, head_dim) or (bsz, seq_len, head_dim)
|
239
|
+
sin size: (1, seq_len, head_dim) or (bsz, seq_len, head_dim)
|
226
240
|
"""
|
227
241
|
q, k, cos, sin = rope_forward(q, k, cos, sin)
|
228
242
|
ctx.save_for_backward(cos, sin)
|
@@ -232,8 +246,8 @@ class LigerRopeFunction(torch.autograd.Function):
|
|
232
246
|
"""
|
233
247
|
dq size: (bsz, n_q_head, seq_len, head_dim)
|
234
248
|
dk size: (bsz, n_kv_head, seq_len, head_dim)
|
235
|
-
cos size: (1, seq_len, head_dim)
|
236
|
-
sin size: (1, seq_len, head_dim)
|
249
|
+
cos size: (1, seq_len, head_dim) or (bsz, seq_len, head_dim)
|
250
|
+
sin size: (1, seq_len, head_dim) or (bsz, seq_len, head_dim)
|
237
251
|
"""
|
238
252
|
|
239
253
|
cos, sin = ctx.saved_tensors
|
@@ -8,8 +8,8 @@ def liger_rotary_pos_emb(q, k, cos, sin, position_ids=None, unsqueeze_dim=1):
|
|
8
8
|
Args:
|
9
9
|
q (torch.Tensor): The query tensor of shape (bsz, n_q_head, seq_len, head_dim).
|
10
10
|
k (torch.Tensor): The key tensor of shape (bsz, n_kv_head, seq_len, head_dim).
|
11
|
-
cos (torch.Tensor): The cosine tensor of shape (1, seq_len, head_dim).
|
12
|
-
sin (torch.Tensor): The sine tensor of shape (1, seq_len, head_dim).
|
11
|
+
cos (torch.Tensor): The cosine tensor of shape (1, seq_len, head_dim) or (bsz, seq_len, head_dim).
|
12
|
+
sin (torch.Tensor): The sine tensor of shape (1, seq_len, head_dim) or (bsz, seq_len, head_dim).
|
13
13
|
position_ids (torch.Tensor, optional): The position ids tensor. Defaults to None.
|
14
14
|
unsqueeze_dim (int, optional): The dimension to unsqueeze. Defaults to 1.
|
15
15
|
|
@@ -21,7 +21,7 @@ liger_kernel/ops/kl_div.py,sha256=vBz1ieu_sPcFbgG_wL0SwrbSQ6xVDK51_FNo-yf7CjY,84
|
|
21
21
|
liger_kernel/ops/layer_norm.py,sha256=_CZggw3GNEIUx5weDzadFit5I-Lzosoo8prgeJzcViY,7589
|
22
22
|
liger_kernel/ops/qwen2vl_mrope.py,sha256=GvP4Cg-2ClYyiqbe7bB_OMvnlZooBmqP2-9V8RMPde4,8598
|
23
23
|
liger_kernel/ops/rms_norm.py,sha256=bleuRC9IS_P3zEX07b0LZ_cpgeTH8l5sdvkelucpRgM,11792
|
24
|
-
liger_kernel/ops/rope.py,sha256=
|
24
|
+
liger_kernel/ops/rope.py,sha256=KyjbI6ya6bDwmdBJKK1IamuTUMpAmfdsHFYRJ4d9cP8,9059
|
25
25
|
liger_kernel/ops/swiglu.py,sha256=Fwxtd76rhHKT9ShQAGca9RsnASplAVxtYKHmiT73_yA,2994
|
26
26
|
liger_kernel/ops/utils.py,sha256=_VQvd1PX5JXm5xaiBrk2gANp3qr4kM7qYG3ypkBwkMs,3850
|
27
27
|
liger_kernel/ops/experimental/embedding.py,sha256=LYR66dB-jhvhtUjeV4PnNro-n77J1mdlmpSLSxB3Y6U,4186
|
@@ -40,7 +40,7 @@ liger_kernel/transformers/layer_norm.py,sha256=fd6o4kSHJWolQMWxh-l1qObfgL08ruNbU
|
|
40
40
|
liger_kernel/transformers/monkey_patch.py,sha256=Fk2v4GZQDJzfh3Cpc6BHNJbs_tungDyWmqS9nuG9Lc4,38406
|
41
41
|
liger_kernel/transformers/qwen2vl_mrope.py,sha256=5EwSqrMdsL9MYspeBMXBsNJKvH0MOmRrtJXAJlnnlOI,1047
|
42
42
|
liger_kernel/transformers/rms_norm.py,sha256=AHstklNIO1PLHjjCBU-TPuUD-Fl_pycJUTLlJNojbV8,1189
|
43
|
-
liger_kernel/transformers/rope.py,sha256=
|
43
|
+
liger_kernel/transformers/rope.py,sha256=ZTrTORSAyfcFIKjk6XEeYmk4ROH7xXED9L4g2NFntlE,999
|
44
44
|
liger_kernel/transformers/swiglu.py,sha256=0-tVJ8xEYfhxnduc16PflXFj8sZPxdx9sHUn3hfwCI4,2468
|
45
45
|
liger_kernel/transformers/trainer_integration.py,sha256=W3ON51O5GkyzNJsItz0y5rKx-uy2f2cFfveZpqbUdhw,123
|
46
46
|
liger_kernel/transformers/experimental/embedding.py,sha256=HpckiAMKM8-SRxKDcGTqortVxnjhwpZsfsp9lfjqfeM,895
|
@@ -58,9 +58,9 @@ liger_kernel/transformers/trainer/__init__.py,sha256=c4OQVJmhNOloj0JYSEc0j_cQuBb
|
|
58
58
|
liger_kernel/transformers/trainer/orpo_trainer.py,sha256=O2k2vdHl-O1S-U61aEmyUFu3QrEuNAipQa2oUBb3HAA,7679
|
59
59
|
liger_kernel/triton/__init__.py,sha256=yfRe0zMb47QnqjecZWG7LnanfCTzeku7SgWRAwNVmzU,101
|
60
60
|
liger_kernel/triton/monkey_patch.py,sha256=5BcGKTtdqeYchypBIBopGIWPx1-cFALz7sOKoEsqXJ0,1584
|
61
|
-
liger_kernel_nightly-0.5.2.
|
62
|
-
liger_kernel_nightly-0.5.2.
|
63
|
-
liger_kernel_nightly-0.5.2.
|
64
|
-
liger_kernel_nightly-0.5.2.
|
65
|
-
liger_kernel_nightly-0.5.2.
|
66
|
-
liger_kernel_nightly-0.5.2.
|
61
|
+
liger_kernel_nightly-0.5.2.dev20241223032630.dist-info/LICENSE,sha256=OhzLDHJ0to4a8sodVLELZiCFylZ1NAAYLs-HrjPy0ag,1312
|
62
|
+
liger_kernel_nightly-0.5.2.dev20241223032630.dist-info/METADATA,sha256=rY2y3vkXwGKfZpmRsIIbD9BwAVpeYe6wbVwKJbMWB8k,21055
|
63
|
+
liger_kernel_nightly-0.5.2.dev20241223032630.dist-info/NOTICE,sha256=njwnoPZLh9AN8SJQzxvCGLHi-8X__AvWRze6joNXIY8,2066
|
64
|
+
liger_kernel_nightly-0.5.2.dev20241223032630.dist-info/WHEEL,sha256=P9jw-gEje8ByB7_hXoICnHtVCrEwMQh-630tKvQWehc,91
|
65
|
+
liger_kernel_nightly-0.5.2.dev20241223032630.dist-info/top_level.txt,sha256=2eghu4hA3LnkM7ElW92tQ8zegWKgSbeo-k-aGe1YnvY,13
|
66
|
+
liger_kernel_nightly-0.5.2.dev20241223032630.dist-info/RECORD,,
|
File without changes
|
File without changes
|
File without changes
|