liger-kernel-nightly 0.5.5.dev20250402185702__py3-none-any.whl → 0.6.4.dev20260112233432__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of liger-kernel-nightly might be problematic.
- liger_kernel/chunked_loss/__init__.py +1 -0
- liger_kernel/chunked_loss/cosine_similarity_loss.py +142 -0
- liger_kernel/chunked_loss/dpo_loss.py +61 -3
- liger_kernel/chunked_loss/functional.py +2 -0
- liger_kernel/chunked_loss/fused_linear_distillation.py +23 -5
- liger_kernel/chunked_loss/fused_linear_ppo.py +36 -0
- liger_kernel/chunked_loss/fused_linear_preference.py +0 -1
- liger_kernel/chunked_loss/grpo_loss.py +76 -5
- liger_kernel/chunked_loss/jsd_loss.py +46 -15
- liger_kernel/ops/__init__.py +141 -0
- liger_kernel/ops/backends/README.md +151 -0
- liger_kernel/ops/backends/__init__.py +13 -0
- liger_kernel/ops/backends/_ascend/__init__.py +5 -0
- liger_kernel/ops/backends/_ascend/ascend-ub-manager-design.md +485 -0
- liger_kernel/ops/backends/_ascend/ops/__init__.py +49 -0
- liger_kernel/ops/backends/_ascend/ops/geglu.py +266 -0
- liger_kernel/ops/backends/_ascend/ops/qwen2vl_mrope.py +285 -0
- liger_kernel/ops/backends/_ascend/ops/rope.py +290 -0
- liger_kernel/ops/backends/_ascend/ops/swiglu.py +142 -0
- liger_kernel/ops/backends/_ascend/ops/tvd.py +221 -0
- liger_kernel/ops/backends/_ascend/ub_manager.py +349 -0
- liger_kernel/ops/backends/registry.py +61 -0
- liger_kernel/ops/cross_entropy.py +134 -65
- liger_kernel/ops/dyt.py +115 -180
- liger_kernel/ops/fused_add_rms_norm.py +416 -0
- liger_kernel/ops/fused_linear_cross_entropy.py +117 -23
- liger_kernel/ops/fused_neighborhood_attention.py +1022 -0
- liger_kernel/ops/geglu.py +6 -4
- liger_kernel/ops/group_norm.py +7 -7
- liger_kernel/ops/grpo_loss.py +312 -0
- liger_kernel/ops/jsd.py +2 -1
- liger_kernel/ops/kl_div.py +9 -5
- liger_kernel/ops/layer_norm.py +146 -78
- liger_kernel/ops/llama4_rope.py +225 -0
- liger_kernel/ops/multi_token_attention.py +207 -0
- liger_kernel/ops/poly_norm.py +390 -0
- liger_kernel/ops/rms_norm.py +398 -99
- liger_kernel/ops/rope.py +1 -1
- liger_kernel/ops/softmax.py +201 -0
- liger_kernel/ops/sparsemax.py +179 -0
- liger_kernel/ops/swiglu.py +1 -1
- liger_kernel/ops/tiled_mlp.py +136 -0
- liger_kernel/ops/utils.py +14 -0
- liger_kernel/transformers/__init__.py +208 -17
- liger_kernel/transformers/auto_model.py +21 -0
- liger_kernel/transformers/cross_entropy.py +9 -4
- liger_kernel/transformers/dyt.py +6 -4
- liger_kernel/transformers/experimental/__init__.py +5 -0
- liger_kernel/transformers/experimental/embedding.py +1 -1
- liger_kernel/transformers/fsdp.py +55 -0
- liger_kernel/transformers/functional.py +122 -20
- liger_kernel/transformers/fused_add_rms_norm.py +39 -0
- liger_kernel/transformers/fused_linear_cross_entropy.py +16 -5
- liger_kernel/transformers/fused_linear_jsd.py +1 -1
- liger_kernel/transformers/fused_neighborhood_attention.py +234 -0
- liger_kernel/transformers/geglu.py +1 -1
- liger_kernel/transformers/group_norm.py +1 -1
- liger_kernel/transformers/grpo_loss.py +153 -0
- liger_kernel/transformers/jsd.py +1 -1
- liger_kernel/transformers/kl_div.py +1 -1
- liger_kernel/transformers/layer_norm.py +1 -1
- liger_kernel/transformers/llama4_rope.py +93 -0
- liger_kernel/transformers/model/exaone4.py +136 -0
- liger_kernel/transformers/model/falcon_h1.py +122 -0
- liger_kernel/transformers/model/gemma.py +57 -27
- liger_kernel/transformers/model/gemma2.py +65 -28
- liger_kernel/transformers/model/gemma3.py +331 -0
- liger_kernel/transformers/model/glm4.py +141 -0
- liger_kernel/transformers/model/glm4v.py +163 -0
- liger_kernel/transformers/model/glm4v_moe.py +172 -0
- liger_kernel/transformers/model/gpt_oss.py +211 -0
- liger_kernel/transformers/model/hunyuan_v1.py +134 -0
- liger_kernel/transformers/model/internvl.py +157 -0
- liger_kernel/transformers/model/llama.py +109 -27
- liger_kernel/transformers/model/llama4.py +121 -0
- liger_kernel/transformers/model/llava.py +111 -136
- liger_kernel/transformers/model/loss_utils.py +50 -12
- liger_kernel/transformers/model/mistral.py +51 -34
- liger_kernel/transformers/model/mixtral.py +50 -29
- liger_kernel/transformers/model/mllama.py +46 -24
- liger_kernel/transformers/model/olmo2.py +47 -22
- liger_kernel/transformers/model/olmo3.py +142 -0
- liger_kernel/transformers/model/output_classes.py +147 -0
- liger_kernel/transformers/model/paligemma.py +50 -14
- liger_kernel/transformers/model/phi3.py +47 -172
- liger_kernel/transformers/model/qwen2.py +55 -23
- liger_kernel/transformers/model/qwen2_5_vl.py +62 -103
- liger_kernel/transformers/model/qwen2_vl.py +59 -108
- liger_kernel/transformers/model/qwen3.py +136 -0
- liger_kernel/transformers/model/qwen3_moe.py +152 -0
- liger_kernel/transformers/model/qwen3_next.py +146 -0
- liger_kernel/transformers/model/qwen3_vl.py +150 -0
- liger_kernel/transformers/model/qwen3_vl_moe.py +126 -0
- liger_kernel/transformers/model/smollm3.py +199 -0
- liger_kernel/transformers/model/smolvlm.py +158 -0
- liger_kernel/transformers/monkey_patch.py +2018 -244
- liger_kernel/transformers/multi_token_attention.py +64 -0
- liger_kernel/transformers/poly_norm.py +42 -0
- liger_kernel/transformers/qwen2vl_mrope.py +1 -1
- liger_kernel/transformers/rms_norm.py +54 -6
- liger_kernel/transformers/rope.py +45 -1
- liger_kernel/transformers/softmax.py +12 -0
- liger_kernel/transformers/sparsemax.py +16 -0
- liger_kernel/transformers/swiglu.py +39 -1
- liger_kernel/transformers/tiled_mlp.py +125 -0
- liger_kernel/transformers/trainer/orpo_trainer.py +1 -53
- liger_kernel/transformers/tvd.py +1 -1
- liger_kernel/utils.py +63 -0
- {liger_kernel_nightly-0.5.5.dev20250402185702.dist-info → liger_kernel_nightly-0.6.4.dev20260112233432.dist-info}/METADATA +73 -39
- liger_kernel_nightly-0.6.4.dev20260112233432.dist-info/RECORD +132 -0
- liger_kernel_nightly-0.5.5.dev20250402185702.dist-info/RECORD +0 -80
- {liger_kernel_nightly-0.5.5.dev20250402185702.dist-info → liger_kernel_nightly-0.6.4.dev20260112233432.dist-info}/LICENSE +0 -0
- {liger_kernel_nightly-0.5.5.dev20250402185702.dist-info → liger_kernel_nightly-0.6.4.dev20260112233432.dist-info}/NOTICE +0 -0
- {liger_kernel_nightly-0.5.5.dev20250402185702.dist-info → liger_kernel_nightly-0.6.4.dev20260112233432.dist-info}/WHEEL +0 -0
- {liger_kernel_nightly-0.5.5.dev20250402185702.dist-info → liger_kernel_nightly-0.6.4.dev20260112233432.dist-info}/top_level.txt +0 -0
@@ -0,0 +1,64 @@
+import math
+
+import torch
+import torch.nn as nn
+
+from torch.nn.modules.utils import _pair
+
+from liger_kernel.ops import LigerMultiTokenAttentionFunction
+
+
+class LigerMultiTokenAttention(nn.Module):
+    r"""
+    Multi-Token Attention:
+        out = mask_{0}(conv2d(softmax(mask_{-\inf}(scores))))
+
+    Reference: https://arxiv.org/pdf/2504.00927
+    """
+
+    def __init__(
+        self,
+        in_channels: int,
+        out_channels: int,
+        kernel_size: int,
+        stride: int = 1,
+        padding: int = 0,
+        dilation: int = 1,
+        groups: int = 1,
+        bias: bool = True,
+        sparse: bool = False,
+    ):
+        super().__init__()
+        self.in_channels = in_channels
+        self.out_channels = out_channels
+        self.kernel_size = _pair(kernel_size)
+        self.stride = _pair(stride)
+        self.padding = _pair(padding)
+        self.dilation = _pair(dilation)
+        self.groups = groups
+        self.sparse = sparse
+
+        self.weight = nn.Parameter(torch.empty(out_channels, in_channels // groups, *self.kernel_size))
+        if bias:
+            self.bias = nn.Parameter(torch.empty(out_channels))
+        else:
+            self.register_parameter("bias", None)
+
+        self.reset_parameters()
+
+    def reset_parameters(self):
+        nn.init.kaiming_uniform_(self.weight, a=math.sqrt(5))
+        if self.bias is not None:
+            nn.init.zeros_(self.bias)
+
+    def forward(self, scores: torch.Tensor) -> torch.Tensor:
+        return LigerMultiTokenAttentionFunction.apply(
+            scores,
+            self.weight,
+            self.bias,
+            self.stride,
+            self.padding,
+            self.dilation,
+            self.groups,
+            self.sparse,
+        )
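The hunk above adds the new LigerMultiTokenAttention wrapper (liger_kernel/transformers/multi_token_attention.py in the file list). A minimal usage sketch, not taken from the package docs; the (batch, num_heads, q_len, k_len) score layout, the per-head grouping, and the CUDA device are assumptions:

import torch

from liger_kernel.transformers.multi_token_attention import LigerMultiTokenAttention

# Convolves attention scores over neighboring query/key positions per head,
# following the referenced paper (arXiv:2504.00927).
mta = LigerMultiTokenAttention(in_channels=8, out_channels=8, kernel_size=3, padding=1, groups=8).cuda()

scores = torch.randn(2, 8, 128, 128, device="cuda")  # assumed (batch, num_heads, q_len, k_len)
out = mta(scores)  # with kernel_size=3 and padding=1 the spatial shape is preserved
print(out.shape)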
liger_kernel/transformers/poly_norm.py
@@ -0,0 +1,42 @@
+import torch
+import torch.nn as nn
+
+from liger_kernel.ops import LigerPolyNormFunction
+
+
+class LigerPolyNorm(nn.Module):
+    """
+    PolyNorm layer wrapper for Liger kernel.
+
+    PolyNorm formula:
+        y = w₀·norm(x³) + w₁·norm(x²) + w₂·norm(x) + b
+        where norm(u) = u / sqrt(mean(u²) + ε)
+
+    Reference:
+        https://github.com/BryceZhuo/PolyCom/
+
+    Args:
+        eps: epsilon for numerical stability (default: 1e-6)
+        in_place: whether to in-place modify grad_output in backward to save memory (default: False).
+            Set to True to save memory if grad_output is not needed elsewhere.
+    """
+
+    def __init__(self, eps=1e-6, in_place=True):
+        super().__init__()
+        # Align with PolyCom reference: initialize weights to (1/3, 1/3, 1/3) and bias to 1.0
+        self.weight = nn.Parameter(torch.full((3,), 1.0 / 3.0))
+        self.bias = nn.Parameter(torch.tensor(1.0))
+        self.variance_epsilon = eps
+        self.in_place = in_place
+
+    def forward(self, hidden_states):
+        return LigerPolyNormFunction.apply(
+            hidden_states,
+            self.weight,
+            self.bias,
+            self.variance_epsilon,
+            self.in_place,
+        )
+
+    def extra_repr(self):
+        return f"weight_shape={tuple(self.weight.shape)}, eps={self.variance_epsilon}, in_place={self.in_place}"
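A minimal sketch of the new LigerPolyNorm module next to a plain-PyTorch rendering of the formula from its docstring; the hidden size, device, normalization over the last dimension, and the comparison tolerance are assumptions:

import torch

from liger_kernel.transformers.poly_norm import LigerPolyNorm


def polynorm_reference(x, weight, bias, eps=1e-6):
    # norm(u) = u / sqrt(mean(u^2) + eps); the mean is assumed to be over the last (hidden) dim
    def norm(u):
        return u / torch.sqrt(u.pow(2).mean(dim=-1, keepdim=True) + eps)

    return weight[0] * norm(x**3) + weight[1] * norm(x**2) + weight[2] * norm(x) + bias


layer = LigerPolyNorm(eps=1e-6).cuda()
x = torch.randn(2, 16, 4096, device="cuda")

y_kernel = layer(x)
y_ref = polynorm_reference(x, layer.weight, layer.bias, layer.variance_epsilon)
print(torch.allclose(y_kernel, y_ref, atol=1e-4))  # expected to match closely; the tolerance is illustrative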
liger_kernel/transformers/rms_norm.py
@@ -1,7 +1,7 @@
 import torch
 import torch.nn as nn
 
-from liger_kernel.ops
+from liger_kernel.ops import LigerRMSNormFunction
 
 
 class LigerRMSNorm(nn.Module):
@@ -13,18 +13,25 @@ class LigerRMSNorm(nn.Module):
         casting_mode="llama",
         init_fn="ones",
         in_place=True,
+        row_mode=None,
+        elementwise_affine=True,
     ):
         super().__init__()
         assert init_fn in [
             "ones",
             "zeros",
         ], f"init_fn must be either 'ones' or 'zeros', got {init_fn}"
-        self.
-
+        self.elementwise_affine = elementwise_affine
+        if self.elementwise_affine:
+            self.weight = nn.Parameter(torch.ones(hidden_size) if init_fn == "ones" else torch.zeros(hidden_size))
+        else:
+            self.register_parameter("weight", None)
+        self.variance_epsilon, self.offset, self.casting_mode, self.in_place, self.row_mode = (
             eps,
             offset,
             casting_mode,
             in_place,
+            row_mode,
         )
 
     def forward(self, hidden_states):
@@ -35,9 +42,50 @@ class LigerRMSNorm(nn.Module):
             self.offset,
             self.casting_mode,
             self.in_place,
+            self.row_mode,
         )
 
     def extra_repr(self):
-        return (
-
-
+        return f"weight_shape={tuple(self.weight.shape) if self.weight is not None else None}, eps={self.variance_epsilon}, offset={self.offset}, in_place={self.in_place}, row_mode={self.row_mode}"
+
+
+class LigerRMSNormForGemma(LigerRMSNorm):
+    def __init__(
+        self, hidden_size, eps=1e-6, offset=1.0, casting_mode="gemma", init_fn="zeros", in_place=True, row_mode=None
+    ):
+        super().__init__(hidden_size, eps, offset, casting_mode, init_fn, in_place, row_mode)
+
+
+class LigerRMSNormForGemma2(LigerRMSNorm):
+    def __init__(
+        self, hidden_size, eps=1e-6, offset=1.0, casting_mode="gemma", init_fn="zeros", in_place=False, row_mode=None
+    ):
+        super().__init__(hidden_size, eps, offset, casting_mode, init_fn, in_place, row_mode)
+
+
+class LigerRMSNormForGemma3(LigerRMSNorm):
+    """Gemma3RMSNorm has a dim argument not hidden_size used in q_norm and k_norm."""
+
+    def __init__(self, dim, eps=0.000001, offset=1.0, casting_mode="gemma", init_fn="zeros", in_place=False):
+        super().__init__(dim, eps, offset, casting_mode, init_fn, in_place)
+
+
+class LigerRMSNormForOlmo2(LigerRMSNorm):
+    def __init__(
+        self, hidden_size, eps=1e-6, offset=0.0, casting_mode="llama", init_fn="ones", in_place=False, row_mode=None
+    ):
+        super().__init__(hidden_size, eps, offset, casting_mode, init_fn, in_place, row_mode)
+
+
+class LigerRMSNormForGlm4(LigerRMSNorm):
+    def __init__(
+        self, hidden_size, eps=1e-6, offset=0.0, casting_mode="llama", init_fn="ones", in_place=False, row_mode=None
+    ):
+        super().__init__(hidden_size, eps, offset, casting_mode, init_fn, in_place, row_mode)
+
+
+class LigerRMSNormForQwen3Next(LigerRMSNorm):
+    def __init__(
+        self, hidden_size, eps=1e-6, offset=1.0, casting_mode="gemma", init_fn="zeros", in_place=False, row_mode=None
+    ):
+        super().__init__(hidden_size, eps, offset, casting_mode, init_fn, in_place, row_mode)
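A minimal sketch of the constructor changes above (row_mode and elementwise_affine are new, and the Gemma3 variant takes dim); the argument values and the CUDA device are illustrative assumptions:

import torch

from liger_kernel.transformers.rms_norm import LigerRMSNorm
from liger_kernel.transformers.rms_norm import LigerRMSNormForGemma3

x = torch.randn(2, 128, 4096, device="cuda")

# Affine RMSNorm, as before.
norm = LigerRMSNorm(hidden_size=4096, eps=1e-6).cuda()
y = norm(x)

# New in this release: no learnable scale is registered when elementwise_affine=False.
no_affine = LigerRMSNorm(hidden_size=4096, elementwise_affine=False)
print(no_affine)  # extra_repr now reports weight_shape=None and the new row_mode field

# Gemma3-style q/k norm takes `dim` instead of `hidden_size`.
qk_norm = LigerRMSNormForGemma3(dim=256)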
liger_kernel/transformers/rope.py
@@ -1,4 +1,8 @@
-from
+from typing import Tuple
+
+import torch
+
+from liger_kernel.ops import LigerRopeFunction
 
 
 def liger_rotary_pos_emb(q, k, cos, sin, position_ids=None, unsqueeze_dim=1):
@@ -18,3 +22,43 @@ def liger_rotary_pos_emb(q, k, cos, sin, position_ids=None, unsqueeze_dim=1):
     """
 
     return LigerRopeFunction.apply(q, k, cos, sin, position_ids, unsqueeze_dim)
+
+
+def liger_rotary_pos_emb_vision(
+    q: torch.Tensor,
+    k: torch.Tensor,
+    cos: torch.Tensor,
+    sin: torch.Tensor,
+) -> Tuple[torch.Tensor, torch.Tensor]:
+    """
+    Modified version of liger_rotary_pos_emb for qwen3_vl's apply_rotary_pos_emb_vision function.
+    Manually tranposed the input and output to match the expected shape for liger_rotary_pos_emb.
+    Reference: https://https://github.com/huggingface/transformers/blob/v5.0.0rc0/src/transformers/models/qwen3_vl/modeling_qwen3_vl.py#L116
+
+    Args:
+        q (torch.Tensor): The query tensor of shape (seq_length, num_heads, head_dim),
+            with stride (num_heads * head_dim, head_dim, 1).
+        k (torch.Tensor): The query tensor of shape (seq_length, num_heads, head_dim),
+            with stride (num_heads * head_dim, head_dim, 1). Same as q.
+        cos (torch.Tensor): The cosine tensor of shape (seq_length, head_dim).
+        sin (torch.Tensor): The sine tensor of shape (seq_length, head_dim).
+
+    Returns:
+        Tuple[torch.Tensor, torch.Tensor]: The query and key tensors with the same shape and stride as inputs.
+    """
+    orig_q_dtype, orig_k_dtype = q.dtype, k.dtype
+
+    # tranpose to (1, num_heads, seq_length, head_dim) and cast to float32 to match liger_rotary_pos_emb input shape
+    # also unsqueeze for batch dim
+    q32 = q.to(torch.float32).unsqueeze(0).transpose(1, 2)
+    k32 = k.to(torch.float32).unsqueeze(0).transpose(1, 2)
+    cos32 = cos.to(torch.float32)
+    sin32 = sin.to(torch.float32)
+
+    q_out, k_out = liger_rotary_pos_emb(q32, k32, cos32, sin32)
+
+    # transpose back to (seq_length, num_heads, head_dim) and cast back to original dtype
+    # also squeeze out batch dim
+    q_out = q_out.transpose(1, 2).squeeze(0).to(orig_q_dtype)
+    k_out = k_out.transpose(1, 2).squeeze(0).to(orig_k_dtype)
+    return q_out, k_out
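A minimal sketch of calling the new vision RoPE helper; the sequence length, head count, head dimension, dtype, and the randomly generated cos/sin tables are illustrative assumptions:

import torch

from liger_kernel.transformers.rope import liger_rotary_pos_emb_vision

seq_len, num_heads, head_dim = 1024, 16, 80  # assumed sizes for a ViT-style vision tower

q = torch.randn(seq_len, num_heads, head_dim, device="cuda", dtype=torch.bfloat16)
k = torch.randn(seq_len, num_heads, head_dim, device="cuda", dtype=torch.bfloat16)
cos = torch.randn(seq_len, head_dim, device="cuda")  # in a real model these come from the rotary embedding
sin = torch.randn(seq_len, head_dim, device="cuda")

q_rot, k_rot = liger_rotary_pos_emb_vision(q, k, cos, sin)
print(q_rot.shape, q_rot.dtype)  # same shape and dtype as the inputs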
liger_kernel/transformers/sparsemax.py
@@ -0,0 +1,16 @@
+import torch
+import torch.nn as nn
+
+from liger_kernel.ops import LigerSparsemaxFunction
+
+
+class LigerSparsemax(nn.Module):
+    def __init__(self, dim: int = -1):
+        super().__init__()
+        self.dim = dim
+
+    def forward(self, x: torch.Tensor) -> torch.Tensor:
+        return LigerSparsemaxFunction.apply(x, self.dim)
+
+    def extra_repr(self) -> str:
+        return f"dim={self.dim}"
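A minimal usage sketch of the new sparsemax module; the logits shape and the CUDA device are assumptions:

import torch

from liger_kernel.transformers.sparsemax import LigerSparsemax

sparsemax = LigerSparsemax(dim=-1)
logits = torch.randn(4, 32000, device="cuda")

probs = sparsemax(logits)
# Like softmax, each row sums to 1, but sparsemax drives many entries to exactly zero.
print(probs.sum(dim=-1), (probs == 0).float().mean())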
liger_kernel/transformers/swiglu.py
@@ -1,6 +1,6 @@
 import torch.nn as nn
 
-from liger_kernel.ops
+from liger_kernel.ops import LigerSiLUMulFunction
 
 
 class LigerSwiGLUMLP(nn.Module):
@@ -56,3 +56,41 @@ class LigerPhi3SwiGLUMLP(nn.Module):
         up_states = self.gate_up_proj(x)
         gate, up_states = up_states.chunk(2, dim=-1)
         return self.down_proj(LigerSiLUMulFunction.apply(gate, up_states))
+
+
+class LigerQwen3MoeSwiGLUMLP(nn.Module):
+    """
+    Patch Qwen3MoeMLP to use LigerSiLUMulFunction.
+    https://github.com/huggingface/transformers/blob/v4.51.3/src/transformers/models/qwen3_moe/modular_qwen3_moe.py#L57
+    """
+
+    def __init__(self, config, intermediate_size=None):
+        super().__init__()
+        self.config = config
+        self.hidden_size = config.hidden_size
+        self.intermediate_size = intermediate_size if intermediate_size is not None else config.intermediate_size
+        self.gate_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False)
+        self.up_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False)
+        self.down_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias=False)
+        if config.hidden_act not in ["silu", "swish"]:
+            raise ValueError(f"Activation function {config.hidden_act} not supported.")
+
+    def forward(self, x):
+        return self.down_proj(LigerSiLUMulFunction.apply(self.gate_proj(x), self.up_proj(x)))
+
+
+class LigerHunyuanV1SwiGLUMLP(nn.Module):
+    def __init__(self, config, layer_idx=None, is_shared_mlp=False):
+        super().__init__()
+        self.config = config
+        self.hidden_size = config.hidden_size
+        self.intermediate_size = config.intermediate_size
+        self.gate_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False)
+        self.up_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False)
+        self.down_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias=False)
+        self.layer_idx = layer_idx
+        if config.hidden_act not in ["silu", "swish"]:
+            raise ValueError(f"Activation function {config.hidden_act} not supported.")
+
+    def forward(self, x):
+        return self.down_proj(LigerSiLUMulFunction.apply(self.gate_proj(x), self.up_proj(x)))
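A minimal sketch of constructing the new LigerQwen3MoeSwiGLUMLP outside a full model; the SimpleNamespace stands in for the real Qwen3-MoE config and its field values are assumptions:

from types import SimpleNamespace

import torch

from liger_kernel.transformers.swiglu import LigerQwen3MoeSwiGLUMLP

# Only the attributes the module actually reads are provided here.
config = SimpleNamespace(hidden_size=2048, intermediate_size=768, hidden_act="silu")

# Qwen3-MoE experts pass a per-expert intermediate size explicitly.
expert_mlp = LigerQwen3MoeSwiGLUMLP(config, intermediate_size=768).cuda().to(torch.bfloat16)

x = torch.randn(4, 128, 2048, device="cuda", dtype=torch.bfloat16)
y = expert_mlp(x)
print(y.shape)  # (4, 128, 2048)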
liger_kernel/transformers/tiled_mlp.py
@@ -0,0 +1,125 @@
+from typing import Optional
+
+import torch.nn as nn
+
+from liger_kernel.ops import LigerGELUMulFunction
+from liger_kernel.ops import LigerSiLUMulFunction
+from liger_kernel.ops import apply_tiled_mlp
+
+
+class LigerTiledGEGLUMLP(nn.Module):
+    """
+    Memory-efficient GEGLU MLP using tiled computation.
+
+    This module combines GEGLU activation with tiled processing to handle
+    very long sequences efficiently. The forward pass is recomputed during
+    backward to save memory.
+
+    Args:
+        config: Model configuration with hidden_size and intermediate_size attributes
+        num_shards: Number of shards to split the sequence. If None, automatically
+            calculated as ceil(seqlen / hidden_size)
+    """
+
+    def __init__(self, config, num_shards: Optional[int] = None):
+        super().__init__()
+        self.config = config
+        self.hidden_size = config.hidden_size
+        self.intermediate_size = config.intermediate_size
+        self.num_shards = num_shards
+
+        self.gate_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False)
+        self.up_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False)
+        self.down_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias=False)
+
+        # Validate activation function
+        if hasattr(config, "hidden_act") and config.hidden_act not in [
+            "gelu",
+            "gelu_new",
+            "gelu_pytorch_tanh",
+        ]:
+            raise ValueError(f"LigerTiledGEGLUMLP requires GELU activation, got {config.hidden_act}")
+
+    def _mlp_forward(self, module, x):
+        """Internal MLP forward function for tiled computation."""
+        gate = module.gate_proj(x)
+        up = module.up_proj(x)
+        return module.down_proj(LigerGELUMulFunction.apply(gate, up))
+
+    def forward(self, x):
+        """
+        Forward pass with tiled computation.
+
+        Args:
+            x: Input tensor of shape [batch_size, seq_len, hidden_size]
+                or [seq_len, hidden_size]
+
+        Returns:
+            Output tensor of the same shape as input
+        """
+        compute_params = [p for p in self.parameters() if p.requires_grad]
+
+        return apply_tiled_mlp(
+            fn=self._mlp_forward,
+            mlp_module=self,
+            x=x,
+            num_shards=self.num_shards,
+            compute_params=compute_params,
+        )
+
+
+class LigerTiledSwiGLUMLP(nn.Module):
+    """
+    Memory-efficient SwiGLU MLP using tiled computation.
+
+    This module combines SwiGLU activation with tiled processing to handle
+    very long sequences efficiently. The forward pass is recomputed during
+    backward to save memory.
+
+    Args:
+        config: Model configuration with hidden_size and intermediate_size attributes
+        num_shards: Number of shards to split the sequence. If None, automatically
+            calculated as ceil(seqlen / hidden_size)
+    """
+
+    def __init__(self, config, num_shards: Optional[int] = None):
+        super().__init__()
+        self.config = config
+        self.hidden_size = config.hidden_size
+        self.intermediate_size = config.intermediate_size
+        self.num_shards = num_shards
+
+        self.gate_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False)
+        self.up_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False)
+        self.down_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias=False)
+
+        # Validate activation function
+        if hasattr(config, "hidden_act") and config.hidden_act not in ["silu", "swish"]:
+            raise ValueError(f"LigerTiledSwiGLUMLP requires SiLU/Swish activation, got {config.hidden_act}")
+
+    def _mlp_forward(self, module, x):
+        """Internal MLP forward function for tiled computation."""
+        gate = module.gate_proj(x)
+        up = module.up_proj(x)
+        return module.down_proj(LigerSiLUMulFunction.apply(gate, up))
+
+    def forward(self, x):
+        """
+        Forward pass with tiled computation.
+
+        Args:
+            x: Input tensor of shape [batch_size, seq_len, hidden_size]
+                or [seq_len, hidden_size]
+
+        Returns:
+            Output tensor of the same shape as input
+        """
+        compute_params = [p for p in self.parameters() if p.requires_grad]
+
+        return apply_tiled_mlp(
+            fn=self._mlp_forward,
+            mlp_module=self,
+            x=x,
+            num_shards=self.num_shards,
+            compute_params=compute_params,
+        )
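A minimal sketch of the tiled SwiGLU variant above, which trades forward recomputation for activation memory on long sequences; the config values, sequence length, dtype, and device are assumptions:

from types import SimpleNamespace

import torch

from liger_kernel.transformers.tiled_mlp import LigerTiledSwiGLUMLP

config = SimpleNamespace(hidden_size=2048, intermediate_size=8192, hidden_act="silu")

# num_shards=None lets the module pick roughly ceil(seq_len / hidden_size) shards.
mlp = LigerTiledSwiGLUMLP(config, num_shards=None).cuda().to(torch.bfloat16)

x = torch.randn(1, 16384, 2048, device="cuda", dtype=torch.bfloat16, requires_grad=True)
y = mlp(x)           # computed shard by shard; activations are recomputed in backward
y.mean().backward()
print(x.grad.shape)  # matches x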
liger_kernel/transformers/trainer/orpo_trainer.py
@@ -1,5 +1,3 @@
-from typing import Any
-from typing import Callable
 from typing import Dict
 from typing import List
 from typing import Literal
@@ -13,57 +11,7 @@ from torch.distributed.fsdp import FullyShardedDataParallel
 from trl.trainer import ORPOTrainer
 
 from liger_kernel.chunked_loss import LigerFusedLinearORPOLoss
-
-
-class _FSDPForwardRedirection:
-    """
-    Modified based on
-    https://github.com/Lightning-AI/pytorch-lightning/blob/d3f9c83d6efa4f1def36aa6c199600946cdb9117/src/lightning/pytorch/strategies/strategy.py#L601-L648
-    Redirect a method call through FullyShardedDataParallel.forward so that the FSDP module's root pre-forward and
-    post-forward can be properly executed around the method call.
-    This is needed in cases where we call a submodule of a FSDP module. For instance, when we want to call only
-    the `LlamaModel` part out of a FSDP-wrapped `LlamaForCausalLM` to get the hidden states without involving
-    GPU-memory-heavy `lm_head` and cross entropy computation, doing this directly (i.e. `model.model.forward()`)
-    will not work because the first `nn.Embedding` layer is not independently wrapped as a FSDP module (because of
-    the transformer-based wrapping policy), and not calling it through FSDP root module forward will not all-gather
-    its parameter, thus resulting in "RuntimeError: 'weight' must be 2-D" error. Similarly, if we want to call just
-    the `lm_head` part of a model, we need this trick too to properly get its params all-gathered.
-    """
-
-    def __call__(
-        self,
-        wrapper_module: FullyShardedDataParallel,
-        method: Callable,
-        *args: Any,
-        **kwargs: Any,
-    ):
-        """Reroutes a method call through the `wrapper_module`'s `forward` method.
-        Args:
-            wrapper_module: The module that has `original_module` wrapped.
-            original_module: The module that was wrapped inside `wrapper_module`.
-            method_name: The name of the method that should be called on the `original_module` after inputs get
-                redirected through the `wrapper_module`'s `forward` method.
-            *args: The positional arguments to the method `method_name`. They will get passed to a patched
-                `forward` method instead.
-            **kwargs: The keyword arguments to the method `method_name`. They will get passed to a patched
-                `forward` method instead.
-        """
-        assert isinstance(wrapper_module, FullyShardedDataParallel)
-        original_module = wrapper_module._fsdp_wrapped_module
-        original_forward = original_module.forward
-
-        def wrapped_forward(*_args: Any, **_kwargs: Any) -> Any:
-            # Unpatch ourselves immediately before calling the method `method_name`
-            # because itself may want to call the real `forward`
-            original_module.forward = original_forward  # type: ignore[method-assign]
-            # Call the actual method e.g. `.training_step(...)`
-            out = method(*_args, **_kwargs)
-            return out
-
-        # Patch the original_module's forward so we can redirect the arguments back to the real method
-        original_module.forward = wrapped_forward  # type: ignore[method-assign]
-        wrapper_output = wrapper_module(*args, **kwargs)
-        return wrapper_output
+from liger_kernel.transformers.fsdp import _FSDPForwardRedirection
 
 
 class LigerORPOTrainer(ORPOTrainer):
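The class removed above now lives in liger_kernel/transformers/fsdp.py and is imported from there. Its docstring describes rerouting a method call through the FSDP wrapper's forward so the root pre-/post-forward hooks (parameter all-gather and reshard) still run around the call. A toy sketch of that redirection pattern using a plain wrapper module instead of FSDP, purely to illustrate the mechanism; the Wrapper class and helper below are hypothetical:

import torch
import torch.nn as nn


class Wrapper(nn.Module):
    """Stand-in for an FSDP wrapper whose forward() runs important hooks."""

    def __init__(self, inner):
        super().__init__()
        self.inner = inner

    def forward(self, *args, **kwargs):
        # In FSDP this is where parameters are all-gathered before the call
        # and resharded afterwards.
        return self.inner(*args, **kwargs)


def redirect_through_wrapper(wrapper, method, *args, **kwargs):
    inner = wrapper.inner
    original_forward = inner.forward

    def patched_forward(*a, **kw):
        inner.forward = original_forward  # unpatch before running the real method
        return method(*a, **kw)

    inner.forward = patched_forward       # instance-level patch, as in the removed class
    return wrapper(*args, **kwargs)       # wrapper hooks now run around `method`


inner = nn.Linear(4, 4)
wrapper = Wrapper(inner)
out = redirect_through_wrapper(wrapper, lambda x: inner(x) * 2, torch.randn(2, 4))
print(out.shape)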
liger_kernel/transformers/tvd.py CHANGED
liger_kernel/utils.py CHANGED
@@ -1,18 +1,81 @@
+try:
+    import peft  # noqa: F401
+
+    PEFT_AVAILABLE = True
+except ImportError:
+    PEFT_AVAILABLE = False
+
 import torch
 
 
+def is_peft_available():
+    return PEFT_AVAILABLE
+
+
+def infer_comm_backend():
+    """
+    Get communication backend name based on the environment.
+    """
+    if torch.distributed.is_nccl_available():
+        # Works for Nvidia
+        # TODO: nccl may not work for AMD decices that may require use of rccl.
+        return "nccl"
+    elif is_npu_available():
+        # Use Ascend NPU if available (torch.npu)
+        # Ascend is not standard torch backend and requires extension.
+        # Assume that it is installed if NPUs are being used in
+        # multi device environment.
+        return "ascend"
+    # XPU (Intel) if available
+    elif torch.distributed.distributed_c10d.is_xccl_available():
+        return "xccl"
+    elif torch.distributed.is_mpi_available():
+        # CPU backend, first option
+        return "mpi"
+    elif torch.distributed.is_gloo_available():
+        # CPU backend, backup option
+        return "gloo"
+    else:
+        raise RuntimeError("There is no distributed backend available.")
+
+
 def infer_device():
     """
     Get current device name based on available devices
     """
     if torch.cuda.is_available():  # Works for both Nvidia and AMD
         return "cuda"
+    # Use Ascend NPU if available (torch.npu)
+    elif is_npu_available():
+        return "npu"
+    # XPU (Intel) if available
     elif torch.xpu.is_available():
         return "xpu"
     else:
         return "cpu"
 
 
+def is_npu_available() -> bool:
+    """Detect Ascend NPU availability."""
+    try:
+        from transformers.utils import is_torch_npu_available
+
+        return is_torch_npu_available()
+    except Exception:
+        return False
+
+
+def get_npu_multi_processor_count() -> int:
+    """Return a heuristic multi-processor count for NPU."""
+    if is_npu_available():
+        NPU_MULTI_PROCESSOR_COUNT = 48
+        dev_props = torch.npu.get_device_properties()
+        # The vector_core_num attribute is supported in the torch.npu v7.2.0 release version.
+        return dev_props.vector_core_num if hasattr(dev_props, "vector_core_num") else NPU_MULTI_PROCESSOR_COUNT
+    # Reasonable default to avoid division by zero
+    return 1
+
+
 def transformers_version_dispatch(
     required_version: str,
     before_fn,
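A minimal sketch combining the new helpers above to pick a device and initialize a process group; the launcher-provided environment (MASTER_ADDR, RANK, WORLD_SIZE, etc.) and the multi-process context are assumptions:

import torch.distributed as dist

from liger_kernel.utils import infer_comm_backend
from liger_kernel.utils import infer_device
from liger_kernel.utils import is_peft_available

device = infer_device()         # "cuda", "npu", "xpu", or "cpu"
backend = infer_comm_backend()  # "nccl", "ascend", "xccl", "mpi", or "gloo"

if not dist.is_initialized():
    # Assumes MASTER_ADDR/MASTER_PORT/RANK/WORLD_SIZE are set by the launcher (e.g. torchrun).
    dist.init_process_group(backend=backend)

print(device, backend, is_peft_available())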