liger-kernel-nightly 0.5.2.dev20241220231758__py3-none-any.whl → 0.5.2.dev20241223032630__py3-none-any.whl

Sign up to get free protection for your applications and to get access to all the features.
@@ -17,8 +17,8 @@ if compare_version("triton", operator.ge, "3.0.0"):
17
17
  else:
18
18
  from triton.language.math import tanh
19
19
 
20
- _TRUE = tl.constexpr(1)
21
- _FALSE = tl.constexpr(0)
20
+ _TRUE: tl.constexpr = tl.constexpr(1)
21
+ _FALSE: tl.constexpr = tl.constexpr(0)
22
22
 
23
23
 
24
24
  @triton.jit
@@ -23,10 +23,10 @@ MAX_FUSED_SIZE = 65536 // 4 # 65536 // 4 or 8 works the best
23
23
 
24
24
  REDUCTION_LITERAL = Literal["none", "sum", "mean", "batchmean"]
25
25
 
26
- _REDUCTION_MODE_NONE = tl.constexpr(0)
27
- _REDUCTION_MODE_SUM = tl.constexpr(1)
28
- _REDUCTION_MODE_MEAN = tl.constexpr(2)
29
- _REDUCTION_MODE_BATCHMEAN = tl.constexpr(3)
26
+ _REDUCTION_MODE_NONE: tl.constexpr = tl.constexpr(0)
27
+ _REDUCTION_MODE_SUM: tl.constexpr = tl.constexpr(1)
28
+ _REDUCTION_MODE_MEAN: tl.constexpr = tl.constexpr(2)
29
+ _REDUCTION_MODE_BATCHMEAN: tl.constexpr = tl.constexpr(3)
30
30
 
31
31
  _str_to_reduction_mode = {
32
32
  "none": _REDUCTION_MODE_NONE.value,
@@ -35,9 +35,9 @@ else:
35
35
  from triton.language.math import rsqrt
36
36
 
37
37
 
38
- _CASTING_MODE_NONE = tl.constexpr(-1)
39
- _CASTING_MODE_LLAMA = tl.constexpr(0)
40
- _CASTING_MODE_GEMMA = tl.constexpr(1)
38
+ _CASTING_MODE_NONE: tl.constexpr = tl.constexpr(-1)
39
+ _CASTING_MODE_LLAMA: tl.constexpr = tl.constexpr(0)
40
+ _CASTING_MODE_GEMMA: tl.constexpr = tl.constexpr(1)
41
41
 
42
42
 
43
43
  @triton.jit
liger_kernel/ops/rope.py CHANGED
@@ -15,6 +15,7 @@ def _triton_rope(
15
15
  sin_row_stride,
16
16
  sl,
17
17
  bs: tl.constexpr,
18
+ cos_bs: tl.constexpr,
18
19
  n_qh: tl.constexpr,
19
20
  n_kh: tl.constexpr,
20
21
  hd: tl.constexpr,
@@ -29,7 +30,7 @@ def _triton_rope(
29
30
  # k size: (bsz, seq_len, num_kv_heads, head_dim)
30
31
  # k stride: (seq_len * num_kv_heads * head_dim, num_kv_heads * head_dim, head_dim, 1)
31
32
 
32
- # cos size: (1, seq_len, head_dim)
33
+ # cos size: (1, seq_len, head_dim) or (bsz, seq_len, head_dim)
33
34
  # cos stride: (seq_len * head_dim, head_dim, 1)
34
35
  pid = tl.program_id(0)
35
36
 
@@ -48,9 +49,19 @@ def _triton_rope(
48
49
  # and pid % sl to get the sequence index.
49
50
  # 2. We only need the left half of cos and sin matrix because the right half is just
50
51
  # a clone of the left half.
51
- cos_row_idx = pid % (sl)
52
- cos = cos + cos_row_idx * cos_row_stride
53
- sin = sin + cos_row_idx * sin_row_stride
52
+ batch_idx = pid // sl
53
+ cos_row_idx = pid % sl
54
+ cos = cos + tl.where(
55
+ cos_bs == 1,
56
+ cos_row_idx * cos_row_stride,
57
+ batch_idx * (sl * cos_row_stride) + cos_row_idx * cos_row_stride,
58
+ )
59
+ sin = sin + tl.where(
60
+ cos_bs == 1,
61
+ cos_row_idx * sin_row_stride,
62
+ batch_idx * (sl * sin_row_stride) + cos_row_idx * sin_row_stride,
63
+ )
64
+
54
65
  cos_offsets = tl.arange(0, pad_hd // 2)
55
66
  cos_mask = cos_offsets < hd // 2
56
67
  cos_row = tl.load(cos + cos_offsets, mask=cos_mask, other=0)
@@ -118,7 +129,6 @@ def _triton_rope(
118
129
 
119
130
 
120
131
  def rope_forward(q, k, cos, sin):
121
-
122
132
  # transpose it back to the physical shape because Triton looks at the physical storage
123
133
  # note: q and k are non-contiguous before the transformation and will become contiguous after transpose
124
134
  q = q.transpose(1, 2)
@@ -138,6 +148,7 @@ def rope_forward(q, k, cos, sin):
138
148
  k = k.contiguous()
139
149
  cos = cos.contiguous()
140
150
  sin = sin.contiguous()
151
+ cos_batch_size = cos.shape[0]
141
152
 
142
153
  _triton_rope[(n_row,)](
143
154
  q,
@@ -150,6 +161,7 @@ def rope_forward(q, k, cos, sin):
150
161
  sin.stride(-2),
151
162
  seq_len,
152
163
  batch_size,
164
+ cos_batch_size,
153
165
  n_q_head,
154
166
  n_kv_head,
155
167
  head_dim,
@@ -167,6 +179,7 @@ def rope_backward(dq, dk, cos, sin):
167
179
  dk = dk.transpose(1, 2)
168
180
 
169
181
  batch_size, seq_len, n_q_head, head_dim = dq.shape
182
+ cos_batch_size = cos.shape[0]
170
183
  n_kv_head = dk.shape[2]
171
184
  pad_hd = triton.next_power_of_2(head_dim)
172
185
  pad_n_q_head = triton.next_power_of_2(n_q_head)
@@ -191,6 +204,7 @@ def rope_backward(dq, dk, cos, sin):
191
204
  sin.stride(-2),
192
205
  seq_len,
193
206
  batch_size,
207
+ cos_batch_size,
194
208
  n_q_head,
195
209
  n_kv_head,
196
210
  head_dim,
@@ -221,8 +235,8 @@ class LigerRopeFunction(torch.autograd.Function):
221
235
  """
222
236
  q size: (bsz, n_q_head, seq_len, head_dim)
223
237
  k size: (bsz, n_kv_head, seq_len, head_dim)
224
- cos size: (1, seq_len, head_dim)
225
- sin size: (1, seq_len, head_dim)
238
+ cos size: (1, seq_len, head_dim) or (bsz, seq_len, head_dim)
239
+ sin size: (1, seq_len, head_dim) or (bsz, seq_len, head_dim)
226
240
  """
227
241
  q, k, cos, sin = rope_forward(q, k, cos, sin)
228
242
  ctx.save_for_backward(cos, sin)
@@ -232,8 +246,8 @@ class LigerRopeFunction(torch.autograd.Function):
232
246
  """
233
247
  dq size: (bsz, n_q_head, seq_len, head_dim)
234
248
  dk size: (bsz, n_kv_head, seq_len, head_dim)
235
- cos size: (1, seq_len, head_dim)
236
- sin size: (1, seq_len, head_dim)
249
+ cos size: (1, seq_len, head_dim) or (bsz, seq_len, head_dim)
250
+ sin size: (1, seq_len, head_dim) or (bsz, seq_len, head_dim)
237
251
  """
238
252
 
239
253
  cos, sin = ctx.saved_tensors
@@ -8,8 +8,8 @@ def liger_rotary_pos_emb(q, k, cos, sin, position_ids=None, unsqueeze_dim=1):
8
8
  Args:
9
9
  q (torch.Tensor): The query tensor of shape (bsz, n_q_head, seq_len, head_dim).
10
10
  k (torch.Tensor): The key tensor of shape (bsz, n_kv_head, seq_len, head_dim).
11
- cos (torch.Tensor): The cosine tensor of shape (1, seq_len, head_dim).
12
- sin (torch.Tensor): The sine tensor of shape (1, seq_len, head_dim).
11
+ cos (torch.Tensor): The cosine tensor of shape (1, seq_len, head_dim) or (bsz, seq_len, head_dim).
12
+ sin (torch.Tensor): The sine tensor of shape (1, seq_len, head_dim) or (bsz, seq_len, head_dim).
13
13
  position_ids (torch.Tensor, optional): The position ids tensor. Defaults to None.
14
14
  unsqueeze_dim (int, optional): The dimension to unsqueeze. Defaults to 1.
15
15
 
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.1
2
2
  Name: liger_kernel_nightly
3
- Version: 0.5.2.dev20241220231758
3
+ Version: 0.5.2.dev20241223032630
4
4
  Summary: Efficient Triton kernels for LLM Training
5
5
  License: BSD 2-CLAUSE LICENSE
6
6
  Copyright 2024 LinkedIn Corporation
@@ -11,17 +11,17 @@ liger_kernel/chunked_loss/fused_linear_preference.py,sha256=vvratrj8rba8NaGbO2ff
11
11
  liger_kernel/chunked_loss/orpo_loss.py,sha256=xHsKjlCWQVew7_hhpyUp3a1wd0tdpgx-zQAezNjk3Q4,3532
12
12
  liger_kernel/chunked_loss/simpo_loss.py,sha256=_5gXIkEAT0Kt_AufziQlYhBjzDJVSQVk7oSDHcrw1xw,3759
13
13
  liger_kernel/ops/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
14
- liger_kernel/ops/cross_entropy.py,sha256=oG5hfrlmnlF5lOoZRhHRglObxgH4B0KadjWMJj9EWPM,15860
14
+ liger_kernel/ops/cross_entropy.py,sha256=3oPrw6KzIVc11gSyfdrLnj0WJB4qOYjE1tC8HJeFFpg,15888
15
15
  liger_kernel/ops/fused_linear_cross_entropy.py,sha256=Tnw4gyAYVVdnCOqhOuLEzbUQ3goOTnoAfk3pqSIM5ac,9301
16
16
  liger_kernel/ops/fused_linear_jsd.py,sha256=nOv4zwfxHqqepKEmMsQuz-B3H-gRjyo8uClpmqSGLYA,9693
17
17
  liger_kernel/ops/geglu.py,sha256=MQL4zyzneZqZYUGPvb1QjI_EYT9_pKfSDgR25WD9jrI,4127
18
18
  liger_kernel/ops/group_norm.py,sha256=VaRErVJGR4JqgXXvuIjNGTn3E2egjLtU1y3ymwIf4d8,10961
19
19
  liger_kernel/ops/jsd.py,sha256=Ap2b0_geCl6fqBXLI1IS6Yn6GlO-8LgPmnOW3y47dus,6151
20
- liger_kernel/ops/kl_div.py,sha256=03FNXfvCb6M-56hhFepAFV9p6brArPR6KOKkdGD34mw,8374
20
+ liger_kernel/ops/kl_div.py,sha256=vBz1ieu_sPcFbgG_wL0SwrbSQ6xVDK51_FNo-yf7CjY,8430
21
21
  liger_kernel/ops/layer_norm.py,sha256=_CZggw3GNEIUx5weDzadFit5I-Lzosoo8prgeJzcViY,7589
22
22
  liger_kernel/ops/qwen2vl_mrope.py,sha256=GvP4Cg-2ClYyiqbe7bB_OMvnlZooBmqP2-9V8RMPde4,8598
23
- liger_kernel/ops/rms_norm.py,sha256=g7OXwuYI8-LXudDwvXuiupVjjOsbu8c4wwv83VaHa54,11750
24
- liger_kernel/ops/rope.py,sha256=jrzaA9-6Orn44y_IIam9_YNPQxOFK2FrIRNfFea4EtU,8513
23
+ liger_kernel/ops/rms_norm.py,sha256=bleuRC9IS_P3zEX07b0LZ_cpgeTH8l5sdvkelucpRgM,11792
24
+ liger_kernel/ops/rope.py,sha256=KyjbI6ya6bDwmdBJKK1IamuTUMpAmfdsHFYRJ4d9cP8,9059
25
25
  liger_kernel/ops/swiglu.py,sha256=Fwxtd76rhHKT9ShQAGca9RsnASplAVxtYKHmiT73_yA,2994
26
26
  liger_kernel/ops/utils.py,sha256=_VQvd1PX5JXm5xaiBrk2gANp3qr4kM7qYG3ypkBwkMs,3850
27
27
  liger_kernel/ops/experimental/embedding.py,sha256=LYR66dB-jhvhtUjeV4PnNro-n77J1mdlmpSLSxB3Y6U,4186
@@ -40,7 +40,7 @@ liger_kernel/transformers/layer_norm.py,sha256=fd6o4kSHJWolQMWxh-l1qObfgL08ruNbU
40
40
  liger_kernel/transformers/monkey_patch.py,sha256=Fk2v4GZQDJzfh3Cpc6BHNJbs_tungDyWmqS9nuG9Lc4,38406
41
41
  liger_kernel/transformers/qwen2vl_mrope.py,sha256=5EwSqrMdsL9MYspeBMXBsNJKvH0MOmRrtJXAJlnnlOI,1047
42
42
  liger_kernel/transformers/rms_norm.py,sha256=AHstklNIO1PLHjjCBU-TPuUD-Fl_pycJUTLlJNojbV8,1189
43
- liger_kernel/transformers/rope.py,sha256=m-ah8vZBYW8tfplTXCiAPMHJWlB1tdp_JPXJeWE-Boo,943
43
+ liger_kernel/transformers/rope.py,sha256=ZTrTORSAyfcFIKjk6XEeYmk4ROH7xXED9L4g2NFntlE,999
44
44
  liger_kernel/transformers/swiglu.py,sha256=0-tVJ8xEYfhxnduc16PflXFj8sZPxdx9sHUn3hfwCI4,2468
45
45
  liger_kernel/transformers/trainer_integration.py,sha256=W3ON51O5GkyzNJsItz0y5rKx-uy2f2cFfveZpqbUdhw,123
46
46
  liger_kernel/transformers/experimental/embedding.py,sha256=HpckiAMKM8-SRxKDcGTqortVxnjhwpZsfsp9lfjqfeM,895
@@ -58,9 +58,9 @@ liger_kernel/transformers/trainer/__init__.py,sha256=c4OQVJmhNOloj0JYSEc0j_cQuBb
58
58
  liger_kernel/transformers/trainer/orpo_trainer.py,sha256=O2k2vdHl-O1S-U61aEmyUFu3QrEuNAipQa2oUBb3HAA,7679
59
59
  liger_kernel/triton/__init__.py,sha256=yfRe0zMb47QnqjecZWG7LnanfCTzeku7SgWRAwNVmzU,101
60
60
  liger_kernel/triton/monkey_patch.py,sha256=5BcGKTtdqeYchypBIBopGIWPx1-cFALz7sOKoEsqXJ0,1584
61
- liger_kernel_nightly-0.5.2.dev20241220231758.dist-info/LICENSE,sha256=OhzLDHJ0to4a8sodVLELZiCFylZ1NAAYLs-HrjPy0ag,1312
62
- liger_kernel_nightly-0.5.2.dev20241220231758.dist-info/METADATA,sha256=o8KNSXeyS1E1vgQVqX7pZRdbzCXPDeG2iaGDZ2a2_mM,21055
63
- liger_kernel_nightly-0.5.2.dev20241220231758.dist-info/NOTICE,sha256=njwnoPZLh9AN8SJQzxvCGLHi-8X__AvWRze6joNXIY8,2066
64
- liger_kernel_nightly-0.5.2.dev20241220231758.dist-info/WHEEL,sha256=P9jw-gEje8ByB7_hXoICnHtVCrEwMQh-630tKvQWehc,91
65
- liger_kernel_nightly-0.5.2.dev20241220231758.dist-info/top_level.txt,sha256=2eghu4hA3LnkM7ElW92tQ8zegWKgSbeo-k-aGe1YnvY,13
66
- liger_kernel_nightly-0.5.2.dev20241220231758.dist-info/RECORD,,
61
+ liger_kernel_nightly-0.5.2.dev20241223032630.dist-info/LICENSE,sha256=OhzLDHJ0to4a8sodVLELZiCFylZ1NAAYLs-HrjPy0ag,1312
62
+ liger_kernel_nightly-0.5.2.dev20241223032630.dist-info/METADATA,sha256=rY2y3vkXwGKfZpmRsIIbD9BwAVpeYe6wbVwKJbMWB8k,21055
63
+ liger_kernel_nightly-0.5.2.dev20241223032630.dist-info/NOTICE,sha256=njwnoPZLh9AN8SJQzxvCGLHi-8X__AvWRze6joNXIY8,2066
64
+ liger_kernel_nightly-0.5.2.dev20241223032630.dist-info/WHEEL,sha256=P9jw-gEje8ByB7_hXoICnHtVCrEwMQh-630tKvQWehc,91
65
+ liger_kernel_nightly-0.5.2.dev20241223032630.dist-info/top_level.txt,sha256=2eghu4hA3LnkM7ElW92tQ8zegWKgSbeo-k-aGe1YnvY,13
66
+ liger_kernel_nightly-0.5.2.dev20241223032630.dist-info/RECORD,,