liger-kernel-nightly 0.5.10.dev20250611191801__py3-none-any.whl → 0.6.4.dev20260112233432__py3-none-any.whl

This diff shows the changes between two publicly released versions of the package, as they appear in their respective public registries. It is provided for informational purposes only.

Potentially problematic release: this version of liger-kernel-nightly might be problematic.

Files changed (107)
  1. liger_kernel/chunked_loss/__init__.py +1 -0
  2. liger_kernel/chunked_loss/cosine_similarity_loss.py +142 -0
  3. liger_kernel/chunked_loss/dpo_loss.py +54 -3
  4. liger_kernel/chunked_loss/functional.py +2 -0
  5. liger_kernel/chunked_loss/fused_linear_distillation.py +23 -5
  6. liger_kernel/chunked_loss/fused_linear_ppo.py +25 -5
  7. liger_kernel/chunked_loss/grpo_loss.py +46 -9
  8. liger_kernel/chunked_loss/jsd_loss.py +44 -13
  9. liger_kernel/ops/__init__.py +141 -0
  10. liger_kernel/ops/backends/README.md +151 -0
  11. liger_kernel/ops/backends/__init__.py +13 -0
  12. liger_kernel/ops/backends/_ascend/__init__.py +5 -0
  13. liger_kernel/ops/backends/_ascend/ascend-ub-manager-design.md +485 -0
  14. liger_kernel/ops/backends/_ascend/ops/__init__.py +49 -0
  15. liger_kernel/ops/backends/_ascend/ops/geglu.py +266 -0
  16. liger_kernel/ops/backends/_ascend/ops/qwen2vl_mrope.py +285 -0
  17. liger_kernel/ops/backends/_ascend/ops/rope.py +290 -0
  18. liger_kernel/ops/backends/_ascend/ops/swiglu.py +142 -0
  19. liger_kernel/ops/backends/_ascend/ops/tvd.py +221 -0
  20. liger_kernel/ops/backends/_ascend/ub_manager.py +349 -0
  21. liger_kernel/ops/backends/registry.py +61 -0
  22. liger_kernel/ops/cross_entropy.py +130 -64
  23. liger_kernel/ops/dyt.py +5 -4
  24. liger_kernel/ops/fused_add_rms_norm.py +416 -0
  25. liger_kernel/ops/fused_linear_cross_entropy.py +115 -22
  26. liger_kernel/ops/geglu.py +6 -4
  27. liger_kernel/ops/group_norm.py +7 -7
  28. liger_kernel/ops/grpo_loss.py +3 -1
  29. liger_kernel/ops/kl_div.py +8 -11
  30. liger_kernel/ops/layer_norm.py +135 -80
  31. liger_kernel/ops/llama4_rope.py +225 -0
  32. liger_kernel/ops/poly_norm.py +390 -0
  33. liger_kernel/ops/rms_norm.py +148 -71
  34. liger_kernel/ops/rope.py +1 -1
  35. liger_kernel/ops/swiglu.py +1 -1
  36. liger_kernel/ops/tiled_mlp.py +136 -0
  37. liger_kernel/ops/utils.py +14 -0
  38. liger_kernel/transformers/__init__.py +65 -0
  39. liger_kernel/transformers/auto_model.py +21 -0
  40. liger_kernel/transformers/cross_entropy.py +9 -4
  41. liger_kernel/transformers/dyt.py +1 -1
  42. liger_kernel/transformers/experimental/__init__.py +5 -0
  43. liger_kernel/transformers/experimental/embedding.py +1 -1
  44. liger_kernel/transformers/functional.py +56 -24
  45. liger_kernel/transformers/fused_add_rms_norm.py +39 -0
  46. liger_kernel/transformers/fused_linear_cross_entropy.py +17 -5
  47. liger_kernel/transformers/fused_linear_jsd.py +1 -1
  48. liger_kernel/transformers/fused_neighborhood_attention.py +1 -1
  49. liger_kernel/transformers/geglu.py +1 -1
  50. liger_kernel/transformers/group_norm.py +1 -1
  51. liger_kernel/transformers/grpo_loss.py +57 -2
  52. liger_kernel/transformers/jsd.py +1 -1
  53. liger_kernel/transformers/kl_div.py +1 -1
  54. liger_kernel/transformers/layer_norm.py +1 -1
  55. liger_kernel/transformers/llama4_rope.py +93 -0
  56. liger_kernel/transformers/model/exaone4.py +136 -0
  57. liger_kernel/transformers/model/falcon_h1.py +122 -0
  58. liger_kernel/transformers/model/gemma.py +28 -8
  59. liger_kernel/transformers/model/gemma2.py +34 -11
  60. liger_kernel/transformers/model/gemma3.py +102 -112
  61. liger_kernel/transformers/model/glm4.py +18 -5
  62. liger_kernel/transformers/model/glm4v.py +163 -0
  63. liger_kernel/transformers/model/glm4v_moe.py +172 -0
  64. liger_kernel/transformers/model/gpt_oss.py +211 -0
  65. liger_kernel/transformers/model/hunyuan_v1.py +134 -0
  66. liger_kernel/transformers/model/internvl.py +157 -0
  67. liger_kernel/transformers/model/llama.py +26 -7
  68. liger_kernel/transformers/model/llama4.py +121 -0
  69. liger_kernel/transformers/model/llava.py +18 -6
  70. liger_kernel/transformers/model/loss_utils.py +34 -3
  71. liger_kernel/transformers/model/mistral.py +17 -10
  72. liger_kernel/transformers/model/mixtral.py +24 -9
  73. liger_kernel/transformers/model/mllama.py +18 -7
  74. liger_kernel/transformers/model/olmo2.py +18 -5
  75. liger_kernel/transformers/model/olmo3.py +142 -0
  76. liger_kernel/transformers/model/output_classes.py +147 -0
  77. liger_kernel/transformers/model/paligemma.py +42 -5
  78. liger_kernel/transformers/model/phi3.py +24 -159
  79. liger_kernel/transformers/model/qwen2.py +26 -4
  80. liger_kernel/transformers/model/qwen2_5_vl.py +21 -8
  81. liger_kernel/transformers/model/qwen2_vl.py +24 -7
  82. liger_kernel/transformers/model/qwen3.py +22 -6
  83. liger_kernel/transformers/model/qwen3_moe.py +27 -7
  84. liger_kernel/transformers/model/qwen3_next.py +146 -0
  85. liger_kernel/transformers/model/qwen3_vl.py +150 -0
  86. liger_kernel/transformers/model/qwen3_vl_moe.py +126 -0
  87. liger_kernel/transformers/model/smollm3.py +199 -0
  88. liger_kernel/transformers/model/smolvlm.py +158 -0
  89. liger_kernel/transformers/monkey_patch.py +1423 -100
  90. liger_kernel/transformers/multi_token_attention.py +2 -2
  91. liger_kernel/transformers/poly_norm.py +42 -0
  92. liger_kernel/transformers/qwen2vl_mrope.py +1 -1
  93. liger_kernel/transformers/rms_norm.py +15 -5
  94. liger_kernel/transformers/rope.py +45 -1
  95. liger_kernel/transformers/softmax.py +1 -1
  96. liger_kernel/transformers/sparsemax.py +1 -1
  97. liger_kernel/transformers/swiglu.py +18 -1
  98. liger_kernel/transformers/tiled_mlp.py +125 -0
  99. liger_kernel/transformers/tvd.py +1 -1
  100. liger_kernel/utils.py +52 -0
  101. {liger_kernel_nightly-0.5.10.dev20250611191801.dist-info → liger_kernel_nightly-0.6.4.dev20260112233432.dist-info}/METADATA +37 -25
  102. liger_kernel_nightly-0.6.4.dev20260112233432.dist-info/RECORD +132 -0
  103. liger_kernel_nightly-0.5.10.dev20250611191801.dist-info/RECORD +0 -95
  104. {liger_kernel_nightly-0.5.10.dev20250611191801.dist-info → liger_kernel_nightly-0.6.4.dev20260112233432.dist-info}/LICENSE +0 -0
  105. {liger_kernel_nightly-0.5.10.dev20250611191801.dist-info → liger_kernel_nightly-0.6.4.dev20260112233432.dist-info}/NOTICE +0 -0
  106. {liger_kernel_nightly-0.5.10.dev20250611191801.dist-info → liger_kernel_nightly-0.6.4.dev20260112233432.dist-info}/WHEEL +0 -0
  107. {liger_kernel_nightly-0.5.10.dev20250611191801.dist-info → liger_kernel_nightly-0.6.4.dev20260112233432.dist-info}/top_level.txt +0 -0
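
A pattern that repeats across the wrapper diffs below: the Triton autograd `Function` classes are now re-exported from the top-level `liger_kernel.ops` package (see the new 141-line `liger_kernel/ops/__init__.py` above), so the wrappers import from the flat namespace rather than from per-kernel submodules. A minimal before/after sketch, assuming the 0.6.4 nightly is installed:

```python
# Before (0.5.10 nightly): each op was imported from its own submodule
# from liger_kernel.ops.rms_norm import LigerRMSNormFunction

# After (0.6.4 nightly): ops/__init__.py re-exports the Function classes
from liger_kernel.ops import LigerRMSNormFunction

# The per-kernel submodules (liger_kernel/ops/rms_norm.py, etc.) still exist,
# so the older deep imports are expected to keep resolving.
```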
liger_kernel/transformers/multi_token_attention.py CHANGED
@@ -5,11 +5,11 @@ import torch.nn as nn
 
 from torch.nn.modules.utils import _pair
 
-from liger_kernel.ops.multi_token_attention import LigerMultiTokenAttentionFunction
+from liger_kernel.ops import LigerMultiTokenAttentionFunction
 
 
 class LigerMultiTokenAttention(nn.Module):
-    """
+    r"""
     Multi-Token Attention:
     out = mask_{0}(conv2d(softmax(mask_{-\inf}(scores))))
 
liger_kernel/transformers/poly_norm.py ADDED
@@ -0,0 +1,42 @@
+import torch
+import torch.nn as nn
+
+from liger_kernel.ops import LigerPolyNormFunction
+
+
+class LigerPolyNorm(nn.Module):
+    """
+    PolyNorm layer wrapper for Liger kernel.
+
+    PolyNorm formula:
+        y = w₀·norm(x³) + w₁·norm(x²) + w₂·norm(x) + b
+        where norm(u) = u / sqrt(mean(u²) + ε)
+
+    Reference:
+        https://github.com/BryceZhuo/PolyCom/
+
+    Args:
+        eps: epsilon for numerical stability (default: 1e-6)
+        in_place: whether to in-place modify grad_output in backward to save memory (default: False).
+            Set to True to save memory if grad_output is not needed elsewhere.
+    """
+
+    def __init__(self, eps=1e-6, in_place=True):
+        super().__init__()
+        # Align with PolyCom reference: initialize weights to (1/3, 1/3, 1/3) and bias to 1.0
+        self.weight = nn.Parameter(torch.full((3,), 1.0 / 3.0))
+        self.bias = nn.Parameter(torch.tensor(1.0))
+        self.variance_epsilon = eps
+        self.in_place = in_place
+
+    def forward(self, hidden_states):
+        return LigerPolyNormFunction.apply(
+            hidden_states,
+            self.weight,
+            self.bias,
+            self.variance_epsilon,
+            self.in_place,
+        )
+
+    def extra_repr(self):
+        return f"weight_shape={tuple(self.weight.shape)}, eps={self.variance_epsilon}, in_place={self.in_place}"
liger_kernel/transformers/qwen2vl_mrope.py CHANGED
@@ -1,4 +1,4 @@
-from liger_kernel.ops.qwen2vl_mrope import LigerQwen2VLMRopeFunction
+from liger_kernel.ops import LigerQwen2VLMRopeFunction
 
 
 def liger_multimodal_rotary_pos_emb(q, k, cos, sin, mrope_section, unsqueeze_dim=1):
liger_kernel/transformers/rms_norm.py CHANGED
@@ -1,7 +1,7 @@
 import torch
 import torch.nn as nn
 
-from liger_kernel.ops.rms_norm import LigerRMSNormFunction
+from liger_kernel.ops import LigerRMSNormFunction
 
 
 class LigerRMSNorm(nn.Module):
@@ -14,13 +14,18 @@ class LigerRMSNorm(nn.Module):
         init_fn="ones",
         in_place=True,
         row_mode=None,
+        elementwise_affine=True,
     ):
         super().__init__()
         assert init_fn in [
             "ones",
             "zeros",
         ], f"init_fn must be either 'ones' or 'zeros', got {init_fn}"
-        self.weight = nn.Parameter(torch.ones(hidden_size) if init_fn == "ones" else torch.zeros(hidden_size))
+        self.elementwise_affine = elementwise_affine
+        if self.elementwise_affine:
+            self.weight = nn.Parameter(torch.ones(hidden_size) if init_fn == "ones" else torch.zeros(hidden_size))
+        else:
+            self.register_parameter("weight", None)
         self.variance_epsilon, self.offset, self.casting_mode, self.in_place, self.row_mode = (
             eps,
             offset,
@@ -41,9 +46,7 @@ class LigerRMSNorm(nn.Module):
         )
 
     def extra_repr(self):
-        return (
-            f"{tuple(self.weight.shape)}, eps={self.variance_epsilon}, offset={self.offset}, in_place={self.in_place}"
-        )
+        return f"weight_shape={tuple(self.weight.shape) if self.weight is not None else None}, eps={self.variance_epsilon}, offset={self.offset}, in_place={self.in_place}, row_mode={self.row_mode}"
 
 
 class LigerRMSNormForGemma(LigerRMSNorm):
@@ -79,3 +82,10 @@ class LigerRMSNormForGlm4(LigerRMSNorm):
         self, hidden_size, eps=1e-6, offset=0.0, casting_mode="llama", init_fn="ones", in_place=False, row_mode=None
     ):
         super().__init__(hidden_size, eps, offset, casting_mode, init_fn, in_place, row_mode)
+
+
+class LigerRMSNormForQwen3Next(LigerRMSNorm):
+    def __init__(
+        self, hidden_size, eps=1e-6, offset=1.0, casting_mode="gemma", init_fn="zeros", in_place=False, row_mode=None
+    ):
+        super().__init__(hidden_size, eps, offset, casting_mode, init_fn, in_place, row_mode)
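
A short construction sketch for the new `elementwise_affine` flag (construction and repr only, which run on CPU; the fused forward itself expects a GPU):

```python
import torch

from liger_kernel.transformers.rms_norm import LigerRMSNorm

# Default: learnable per-channel weight, as before
norm = LigerRMSNorm(hidden_size=4096, eps=1e-6)
assert isinstance(norm.weight, torch.nn.Parameter)

# New in this release: disable the affine weight entirely
plain = LigerRMSNorm(hidden_size=4096, eps=1e-6, elementwise_affine=False)
assert plain.weight is None    # registered as a None parameter
print(plain)                   # extra_repr now reports weight_shape=None and row_mode
```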
liger_kernel/transformers/rope.py CHANGED
@@ -1,4 +1,8 @@
-from liger_kernel.ops.rope import LigerRopeFunction
+from typing import Tuple
+
+import torch
+
+from liger_kernel.ops import LigerRopeFunction
 
 
 def liger_rotary_pos_emb(q, k, cos, sin, position_ids=None, unsqueeze_dim=1):
@@ -18,3 +22,43 @@ def liger_rotary_pos_emb(q, k, cos, sin, position_ids=None, unsqueeze_dim=1):
     """
 
     return LigerRopeFunction.apply(q, k, cos, sin, position_ids, unsqueeze_dim)
+
+
+def liger_rotary_pos_emb_vision(
+    q: torch.Tensor,
+    k: torch.Tensor,
+    cos: torch.Tensor,
+    sin: torch.Tensor,
+) -> Tuple[torch.Tensor, torch.Tensor]:
+    """
+    Modified version of liger_rotary_pos_emb for qwen3_vl's apply_rotary_pos_emb_vision function.
+    Manually transposes the input and output to match the expected shape for liger_rotary_pos_emb.
+    Reference: https://github.com/huggingface/transformers/blob/v5.0.0rc0/src/transformers/models/qwen3_vl/modeling_qwen3_vl.py#L116
+
+    Args:
+        q (torch.Tensor): The query tensor of shape (seq_length, num_heads, head_dim),
+            with stride (num_heads * head_dim, head_dim, 1).
+        k (torch.Tensor): The key tensor of shape (seq_length, num_heads, head_dim),
+            with stride (num_heads * head_dim, head_dim, 1). Same layout as q.
+        cos (torch.Tensor): The cosine tensor of shape (seq_length, head_dim).
+        sin (torch.Tensor): The sine tensor of shape (seq_length, head_dim).
+
+    Returns:
+        Tuple[torch.Tensor, torch.Tensor]: The query and key tensors with the same shape and stride as the inputs.
+    """
+    orig_q_dtype, orig_k_dtype = q.dtype, k.dtype
+
+    # transpose to (1, num_heads, seq_length, head_dim) and cast to float32 to match liger_rotary_pos_emb input shape
+    # also unsqueeze for batch dim
+    q32 = q.to(torch.float32).unsqueeze(0).transpose(1, 2)
+    k32 = k.to(torch.float32).unsqueeze(0).transpose(1, 2)
+    cos32 = cos.to(torch.float32)
+    sin32 = sin.to(torch.float32)
+
+    q_out, k_out = liger_rotary_pos_emb(q32, k32, cos32, sin32)
+
+    # transpose back to (seq_length, num_heads, head_dim) and cast back to original dtype
+    # also squeeze out batch dim
+    q_out = q_out.transpose(1, 2).squeeze(0).to(orig_q_dtype)
+    k_out = k_out.transpose(1, 2).squeeze(0).to(orig_k_dtype)
+    return q_out, k_out
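
To make the shape contract of the new vision helper concrete, a hedged sketch (random cos/sin stand in for the real rotary tables; a CUDA device is assumed for the underlying Triton kernel):

```python
import torch

from liger_kernel.transformers.rope import liger_rotary_pos_emb_vision

seq_len, num_heads, head_dim = 1024, 16, 64

# Vision-tower layout: (seq_len, num_heads, head_dim), no batch dimension
q = torch.randn(seq_len, num_heads, head_dim, device="cuda", dtype=torch.bfloat16)
k = torch.randn_like(q)
cos = torch.randn(seq_len, head_dim, device="cuda")
sin = torch.randn(seq_len, head_dim, device="cuda")

q_rot, k_rot = liger_rotary_pos_emb_vision(q, k, cos, sin)
assert q_rot.shape == q.shape and q_rot.dtype == q.dtype  # shape and dtype are round-tripped
```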
liger_kernel/transformers/softmax.py CHANGED
@@ -1,7 +1,7 @@
 import torch
 import torch.nn as nn
 
-from liger_kernel.ops.softmax import LigerSoftmaxFunction
+from liger_kernel.ops import LigerSoftmaxFunction
 
 
 class LigerSoftmax(nn.Module):
liger_kernel/transformers/sparsemax.py CHANGED
@@ -1,7 +1,7 @@
 import torch
 import torch.nn as nn
 
-from liger_kernel.ops.sparsemax import LigerSparsemaxFunction
+from liger_kernel.ops import LigerSparsemaxFunction
 
 
 class LigerSparsemax(nn.Module):
liger_kernel/transformers/swiglu.py CHANGED
@@ -1,6 +1,6 @@
 import torch.nn as nn
 
-from liger_kernel.ops.swiglu import LigerSiLUMulFunction
+from liger_kernel.ops import LigerSiLUMulFunction
 
 
 class LigerSwiGLUMLP(nn.Module):
@@ -77,3 +77,20 @@ class LigerQwen3MoeSwiGLUMLP(nn.Module):
 
     def forward(self, x):
         return self.down_proj(LigerSiLUMulFunction.apply(self.gate_proj(x), self.up_proj(x)))
+
+
+class LigerHunyuanV1SwiGLUMLP(nn.Module):
+    def __init__(self, config, layer_idx=None, is_shared_mlp=False):
+        super().__init__()
+        self.config = config
+        self.hidden_size = config.hidden_size
+        self.intermediate_size = config.intermediate_size
+        self.gate_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False)
+        self.up_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False)
+        self.down_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias=False)
+        self.layer_idx = layer_idx
+        if config.hidden_act not in ["silu", "swish"]:
+            raise ValueError(f"Activation function {config.hidden_act} not supported.")
+
+    def forward(self, x):
+        return self.down_proj(LigerSiLUMulFunction.apply(self.gate_proj(x), self.up_proj(x)))
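
A construction sketch for the new HunyuanV1 MLP wrapper with a stand-in config (hypothetical sizes; in practice the module is wired in by the `apply_liger_kernel_to_hunyuan_v1_dense` / `_moe` patches listed in the METADATA table below):

```python
from types import SimpleNamespace

from liger_kernel.transformers.swiglu import LigerHunyuanV1SwiGLUMLP

# Stand-in for a HunyuanV1 config; only the fields the module reads are provided
cfg = SimpleNamespace(hidden_size=4096, intermediate_size=11008, hidden_act="silu")

mlp = LigerHunyuanV1SwiGLUMLP(cfg, layer_idx=0, is_shared_mlp=False)
# forward(x) expects x of shape (..., hidden_size) and fuses SiLU(gate) * up via LigerSiLUMulFunction
```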
liger_kernel/transformers/tiled_mlp.py ADDED
@@ -0,0 +1,125 @@
+from typing import Optional
+
+import torch.nn as nn
+
+from liger_kernel.ops import LigerGELUMulFunction
+from liger_kernel.ops import LigerSiLUMulFunction
+from liger_kernel.ops import apply_tiled_mlp
+
+
+class LigerTiledGEGLUMLP(nn.Module):
+    """
+    Memory-efficient GEGLU MLP using tiled computation.
+
+    This module combines GEGLU activation with tiled processing to handle
+    very long sequences efficiently. The forward pass is recomputed during
+    backward to save memory.
+
+    Args:
+        config: Model configuration with hidden_size and intermediate_size attributes
+        num_shards: Number of shards to split the sequence. If None, automatically
+            calculated as ceil(seqlen / hidden_size)
+    """
+
+    def __init__(self, config, num_shards: Optional[int] = None):
+        super().__init__()
+        self.config = config
+        self.hidden_size = config.hidden_size
+        self.intermediate_size = config.intermediate_size
+        self.num_shards = num_shards
+
+        self.gate_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False)
+        self.up_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False)
+        self.down_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias=False)
+
+        # Validate activation function
+        if hasattr(config, "hidden_act") and config.hidden_act not in [
+            "gelu",
+            "gelu_new",
+            "gelu_pytorch_tanh",
+        ]:
+            raise ValueError(f"LigerTiledGEGLUMLP requires GELU activation, got {config.hidden_act}")
+
+    def _mlp_forward(self, module, x):
+        """Internal MLP forward function for tiled computation."""
+        gate = module.gate_proj(x)
+        up = module.up_proj(x)
+        return module.down_proj(LigerGELUMulFunction.apply(gate, up))
+
+    def forward(self, x):
+        """
+        Forward pass with tiled computation.
+
+        Args:
+            x: Input tensor of shape [batch_size, seq_len, hidden_size]
+                or [seq_len, hidden_size]
+
+        Returns:
+            Output tensor of the same shape as input
+        """
+        compute_params = [p for p in self.parameters() if p.requires_grad]
+
+        return apply_tiled_mlp(
+            fn=self._mlp_forward,
+            mlp_module=self,
+            x=x,
+            num_shards=self.num_shards,
+            compute_params=compute_params,
+        )
+
+
+class LigerTiledSwiGLUMLP(nn.Module):
+    """
+    Memory-efficient SwiGLU MLP using tiled computation.
+
+    This module combines SwiGLU activation with tiled processing to handle
+    very long sequences efficiently. The forward pass is recomputed during
+    backward to save memory.
+
+    Args:
+        config: Model configuration with hidden_size and intermediate_size attributes
+        num_shards: Number of shards to split the sequence. If None, automatically
+            calculated as ceil(seqlen / hidden_size)
+    """
+
+    def __init__(self, config, num_shards: Optional[int] = None):
+        super().__init__()
+        self.config = config
+        self.hidden_size = config.hidden_size
+        self.intermediate_size = config.intermediate_size
+        self.num_shards = num_shards
+
+        self.gate_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False)
+        self.up_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False)
+        self.down_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias=False)
+
+        # Validate activation function
+        if hasattr(config, "hidden_act") and config.hidden_act not in ["silu", "swish"]:
+            raise ValueError(f"LigerTiledSwiGLUMLP requires SiLU/Swish activation, got {config.hidden_act}")
+
+    def _mlp_forward(self, module, x):
+        """Internal MLP forward function for tiled computation."""
+        gate = module.gate_proj(x)
+        up = module.up_proj(x)
+        return module.down_proj(LigerSiLUMulFunction.apply(gate, up))
+
+    def forward(self, x):
+        """
+        Forward pass with tiled computation.
+
+        Args:
+            x: Input tensor of shape [batch_size, seq_len, hidden_size]
+                or [seq_len, hidden_size]
+
+        Returns:
+            Output tensor of the same shape as input
+        """
+        compute_params = [p for p in self.parameters() if p.requires_grad]
+
+        return apply_tiled_mlp(
+            fn=self._mlp_forward,
+            mlp_module=self,
+            x=x,
+            num_shards=self.num_shards,
+            compute_params=compute_params,
+        )
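
A minimal usage sketch of the tiled SwiGLU variant with a stand-in config object (hypothetical sizes; a CUDA device is assumed since the fused SiLU-mul op is a Triton kernel):

```python
from types import SimpleNamespace

import torch

from liger_kernel.transformers.tiled_mlp import LigerTiledSwiGLUMLP

cfg = SimpleNamespace(hidden_size=2048, intermediate_size=8192, hidden_act="silu")
mlp = LigerTiledSwiGLUMLP(cfg, num_shards=None).to("cuda")  # None => ceil(seq_len / hidden_size) shards

x = torch.randn(1, 32768, cfg.hidden_size, device="cuda", requires_grad=True)
y = mlp(x)            # same shape as x
y.mean().backward()   # forward is recomputed shard-by-shard during backward to save activation memory
```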
liger_kernel/transformers/tvd.py CHANGED
@@ -1,6 +1,6 @@
 import torch.nn as nn
 
-from liger_kernel.ops.tvd import LigerTVDLossFunction
+from liger_kernel.ops import LigerTVDLossFunction
 
 
 class LigerTVDLoss(nn.Module):
liger_kernel/utils.py CHANGED
@@ -12,18 +12,70 @@ def is_peft_available():
     return PEFT_AVAILABLE
 
 
+def infer_comm_backend():
+    """
+    Get communication backend name based on the environment.
+    """
+    if torch.distributed.is_nccl_available():
+        # Works for Nvidia
+        # TODO: nccl may not work for AMD devices, which may require use of rccl.
+        return "nccl"
+    elif is_npu_available():
+        # Use Ascend NPU if available (torch.npu)
+        # Ascend is not a standard torch backend and requires an extension.
+        # Assume that it is installed if NPUs are being used in a
+        # multi-device environment.
+        return "ascend"
+    # XPU (Intel) if available
+    elif torch.distributed.distributed_c10d.is_xccl_available():
+        return "xccl"
+    elif torch.distributed.is_mpi_available():
+        # CPU backend, first option
+        return "mpi"
+    elif torch.distributed.is_gloo_available():
+        # CPU backend, backup option
+        return "gloo"
+    else:
+        raise RuntimeError("There is no distributed backend available.")
+
+
 def infer_device():
     """
     Get current device name based on available devices
     """
     if torch.cuda.is_available():  # Works for both Nvidia and AMD
         return "cuda"
+    # Use Ascend NPU if available (torch.npu)
+    elif is_npu_available():
+        return "npu"
+    # XPU (Intel) if available
     elif torch.xpu.is_available():
         return "xpu"
     else:
         return "cpu"
 
 
+def is_npu_available() -> bool:
+    """Detect Ascend NPU availability."""
+    try:
+        from transformers.utils import is_torch_npu_available
+
+        return is_torch_npu_available()
+    except Exception:
+        return False
+
+
+def get_npu_multi_processor_count() -> int:
+    """Return a heuristic multi-processor count for NPU."""
+    if is_npu_available():
+        NPU_MULTI_PROCESSOR_COUNT = 48
+        dev_props = torch.npu.get_device_properties()
+        # The vector_core_num attribute is supported in the torch.npu v7.2.0 release version.
+        return dev_props.vector_core_num if hasattr(dev_props, "vector_core_num") else NPU_MULTI_PROCESSOR_COUNT
+    # Reasonable default to avoid division by zero
+    return 1
+
+
 def transformers_version_dispatch(
     required_version: str,
     before_fn,
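
A hedged sketch of how the new helpers might be combined to set up distributed training without hard-coding a backend (run under `torchrun` so the rendezvous environment variables exist; the "ascend" backend name assumes the torch_npu extension is installed, as the comment in the diff notes):

```python
import torch.distributed as dist

from liger_kernel.utils import infer_comm_backend, infer_device

if not dist.is_initialized():
    # Picks nccl, ascend, xccl, mpi, or gloo based on what this environment supports
    dist.init_process_group(backend=infer_comm_backend())

device = infer_device()  # "cuda", "npu", "xpu", or "cpu"
print(f"rank {dist.get_rank()}: backend={dist.get_backend()}, device={device}")
```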
{liger_kernel_nightly-0.5.10.dev20250611191801.dist-info → liger_kernel_nightly-0.6.4.dev20260112233432.dist-info}/METADATA CHANGED
@@ -1,6 +1,6 @@
 Metadata-Version: 2.1
 Name: liger_kernel_nightly
-Version: 0.5.10.dev20250611191801
+Version: 0.6.4.dev20260112233432
 Summary: Efficient Triton kernels for LLM Training
 License: BSD 2-CLAUSE LICENSE
 Copyright 2024 LinkedIn Corporation
@@ -33,18 +33,18 @@ License-File: NOTICE
 Requires-Dist: torch>=2.1.2
 Requires-Dist: triton>=2.3.1
 Provides-Extra: dev
-Requires-Dist: transformers>=4.44.2; extra == "dev"
+Requires-Dist: transformers>=4.49.0; extra == "dev"
 Requires-Dist: matplotlib>=3.7.2; extra == "dev"
-Requires-Dist: flake8>=4.0.1.1; extra == "dev"
-Requires-Dist: black>=24.4.2; extra == "dev"
-Requires-Dist: isort>=5.13.2; extra == "dev"
+Requires-Dist: ruff>=0.12.0; extra == "dev"
 Requires-Dist: pytest>=7.1.2; extra == "dev"
 Requires-Dist: pytest-xdist; extra == "dev"
+Requires-Dist: pytest-cov; extra == "dev"
+Requires-Dist: pytest-asyncio; extra == "dev"
 Requires-Dist: pytest-rerunfailures; extra == "dev"
 Requires-Dist: datasets>=2.19.2; extra == "dev"
 Requires-Dist: seaborn; extra == "dev"
-Requires-Dist: mkdocs; extra == "dev"
 Requires-Dist: mkdocs-material; extra == "dev"
+Requires-Dist: torchvision>=0.20; extra == "dev"
 
 <a name="readme-top"></a>
 
@@ -79,8 +79,8 @@ Requires-Dist: mkdocs-material; extra == "dev"
         </a>
       </td>
       <td style="padding: 10px;">
-        <a href="https://discord.gg/gpumode">
-          <img src="https://dcbadge.vercel.app/api/server/gpumode?style=flat" alt="Join Our Discord">
+        <a href="https://discord.gg/X4MaxPgA">
+          <img src="https://dcbadge.limes.pink/api/server/https://discord.gg/X4MaxPgA?style=flat" alt="Join Our Discord">
         </a>
       </td>
     </tr>
@@ -95,6 +95,7 @@ Requires-Dist: mkdocs-material; extra == "dev"
 <details>
   <summary>Latest News 🔥</summary>
 
+- [2025/12/19] We announced a Liger Kernel Discord channel at https://discord.gg/X4MaxPgA; we will be hosting a Liger Kernel x Triton China Meetup in mid-January 2026
 - [2025/03/06] We release a joint blog post on TorchTune × Liger - [Peak Performance, Minimized Memory: Optimizing torchtune’s performance with torch.compile & Liger Kernel](https://pytorch.org/blog/peak-performance-minimized-memory/)
 - [2024/12/11] We release [v0.5.0](https://github.com/linkedin/Liger-Kernel/releases/tag/v0.5.0): 80% more memory efficient post training losses (DPO, ORPO, CPO, etc)!
 - [2024/12/5] We release LinkedIn Engineering Blog - [Liger-Kernel: Empowering an open source ecosystem of Triton Kernels for Efficient LLM Training](https://www.linkedin.com/blog/engineering/open-source/liger-kernel-open-source-ecosystem-for-efficient-llm-training)
@@ -113,6 +114,8 @@ We've also added optimized Post-Training kernels that deliver **up to 80% memory
 
 You can view the documentation site for additional installation, usage examples, and API references: https://linkedin.github.io/Liger-Kernel/
 
+You can view the Liger Kernel Technical Report: https://openreview.net/forum?id=36SjAIT42G
+
 ## Supercharge Your Model with Liger Kernel
 
 ![Banner](https://raw.githubusercontent.com/linkedin/Liger-Kernel/main/docs/images/banner.GIF)
@@ -177,8 +180,8 @@ y = orpo_loss(lm_head.weight, x, target)
 - `triton >= 3.0.0` Install from pypi. (e.g. `pip install triton==3.0.0`)
 
 ```bash
-# Need to pass the url when installing
-pip install -e .[dev] --extra-index-url https://download.pytorch.org/whl/nightly/rocm6.2
+pip install -e .[dev]
+pip3 install --pre torch torchvision torchaudio --index-url https://download.pytorch.org/whl/nightly/rocm6.3/
 ```
 
 ### Optional Dependencies
@@ -212,6 +215,9 @@ pip install -e .
 
 # Setup Development Dependencies
 pip install -e ".[dev]"
+
+# NOTE -> For AMD users only
+pip3 install --pre torch torchvision torchaudio --index-url https://download.pytorch.org/whl/nightly/rocm6.3/
 ```
 
 
@@ -289,6 +295,7 @@ loss.backward()
 
 | **Model**   | **API**                                                      | **Supported Operations**                                                |
 |-------------|--------------------------------------------------------------|-------------------------------------------------------------------------|
+| Llama4 (Text) & (Multimodal) | `liger_kernel.transformers.apply_liger_kernel_to_llama4` | RMSNorm, LayerNorm, GeGLU, CrossEntropyLoss, FusedLinearCrossEntropy |
 | LLaMA 2 & 3 | `liger_kernel.transformers.apply_liger_kernel_to_llama`      | RoPE, RMSNorm, SwiGLU, CrossEntropyLoss, FusedLinearCrossEntropy        |
 | LLaMA 3.2-Vision | `liger_kernel.transformers.apply_liger_kernel_to_mllama` | RoPE, RMSNorm, SwiGLU, CrossEntropyLoss, FusedLinearCrossEntropy       |
 | Mistral     | `liger_kernel.transformers.apply_liger_kernel_to_mistral`    | RoPE, RMSNorm, SwiGLU, CrossEntropyLoss, FusedLinearCrossEntropy        |
@@ -302,11 +309,16 @@ loss.backward()
 | Qwen2-VL, & QVQ | `liger_kernel.transformers.apply_liger_kernel_to_qwen2_vl` | RMSNorm, LayerNorm, SwiGLU, CrossEntropyLoss, FusedLinearCrossEntropy |
 | Qwen2.5-VL  | `liger_kernel.transformers.apply_liger_kernel_to_qwen2_5_vl` | RMSNorm, SwiGLU, CrossEntropyLoss, FusedLinearCrossEntropy              |
 | Qwen3       | `liger_kernel.transformers.apply_liger_kernel_to_qwen3`      | RoPE, RMSNorm, SwiGLU, CrossEntropyLoss, FusedLinearCrossEntropy        |
-| Qwen3 MoE   | `liger_kernel_transformers.apply_liger_kernel_to_qwen3_moe`  | RoPE, RMSNorm, SwiGLU, CrossEntropyLoss, FusedLinearCrossEntropy        |
+| Qwen3 MoE   | `liger_kernel.transformers.apply_liger_kernel_to_qwen3_moe`  | RoPE, RMSNorm, SwiGLU, CrossEntropyLoss, FusedLinearCrossEntropy        |
 | Phi3 & Phi3.5 | `liger_kernel.transformers.apply_liger_kernel_to_phi3`     | RoPE, RMSNorm, SwiGLU, CrossEntropyLoss, FusedLinearCrossEntropy        |
 | Granite 3.0 & 3.1 | `liger_kernel.transformers.apply_liger_kernel_to_granite` | RoPE, RMSNorm, SwiGLU, CrossEntropyLoss                               |
 | OLMo2       | `liger_kernel.transformers.apply_liger_kernel_to_olmo2`      | RoPE, RMSNorm, SwiGLU, CrossEntropyLoss, FusedLinearCrossEntropy        |
+| Olmo3       | `liger_kernel.transformers.apply_liger_kernel_to_olmo3`      | RoPE, RMSNorm, SwiGLU, CrossEntropyLoss, FusedLinearCrossEntropy        |
 | GLM-4       | `liger_kernel.transformers.apply_liger_kernel_to_glm4`       | RoPE, RMSNorm, SwiGLU, CrossEntropyLoss, FusedLinearCrossEntropy        |
+| GPT-OSS     | `liger_kernel.transformers.apply_liger_kernel_to_gpt_oss`    | RoPE, RMSNorm, CrossEntropyLoss, FusedLinearCrossEntropy                |
+| InternVL3   | `liger_kernel.transformers.apply_liger_kernel_to_internvl`   | RoPE, RMSNorm, SwiGLU, CrossEntropyLoss, FusedLinearCrossEntropy        |
+| HunyuanV1   | `liger_kernel.transformers.apply_liger_kernel_to_hunyuan_v1_dense` | RoPE, RMSNorm, SwiGLU, CrossEntropyLoss, FusedLinearCrossEntropy  |
+| HunyuanV1 MoE | `liger_kernel.transformers.apply_liger_kernel_to_hunyuan_v1_moe` | RoPE, RMSNorm, SwiGLU, CrossEntropyLoss, FusedLinearCrossEntropy    |
 
 
 ## Low-level APIs
@@ -386,17 +398,17 @@ loss.backward()
    <td style="padding: 10px;">
       <div style="display: block;">
          <a href="https://github.com/linkedin/Liger-Kernel/actions/workflows/nvi-ci.yml">
-            <img src="https://github.com/linkedin/Liger-Kernel/actions/workflows/nvi-ci.yml/badge.svg?event=schedule" alt="Build">
+            <img src="https://github.com/linkedin/Liger-Kernel/actions/workflows/nvi-ci.yml/badge.svg?branch=main&event=push" alt="Build">
          </a>
       </div>
       <div style="display: block;">
          <a href="https://github.com/linkedin/Liger-Kernel/actions/workflows/amd-ci.yml">
-            <img src="https://github.com/linkedin/Liger-Kernel/actions/workflows/amd-ci.yml/badge.svg?event=schedule" alt="Build">
+            <img src="https://github.com/linkedin/Liger-Kernel/actions/workflows/amd-ci.yml/badge.svg?branch=main&event=push" alt="Build">
          </a>
       </div>
       <div style="display: block;">
-         <a href="https://github.com/linkedin/Liger-Kernel/actions/workflows/amd-ci.yml">
-            <img src="https://github.com/linkedin/Liger-Kernel/actions/workflows/intel-ci.yml/badge.svg?event=schedule" alt="Build">
+         <a href="https://github.com/linkedin/Liger-Kernel/actions/workflows/intel-ci.yml">
+            <img src="https://github.com/linkedin/Liger-Kernel/actions/workflows/intel-ci.yml/badge.svg?branch=main&event=push" alt="Build">
          </a>
       </div>
    </td>
@@ -409,21 +421,19 @@ loss.backward()
 
 - For issues, create a Github ticket in this repository
 - For open discussion, join [our discord channel on GPUMode](https://discord.com/channels/1189498204333543425/1275130785933951039)
-- For formal collaboration, send an email to yannchen@linkedin.com and hning@linkedin.com
+- For formal collaboration, send an email to Yanning Chen (yannchen@linkedin.com) and Zhipeng Wang (zhipwang@linkedin.com)
 
 ## Cite this work
 
 Biblatex entry:
 ```bib
-@article{hsu2024ligerkernelefficienttriton,
-      title={Liger Kernel: Efficient Triton Kernels for LLM Training},
-      author={Pin-Lun Hsu and Yun Dai and Vignesh Kothapalli and Qingquan Song and Shao Tang and Siyu Zhu and Steven Shimizu and Shivam Sahni and Haowen Ning and Yanning Chen},
-      year={2024},
-      eprint={2410.10989},
-      archivePrefix={arXiv},
-      primaryClass={cs.LG},
-      url={https://arxiv.org/abs/2410.10989},
-      journal={arXiv preprint arXiv:2410.10989},
+@inproceedings{
+    hsu2025ligerkernel,
+    title={Liger-Kernel: Efficient Triton Kernels for {LLM} Training},
+    author={Pin-Lun Hsu and Yun Dai and Vignesh Kothapalli and Qingquan Song and Shao Tang and Siyu Zhu and Steven Shimizu and Shivam Sahni and Haowen Ning and Yanning Chen and Zhipeng Wang},
+    booktitle={Championing Open-source DEvelopment in ML Workshop @ ICML25},
+    year={2025},
+    url={https://openreview.net/forum?id=36SjAIT42G}
 }
 ```
 
@@ -435,3 +445,5 @@ Biblatex entry:
         ↑ Back to Top ↑
       </a>
     </p>
+
+
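
For reference, the patch entry points added to the supported-models table above follow the usual Liger pattern of being applied before the model is instantiated; a hedged sketch for the new Llama4 entry (the checkpoint path is a placeholder):

```python
import transformers

from liger_kernel.transformers import apply_liger_kernel_to_llama4

# Monkey-patch the HF Llama4 modules with Liger RMSNorm/GeGLU/fused cross-entropy kernels
apply_liger_kernel_to_llama4()

# Load the model afterwards so the patched classes are picked up
model = transformers.AutoModelForCausalLM.from_pretrained("path/to/llama4-checkpoint")
```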