liger-kernel-nightly 0.5.2.dev20241223032015__py3-none-any.whl → 0.5.2.dev20241223032630__py3-none-any.whl

This diff shows the differences between publicly available package versions that have been released to one of the supported registries. It is provided for informational purposes only and reflects the changes between package versions as they appear in their respective public registries.
liger_kernel/ops/rope.py CHANGED
@@ -15,6 +15,7 @@ def _triton_rope(
15
15
  sin_row_stride,
16
16
  sl,
17
17
  bs: tl.constexpr,
18
+ cos_bs: tl.constexpr,
18
19
  n_qh: tl.constexpr,
19
20
  n_kh: tl.constexpr,
20
21
  hd: tl.constexpr,
@@ -29,7 +30,7 @@ def _triton_rope(
29
30
  # k size: (bsz, seq_len, num_kv_heads, head_dim)
30
31
  # k stride: (seq_len * num_kv_heads * head_dim, num_kv_heads * head_dim, head_dim, 1)
31
32
 
32
- # cos size: (1, seq_len, head_dim)
33
+ # cos size: (1, seq_len, head_dim) or (bsz, seq_len, head_dim)
33
34
  # stride: (seq_len * head_dim, head_dim, 1)
34
35
  pid = tl.program_id(0)
35
36
 
@@ -48,9 +49,19 @@ def _triton_rope(
48
49
  # and pid % sl to get the sequence index.
49
50
  # 2. We only need the left half of cos and sin matrix because the right half is just
50
51
  # a clone of the left half.
51
- cos_row_idx = pid % (sl)
52
- cos = cos + cos_row_idx * cos_row_stride
53
- sin = sin + cos_row_idx * sin_row_stride
52
+ batch_idx = pid // sl
53
+ cos_row_idx = pid % sl
54
+ cos = cos + tl.where(
55
+ cos_bs == 1,
56
+ cos_row_idx * cos_row_stride,
57
+ batch_idx * (sl * cos_row_stride) + cos_row_idx * cos_row_stride,
58
+ )
59
+ sin = sin + tl.where(
60
+ cos_bs == 1,
61
+ cos_row_idx * sin_row_stride,
62
+ batch_idx * (sl * sin_row_stride) + cos_row_idx * sin_row_stride,
63
+ )
64
+
54
65
  cos_offsets = tl.arange(0, pad_hd // 2)
55
66
  cos_mask = cos_offsets < hd // 2
56
67
  cos_row = tl.load(cos + cos_offsets, mask=cos_mask, other=0)
@@ -118,7 +129,6 @@ def _triton_rope(
118
129
 
119
130
 
120
131
  def rope_forward(q, k, cos, sin):
121
-
122
132
  # transpose it back to the physical shape because Triton looks at the physical storage
123
133
  # note: q and k are incontiguous before the transformation and will become contiguous after transpose
124
134
  q = q.transpose(1, 2)
@@ -138,6 +148,7 @@ def rope_forward(q, k, cos, sin):
138
148
  k = k.contiguous()
139
149
  cos = cos.contiguous()
140
150
  sin = sin.contiguous()
151
+ cos_batch_size = cos.shape[0]
141
152
 
142
153
  _triton_rope[(n_row,)](
143
154
  q,
@@ -150,6 +161,7 @@ def rope_forward(q, k, cos, sin):
150
161
  sin.stride(-2),
151
162
  seq_len,
152
163
  batch_size,
164
+ cos_batch_size,
153
165
  n_q_head,
154
166
  n_kv_head,
155
167
  head_dim,
@@ -167,6 +179,7 @@ def rope_backward(dq, dk, cos, sin):
167
179
  dk = dk.transpose(1, 2)
168
180
 
169
181
  batch_size, seq_len, n_q_head, head_dim = dq.shape
182
+ cos_batch_size = cos.shape[0]
170
183
  n_kv_head = dk.shape[2]
171
184
  pad_hd = triton.next_power_of_2(head_dim)
172
185
  pad_n_q_head = triton.next_power_of_2(n_q_head)
@@ -191,6 +204,7 @@ def rope_backward(dq, dk, cos, sin):
191
204
  sin.stride(-2),
192
205
  seq_len,
193
206
  batch_size,
207
+ cos_batch_size,
194
208
  n_q_head,
195
209
  n_kv_head,
196
210
  head_dim,
@@ -221,8 +235,8 @@ class LigerRopeFunction(torch.autograd.Function):
221
235
  """
222
236
  q size: (bsz, n_q_head, seq_len, head_dim)
223
237
  k size: (bsz, n_kv_head, seq_len, head_dim)
224
- cos size: (1, seq_len, head_dim)
225
- sin size: (1, seq_len, head_dim)
238
+ cos size: (1, seq_len, head_dim) or (bsz, seq_len, head_dim)
239
+ sin size: (1, seq_len, head_dim) or (bsz, seq_len, head_dim)
226
240
  """
227
241
  q, k, cos, sin = rope_forward(q, k, cos, sin)
228
242
  ctx.save_for_backward(cos, sin)
@@ -232,8 +246,8 @@ class LigerRopeFunction(torch.autograd.Function):
232
246
  """
233
247
  dq size: (bsz, n_q_head, seq_len, head_dim)
234
248
  dk size: (bsz, n_kv_head, seq_len, head_dim)
235
- cos size: (1, seq_len, head_dim)
236
- sin size: (1, seq_len, head_dim)
249
+ cos size: (1, seq_len, head_dim) or (bsz, seq_len, head_dim)
250
+ sin size: (1, seq_len, head_dim) or (bsz, seq_len, head_dim)
237
251
  """
238
252
 
239
253
  cos, sin = ctx.saved_tensors
@@ -8,8 +8,8 @@ def liger_rotary_pos_emb(q, k, cos, sin, position_ids=None, unsqueeze_dim=1):
8
8
  Args:
9
9
  q (torch.Tensor): The query tensor of shape (bsz, n_q_head, seq_len, head_dim).
10
10
  k (torch.Tensor): The key tensor of shape (bsz, n_kv_head, seq_len, head_dim).
11
- cos (torch.Tensor): The cosine tensor of shape (1, seq_len, head_dim).
12
- sin (torch.Tensor): The sine tensor of shape (1, seq_len, head_dim).
11
+ cos (torch.Tensor): The cosine tensor of shape (1, seq_len, head_dim) or (bsz, seq_len, head_dim).
12
+ sin (torch.Tensor): The sine tensor of shape (1, seq_len, head_dim) or (bsz, seq_len, head_dim).
13
13
  position_ids (torch.Tensor, optional): The position ids tensor. Defaults to None.
14
14
  unsqueeze_dim (int, optional): The dimension to unsqueeze. Defaults to 1.
15
15
 
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.1
2
2
  Name: liger_kernel_nightly
3
- Version: 0.5.2.dev20241223032015
3
+ Version: 0.5.2.dev20241223032630
4
4
  Summary: Efficient Triton kernels for LLM Training
5
5
  License: BSD 2-CLAUSE LICENSE
6
6
  Copyright 2024 LinkedIn Corporation
@@ -21,7 +21,7 @@ liger_kernel/ops/kl_div.py,sha256=vBz1ieu_sPcFbgG_wL0SwrbSQ6xVDK51_FNo-yf7CjY,84
21
21
  liger_kernel/ops/layer_norm.py,sha256=_CZggw3GNEIUx5weDzadFit5I-Lzosoo8prgeJzcViY,7589
22
22
  liger_kernel/ops/qwen2vl_mrope.py,sha256=GvP4Cg-2ClYyiqbe7bB_OMvnlZooBmqP2-9V8RMPde4,8598
23
23
  liger_kernel/ops/rms_norm.py,sha256=bleuRC9IS_P3zEX07b0LZ_cpgeTH8l5sdvkelucpRgM,11792
24
- liger_kernel/ops/rope.py,sha256=jrzaA9-6Orn44y_IIam9_YNPQxOFK2FrIRNfFea4EtU,8513
24
+ liger_kernel/ops/rope.py,sha256=KyjbI6ya6bDwmdBJKK1IamuTUMpAmfdsHFYRJ4d9cP8,9059
25
25
  liger_kernel/ops/swiglu.py,sha256=Fwxtd76rhHKT9ShQAGca9RsnASplAVxtYKHmiT73_yA,2994
26
26
  liger_kernel/ops/utils.py,sha256=_VQvd1PX5JXm5xaiBrk2gANp3qr4kM7qYG3ypkBwkMs,3850
27
27
  liger_kernel/ops/experimental/embedding.py,sha256=LYR66dB-jhvhtUjeV4PnNro-n77J1mdlmpSLSxB3Y6U,4186
@@ -40,7 +40,7 @@ liger_kernel/transformers/layer_norm.py,sha256=fd6o4kSHJWolQMWxh-l1qObfgL08ruNbU
40
40
  liger_kernel/transformers/monkey_patch.py,sha256=Fk2v4GZQDJzfh3Cpc6BHNJbs_tungDyWmqS9nuG9Lc4,38406
41
41
  liger_kernel/transformers/qwen2vl_mrope.py,sha256=5EwSqrMdsL9MYspeBMXBsNJKvH0MOmRrtJXAJlnnlOI,1047
42
42
  liger_kernel/transformers/rms_norm.py,sha256=AHstklNIO1PLHjjCBU-TPuUD-Fl_pycJUTLlJNojbV8,1189
43
- liger_kernel/transformers/rope.py,sha256=m-ah8vZBYW8tfplTXCiAPMHJWlB1tdp_JPXJeWE-Boo,943
43
+ liger_kernel/transformers/rope.py,sha256=ZTrTORSAyfcFIKjk6XEeYmk4ROH7xXED9L4g2NFntlE,999
44
44
  liger_kernel/transformers/swiglu.py,sha256=0-tVJ8xEYfhxnduc16PflXFj8sZPxdx9sHUn3hfwCI4,2468
45
45
  liger_kernel/transformers/trainer_integration.py,sha256=W3ON51O5GkyzNJsItz0y5rKx-uy2f2cFfveZpqbUdhw,123
46
46
  liger_kernel/transformers/experimental/embedding.py,sha256=HpckiAMKM8-SRxKDcGTqortVxnjhwpZsfsp9lfjqfeM,895
@@ -58,9 +58,9 @@ liger_kernel/transformers/trainer/__init__.py,sha256=c4OQVJmhNOloj0JYSEc0j_cQuBb
58
58
  liger_kernel/transformers/trainer/orpo_trainer.py,sha256=O2k2vdHl-O1S-U61aEmyUFu3QrEuNAipQa2oUBb3HAA,7679
59
59
  liger_kernel/triton/__init__.py,sha256=yfRe0zMb47QnqjecZWG7LnanfCTzeku7SgWRAwNVmzU,101
60
60
  liger_kernel/triton/monkey_patch.py,sha256=5BcGKTtdqeYchypBIBopGIWPx1-cFALz7sOKoEsqXJ0,1584
61
- liger_kernel_nightly-0.5.2.dev20241223032015.dist-info/LICENSE,sha256=OhzLDHJ0to4a8sodVLELZiCFylZ1NAAYLs-HrjPy0ag,1312
62
- liger_kernel_nightly-0.5.2.dev20241223032015.dist-info/METADATA,sha256=glSPMysElXhTUr1u74GrG_xjFSIek9GtE9AlPR6GkLs,21055
63
- liger_kernel_nightly-0.5.2.dev20241223032015.dist-info/NOTICE,sha256=njwnoPZLh9AN8SJQzxvCGLHi-8X__AvWRze6joNXIY8,2066
64
- liger_kernel_nightly-0.5.2.dev20241223032015.dist-info/WHEEL,sha256=P9jw-gEje8ByB7_hXoICnHtVCrEwMQh-630tKvQWehc,91
65
- liger_kernel_nightly-0.5.2.dev20241223032015.dist-info/top_level.txt,sha256=2eghu4hA3LnkM7ElW92tQ8zegWKgSbeo-k-aGe1YnvY,13
66
- liger_kernel_nightly-0.5.2.dev20241223032015.dist-info/RECORD,,
61
+ liger_kernel_nightly-0.5.2.dev20241223032630.dist-info/LICENSE,sha256=OhzLDHJ0to4a8sodVLELZiCFylZ1NAAYLs-HrjPy0ag,1312
62
+ liger_kernel_nightly-0.5.2.dev20241223032630.dist-info/METADATA,sha256=rY2y3vkXwGKfZpmRsIIbD9BwAVpeYe6wbVwKJbMWB8k,21055
63
+ liger_kernel_nightly-0.5.2.dev20241223032630.dist-info/NOTICE,sha256=njwnoPZLh9AN8SJQzxvCGLHi-8X__AvWRze6joNXIY8,2066
64
+ liger_kernel_nightly-0.5.2.dev20241223032630.dist-info/WHEEL,sha256=P9jw-gEje8ByB7_hXoICnHtVCrEwMQh-630tKvQWehc,91
65
+ liger_kernel_nightly-0.5.2.dev20241223032630.dist-info/top_level.txt,sha256=2eghu4hA3LnkM7ElW92tQ8zegWKgSbeo-k-aGe1YnvY,13
66
+ liger_kernel_nightly-0.5.2.dev20241223032630.dist-info/RECORD,,