liger-kernel-nightly 0.5.1.dev20241210095557__py3-none-any.whl → 0.5.1.dev20241210172102__py3-none-any.whl

This diff shows the differences between two publicly available versions of this package, each of which has been released to one of the supported registries. It is provided for informational purposes only and reflects the changes between the package versions exactly as they appear in their respective public registries.
@@ -10,6 +10,7 @@ def _triton_qwen2vl_mrope(
10
10
  cos,
11
11
  sin,
12
12
  sl,
13
+ bs: tl.constexpr,
13
14
  n_qh: tl.constexpr,
14
15
  n_kh: tl.constexpr,
15
16
  hd: tl.constexpr,
@@ -41,13 +42,12 @@ def _triton_qwen2vl_mrope(
41
42
  t_end = mrope_section_t
42
43
  h_end = t_end + mrope_section_h
43
44
 
44
- cos_row_idx = pid % sl
45
- t_cos = cos + cos_row_idx * hd
46
- h_cos = t_cos + sl * hd
47
- w_cos = h_cos + sl * hd
48
- t_sin = sin + cos_row_idx * hd
49
- h_sin = t_sin + sl * hd
50
- w_sin = h_sin + sl * hd
45
+ t_cos = cos + pid * hd
46
+ h_cos = t_cos + bs * sl * hd
47
+ w_cos = h_cos + bs * sl * hd
48
+ t_sin = sin + pid * hd
49
+ h_sin = t_sin + bs * sl * hd
50
+ w_sin = h_sin + bs * sl * hd
51
51
 
52
52
  cos_offsets = tl.arange(0, pad_hd // 2)
53
53
  t_mask = cos_offsets < t_end
@@ -151,6 +151,7 @@ def qwen2vl_mrope_forward(q, k, cos, sin, mrope_section):
151
151
  cos,
152
152
  sin,
153
153
  seq_len,
154
+ batch_size,
154
155
  n_q_head,
155
156
  n_kv_head,
156
157
  head_dim,
@@ -189,6 +190,7 @@ def qwen2vl_mrope_backward(dq, dk, cos, sin, mrope_section):
189
190
  cos,
190
191
  sin,
191
192
  seq_len,
193
+ batch_size,
192
194
  n_q_head,
193
195
  n_kv_head,
194
196
  head_dim,
@@ -216,8 +218,8 @@ class LigerQwen2VLMRopeFunction(torch.autograd.Function):
216
218
  """
217
219
  q size: (bsz, n_q_head, seq_len, head_dim)
218
220
  k size: (bsz, n_kv_head, seq_len, head_dim)
219
- cos size: (3, 1, seq_len, head_dim)
220
- sin size: (3, 1, seq_len, head_dim)
221
+ cos size: (3, bsz, seq_len, head_dim)
222
+ sin size: (3, bsz, seq_len, head_dim)
221
223
  """
222
224
  q, k, cos, sin = qwen2vl_mrope_forward(q, k, cos, sin, mrope_section)
223
225
  ctx.save_for_backward(cos, sin)
@@ -228,10 +230,9 @@ class LigerQwen2VLMRopeFunction(torch.autograd.Function):
228
230
  """
229
231
  dq size: (bsz, n_q_head, seq_len, head_dim)
230
232
  dk size: (bsz, n_kv_head, seq_len, head_dim)
231
- cos size: (3, 1, seq_len, head_dim)
232
- sin size: (3, 1, seq_len, head_dim)
233
+ cos size: (3, bsz, seq_len, head_dim)
234
+ sin size: (3, bsz, seq_len, head_dim)
233
235
  """
234
-
235
236
  cos, sin = ctx.saved_tensors
236
237
  mrope_section = ctx.mrope_section
237
238
  dq, dk = qwen2vl_mrope_backward(dq, dk, cos, sin, mrope_section)
@@ -8,8 +8,8 @@ def liger_multimodal_rotary_pos_emb(q, k, cos, sin, mrope_section, unsqueeze_dim
8
8
  Args:
9
9
  q (torch.Tensor): The query tensor of shape (bsz, n_q_head, seq_len, head_dim).
10
10
  k (torch.Tensor): The key tensor of shape (bsz, n_kv_head, seq_len, head_dim).
11
- cos (torch.Tensor): The cosine tensor of shape (3, 1, seq_len, head_dim).
12
- sin (torch.Tensor): The sine tensor of shape (3, 1, seq_len, head_dim).
11
+ cos (torch.Tensor): The cosine tensor of shape (3, bsz, seq_len, head_dim).
12
+ sin (torch.Tensor): The sine tensor of shape (3, bsz, seq_len, head_dim).
13
13
  mrope_section (List[int]): The multimodal rope section for channel dimension of temporal, height and width in rope calculation.
14
14
  unsqueeze_dim (int, optional): The dimension to unsqueeze. Defaults to 1.
15
15
 
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.1
2
2
  Name: liger_kernel_nightly
3
- Version: 0.5.1.dev20241210095557
3
+ Version: 0.5.1.dev20241210172102
4
4
  Summary: Efficient Triton kernels for LLM Training
5
5
  License: BSD 2-CLAUSE LICENSE
6
6
  Copyright 2024 LinkedIn Corporation
@@ -18,7 +18,7 @@ liger_kernel/ops/group_norm.py,sha256=VaRErVJGR4JqgXXvuIjNGTn3E2egjLtU1y3ymwIf4d
18
18
  liger_kernel/ops/jsd.py,sha256=Ap2b0_geCl6fqBXLI1IS6Yn6GlO-8LgPmnOW3y47dus,6151
19
19
  liger_kernel/ops/kl_div.py,sha256=03FNXfvCb6M-56hhFepAFV9p6brArPR6KOKkdGD34mw,8374
20
20
  liger_kernel/ops/layer_norm.py,sha256=_CZggw3GNEIUx5weDzadFit5I-Lzosoo8prgeJzcViY,7589
21
- liger_kernel/ops/qwen2vl_mrope.py,sha256=xZvQnhkSTjU-k6KiiRn9e0SYO1ESs1jmuZFMICduLpc,8552
21
+ liger_kernel/ops/qwen2vl_mrope.py,sha256=GvP4Cg-2ClYyiqbe7bB_OMvnlZooBmqP2-9V8RMPde4,8598
22
22
  liger_kernel/ops/rms_norm.py,sha256=g7OXwuYI8-LXudDwvXuiupVjjOsbu8c4wwv83VaHa54,11750
23
23
  liger_kernel/ops/rope.py,sha256=jrzaA9-6Orn44y_IIam9_YNPQxOFK2FrIRNfFea4EtU,8513
24
24
  liger_kernel/ops/swiglu.py,sha256=Fwxtd76rhHKT9ShQAGca9RsnASplAVxtYKHmiT73_yA,2994
@@ -37,7 +37,7 @@ liger_kernel/transformers/jsd.py,sha256=sbr8DnKSYZJH9pv2rpmboNijYGpZKbhb2-WSGp5_
37
37
  liger_kernel/transformers/kl_div.py,sha256=qVhjBg6tjRyue5iZ3NFxo8uySY4JuIFJyv0IM_50F24,431
38
38
  liger_kernel/transformers/layer_norm.py,sha256=fd6o4kSHJWolQMWxh-l1qObfgL08ruNbUoBiANKX1ow,972
39
39
  liger_kernel/transformers/monkey_patch.py,sha256=Fk2v4GZQDJzfh3Cpc6BHNJbs_tungDyWmqS9nuG9Lc4,38406
40
- liger_kernel/transformers/qwen2vl_mrope.py,sha256=SfSQVwOe7ArrVfpmIdfZrdzCxmcj7V-YQp9zDu17-ao,1043
40
+ liger_kernel/transformers/qwen2vl_mrope.py,sha256=5EwSqrMdsL9MYspeBMXBsNJKvH0MOmRrtJXAJlnnlOI,1047
41
41
  liger_kernel/transformers/rms_norm.py,sha256=AHstklNIO1PLHjjCBU-TPuUD-Fl_pycJUTLlJNojbV8,1189
42
42
  liger_kernel/transformers/rope.py,sha256=m-ah8vZBYW8tfplTXCiAPMHJWlB1tdp_JPXJeWE-Boo,943
43
43
  liger_kernel/transformers/swiglu.py,sha256=0-tVJ8xEYfhxnduc16PflXFj8sZPxdx9sHUn3hfwCI4,2468
@@ -57,9 +57,9 @@ liger_kernel/transformers/trainer/__init__.py,sha256=c4OQVJmhNOloj0JYSEc0j_cQuBb
57
57
  liger_kernel/transformers/trainer/orpo_trainer.py,sha256=jko6oq_XQdBSmXubp05E-_YXOyhtB5Bj75dg5YNwOsE,7517
58
58
  liger_kernel/triton/__init__.py,sha256=yfRe0zMb47QnqjecZWG7LnanfCTzeku7SgWRAwNVmzU,101
59
59
  liger_kernel/triton/monkey_patch.py,sha256=5BcGKTtdqeYchypBIBopGIWPx1-cFALz7sOKoEsqXJ0,1584
60
- liger_kernel_nightly-0.5.1.dev20241210095557.dist-info/LICENSE,sha256=OhzLDHJ0to4a8sodVLELZiCFylZ1NAAYLs-HrjPy0ag,1312
61
- liger_kernel_nightly-0.5.1.dev20241210095557.dist-info/METADATA,sha256=yj_0kUeJuqUfLhxbnMPiZz3Bpgi9TOz6_pGaZoKDDa8,20721
62
- liger_kernel_nightly-0.5.1.dev20241210095557.dist-info/NOTICE,sha256=njwnoPZLh9AN8SJQzxvCGLHi-8X__AvWRze6joNXIY8,2066
63
- liger_kernel_nightly-0.5.1.dev20241210095557.dist-info/WHEEL,sha256=P9jw-gEje8ByB7_hXoICnHtVCrEwMQh-630tKvQWehc,91
64
- liger_kernel_nightly-0.5.1.dev20241210095557.dist-info/top_level.txt,sha256=2eghu4hA3LnkM7ElW92tQ8zegWKgSbeo-k-aGe1YnvY,13
65
- liger_kernel_nightly-0.5.1.dev20241210095557.dist-info/RECORD,,
60
+ liger_kernel_nightly-0.5.1.dev20241210172102.dist-info/LICENSE,sha256=OhzLDHJ0to4a8sodVLELZiCFylZ1NAAYLs-HrjPy0ag,1312
61
+ liger_kernel_nightly-0.5.1.dev20241210172102.dist-info/METADATA,sha256=2uxguJQFvHMsixaUtTeLF71EpHSwnsXKe33bQfR4Bbo,20721
62
+ liger_kernel_nightly-0.5.1.dev20241210172102.dist-info/NOTICE,sha256=njwnoPZLh9AN8SJQzxvCGLHi-8X__AvWRze6joNXIY8,2066
63
+ liger_kernel_nightly-0.5.1.dev20241210172102.dist-info/WHEEL,sha256=P9jw-gEje8ByB7_hXoICnHtVCrEwMQh-630tKvQWehc,91
64
+ liger_kernel_nightly-0.5.1.dev20241210172102.dist-info/top_level.txt,sha256=2eghu4hA3LnkM7ElW92tQ8zegWKgSbeo-k-aGe1YnvY,13
65
+ liger_kernel_nightly-0.5.1.dev20241210172102.dist-info/RECORD,,