liger-kernel 0.5.0.tar.gz → 0.5.2.tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (70)
  1. {liger_kernel-0.5.0/src/liger_kernel.egg-info → liger_kernel-0.5.2}/PKG-INFO +3 -2
  2. {liger_kernel-0.5.0 → liger_kernel-0.5.2}/pyproject.toml +5 -2
  3. {liger_kernel-0.5.0 → liger_kernel-0.5.2}/src/liger_kernel/ops/qwen2vl_mrope.py +13 -12
  4. {liger_kernel-0.5.0 → liger_kernel-0.5.2}/src/liger_kernel/transformers/__init__.py +0 -1
  5. {liger_kernel-0.5.0 → liger_kernel-0.5.2}/src/liger_kernel/transformers/qwen2vl_mrope.py +2 -2
  6. liger_kernel-0.5.2/src/liger_kernel/transformers/trainer/__init__.py +6 -0
  7. {liger_kernel-0.5.0/src/liger_kernel/transformers → liger_kernel-0.5.2/src/liger_kernel/transformers/trainer}/orpo_trainer.py +1 -3
  8. {liger_kernel-0.5.0 → liger_kernel-0.5.2/src/liger_kernel.egg-info}/PKG-INFO +3 -2
  9. {liger_kernel-0.5.0 → liger_kernel-0.5.2}/src/liger_kernel.egg-info/SOURCES.txt +2 -1
  10. {liger_kernel-0.5.0 → liger_kernel-0.5.2}/src/liger_kernel.egg-info/requires.txt +3 -1
  11. {liger_kernel-0.5.0 → liger_kernel-0.5.2}/LICENSE +0 -0
  12. {liger_kernel-0.5.0 → liger_kernel-0.5.2}/NOTICE +0 -0
  13. {liger_kernel-0.5.0 → liger_kernel-0.5.2}/README.md +0 -0
  14. {liger_kernel-0.5.0 → liger_kernel-0.5.2}/setup.cfg +0 -0
  15. {liger_kernel-0.5.0 → liger_kernel-0.5.2}/src/liger_kernel/__init__.py +0 -0
  16. {liger_kernel-0.5.0 → liger_kernel-0.5.2}/src/liger_kernel/chunked_loss/__init__.py +0 -0
  17. {liger_kernel-0.5.0 → liger_kernel-0.5.2}/src/liger_kernel/chunked_loss/cpo_loss.py +0 -0
  18. {liger_kernel-0.5.0 → liger_kernel-0.5.2}/src/liger_kernel/chunked_loss/dpo_loss.py +0 -0
  19. {liger_kernel-0.5.0 → liger_kernel-0.5.2}/src/liger_kernel/chunked_loss/functional.py +0 -0
  20. {liger_kernel-0.5.0 → liger_kernel-0.5.2}/src/liger_kernel/chunked_loss/fused_linear_distillation.py +0 -0
  21. {liger_kernel-0.5.0 → liger_kernel-0.5.2}/src/liger_kernel/chunked_loss/fused_linear_preference.py +0 -0
  22. {liger_kernel-0.5.0 → liger_kernel-0.5.2}/src/liger_kernel/chunked_loss/orpo_loss.py +0 -0
  23. {liger_kernel-0.5.0 → liger_kernel-0.5.2}/src/liger_kernel/chunked_loss/simpo_loss.py +0 -0
  24. {liger_kernel-0.5.0 → liger_kernel-0.5.2}/src/liger_kernel/env_report.py +0 -0
  25. {liger_kernel-0.5.0 → liger_kernel-0.5.2}/src/liger_kernel/ops/__init__.py +0 -0
  26. {liger_kernel-0.5.0 → liger_kernel-0.5.2}/src/liger_kernel/ops/cross_entropy.py +0 -0
  27. {liger_kernel-0.5.0 → liger_kernel-0.5.2}/src/liger_kernel/ops/experimental/embedding.py +0 -0
  28. {liger_kernel-0.5.0 → liger_kernel-0.5.2}/src/liger_kernel/ops/experimental/mm_int8int2.py +0 -0
  29. {liger_kernel-0.5.0 → liger_kernel-0.5.2}/src/liger_kernel/ops/fused_linear_cross_entropy.py +0 -0
  30. {liger_kernel-0.5.0 → liger_kernel-0.5.2}/src/liger_kernel/ops/fused_linear_jsd.py +0 -0
  31. {liger_kernel-0.5.0 → liger_kernel-0.5.2}/src/liger_kernel/ops/geglu.py +0 -0
  32. {liger_kernel-0.5.0 → liger_kernel-0.5.2}/src/liger_kernel/ops/group_norm.py +0 -0
  33. {liger_kernel-0.5.0 → liger_kernel-0.5.2}/src/liger_kernel/ops/jsd.py +0 -0
  34. {liger_kernel-0.5.0 → liger_kernel-0.5.2}/src/liger_kernel/ops/kl_div.py +0 -0
  35. {liger_kernel-0.5.0 → liger_kernel-0.5.2}/src/liger_kernel/ops/layer_norm.py +0 -0
  36. {liger_kernel-0.5.0 → liger_kernel-0.5.2}/src/liger_kernel/ops/rms_norm.py +0 -0
  37. {liger_kernel-0.5.0 → liger_kernel-0.5.2}/src/liger_kernel/ops/rope.py +0 -0
  38. {liger_kernel-0.5.0 → liger_kernel-0.5.2}/src/liger_kernel/ops/swiglu.py +0 -0
  39. {liger_kernel-0.5.0 → liger_kernel-0.5.2}/src/liger_kernel/ops/utils.py +0 -0
  40. {liger_kernel-0.5.0 → liger_kernel-0.5.2}/src/liger_kernel/transformers/auto_model.py +0 -0
  41. {liger_kernel-0.5.0 → liger_kernel-0.5.2}/src/liger_kernel/transformers/cross_entropy.py +0 -0
  42. {liger_kernel-0.5.0 → liger_kernel-0.5.2}/src/liger_kernel/transformers/experimental/embedding.py +0 -0
  43. {liger_kernel-0.5.0 → liger_kernel-0.5.2}/src/liger_kernel/transformers/functional.py +0 -0
  44. {liger_kernel-0.5.0 → liger_kernel-0.5.2}/src/liger_kernel/transformers/fused_linear_cross_entropy.py +0 -0
  45. {liger_kernel-0.5.0 → liger_kernel-0.5.2}/src/liger_kernel/transformers/fused_linear_jsd.py +0 -0
  46. {liger_kernel-0.5.0 → liger_kernel-0.5.2}/src/liger_kernel/transformers/geglu.py +0 -0
  47. {liger_kernel-0.5.0 → liger_kernel-0.5.2}/src/liger_kernel/transformers/group_norm.py +0 -0
  48. {liger_kernel-0.5.0 → liger_kernel-0.5.2}/src/liger_kernel/transformers/jsd.py +0 -0
  49. {liger_kernel-0.5.0 → liger_kernel-0.5.2}/src/liger_kernel/transformers/kl_div.py +0 -0
  50. {liger_kernel-0.5.0 → liger_kernel-0.5.2}/src/liger_kernel/transformers/layer_norm.py +0 -0
  51. {liger_kernel-0.5.0 → liger_kernel-0.5.2}/src/liger_kernel/transformers/model/__init__.py +0 -0
  52. {liger_kernel-0.5.0 → liger_kernel-0.5.2}/src/liger_kernel/transformers/model/gemma.py +0 -0
  53. {liger_kernel-0.5.0 → liger_kernel-0.5.2}/src/liger_kernel/transformers/model/gemma2.py +0 -0
  54. {liger_kernel-0.5.0 → liger_kernel-0.5.2}/src/liger_kernel/transformers/model/llama.py +0 -0
  55. {liger_kernel-0.5.0 → liger_kernel-0.5.2}/src/liger_kernel/transformers/model/mistral.py +0 -0
  56. {liger_kernel-0.5.0 → liger_kernel-0.5.2}/src/liger_kernel/transformers/model/mixtral.py +0 -0
  57. {liger_kernel-0.5.0 → liger_kernel-0.5.2}/src/liger_kernel/transformers/model/mllama.py +0 -0
  58. {liger_kernel-0.5.0 → liger_kernel-0.5.2}/src/liger_kernel/transformers/model/phi3.py +0 -0
  59. {liger_kernel-0.5.0 → liger_kernel-0.5.2}/src/liger_kernel/transformers/model/qwen2.py +0 -0
  60. {liger_kernel-0.5.0 → liger_kernel-0.5.2}/src/liger_kernel/transformers/model/qwen2_vl.py +0 -0
  61. {liger_kernel-0.5.0 → liger_kernel-0.5.2}/src/liger_kernel/transformers/monkey_patch.py +0 -0
  62. {liger_kernel-0.5.0 → liger_kernel-0.5.2}/src/liger_kernel/transformers/rms_norm.py +0 -0
  63. {liger_kernel-0.5.0 → liger_kernel-0.5.2}/src/liger_kernel/transformers/rope.py +0 -0
  64. {liger_kernel-0.5.0 → liger_kernel-0.5.2}/src/liger_kernel/transformers/swiglu.py +0 -0
  65. {liger_kernel-0.5.0 → liger_kernel-0.5.2}/src/liger_kernel/transformers/trainer_integration.py +0 -0
  66. {liger_kernel-0.5.0 → liger_kernel-0.5.2}/src/liger_kernel/triton/__init__.py +0 -0
  67. {liger_kernel-0.5.0 → liger_kernel-0.5.2}/src/liger_kernel/triton/monkey_patch.py +0 -0
  68. {liger_kernel-0.5.0 → liger_kernel-0.5.2}/src/liger_kernel/utils.py +0 -0
  69. {liger_kernel-0.5.0 → liger_kernel-0.5.2}/src/liger_kernel.egg-info/dependency_links.txt +0 -0
  70. {liger_kernel-0.5.0 → liger_kernel-0.5.2}/src/liger_kernel.egg-info/top_level.txt +0 -0
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.1
2
2
  Name: liger_kernel
3
- Version: 0.5.0
3
+ Version: 0.5.2
4
4
  Summary: Efficient Triton kernels for LLM Training
5
5
  License: BSD 2-CLAUSE LICENSE
6
6
  Copyright 2024 LinkedIn Corporation
@@ -34,9 +34,10 @@ Requires-Dist: torch>=2.1.2
34
34
  Requires-Dist: triton>=2.3.1
35
35
  Provides-Extra: transformers
36
36
  Requires-Dist: transformers~=4.0; extra == "transformers"
37
+ Provides-Extra: trl
38
+ Requires-Dist: trl>=0.11.0; extra == "trl"
37
39
  Provides-Extra: dev
38
40
  Requires-Dist: transformers>=4.44.2; extra == "dev"
39
- Requires-Dist: trl>=0.11.0; extra == "dev"
40
41
  Requires-Dist: matplotlib>=3.7.2; extra == "dev"
41
42
  Requires-Dist: flake8>=4.0.1.1; extra == "dev"
42
43
  Requires-Dist: black>=24.4.2; extra == "dev"
@@ -4,7 +4,7 @@ build-backend = "setuptools.build_meta"
4
4
 
5
5
  [project]
6
6
  name = "liger_kernel"
7
- version = "0.5.0"
7
+ version = "0.5.2"
8
8
  description = "Efficient Triton kernels for LLM Training"
9
9
  urls = { "Homepage" = "https://github.com/linkedin/Liger-Kernel" }
10
10
  readme = { file = "README.md", content-type = "text/markdown" }
@@ -20,9 +20,12 @@ transformers = [
20
20
  "transformers~=4.0"
21
21
  ]
22
22
 
23
+ trl = [
24
+ "trl>=0.11.0",
25
+ ]
26
+
23
27
  dev = [
24
28
  "transformers>=4.44.2",
25
- "trl>=0.11.0",
26
29
  "matplotlib>=3.7.2",
27
30
  "flake8>=4.0.1.1",
28
31
  "black>=24.4.2",
@@ -10,6 +10,7 @@ def _triton_qwen2vl_mrope(
10
10
  cos,
11
11
  sin,
12
12
  sl,
13
+ bs: tl.constexpr,
13
14
  n_qh: tl.constexpr,
14
15
  n_kh: tl.constexpr,
15
16
  hd: tl.constexpr,
@@ -41,13 +42,12 @@ def _triton_qwen2vl_mrope(
41
42
  t_end = mrope_section_t
42
43
  h_end = t_end + mrope_section_h
43
44
 
44
- cos_row_idx = pid % sl
45
- t_cos = cos + cos_row_idx * hd
46
- h_cos = t_cos + sl * hd
47
- w_cos = h_cos + sl * hd
48
- t_sin = sin + cos_row_idx * hd
49
- h_sin = t_sin + sl * hd
50
- w_sin = h_sin + sl * hd
45
+ t_cos = cos + pid * hd
46
+ h_cos = t_cos + bs * sl * hd
47
+ w_cos = h_cos + bs * sl * hd
48
+ t_sin = sin + pid * hd
49
+ h_sin = t_sin + bs * sl * hd
50
+ w_sin = h_sin + bs * sl * hd
51
51
 
52
52
  cos_offsets = tl.arange(0, pad_hd // 2)
53
53
  t_mask = cos_offsets < t_end
@@ -151,6 +151,7 @@ def qwen2vl_mrope_forward(q, k, cos, sin, mrope_section):
151
151
  cos,
152
152
  sin,
153
153
  seq_len,
154
+ batch_size,
154
155
  n_q_head,
155
156
  n_kv_head,
156
157
  head_dim,
@@ -189,6 +190,7 @@ def qwen2vl_mrope_backward(dq, dk, cos, sin, mrope_section):
189
190
  cos,
190
191
  sin,
191
192
  seq_len,
193
+ batch_size,
192
194
  n_q_head,
193
195
  n_kv_head,
194
196
  head_dim,
@@ -216,8 +218,8 @@ class LigerQwen2VLMRopeFunction(torch.autograd.Function):
216
218
  """
217
219
  q size: (bsz, n_q_head, seq_len, head_dim)
218
220
  k size: (bsz, n_kv_head, seq_len, head_dim)
219
- cos size: (3, 1, seq_len, head_dim)
220
- sin size: (3, 1, seq_len, head_dim)
221
+ cos size: (3, bsz, seq_len, head_dim)
222
+ sin size: (3, bsz, seq_len, head_dim)
221
223
  """
222
224
  q, k, cos, sin = qwen2vl_mrope_forward(q, k, cos, sin, mrope_section)
223
225
  ctx.save_for_backward(cos, sin)
@@ -228,10 +230,9 @@ class LigerQwen2VLMRopeFunction(torch.autograd.Function):
228
230
  """
229
231
  dq size: (bsz, n_q_head, seq_len, head_dim)
230
232
  dk size: (bsz, n_kv_head, seq_len, head_dim)
231
- cos size: (3, 1, seq_len, head_dim)
232
- sin size: (3, 1, seq_len, head_dim)
233
+ cos size: (3, bsz, seq_len, head_dim)
234
+ sin size: (3, bsz, seq_len, head_dim)
233
235
  """
234
-
235
236
  cos, sin = ctx.saved_tensors
236
237
  mrope_section = ctx.mrope_section
237
238
  dq, dk = qwen2vl_mrope_backward(dq, dk, cos, sin, mrope_section)
@@ -22,7 +22,6 @@ from liger_kernel.transformers.monkey_patch import ( # noqa: F401
22
22
  apply_liger_kernel_to_qwen2,
23
23
  apply_liger_kernel_to_qwen2_vl,
24
24
  )
25
- from liger_kernel.transformers.orpo_trainer import LigerORPOTrainer # noqa: F401
26
25
  from liger_kernel.transformers.rms_norm import LigerRMSNorm # noqa: F401
27
26
  from liger_kernel.transformers.rope import liger_rotary_pos_emb # noqa: F401
28
27
  from liger_kernel.transformers.swiglu import ( # noqa: F401
@@ -8,8 +8,8 @@ def liger_multimodal_rotary_pos_emb(q, k, cos, sin, mrope_section, unsqueeze_dim
8
8
  Args:
9
9
  q (torch.Tensor): The query tensor of shape (bsz, n_q_head, seq_len, head_dim).
10
10
  k (torch.Tensor): The key tensor of shape (bsz, n_kv_head, seq_len, head_dim).
11
- cos (torch.Tensor): The cosine tensor of shape (3, 1, seq_len, head_dim).
12
- sin (torch.Tensor): The sine tensor of shape (3, 1, seq_len, head_dim).
11
+ cos (torch.Tensor): The cosine tensor of shape (3, bsz, seq_len, head_dim).
12
+ sin (torch.Tensor): The sine tensor of shape (3, bsz, seq_len, head_dim).
13
13
  mrope_section (List[int]): The multimodal rope section for channel dimension of temporal, height and width in rope calculation.
14
14
  unsqueeze_dim (int, optional): The dimension to unsqueeze. Defaults to 1.
15
15
 
@@ -0,0 +1,6 @@
1
+ try:
2
+ from liger_kernel.transformers.trainer.orpo_trainer import ( # noqa: F401
3
+ LigerORPOTrainer,
4
+ )
5
+ except ImportError:
6
+ raise ImportError("Please `pip install trl` to use LigerORPOTrainer")
@@ -76,9 +76,7 @@ class LigerORPOTrainer(ORPOTrainer):
76
76
  padding_value=self.padding_value,
77
77
  device=self.accelerator.device,
78
78
  )
79
- # if self.accelerator.is_main_process:
80
- # import pdb; pdb.set_trace()
81
- # torch.distributed.barrier()
79
+
82
80
  model_kwargs = (
83
81
  {
84
82
  "decoder_input_ids": self._shift_right(
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.1
2
2
  Name: liger_kernel
3
- Version: 0.5.0
3
+ Version: 0.5.2
4
4
  Summary: Efficient Triton kernels for LLM Training
5
5
  License: BSD 2-CLAUSE LICENSE
6
6
  Copyright 2024 LinkedIn Corporation
@@ -34,9 +34,10 @@ Requires-Dist: torch>=2.1.2
34
34
  Requires-Dist: triton>=2.3.1
35
35
  Provides-Extra: transformers
36
36
  Requires-Dist: transformers~=4.0; extra == "transformers"
37
+ Provides-Extra: trl
38
+ Requires-Dist: trl>=0.11.0; extra == "trl"
37
39
  Provides-Extra: dev
38
40
  Requires-Dist: transformers>=4.44.2; extra == "dev"
39
- Requires-Dist: trl>=0.11.0; extra == "dev"
40
41
  Requires-Dist: matplotlib>=3.7.2; extra == "dev"
41
42
  Requires-Dist: flake8>=4.0.1.1; extra == "dev"
42
43
  Requires-Dist: black>=24.4.2; extra == "dev"
@@ -46,7 +46,6 @@ src/liger_kernel/transformers/jsd.py
46
46
  src/liger_kernel/transformers/kl_div.py
47
47
  src/liger_kernel/transformers/layer_norm.py
48
48
  src/liger_kernel/transformers/monkey_patch.py
49
- src/liger_kernel/transformers/orpo_trainer.py
50
49
  src/liger_kernel/transformers/qwen2vl_mrope.py
51
50
  src/liger_kernel/transformers/rms_norm.py
52
51
  src/liger_kernel/transformers/rope.py
@@ -63,5 +62,7 @@ src/liger_kernel/transformers/model/mllama.py
63
62
  src/liger_kernel/transformers/model/phi3.py
64
63
  src/liger_kernel/transformers/model/qwen2.py
65
64
  src/liger_kernel/transformers/model/qwen2_vl.py
65
+ src/liger_kernel/transformers/trainer/__init__.py
66
+ src/liger_kernel/transformers/trainer/orpo_trainer.py
66
67
  src/liger_kernel/triton/__init__.py
67
68
  src/liger_kernel/triton/monkey_patch.py
@@ -9,7 +9,6 @@ triton>=3.0.0
9
9
 
10
10
  [dev]
11
11
  transformers>=4.44.2
12
- trl>=0.11.0
13
12
  matplotlib>=3.7.2
14
13
  flake8>=4.0.1.1
15
14
  black>=24.4.2
@@ -23,3 +22,6 @@ seaborn
23
22
 
24
23
  [transformers]
25
24
  transformers~=4.0
25
+
26
+ [trl]
27
+ trl>=0.11.0
File without changes
File without changes
File without changes
File without changes