liger-kernel 0.5.0__py3-none-any.whl → 0.5.2__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -10,6 +10,7 @@ def _triton_qwen2vl_mrope(
10
10
  cos,
11
11
  sin,
12
12
  sl,
13
+ bs: tl.constexpr,
13
14
  n_qh: tl.constexpr,
14
15
  n_kh: tl.constexpr,
15
16
  hd: tl.constexpr,
@@ -41,13 +42,12 @@ def _triton_qwen2vl_mrope(
41
42
  t_end = mrope_section_t
42
43
  h_end = t_end + mrope_section_h
43
44
 
44
- cos_row_idx = pid % sl
45
- t_cos = cos + cos_row_idx * hd
46
- h_cos = t_cos + sl * hd
47
- w_cos = h_cos + sl * hd
48
- t_sin = sin + cos_row_idx * hd
49
- h_sin = t_sin + sl * hd
50
- w_sin = h_sin + sl * hd
45
+ t_cos = cos + pid * hd
46
+ h_cos = t_cos + bs * sl * hd
47
+ w_cos = h_cos + bs * sl * hd
48
+ t_sin = sin + pid * hd
49
+ h_sin = t_sin + bs * sl * hd
50
+ w_sin = h_sin + bs * sl * hd
51
51
 
52
52
  cos_offsets = tl.arange(0, pad_hd // 2)
53
53
  t_mask = cos_offsets < t_end
@@ -151,6 +151,7 @@ def qwen2vl_mrope_forward(q, k, cos, sin, mrope_section):
151
151
  cos,
152
152
  sin,
153
153
  seq_len,
154
+ batch_size,
154
155
  n_q_head,
155
156
  n_kv_head,
156
157
  head_dim,
@@ -189,6 +190,7 @@ def qwen2vl_mrope_backward(dq, dk, cos, sin, mrope_section):
189
190
  cos,
190
191
  sin,
191
192
  seq_len,
193
+ batch_size,
192
194
  n_q_head,
193
195
  n_kv_head,
194
196
  head_dim,
@@ -216,8 +218,8 @@ class LigerQwen2VLMRopeFunction(torch.autograd.Function):
216
218
  """
217
219
  q size: (bsz, n_q_head, seq_len, head_dim)
218
220
  k size: (bsz, n_kv_head, seq_len, head_dim)
219
- cos size: (3, 1, seq_len, head_dim)
220
- sin size: (3, 1, seq_len, head_dim)
221
+ cos size: (3, bsz, seq_len, head_dim)
222
+ sin size: (3, bsz, seq_len, head_dim)
221
223
  """
222
224
  q, k, cos, sin = qwen2vl_mrope_forward(q, k, cos, sin, mrope_section)
223
225
  ctx.save_for_backward(cos, sin)
@@ -228,10 +230,9 @@ class LigerQwen2VLMRopeFunction(torch.autograd.Function):
228
230
  """
229
231
  dq size: (bsz, n_q_head, seq_len, head_dim)
230
232
  dk size: (bsz, n_kv_head, seq_len, head_dim)
231
- cos size: (3, 1, seq_len, head_dim)
232
- sin size: (3, 1, seq_len, head_dim)
233
+ cos size: (3, bsz, seq_len, head_dim)
234
+ sin size: (3, bsz, seq_len, head_dim)
233
235
  """
234
-
235
236
  cos, sin = ctx.saved_tensors
236
237
  mrope_section = ctx.mrope_section
237
238
  dq, dk = qwen2vl_mrope_backward(dq, dk, cos, sin, mrope_section)
@@ -22,7 +22,6 @@ from liger_kernel.transformers.monkey_patch import ( # noqa: F401
22
22
  apply_liger_kernel_to_qwen2,
23
23
  apply_liger_kernel_to_qwen2_vl,
24
24
  )
25
- from liger_kernel.transformers.orpo_trainer import LigerORPOTrainer # noqa: F401
26
25
  from liger_kernel.transformers.rms_norm import LigerRMSNorm # noqa: F401
27
26
  from liger_kernel.transformers.rope import liger_rotary_pos_emb # noqa: F401
28
27
  from liger_kernel.transformers.swiglu import ( # noqa: F401
@@ -8,8 +8,8 @@ def liger_multimodal_rotary_pos_emb(q, k, cos, sin, mrope_section, unsqueeze_dim
8
8
  Args:
9
9
  q (torch.Tensor): The query tensor of shape (bsz, n_q_head, seq_len, head_dim).
10
10
  k (torch.Tensor): The key tensor of shape (bsz, n_kv_head, seq_len, head_dim).
11
- cos (torch.Tensor): The cosine tensor of shape (3, 1, seq_len, head_dim).
12
- sin (torch.Tensor): The sine tensor of shape (3, 1, seq_len, head_dim).
11
+ cos (torch.Tensor): The cosine tensor of shape (3, bsz, seq_len, head_dim).
12
+ sin (torch.Tensor): The sine tensor of shape (3, bsz, seq_len, head_dim).
13
13
  mrope_section (List[int]): The multimodal rope section for channel dimension of temporal, height and width in rope calculation.
14
14
  unsqueeze_dim (int, optional): The dimension to unsqueeze. Defaults to 1.
15
15
 
@@ -0,0 +1,6 @@
1
+ try:
2
+ from liger_kernel.transformers.trainer.orpo_trainer import ( # noqa: F401
3
+ LigerORPOTrainer,
4
+ )
5
+ except ImportError:
6
+ raise ImportError("Please `pip install trl` to use LigerORPOTrainer")
@@ -76,9 +76,7 @@ class LigerORPOTrainer(ORPOTrainer):
76
76
  padding_value=self.padding_value,
77
77
  device=self.accelerator.device,
78
78
  )
79
- # if self.accelerator.is_main_process:
80
- # import pdb; pdb.set_trace()
81
- # torch.distributed.barrier()
79
+
82
80
  model_kwargs = (
83
81
  {
84
82
  "decoder_input_ids": self._shift_right(
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.1
2
2
  Name: liger_kernel
3
- Version: 0.5.0
3
+ Version: 0.5.2
4
4
  Summary: Efficient Triton kernels for LLM Training
5
5
  License: BSD 2-CLAUSE LICENSE
6
6
  Copyright 2024 LinkedIn Corporation
@@ -34,9 +34,10 @@ Requires-Dist: torch>=2.1.2
34
34
  Requires-Dist: triton>=2.3.1
35
35
  Provides-Extra: transformers
36
36
  Requires-Dist: transformers~=4.0; extra == "transformers"
37
+ Provides-Extra: trl
38
+ Requires-Dist: trl>=0.11.0; extra == "trl"
37
39
  Provides-Extra: dev
38
40
  Requires-Dist: transformers>=4.44.2; extra == "dev"
39
- Requires-Dist: trl>=0.11.0; extra == "dev"
40
41
  Requires-Dist: matplotlib>=3.7.2; extra == "dev"
41
42
  Requires-Dist: flake8>=4.0.1.1; extra == "dev"
42
43
  Requires-Dist: black>=24.4.2; extra == "dev"
@@ -18,14 +18,14 @@ liger_kernel/ops/group_norm.py,sha256=VaRErVJGR4JqgXXvuIjNGTn3E2egjLtU1y3ymwIf4d
18
18
  liger_kernel/ops/jsd.py,sha256=Ap2b0_geCl6fqBXLI1IS6Yn6GlO-8LgPmnOW3y47dus,6151
19
19
  liger_kernel/ops/kl_div.py,sha256=03FNXfvCb6M-56hhFepAFV9p6brArPR6KOKkdGD34mw,8374
20
20
  liger_kernel/ops/layer_norm.py,sha256=_CZggw3GNEIUx5weDzadFit5I-Lzosoo8prgeJzcViY,7589
21
- liger_kernel/ops/qwen2vl_mrope.py,sha256=xZvQnhkSTjU-k6KiiRn9e0SYO1ESs1jmuZFMICduLpc,8552
21
+ liger_kernel/ops/qwen2vl_mrope.py,sha256=GvP4Cg-2ClYyiqbe7bB_OMvnlZooBmqP2-9V8RMPde4,8598
22
22
  liger_kernel/ops/rms_norm.py,sha256=g7OXwuYI8-LXudDwvXuiupVjjOsbu8c4wwv83VaHa54,11750
23
23
  liger_kernel/ops/rope.py,sha256=jrzaA9-6Orn44y_IIam9_YNPQxOFK2FrIRNfFea4EtU,8513
24
24
  liger_kernel/ops/swiglu.py,sha256=Fwxtd76rhHKT9ShQAGca9RsnASplAVxtYKHmiT73_yA,2994
25
25
  liger_kernel/ops/utils.py,sha256=_VQvd1PX5JXm5xaiBrk2gANp3qr4kM7qYG3ypkBwkMs,3850
26
26
  liger_kernel/ops/experimental/embedding.py,sha256=LYR66dB-jhvhtUjeV4PnNro-n77J1mdlmpSLSxB3Y6U,4186
27
27
  liger_kernel/ops/experimental/mm_int8int2.py,sha256=JpGVZCgRC6T8XMUJ_QbZRS2XU1bh0urIZphs5DTc1mY,13358
28
- liger_kernel/transformers/__init__.py,sha256=P5JR3fI-znhG92nRrFS2j0TIJTLhP-xD5dvEy4HP9ik,1418
28
+ liger_kernel/transformers/__init__.py,sha256=gia-eBxr7TLxU0GdDf8AfCY4WgDlFLqIGSt7EoQGsBA,1336
29
29
  liger_kernel/transformers/auto_model.py,sha256=RMIwQHSiXoksXFTIqFZ4PLBgoqkxJJAT3q1Qh47bGN8,1552
30
30
  liger_kernel/transformers/cross_entropy.py,sha256=yEm_YQ7oa3_BzT3hdW6KrAslduhSqWcJQVNZZDcWCg4,1758
31
31
  liger_kernel/transformers/functional.py,sha256=sUBoU8Vb4pLpr9G6IdkRsToYgh-rCXL4OLYat7Tv_GU,4450
@@ -37,8 +37,7 @@ liger_kernel/transformers/jsd.py,sha256=sbr8DnKSYZJH9pv2rpmboNijYGpZKbhb2-WSGp5_
37
37
  liger_kernel/transformers/kl_div.py,sha256=qVhjBg6tjRyue5iZ3NFxo8uySY4JuIFJyv0IM_50F24,431
38
38
  liger_kernel/transformers/layer_norm.py,sha256=fd6o4kSHJWolQMWxh-l1qObfgL08ruNbUoBiANKX1ow,972
39
39
  liger_kernel/transformers/monkey_patch.py,sha256=Fk2v4GZQDJzfh3Cpc6BHNJbs_tungDyWmqS9nuG9Lc4,38406
40
- liger_kernel/transformers/orpo_trainer.py,sha256=mC8ePS-Oq-BrdM0lKpgSBLuYLqYsWxH_4Q2RnDthz5M,7643
41
- liger_kernel/transformers/qwen2vl_mrope.py,sha256=SfSQVwOe7ArrVfpmIdfZrdzCxmcj7V-YQp9zDu17-ao,1043
40
+ liger_kernel/transformers/qwen2vl_mrope.py,sha256=5EwSqrMdsL9MYspeBMXBsNJKvH0MOmRrtJXAJlnnlOI,1047
42
41
  liger_kernel/transformers/rms_norm.py,sha256=AHstklNIO1PLHjjCBU-TPuUD-Fl_pycJUTLlJNojbV8,1189
43
42
  liger_kernel/transformers/rope.py,sha256=m-ah8vZBYW8tfplTXCiAPMHJWlB1tdp_JPXJeWE-Boo,943
44
43
  liger_kernel/transformers/swiglu.py,sha256=0-tVJ8xEYfhxnduc16PflXFj8sZPxdx9sHUn3hfwCI4,2468
@@ -54,11 +53,13 @@ liger_kernel/transformers/model/mllama.py,sha256=mesNCgj0Ea1O-fqRD4LVxDJ1CR2abY_
54
53
  liger_kernel/transformers/model/phi3.py,sha256=xUZPlaPKwknLjHc3uUW3EPodm1h0vD3G7Qnhh51v-Io,10332
55
54
  liger_kernel/transformers/model/qwen2.py,sha256=EyhSSzQOskGjSnCsKMZpd1s5IAIlHd5PBO3q0MoCs00,9619
56
55
  liger_kernel/transformers/model/qwen2_vl.py,sha256=bIQe2bWiY--G84FhCD29Gdi64_qHP6vbcGsK6vKysQE,8547
56
+ liger_kernel/transformers/trainer/__init__.py,sha256=c4OQVJmhNOloj0JYSEc0j_cQuBbzGWILfaowUR1hmRw,210
57
+ liger_kernel/transformers/trainer/orpo_trainer.py,sha256=jko6oq_XQdBSmXubp05E-_YXOyhtB5Bj75dg5YNwOsE,7517
57
58
  liger_kernel/triton/__init__.py,sha256=yfRe0zMb47QnqjecZWG7LnanfCTzeku7SgWRAwNVmzU,101
58
59
  liger_kernel/triton/monkey_patch.py,sha256=5BcGKTtdqeYchypBIBopGIWPx1-cFALz7sOKoEsqXJ0,1584
59
- liger_kernel-0.5.0.dist-info/LICENSE,sha256=OhzLDHJ0to4a8sodVLELZiCFylZ1NAAYLs-HrjPy0ag,1312
60
- liger_kernel-0.5.0.dist-info/METADATA,sha256=7c5Tzf84zfQFOdXxx5nXg0wqGKH8VhsLCfTvoMN3kNM,20675
61
- liger_kernel-0.5.0.dist-info/NOTICE,sha256=njwnoPZLh9AN8SJQzxvCGLHi-8X__AvWRze6joNXIY8,2066
62
- liger_kernel-0.5.0.dist-info/WHEEL,sha256=PZUExdf71Ui_so67QXpySuHtCi3-J3wvF4ORK6k_S8U,91
63
- liger_kernel-0.5.0.dist-info/top_level.txt,sha256=2eghu4hA3LnkM7ElW92tQ8zegWKgSbeo-k-aGe1YnvY,13
64
- liger_kernel-0.5.0.dist-info/RECORD,,
60
+ liger_kernel-0.5.2.dist-info/LICENSE,sha256=OhzLDHJ0to4a8sodVLELZiCFylZ1NAAYLs-HrjPy0ag,1312
61
+ liger_kernel-0.5.2.dist-info/METADATA,sha256=olSIT-Jd2Mowu2ja4QLwyPYBhCnY22znBq9pV7stkKI,20695
62
+ liger_kernel-0.5.2.dist-info/NOTICE,sha256=njwnoPZLh9AN8SJQzxvCGLHi-8X__AvWRze6joNXIY8,2066
63
+ liger_kernel-0.5.2.dist-info/WHEEL,sha256=PZUExdf71Ui_so67QXpySuHtCi3-J3wvF4ORK6k_S8U,91
64
+ liger_kernel-0.5.2.dist-info/top_level.txt,sha256=2eghu4hA3LnkM7ElW92tQ8zegWKgSbeo-k-aGe1YnvY,13
65
+ liger_kernel-0.5.2.dist-info/RECORD,,