liger-kernel-nightly 0.4.0.dev20241107052928__py3-none-any.whl → 0.6.3.dev20251121010306__py3-none-any.whl

This diff shows the changes between two publicly released versions of the package, as they appear in their respective public registries. It is provided for informational purposes only.

Potentially problematic release: this version of liger-kernel-nightly might be problematic.
Files changed (114)
  1. liger_kernel/__init__.py +0 -0
  2. liger_kernel/chunked_loss/README.md +25 -0
  3. liger_kernel/chunked_loss/__init__.py +8 -0
  4. liger_kernel/chunked_loss/cosine_similarity_loss.py +136 -0
  5. liger_kernel/chunked_loss/cpo_loss.py +157 -0
  6. liger_kernel/chunked_loss/dpo_loss.py +229 -0
  7. liger_kernel/chunked_loss/functional.py +17 -0
  8. liger_kernel/chunked_loss/fused_linear_distillation.py +292 -0
  9. liger_kernel/chunked_loss/fused_linear_ppo.py +350 -0
  10. liger_kernel/chunked_loss/fused_linear_preference.py +433 -0
  11. liger_kernel/chunked_loss/fused_linear_unpaired_preference.py +341 -0
  12. liger_kernel/chunked_loss/grpo_loss.py +304 -0
  13. liger_kernel/chunked_loss/jsd_loss.py +200 -0
  14. liger_kernel/chunked_loss/kto_loss.py +210 -0
  15. liger_kernel/chunked_loss/orpo_loss.py +144 -0
  16. liger_kernel/chunked_loss/simpo_loss.py +165 -0
  17. liger_kernel/env_report.py +21 -4
  18. liger_kernel/ops/cross_entropy.py +235 -84
  19. liger_kernel/ops/dyt.py +157 -0
  20. liger_kernel/ops/experimental/embedding.py +1 -3
  21. liger_kernel/ops/experimental/mm_int8int2.py +3 -9
  22. liger_kernel/ops/fused_add_rms_norm.py +412 -0
  23. liger_kernel/ops/fused_linear_cross_entropy.py +197 -75
  24. liger_kernel/ops/fused_linear_jsd.py +17 -34
  25. liger_kernel/ops/fused_neighborhood_attention.py +1022 -0
  26. liger_kernel/ops/geglu.py +7 -18
  27. liger_kernel/ops/group_norm.py +305 -0
  28. liger_kernel/ops/grpo_loss.py +310 -0
  29. liger_kernel/ops/jsd.py +46 -21
  30. liger_kernel/ops/kl_div.py +23 -19
  31. liger_kernel/ops/layer_norm.py +150 -86
  32. liger_kernel/ops/llama4_rope.py +225 -0
  33. liger_kernel/ops/multi_token_attention.py +207 -0
  34. liger_kernel/ops/poly_norm.py +386 -0
  35. liger_kernel/ops/qwen2vl_mrope.py +222 -0
  36. liger_kernel/ops/rms_norm.py +314 -84
  37. liger_kernel/ops/rope.py +32 -34
  38. liger_kernel/ops/softmax.py +201 -0
  39. liger_kernel/ops/sparsemax.py +179 -0
  40. liger_kernel/ops/swiglu.py +5 -9
  41. liger_kernel/ops/tiled_mlp.py +136 -0
  42. liger_kernel/ops/tvd.py +207 -0
  43. liger_kernel/ops/utils.py +8 -4
  44. liger_kernel/transformers/__init__.py +199 -24
  45. liger_kernel/transformers/auto_model.py +6 -13
  46. liger_kernel/transformers/cross_entropy.py +33 -20
  47. liger_kernel/transformers/dyt.py +22 -0
  48. liger_kernel/transformers/experimental/__init__.py +5 -0
  49. liger_kernel/transformers/experimental/embedding.py +1 -3
  50. liger_kernel/transformers/fsdp.py +55 -0
  51. liger_kernel/transformers/functional.py +291 -13
  52. liger_kernel/transformers/fused_add_rms_norm.py +39 -0
  53. liger_kernel/transformers/fused_linear_cross_entropy.py +43 -14
  54. liger_kernel/transformers/fused_linear_jsd.py +1 -4
  55. liger_kernel/transformers/fused_neighborhood_attention.py +234 -0
  56. liger_kernel/transformers/geglu.py +1 -4
  57. liger_kernel/transformers/group_norm.py +50 -0
  58. liger_kernel/transformers/grpo_loss.py +98 -0
  59. liger_kernel/transformers/jsd.py +2 -7
  60. liger_kernel/transformers/kl_div.py +1 -3
  61. liger_kernel/transformers/layer_norm.py +3 -9
  62. liger_kernel/transformers/llama4_rope.py +93 -0
  63. liger_kernel/transformers/model/falcon_h1.py +122 -0
  64. liger_kernel/transformers/model/gemma.py +77 -77
  65. liger_kernel/transformers/model/gemma2.py +283 -0
  66. liger_kernel/transformers/model/gemma3.py +331 -0
  67. liger_kernel/transformers/model/glm4.py +141 -0
  68. liger_kernel/transformers/model/glm4v.py +163 -0
  69. liger_kernel/transformers/model/glm4v_moe.py +172 -0
  70. liger_kernel/transformers/model/internvl.py +157 -0
  71. liger_kernel/transformers/model/llama.py +128 -79
  72. liger_kernel/transformers/model/llama4.py +121 -0
  73. liger_kernel/transformers/model/llava.py +344 -0
  74. liger_kernel/transformers/model/loss_utils.py +95 -0
  75. liger_kernel/transformers/model/mistral.py +68 -64
  76. liger_kernel/transformers/model/mixtral.py +75 -91
  77. liger_kernel/transformers/model/mllama.py +63 -68
  78. liger_kernel/transformers/model/olmo2.py +141 -0
  79. liger_kernel/transformers/model/output_classes.py +147 -0
  80. liger_kernel/transformers/model/paligemma.py +432 -0
  81. liger_kernel/transformers/model/phi3.py +59 -213
  82. liger_kernel/transformers/model/qwen2.py +75 -72
  83. liger_kernel/transformers/model/qwen2_5_vl.py +163 -0
  84. liger_kernel/transformers/model/qwen2_vl.py +78 -98
  85. liger_kernel/transformers/model/qwen3.py +136 -0
  86. liger_kernel/transformers/model/qwen3_moe.py +152 -0
  87. liger_kernel/transformers/model/qwen3_next.py +146 -0
  88. liger_kernel/transformers/model/qwen3_vl.py +150 -0
  89. liger_kernel/transformers/model/qwen3_vl_moe.py +126 -0
  90. liger_kernel/transformers/model/smollm3.py +199 -0
  91. liger_kernel/transformers/model/smolvlm.py +158 -0
  92. liger_kernel/transformers/monkey_patch.py +2106 -289
  93. liger_kernel/transformers/multi_token_attention.py +64 -0
  94. liger_kernel/transformers/poly_norm.py +42 -0
  95. liger_kernel/transformers/qwen2vl_mrope.py +20 -0
  96. liger_kernel/transformers/rms_norm.py +57 -6
  97. liger_kernel/transformers/rope.py +45 -2
  98. liger_kernel/transformers/softmax.py +12 -0
  99. liger_kernel/transformers/sparsemax.py +16 -0
  100. liger_kernel/transformers/swiglu.py +23 -8
  101. liger_kernel/transformers/tiled_mlp.py +133 -0
  102. liger_kernel/transformers/trainer/__init__.py +4 -0
  103. liger_kernel/transformers/trainer/orpo_trainer.py +130 -0
  104. liger_kernel/transformers/tvd.py +13 -0
  105. liger_kernel/triton/__init__.py +1 -3
  106. liger_kernel/triton/monkey_patch.py +1 -3
  107. liger_kernel/utils.py +71 -0
  108. {liger_kernel_nightly-0.4.0.dev20241107052928.dist-info → liger_kernel_nightly-0.6.3.dev20251121010306.dist-info}/METADATA +150 -137
  109. liger_kernel_nightly-0.6.3.dev20251121010306.dist-info/RECORD +116 -0
  110. {liger_kernel_nightly-0.4.0.dev20241107052928.dist-info → liger_kernel_nightly-0.6.3.dev20251121010306.dist-info}/WHEEL +1 -1
  111. liger_kernel_nightly-0.4.0.dev20241107052928.dist-info/RECORD +0 -48
  112. {liger_kernel_nightly-0.4.0.dev20241107052928.dist-info → liger_kernel_nightly-0.6.3.dev20251121010306.dist-info}/LICENSE +0 -0
  113. {liger_kernel_nightly-0.4.0.dev20241107052928.dist-info → liger_kernel_nightly-0.6.3.dev20251121010306.dist-info}/NOTICE +0 -0
  114. {liger_kernel_nightly-0.4.0.dev20241107052928.dist-info → liger_kernel_nightly-0.6.3.dev20251121010306.dist-info}/top_level.txt +0 -0
liger_kernel/transformers/multi_token_attention.py
@@ -0,0 +1,64 @@
+ import math
+
+ import torch
+ import torch.nn as nn
+
+ from torch.nn.modules.utils import _pair
+
+ from liger_kernel.ops.multi_token_attention import LigerMultiTokenAttentionFunction
+
+
+ class LigerMultiTokenAttention(nn.Module):
+     r"""
+     Multi-Token Attention:
+     out = mask_{0}(conv2d(softmax(mask_{-\inf}(scores))))
+
+     Reference: https://arxiv.org/pdf/2504.00927
+     """
+
+     def __init__(
+         self,
+         in_channels: int,
+         out_channels: int,
+         kernel_size: int,
+         stride: int = 1,
+         padding: int = 0,
+         dilation: int = 1,
+         groups: int = 1,
+         bias: bool = True,
+         sparse: bool = False,
+     ):
+         super().__init__()
+         self.in_channels = in_channels
+         self.out_channels = out_channels
+         self.kernel_size = _pair(kernel_size)
+         self.stride = _pair(stride)
+         self.padding = _pair(padding)
+         self.dilation = _pair(dilation)
+         self.groups = groups
+         self.sparse = sparse
+
+         self.weight = nn.Parameter(torch.empty(out_channels, in_channels // groups, *self.kernel_size))
+         if bias:
+             self.bias = nn.Parameter(torch.empty(out_channels))
+         else:
+             self.register_parameter("bias", None)
+
+         self.reset_parameters()
+
+     def reset_parameters(self):
+         nn.init.kaiming_uniform_(self.weight, a=math.sqrt(5))
+         if self.bias is not None:
+             nn.init.zeros_(self.bias)
+
+     def forward(self, scores: torch.Tensor) -> torch.Tensor:
+         return LigerMultiTokenAttentionFunction.apply(
+             scores,
+             self.weight,
+             self.bias,
+             self.stride,
+             self.padding,
+             self.dilation,
+             self.groups,
+             self.sparse,
+         )
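A minimal usage sketch for the new wrapper above. The score-map shape, channel count, and CUDA placement are illustrative assumptions, not taken from the diff.

    import torch
    from liger_kernel.transformers.multi_token_attention import LigerMultiTokenAttention

    # Hypothetical sizes: 16 attention-score channels over a 64x64 score map.
    mta = LigerMultiTokenAttention(in_channels=16, out_channels=16, kernel_size=3, padding=1).cuda()
    scores = torch.randn(2, 16, 64, 64, device="cuda")  # (batch, channels, seq_len, seq_len)
    out = mta(scores)  # with stride 1, padding 1 and equal channels, the output keeps the input shape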
liger_kernel/transformers/poly_norm.py
@@ -0,0 +1,42 @@
+ import torch
+ import torch.nn as nn
+
+ from liger_kernel.ops.poly_norm import LigerPolyNormFunction
+
+
+ class LigerPolyNorm(nn.Module):
+     """
+     PolyNorm layer wrapper for Liger kernel.
+
+     PolyNorm formula:
+     y = w₀·norm(x³) + w₁·norm(x²) + w₂·norm(x) + b
+     where norm(u) = u / sqrt(mean(u²) + ε)
+
+     Reference:
+     https://github.com/BryceZhuo/PolyCom/
+
+     Args:
+         eps: epsilon for numerical stability (default: 1e-6)
+         in_place: whether to in-place modify grad_output in backward to save memory (default: False).
+             Set to True to save memory if grad_output is not needed elsewhere.
+     """
+
+     def __init__(self, eps=1e-6, in_place=True):
+         super().__init__()
+         # Align with PolyCom reference: initialize weights to (1/3, 1/3, 1/3) and bias to 1.0
+         self.weight = nn.Parameter(torch.full((3,), 1.0 / 3.0))
+         self.bias = nn.Parameter(torch.tensor(1.0))
+         self.variance_epsilon = eps
+         self.in_place = in_place
+
+     def forward(self, hidden_states):
+         return LigerPolyNormFunction.apply(
+             hidden_states,
+             self.weight,
+             self.bias,
+             self.variance_epsilon,
+             self.in_place,
+         )
+
+     def extra_repr(self):
+         return f"weight_shape={tuple(self.weight.shape)}, eps={self.variance_epsilon}, in_place={self.in_place}"
liger_kernel/transformers/qwen2vl_mrope.py
@@ -0,0 +1,20 @@
+ from liger_kernel.ops.qwen2vl_mrope import LigerQwen2VLMRopeFunction
+
+
+ def liger_multimodal_rotary_pos_emb(q, k, cos, sin, mrope_section, unsqueeze_dim=1):
+     """
+     Applies Multimodal Rotary Positional Embedding (M-RoPE) operation to query and key states.
+
+     Args:
+         q (torch.Tensor): The query tensor of shape (bsz, n_q_head, seq_len, head_dim).
+         k (torch.Tensor): The key tensor of shape (bsz, n_kv_head, seq_len, head_dim).
+         cos (torch.Tensor): The cosine tensor of shape (3, bsz, seq_len, head_dim).
+         sin (torch.Tensor): The sine tensor of shape (3, bsz, seq_len, head_dim).
+         mrope_section (List[int]): The multimodal rope section for channel dimension of temporal, height and width in rope calculation.
+         unsqueeze_dim (int, optional): The dimension to unsqueeze. Defaults to 1.
+
+     Returns:
+         Tuple[torch.Tensor, torch.Tensor]: The query and key tensors after applying the M-RoPE operation.
+     """
+
+     return LigerQwen2VLMRopeFunction.apply(q, k, cos, sin, mrope_section, unsqueeze_dim)
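A sketch of how the helper might be called, following the shapes in its docstring. The concrete sizes and the [16, 24, 24] temporal/height/width split are assumptions (in Qwen2-VL the section sizes sum to head_dim // 2).

    import torch
    from liger_kernel.transformers.qwen2vl_mrope import liger_multimodal_rotary_pos_emb

    bsz, n_q_head, n_kv_head, seq_len, head_dim = 1, 28, 4, 256, 128
    q = torch.randn(bsz, n_q_head, seq_len, head_dim, device="cuda", dtype=torch.bfloat16)
    k = torch.randn(bsz, n_kv_head, seq_len, head_dim, device="cuda", dtype=torch.bfloat16)
    cos = torch.randn(3, bsz, seq_len, head_dim, device="cuda", dtype=torch.bfloat16)
    sin = torch.randn(3, bsz, seq_len, head_dim, device="cuda", dtype=torch.bfloat16)
    mrope_section = [16, 24, 24]  # assumed split; sums to head_dim // 2

    q_embed, k_embed = liger_multimodal_rotary_pos_emb(q, k, cos, sin, mrope_section)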
liger_kernel/transformers/rms_norm.py
@@ -6,20 +6,27 @@ from liger_kernel.ops.rms_norm import LigerRMSNormFunction

  class LigerRMSNorm(nn.Module):
      def __init__(
-         self, hidden_size, eps=1e-6, offset=0.0, casting_mode="llama", init_fn="ones"
+         self,
+         hidden_size,
+         eps=1e-6,
+         offset=0.0,
+         casting_mode="llama",
+         init_fn="ones",
+         in_place=True,
+         row_mode=None,
      ):
          super().__init__()
          assert init_fn in [
              "ones",
              "zeros",
          ], f"init_fn must be either 'ones' or 'zeros', got {init_fn}"
-         self.weight = nn.Parameter(
-             torch.ones(hidden_size) if init_fn == "ones" else torch.zeros(hidden_size)
-         )
-         self.variance_epsilon, self.offset, self.casting_mode = (
+         self.weight = nn.Parameter(torch.ones(hidden_size) if init_fn == "ones" else torch.zeros(hidden_size))
+         self.variance_epsilon, self.offset, self.casting_mode, self.in_place, self.row_mode = (
              eps,
              offset,
              casting_mode,
+             in_place,
+             row_mode,
          )

      def forward(self, hidden_states):
@@ -29,7 +36,51 @@ class LigerRMSNorm(nn.Module):
              self.variance_epsilon,
              self.offset,
              self.casting_mode,
+             self.in_place,
+             self.row_mode,
          )

      def extra_repr(self):
-         return f"{tuple(self.weight.shape)}, eps={self.variance_epsilon}, offset={self.offset}"
+         return f"{tuple(self.weight.shape)}, eps={self.variance_epsilon}, offset={self.offset}, in_place={self.in_place}, row_mode={self.row_mode}"
+
+
+ class LigerRMSNormForGemma(LigerRMSNorm):
+     def __init__(
+         self, hidden_size, eps=1e-6, offset=1.0, casting_mode="gemma", init_fn="zeros", in_place=True, row_mode=None
+     ):
+         super().__init__(hidden_size, eps, offset, casting_mode, init_fn, in_place, row_mode)
+
+
+ class LigerRMSNormForGemma2(LigerRMSNorm):
+     def __init__(
+         self, hidden_size, eps=1e-6, offset=1.0, casting_mode="gemma", init_fn="zeros", in_place=False, row_mode=None
+     ):
+         super().__init__(hidden_size, eps, offset, casting_mode, init_fn, in_place, row_mode)
+
+
+ class LigerRMSNormForGemma3(LigerRMSNorm):
+     """Gemma3RMSNorm has a dim argument not hidden_size used in q_norm and k_norm."""
+
+     def __init__(self, dim, eps=0.000001, offset=1.0, casting_mode="gemma", init_fn="zeros", in_place=False):
+         super().__init__(dim, eps, offset, casting_mode, init_fn, in_place)
+
+
+ class LigerRMSNormForOlmo2(LigerRMSNorm):
+     def __init__(
+         self, hidden_size, eps=1e-6, offset=0.0, casting_mode="llama", init_fn="ones", in_place=False, row_mode=None
+     ):
+         super().__init__(hidden_size, eps, offset, casting_mode, init_fn, in_place, row_mode)
+
+
+ class LigerRMSNormForGlm4(LigerRMSNorm):
+     def __init__(
+         self, hidden_size, eps=1e-6, offset=0.0, casting_mode="llama", init_fn="ones", in_place=False, row_mode=None
+     ):
+         super().__init__(hidden_size, eps, offset, casting_mode, init_fn, in_place, row_mode)
+
+
+ class LigerRMSNormForQwen3Next(LigerRMSNorm):
+     def __init__(
+         self, hidden_size, eps=1e-6, offset=1.0, casting_mode="gemma", init_fn="zeros", in_place=False, row_mode=None
+     ):
+         super().__init__(hidden_size, eps, offset, casting_mode, init_fn, in_place, row_mode)
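A drop-in usage sketch for the extended LigerRMSNorm API; the hidden sizes, dtype, and CUDA placement are illustrative.

    import torch
    from liger_kernel.transformers.rms_norm import LigerRMSNorm, LigerRMSNormForGemma

    norm = LigerRMSNorm(hidden_size=4096, eps=1e-6).cuda()      # llama-style: init_fn="ones", offset=0.0
    gemma_norm = LigerRMSNormForGemma(hidden_size=3072).cuda()  # gemma-style: init_fn="zeros", offset=1.0

    x = torch.randn(2, 128, 4096, device="cuda", dtype=torch.bfloat16)
    y = norm(x)  # same shape as x
    print(norm)  # extra_repr now also reports in_place and row_mode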
liger_kernel/transformers/rope.py
@@ -1,3 +1,8 @@
+ from typing import Optional
+ from typing import Tuple
+
+ import torch
+
  from liger_kernel.ops.rope import LigerRopeFunction


@@ -8,8 +13,8 @@ def liger_rotary_pos_emb(q, k, cos, sin, position_ids=None, unsqueeze_dim=1):
      Args:
          q (torch.Tensor): The query tensor of shape (bsz, n_q_head, seq_len, head_dim).
          k (torch.Tensor): The key tensor of shape (bsz, n_kv_head, seq_len, head_dim).
-         cos (torch.Tensor): The cosine tensor of shape (1, seq_len, head_dim).
-         sin (torch.Tensor): The sine tensor of shape (1, seq_len, head_dim).
+         cos (torch.Tensor): The cosine tensor of shape (1, seq_len, head_dim) or (bsz, seq_len, head_dim).
+         sin (torch.Tensor): The sine tensor of shape (1, seq_len, head_dim) or (bsz, seq_len, head_dim).
          position_ids (torch.Tensor, optional): The position ids tensor. Defaults to None.
          unsqueeze_dim (int, optional): The dimension to unsqueeze. Defaults to 1.

@@ -18,3 +23,41 @@ def liger_rotary_pos_emb(q, k, cos, sin, position_ids=None, unsqueeze_dim=1):
      """

      return LigerRopeFunction.apply(q, k, cos, sin, position_ids, unsqueeze_dim)
+
+
+ def liger_rotary_pos_emb_with_cast(
+     q: torch.Tensor,
+     k: torch.Tensor,
+     cos: torch.Tensor,
+     sin: torch.Tensor,
+     position_ids: Optional[torch.Tensor] = None,
+     unsqueeze_dim: int = 1,
+ ) -> Tuple[torch.Tensor, torch.Tensor]:
+     orig_q_dtype, orig_k_dtype = q.dtype, k.dtype
+
+     q32 = q.to(torch.float32)
+     k32 = k.to(torch.float32)
+     cos32 = cos.to(torch.float32)
+     sin32 = sin.to(torch.float32)
+
+     q_out, k_out = liger_rotary_pos_emb(q32, k32, cos32, sin32, position_ids=position_ids, unsqueeze_dim=unsqueeze_dim)
+     return q_out.to(orig_q_dtype), k_out.to(orig_k_dtype)
+
+
+ def liger_rotary_pos_emb_with_cast_and_leading_batch(
+     q: torch.Tensor,
+     k: torch.Tensor,
+     cos: torch.Tensor,
+     sin: torch.Tensor,
+     position_ids: Optional[torch.Tensor] = None,
+     unsqueeze_dim: int = 1,
+ ) -> Tuple[torch.Tensor, torch.Tensor]:
+     orig_q_dtype, orig_k_dtype = q.dtype, k.dtype
+
+     q32 = q.to(torch.float32).unsqueeze(0)
+     k32 = k.to(torch.float32).unsqueeze(0)
+     cos32 = cos.to(torch.float32).unsqueeze(0)
+     sin32 = sin.to(torch.float32).unsqueeze(0)
+
+     q_out, k_out = liger_rotary_pos_emb(q32, k32, cos32, sin32, position_ids=position_ids, unsqueeze_dim=unsqueeze_dim)
+     return q_out.to(orig_q_dtype).squeeze(0), k_out.to(orig_k_dtype).squeeze(0)
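A sketch contrasting the original helper with the new fp32-casting variant; shapes follow the docstring above, while the concrete sizes, dtype, and CUDA placement are illustrative.

    import torch
    from liger_kernel.transformers.rope import liger_rotary_pos_emb, liger_rotary_pos_emb_with_cast

    bsz, n_head, seq_len, head_dim = 1, 32, 512, 128
    q = torch.randn(bsz, n_head, seq_len, head_dim, device="cuda", dtype=torch.bfloat16)
    k = torch.randn(bsz, n_head, seq_len, head_dim, device="cuda", dtype=torch.bfloat16)
    cos = torch.randn(1, seq_len, head_dim, device="cuda", dtype=torch.bfloat16)
    sin = torch.randn(1, seq_len, head_dim, device="cuda", dtype=torch.bfloat16)

    q1, k1 = liger_rotary_pos_emb(q, k, cos, sin)            # applied in the input dtype
    q2, k2 = liger_rotary_pos_emb_with_cast(q, k, cos, sin)  # computed in fp32, results cast back to bf16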
liger_kernel/transformers/softmax.py
@@ -0,0 +1,12 @@
+ import torch
+ import torch.nn as nn
+
+ from liger_kernel.ops.softmax import LigerSoftmaxFunction
+
+
+ class LigerSoftmax(nn.Module):
+     def __init__(self):
+         super().__init__()
+
+     def forward(self, x: torch.Tensor):
+         return LigerSoftmaxFunction.apply(x)
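A minimal usage sketch. The wrapper takes no dim argument; treating the softmax as applied over the last dimension is an assumption here, not something stated in the hunk.

    import torch
    from liger_kernel.transformers.softmax import LigerSoftmax

    softmax = LigerSoftmax()
    x = torch.randn(4, 1024, device="cuda")
    probs = softmax(x)  # each row sums to 1 (softmax over the last dimension, by assumption)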
liger_kernel/transformers/sparsemax.py
@@ -0,0 +1,16 @@
+ import torch
+ import torch.nn as nn
+
+ from liger_kernel.ops.sparsemax import LigerSparsemaxFunction
+
+
+ class LigerSparsemax(nn.Module):
+     def __init__(self, dim: int = -1):
+         super().__init__()
+         self.dim = dim
+
+     def forward(self, x: torch.Tensor) -> torch.Tensor:
+         return LigerSparsemaxFunction.apply(x, self.dim)
+
+     def extra_repr(self) -> str:
+         return f"dim={self.dim}"
liger_kernel/transformers/swiglu.py
@@ -16,10 +16,7 @@ class LigerSwiGLUMLP(nn.Module):
              raise ValueError(f"Activation function {config.hidden_act} not supported.")

      def forward(self, x):
-
-         return self.down_proj(
-             LigerSiLUMulFunction.apply(self.gate_proj(x), self.up_proj(x))
-         )
+         return self.down_proj(LigerSiLUMulFunction.apply(self.gate_proj(x), self.up_proj(x)))


  class LigerBlockSparseTop2MLP(nn.Module):
@@ -36,7 +33,6 @@ class LigerBlockSparseTop2MLP(nn.Module):
              raise ValueError(f"Activation function {config.hidden_act} not supported.")

      def forward(self, x):
-
          return self.w2(LigerSiLUMulFunction.apply(self.w1(x), self.w3(x)))


@@ -51,9 +47,7 @@ class LigerPhi3SwiGLUMLP(nn.Module):
          self.config = config
          self.hidden_size = config.hidden_size
          self.intermediate_size = config.intermediate_size
-         self.gate_up_proj = nn.Linear(
-             self.hidden_size, 2 * self.intermediate_size, bias=False
-         )
+         self.gate_up_proj = nn.Linear(self.hidden_size, 2 * self.intermediate_size, bias=False)
          self.down_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias=False)
          if config.hidden_act not in ["silu", "swish"]:
              raise ValueError(f"Activation function {config.hidden_act} not supported.")
@@ -62,3 +56,24 @@ class LigerPhi3SwiGLUMLP(nn.Module):
          up_states = self.gate_up_proj(x)
          gate, up_states = up_states.chunk(2, dim=-1)
          return self.down_proj(LigerSiLUMulFunction.apply(gate, up_states))
+
+
+ class LigerQwen3MoeSwiGLUMLP(nn.Module):
+     """
+     Patch Qwen3MoeMLP to use LigerSiLUMulFunction.
+     https://github.com/huggingface/transformers/blob/v4.51.3/src/transformers/models/qwen3_moe/modular_qwen3_moe.py#L57
+     """
+
+     def __init__(self, config, intermediate_size=None):
+         super().__init__()
+         self.config = config
+         self.hidden_size = config.hidden_size
+         self.intermediate_size = intermediate_size if intermediate_size is not None else config.intermediate_size
+         self.gate_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False)
+         self.up_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False)
+         self.down_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias=False)
+         if config.hidden_act not in ["silu", "swish"]:
+             raise ValueError(f"Activation function {config.hidden_act} not supported.")
+
+     def forward(self, x):
+         return self.down_proj(LigerSiLUMulFunction.apply(self.gate_proj(x), self.up_proj(x)))
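A construction sketch for the SwiGLU wrappers. A SimpleNamespace stands in for a transformers config object, the sizes are invented, and LigerSwiGLUMLP is assumed to be built from a config the same way (its __init__ falls outside this hunk); only the attributes the wrappers read (hidden_size, intermediate_size, hidden_act) are set.

    import torch
    from types import SimpleNamespace
    from liger_kernel.transformers.swiglu import LigerQwen3MoeSwiGLUMLP, LigerSwiGLUMLP

    config = SimpleNamespace(hidden_size=2048, intermediate_size=5632, hidden_act="silu")
    mlp = LigerSwiGLUMLP(config).to("cuda", torch.bfloat16)
    expert = LigerQwen3MoeSwiGLUMLP(config, intermediate_size=1408).to("cuda", torch.bfloat16)

    x = torch.randn(2, 128, 2048, device="cuda", dtype=torch.bfloat16)
    print(mlp(x).shape, expert(x).shape)  # both (2, 128, 2048)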
liger_kernel/transformers/tiled_mlp.py
@@ -0,0 +1,133 @@
+ from typing import Optional
+
+ import torch.nn as nn
+
+ from liger_kernel.ops.geglu import LigerGELUMulFunction
+ from liger_kernel.ops.swiglu import LigerSiLUMulFunction
+ from liger_kernel.ops.tiled_mlp import apply_tiled_mlp
+
+
+ class LigerTiledGEGLUMLP(nn.Module):
+     """
+     Memory-efficient GEGLU MLP using tiled computation.
+
+     This module combines GEGLU activation with tiled processing to handle
+     very long sequences efficiently. The forward pass is recomputed during
+     backward to save memory.
+
+     Args:
+         config: Model configuration with hidden_size and intermediate_size attributes
+         num_shards: Number of shards to split the sequence. If None, automatically
+             calculated as ceil(seqlen / hidden_size)
+     """
+
+     def __init__(self, config, num_shards: Optional[int] = None):
+         super().__init__()
+         self.config = config
+         self.hidden_size = config.hidden_size
+         self.intermediate_size = config.intermediate_size
+         self.num_shards = num_shards
+
+         self.gate_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False)
+         self.up_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False)
+         self.down_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias=False)
+
+         # Validate activation function
+         if hasattr(config, "hidden_act") and config.hidden_act not in [
+             "gelu",
+             "gelu_new",
+             "gelu_pytorch_tanh",
+         ]:
+             raise ValueError(f"LigerTiledGEGLUMLP requires GELU activation, got {config.hidden_act}")
+
+     def _mlp_forward(self, module, x):
+         """Internal MLP forward function for tiled computation."""
+         gate = module.gate_proj(x)
+         up = module.up_proj(x)
+         return module.down_proj(LigerGELUMulFunction.apply(gate, up))
+
+     def forward(self, x):
+         """
+         Forward pass with tiled computation.
+
+         Args:
+             x: Input tensor of shape [batch_size, seq_len, hidden_size]
+                 or [seq_len, hidden_size]
+
+         Returns:
+             Output tensor of the same shape as input
+         """
+         compute_params = [
+             self.gate_proj.weight,
+             self.up_proj.weight,
+             self.down_proj.weight,
+         ]
+
+         return apply_tiled_mlp(
+             fn=self._mlp_forward,
+             mlp_module=self,
+             x=x,
+             num_shards=self.num_shards,
+             compute_params=compute_params,
+         )
+
+
+ class LigerTiledSwiGLUMLP(nn.Module):
+     """
+     Memory-efficient SwiGLU MLP using tiled computation.
+
+     This module combines SwiGLU activation with tiled processing to handle
+     very long sequences efficiently. The forward pass is recomputed during
+     backward to save memory.
+
+     Args:
+         config: Model configuration with hidden_size and intermediate_size attributes
+         num_shards: Number of shards to split the sequence. If None, automatically
+             calculated as ceil(seqlen / hidden_size)
+     """
+
+     def __init__(self, config, num_shards: Optional[int] = None):
+         super().__init__()
+         self.config = config
+         self.hidden_size = config.hidden_size
+         self.intermediate_size = config.intermediate_size
+         self.num_shards = num_shards
+
+         self.gate_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False)
+         self.up_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False)
+         self.down_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias=False)
+
+         # Validate activation function
+         if hasattr(config, "hidden_act") and config.hidden_act not in ["silu", "swish"]:
+             raise ValueError(f"LigerTiledSwiGLUMLP requires SiLU/Swish activation, got {config.hidden_act}")
+
+     def _mlp_forward(self, module, x):
+         """Internal MLP forward function for tiled computation."""
+         gate = module.gate_proj(x)
+         up = module.up_proj(x)
+         return module.down_proj(LigerSiLUMulFunction.apply(gate, up))
+
+     def forward(self, x):
+         """
+         Forward pass with tiled computation.
+
+         Args:
+             x: Input tensor of shape [batch_size, seq_len, hidden_size]
+                 or [seq_len, hidden_size]
+
+         Returns:
+             Output tensor of the same shape as input
+         """
+         compute_params = [
+             self.gate_proj.weight,
+             self.up_proj.weight,
+             self.down_proj.weight,
+         ]
+
+         return apply_tiled_mlp(
+             fn=self._mlp_forward,
+             mlp_module=self,
+             x=x,
+             num_shards=self.num_shards,
+             compute_params=compute_params,
+         )
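A usage sketch for the tiled SwiGLU variant; the config stand-in, the long sequence length, and the CUDA/bfloat16 placement are assumptions chosen to illustrate the long-sequence use case.

    import torch
    from types import SimpleNamespace
    from liger_kernel.transformers.tiled_mlp import LigerTiledSwiGLUMLP

    config = SimpleNamespace(hidden_size=2048, intermediate_size=8192, hidden_act="silu")
    mlp = LigerTiledSwiGLUMLP(config, num_shards=None).to("cuda", torch.bfloat16)  # shards derived from seq_len

    x = torch.randn(1, 32768, 2048, device="cuda", dtype=torch.bfloat16, requires_grad=True)
    y = mlp(x)           # same shape as x; the forward is recomputed shard-by-shard in backward
    y.sum().backward()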
liger_kernel/transformers/trainer/__init__.py
@@ -0,0 +1,4 @@
+ try:
+     from liger_kernel.transformers.trainer.orpo_trainer import LigerORPOTrainer  # noqa: F401
+ except ImportError:
+     raise ImportError("Please `pip install trl` to use LigerORPOTrainer")
liger_kernel/transformers/trainer/orpo_trainer.py
@@ -0,0 +1,130 @@
+ from typing import Dict
+ from typing import List
+ from typing import Literal
+ from typing import Tuple
+ from typing import Union
+
+ import torch
+ import torch.nn as nn
+
+ from torch.distributed.fsdp import FullyShardedDataParallel
+ from trl.trainer import ORPOTrainer
+
+ from liger_kernel.chunked_loss import LigerFusedLinearORPOLoss
+ from liger_kernel.transformers.fsdp import _FSDPForwardRedirection
+
+
+ class LigerORPOTrainer(ORPOTrainer):
+     def concatenated_forward(
+         self, model: nn.Module, batch: Dict[str, Union[List, torch.LongTensor]]
+     ) -> Tuple[torch.FloatTensor, torch.FloatTensor, torch.FloatTensor, torch.FloatTensor]:
+         """
+         Run the given model on the given batch of inputs, concatenating the chosen and rejected inputs together.
+         We do this to avoid doing two forward passes, because it's faster for FSDP.
+         """
+         concatenated_batch = self.concatenated_inputs(
+             batch,
+             is_encoder_decoder=self.is_encoder_decoder,
+             label_pad_token_id=self.label_pad_token_id,
+             padding_value=self.padding_value,
+             device=self.accelerator.device,
+         )
+
+         model_kwargs = (
+             {
+                 "decoder_input_ids": self._shift_right(concatenated_batch["concatenated_labels"]),
+             }
+             if self.is_encoder_decoder
+             else {}
+         )
+
+         if self.aux_loss_enabled:
+             model_kwargs["output_router_logits"] = True
+
+         if self.is_encoder_decoder:
+             labels = concatenated_batch["concatenated_labels"].clone()
+         else:
+             labels = concatenated_batch["concatenated_input_ids"].clone()
+             attention_mask = concatenated_batch["concatenated_attention_mask"]
+             labels = torch.where(attention_mask == 1, labels, self.label_pad_token_id)
+
+         if isinstance(model, FullyShardedDataParallel):
+             outputs = _FSDPForwardRedirection()(
+                 model,
+                 model._fsdp_wrapped_module.model,
+                 concatenated_batch["concatenated_input_ids"],
+                 attention_mask=concatenated_batch["concatenated_attention_mask"],
+                 use_cache=False,
+                 **model_kwargs,
+             )
+         else:
+             if isinstance(model, torch.nn.DataParallel):
+                 model = model.module
+             outputs = model.model(
+                 concatenated_batch["concatenated_input_ids"],
+                 attention_mask=concatenated_batch["concatenated_attention_mask"],
+                 use_cache=False,
+                 **model_kwargs,
+             )
+
+         orpo_loss_fn = LigerFusedLinearORPOLoss(ignore_index=self.label_pad_token_id, beta=self.beta)
+
+         def orpo_partial(lm_head, last_hidden_state, concatenated_labels, nll_target):
+             return orpo_loss_fn(
+                 lm_head.weight, last_hidden_state, concatenated_labels, lm_head.bias, nll_target=nll_target
+             )
+
+         orpo_loss, aux_outputs = _FSDPForwardRedirection()(
+             model,
+             orpo_partial,
+             model.lm_head,
+             outputs.last_hidden_state[:, :-1] if not self.is_encoder_decoder else outputs.last_hidden_state,
+             concatenated_batch["concatenated_labels"][:, 1:]
+             if not self.is_encoder_decoder
+             else concatenated_batch["concatenated_labels"],
+             labels[:, 1:] if not self.is_encoder_decoder else labels,
+         )
+         # if aux_loss_enabled, add the aux_loss to the orpo_loss
+         if self.aux_loss_enabled:
+             orpo_loss += self.aux_loss_coef * outputs.aux_loss
+
+         return orpo_loss, aux_outputs
+
+     def get_batch_loss_metrics(
+         self,
+         model,
+         batch: Dict[str, Union[List, torch.LongTensor]],
+         train_eval: Literal["train", "eval"] = "train",
+     ):
+         """Compute the ORPO loss and other metrics for the given batch of inputs for train or test."""
+         metrics = {}
+         loss, aux_outputs = self.concatenated_forward(model, batch)
+         (
+             policy_chosen_logps,
+             policy_rejected_logps,
+             policy_chosen_logits,
+             policy_rejected_logits,
+             policy_nll_loss,
+         ) = aux_outputs[:5]
+
+         # return loss, metrics
+         chosen_rewards, rejected_rewards, log_odds_ratio, log_odds_chosen = aux_outputs[5:]
+
+         reward_accuracies = (chosen_rewards > rejected_rewards).float()
+
+         prefix = "eval_" if train_eval == "eval" else ""
+         metrics[f"{prefix}rewards/chosen"] = chosen_rewards.mean()
+         metrics[f"{prefix}rewards/rejected"] = rejected_rewards.mean()
+         metrics[f"{prefix}rewards/accuracies"] = reward_accuracies.mean()
+         metrics[f"{prefix}rewards/margins"] = (chosen_rewards - rejected_rewards).mean()
+         metrics[f"{prefix}logps/rejected"] = policy_rejected_logps.detach().mean()
+         metrics[f"{prefix}logps/chosen"] = policy_chosen_logps.detach().mean()
+         metrics[f"{prefix}logits/rejected"] = policy_rejected_logits.detach().mean()
+         metrics[f"{prefix}logits/chosen"] = policy_chosen_logits.detach().mean()
+         metrics[f"{prefix}nll_loss"] = policy_nll_loss.detach().mean()
+         metrics[f"{prefix}log_odds_ratio"] = log_odds_ratio
+         metrics[f"{prefix}log_odds_chosen"] = log_odds_chosen
+         for k, v in metrics.items():
+             metrics[k] = v.item()
+
+         return loss, metrics
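A training-loop sketch for LigerORPOTrainer. The checkpoint, dataset, and hyperparameters are placeholders, and the keyword names follow recent trl releases (older versions take tokenizer= instead of processing_class=); adjust to the installed trl version.

    from datasets import load_dataset
    from transformers import AutoModelForCausalLM, AutoTokenizer
    from trl import ORPOConfig

    from liger_kernel.transformers.trainer import LigerORPOTrainer

    model_id = "meta-llama/Llama-3.2-1B-Instruct"  # placeholder checkpoint
    model = AutoModelForCausalLM.from_pretrained(model_id)
    tokenizer = AutoTokenizer.from_pretrained(model_id)
    train_dataset = load_dataset("trl-lib/ultrafeedback_binarized", split="train")

    trainer = LigerORPOTrainer(
        model=model,
        args=ORPOConfig(output_dir="./orpo-liger", beta=0.1, max_length=1024),
        train_dataset=train_dataset,
        processing_class=tokenizer,
    )
    trainer.train()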
liger_kernel/transformers/tvd.py
@@ -0,0 +1,13 @@
+ import torch.nn as nn
+
+ from liger_kernel.ops.tvd import LigerTVDLossFunction
+
+
+ class LigerTVDLoss(nn.Module):
+     def __init__(self, reduction="batchmean", ignore_index: int = -100):
+         super(LigerTVDLoss, self).__init__()
+         self.reduction = reduction
+         self.ignore_index = ignore_index
+
+     def forward(self, p, q, shift_labels=None):
+         return LigerTVDLossFunction.apply(p, q, shift_labels, self.reduction, self.ignore_index)
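A distillation-style sketch for LigerTVDLoss. The assumption here is that p and q are probability distributions of shape (num_tokens, vocab_size), e.g. softmaxed student and teacher logits; the sizes are invented and the kernel's exact shape expectations are not stated in this hunk.

    import torch
    import torch.nn.functional as F
    from liger_kernel.transformers.tvd import LigerTVDLoss

    tvd = LigerTVDLoss(reduction="batchmean")
    student_logits = torch.randn(32, 4096, device="cuda")
    teacher_logits = torch.randn(32, 4096, device="cuda")
    loss = tvd(F.softmax(student_logits, dim=-1), F.softmax(teacher_logits, dim=-1))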
liger_kernel/triton/__init__.py
@@ -1,3 +1 @@
- from liger_kernel.triton.monkey_patch import (  # noqa: F401
-     apply_liger_triton_cache_manager,
- )
+ from liger_kernel.triton.monkey_patch import apply_liger_triton_cache_manager  # noqa: F401
liger_kernel/triton/monkey_patch.py
@@ -37,6 +37,4 @@ def apply_liger_triton_cache_manager():
      Experimental feature to get around transient FileNotFoundError in triton compilation.
      For more details please see https://github.com/triton-lang/triton/pull/4295
      """
-     os.environ["TRITON_CACHE_MANAGER"] = (
-         "liger_kernel.triton.monkey_patch:LigerTritonFileCacheManager"
-     )
+     os.environ["TRITON_CACHE_MANAGER"] = "liger_kernel.triton.monkey_patch:LigerTritonFileCacheManager"
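Since the patch only sets an environment variable, it should be called before Triton compiles any kernels; a minimal sketch:

    from liger_kernel.triton import apply_liger_triton_cache_manager

    # Point TRITON_CACHE_MANAGER at Liger's cache manager before any Triton JIT compilation happens.
    apply_liger_triton_cache_manager()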