liger-kernel 0.6.3__py3-none-any.whl → 0.6.4__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- liger_kernel/chunked_loss/cosine_similarity_loss.py +13 -4
- liger_kernel/chunked_loss/fused_linear_distillation.py +13 -2
- liger_kernel/chunked_loss/fused_linear_ppo.py +21 -5
- liger_kernel/chunked_loss/grpo_loss.py +8 -5
- liger_kernel/chunked_loss/jsd_loss.py +18 -5
- liger_kernel/ops/cross_entropy.py +59 -9
- liger_kernel/ops/fused_linear_cross_entropy.py +30 -4
- liger_kernel/ops/grpo_loss.py +3 -1
- liger_kernel/ops/layer_norm.py +84 -65
- liger_kernel/ops/tiled_mlp.py +136 -0
- liger_kernel/transformers/__init__.py +19 -0
- liger_kernel/transformers/cross_entropy.py +8 -3
- liger_kernel/transformers/functional.py +24 -6
- liger_kernel/transformers/fused_linear_cross_entropy.py +8 -3
- liger_kernel/transformers/grpo_loss.py +56 -1
- liger_kernel/transformers/model/falcon_h1.py +19 -5
- liger_kernel/transformers/model/gemma.py +17 -6
- liger_kernel/transformers/model/gemma2.py +14 -5
- liger_kernel/transformers/model/gemma3.py +25 -12
- liger_kernel/transformers/model/glm4.py +16 -4
- liger_kernel/transformers/model/glm4v.py +16 -4
- liger_kernel/transformers/model/glm4v_moe.py +23 -4
- liger_kernel/transformers/model/hunyuan_v1.py +134 -0
- liger_kernel/transformers/model/internvl.py +12 -5
- liger_kernel/transformers/model/llama.py +14 -5
- liger_kernel/transformers/model/llama4.py +16 -4
- liger_kernel/transformers/model/llava.py +12 -4
- liger_kernel/transformers/model/loss_utils.py +31 -3
- liger_kernel/transformers/model/mistral.py +15 -6
- liger_kernel/transformers/model/mixtral.py +16 -7
- liger_kernel/transformers/model/mllama.py +12 -4
- liger_kernel/transformers/model/olmo2.py +16 -4
- liger_kernel/transformers/model/olmo3.py +142 -0
- liger_kernel/transformers/model/output_classes.py +147 -0
- liger_kernel/transformers/model/paligemma.py +22 -5
- liger_kernel/transformers/model/phi3.py +14 -7
- liger_kernel/transformers/model/qwen2.py +16 -3
- liger_kernel/transformers/model/qwen2_5_vl.py +14 -6
- liger_kernel/transformers/model/qwen2_vl.py +16 -4
- liger_kernel/transformers/model/qwen3.py +20 -5
- liger_kernel/transformers/model/qwen3_moe.py +19 -5
- liger_kernel/transformers/model/qwen3_next.py +17 -5
- liger_kernel/transformers/model/qwen3_vl.py +150 -0
- liger_kernel/transformers/model/qwen3_vl_moe.py +126 -0
- liger_kernel/transformers/model/smollm3.py +15 -6
- liger_kernel/transformers/monkey_patch.py +398 -20
- liger_kernel/transformers/rope.py +43 -0
- liger_kernel/transformers/swiglu.py +17 -0
- liger_kernel/transformers/tiled_mlp.py +133 -0
- {liger_kernel-0.6.3.dist-info → liger_kernel-0.6.4.dist-info}/METADATA +4 -1
- {liger_kernel-0.6.3.dist-info → liger_kernel-0.6.4.dist-info}/RECORD +55 -48
- {liger_kernel-0.6.3.dist-info → liger_kernel-0.6.4.dist-info}/WHEEL +0 -0
- {liger_kernel-0.6.3.dist-info → liger_kernel-0.6.4.dist-info}/licenses/LICENSE +0 -0
- {liger_kernel-0.6.3.dist-info → liger_kernel-0.6.4.dist-info}/licenses/NOTICE +0 -0
- {liger_kernel-0.6.3.dist-info → liger_kernel-0.6.4.dist-info}/top_level.txt +0 -0
@@ -7,11 +7,12 @@ from typing import Union
 import torch

 from torch.distributed.fsdp import FullyShardedDataParallel
-from transformers.modeling_outputs import CausalLMOutputWithPast
 from transformers.utils.deprecation import deprecate_kwarg

 from liger_kernel.transformers.fsdp import _FSDPForwardRedirection
 from liger_kernel.transformers.model.loss_utils import LigerForCausalLMLoss
+from liger_kernel.transformers.model.loss_utils import unpack_cross_entropy_result
+from liger_kernel.transformers.model.output_classes import LigerCausalLMOutputWithPast
 from liger_kernel.utils import PEFT_AVAILABLE

 if TYPE_CHECKING:

@@ -38,7 +39,7 @@ def lce_forward(
     logits_to_keep: Union[int, torch.Tensor] = 0,
     skip_logits: Optional[bool] = None,
     **kwargs,
-) -> Union[Tuple,
+) -> Union[Tuple, LigerCausalLMOutputWithPast]:
     r"""
     Args:
         labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):

@@ -101,6 +102,8 @@ def lce_forward(
     shift_labels = kwargs.pop("shift_labels", None)
     logits = None
     loss = None
+    token_accuracy = None
+
     # if in training mode, don't materialize logits
     if skip_logits and labels is None and shift_labels is None:
         raise ValueError("skip_logits is True, but labels and shift_labels are None")

@@ -109,8 +112,9 @@ def lce_forward(
         # By default, if in training mode, don't materialize logits
         skip_logits = self.training and (labels is not None or shift_labels is not None)

+    # Compute loss
     if skip_logits:
-
+        result = lce_maybe_trainable_lm_head(
             self,
             hidden_states=kept_hidden_states,
             hidden_size=self.config.hidden_size,

@@ -118,6 +122,7 @@ def lce_forward(
             shift_labels=shift_labels,
             **kwargs,
         )
+        loss, _, token_accuracy = unpack_cross_entropy_result(result)

     else:
         logits = self.lm_head(kept_hidden_states)

@@ -131,15 +136,19 @@ def lce_forward(
         )

     if not return_dict:
-
-
+        output_tuple = (logits,) + outputs[1:]
+        output = (loss,) + output_tuple if loss is not None else output_tuple
+        output = output + (token_accuracy,) if token_accuracy is not None else output
+        return output

-
+    # Return custom output class with token_accuracy field
+    return LigerCausalLMOutputWithPast(
         loss=loss,
         logits=logits,
         past_key_values=outputs.past_key_values,
         hidden_states=outputs.hidden_states,
         attentions=outputs.attentions,
+        token_accuracy=token_accuracy,
     )

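
Taken together, these hunks (mirrored across the patched model forwards in this release) thread a new token_accuracy value through the fused loss path and return it on the new LigerCausalLMOutputWithPast class from output_classes.py. A minimal consumer sketch, assuming a causal LM patched by one of the apply_liger_kernel_to_* entry points and a hypothetical checkpoint path; whether token_accuracy is populated depends on how the loss kwargs are set, so treat it as optional:

    import torch
    from transformers import AutoModelForCausalLM, AutoTokenizer
    from liger_kernel.transformers import apply_liger_kernel_to_llama

    # hypothetical checkpoint path; any Llama-style causal LM follows the same pattern
    model = AutoModelForCausalLM.from_pretrained("path/to/llama-checkpoint")
    tokenizer = AutoTokenizer.from_pretrained("path/to/llama-checkpoint")
    apply_liger_kernel_to_llama(model=model)  # swaps in the Liger lce_forward shown above

    model.train()  # with labels in training mode, logits are not materialized (skip_logits)
    batch = tokenizer("hello world", return_tensors="pt").to(model.device)
    out = model(**batch, labels=batch["input_ids"])
    print(out.loss)            # fused linear cross entropy loss
    print(out.token_accuracy)  # new field on LigerCausalLMOutputWithPast; may be None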
@@ -34,6 +34,8 @@ from liger_kernel.transformers.model.smollm3 import lce_forward as smollm3_lce_f
 from liger_kernel.transformers.qwen2vl_mrope import liger_multimodal_rotary_pos_emb
 from liger_kernel.transformers.rms_norm import LigerRMSNorm
 from liger_kernel.transformers.rope import liger_rotary_pos_emb
+from liger_kernel.transformers.rope import liger_rotary_pos_emb_with_cast
+from liger_kernel.transformers.rope import liger_rotary_pos_emb_with_cast_and_leading_batch
 from liger_kernel.transformers.swiglu import LigerBlockSparseTop2MLP
 from liger_kernel.transformers.swiglu import LigerPhi3SwiGLUMLP
 from liger_kernel.transformers.swiglu import LigerSwiGLUMLP

@@ -1643,6 +1645,158 @@ def apply_liger_kernel_to_qwen2_5_vl(
             _patch_rms_norm_module(decoder_layer.post_attention_layernorm)


+def apply_liger_kernel_to_qwen3_vl(
+    rope: bool = True,
+    cross_entropy: bool = False,
+    fused_linear_cross_entropy: bool = True,
+    rms_norm: bool = True,
+    swiglu: bool = False,
+    model: PreTrainedModel = None,
+) -> None:
+    """
+    Apply Liger kernels to replace original implementation in HuggingFace Qwen3-VL models.
+
+    Args:
+        cross_entropy (bool): Whether to apply Liger's cross entropy loss. Default is False.
+        fused_linear_cross_entropy (bool):
+            Whether to apply Liger's fused linear cross entropy loss. Default is True.
+            `cross_entropy` and `fused_linear_cross_entropy` cannot both be True.
+            If `fused_linear_cross_entropy` is True, the logits will not be materialized but more memory efficient.
+        rms_norm (bool): Whether to apply Liger's RMSNorm. Default is True.
+        swiglu (bool): Whether to apply Liger's SwiGLU MLP. Default is False.
+        model (PreTrainedModel): The model instance to apply Liger kernels to, if the model has already been
+        loaded. Default is None.
+    """
+
+    assert not (cross_entropy and fused_linear_cross_entropy), (
+        "cross_entropy and fused_linear_cross_entropy cannot both be True."
+    )
+
+    from transformers.models.qwen3_vl import modeling_qwen3_vl
+    from transformers.models.qwen3_vl.modeling_qwen3_vl import Qwen3VLForConditionalGeneration
+    from transformers.models.qwen3_vl.modeling_qwen3_vl import Qwen3VLModel
+    from transformers.models.qwen3_vl.modeling_qwen3_vl import Qwen3VLTextModel
+
+    from liger_kernel.transformers.model.qwen3_vl import lce_forward as qwen3_vl_lce_forward
+
+    if rope:
+        modeling_qwen3_vl.apply_rotary_pos_emb = liger_rotary_pos_emb_with_cast
+        modeling_qwen3_vl.apply_rotary_pos_emb_vision = liger_rotary_pos_emb_with_cast_and_leading_batch
+
+    if rms_norm:
+        modeling_qwen3_vl.Qwen3VLTextRMSNorm = LigerRMSNorm
+
+    if cross_entropy:
+        from transformers.loss.loss_utils import nn
+
+        nn.functional.cross_entropy = liger_cross_entropy
+
+    if fused_linear_cross_entropy:
+        if model is not None:
+            model.forward = MethodType(qwen3_vl_lce_forward, model)
+        else:
+            modeling_qwen3_vl.Qwen3VLForConditionalGeneration.forward = qwen3_vl_lce_forward
+
+    if model is not None and rms_norm:
+        if isinstance(model, (Qwen3VLForConditionalGeneration, Qwen3VLModel)):
+            text_model: Qwen3VLTextModel = model.language_model
+        elif isinstance(model, Qwen3VLTextModel):
+            text_model = model
+        else:
+            raise TypeError(
+                f"Unsupported Qwen3VL model type. `model` must be `Qwen3VLForConditionalGeneration`, `Qwen3VLModel` or `Qwen3VLTextModel`. Got: {type(model)}"
+            )
+
+        _patch_qwen3_vl_rms_norm = partial(_patch_rms_norm_module, offset=0.0, casting_mode="llama")
+
+        if text_model is not None:
+            _patch_qwen3_vl_rms_norm(text_model.norm)
+            for decoder_layer in text_model.layers:
+                _patch_qwen3_vl_rms_norm(decoder_layer.input_layernorm)
+                _patch_qwen3_vl_rms_norm(decoder_layer.post_attention_layernorm)
+                self_attn = getattr(decoder_layer, "self_attn", None)
+                if self_attn is not None:
+                    if hasattr(self_attn, "q_norm") and self_attn.q_norm is not None:
+                        _patch_qwen3_vl_rms_norm(self_attn.q_norm)
+                    if hasattr(self_attn, "k_norm") and self_attn.k_norm is not None:
+                        _patch_qwen3_vl_rms_norm(self_attn.k_norm)
+
+
+def apply_liger_kernel_to_qwen3_vl_moe(
+    rope: bool = True,
+    cross_entropy: bool = False,
+    fused_linear_cross_entropy: bool = True,
+    rms_norm: bool = True,
+    swiglu: bool = False,
+    model: PreTrainedModel = None,
+) -> None:
+    """
+    Apply Liger kernels to replace original implementation in HuggingFace Qwen3-VL MoE models.
+
+    Args:
+        cross_entropy (bool): Whether to apply Liger's cross entropy loss. Default is False.
+        fused_linear_cross_entropy (bool):
+            Whether to apply Liger's fused linear cross entropy loss. Default is False.
+        rms_norm (bool): Whether to apply Liger's RMSNorm. Default is True.
+        swiglu (bool): Whether to apply Liger's SwiGLU MLP. Default is False.
+        model (PreTrainedModel): The model instance to apply Liger kernels to, if the model has already been
+        loaded. Default is None.
+    """
+
+    assert not (cross_entropy and fused_linear_cross_entropy), (
+        "cross_entropy and fused_linear_cross_entropy cannot both be True."
+    )
+
+    from transformers.models.qwen3_vl_moe import modeling_qwen3_vl_moe
+    from transformers.models.qwen3_vl_moe.modeling_qwen3_vl_moe import Qwen3VLMoeForConditionalGeneration
+    from transformers.models.qwen3_vl_moe.modeling_qwen3_vl_moe import Qwen3VLMoeModel
+    from transformers.models.qwen3_vl_moe.modeling_qwen3_vl_moe import Qwen3VLMoeTextModel
+
+    from liger_kernel.transformers.model.qwen3_vl_moe import lce_forward as qwen3_vl_moe_lce_forward
+
+    if rope:
+        modeling_qwen3_vl_moe.apply_rotary_pos_emb = liger_rotary_pos_emb_with_cast
+        modeling_qwen3_vl_moe.apply_rotary_pos_emb_vision = liger_rotary_pos_emb_with_cast_and_leading_batch
+
+    if rms_norm:
+        modeling_qwen3_vl_moe.Qwen3VLMoeTextRMSNorm = LigerRMSNorm
+
+    if cross_entropy:
+        from transformers.loss.loss_utils import nn
+
+        nn.functional.cross_entropy = liger_cross_entropy
+
+    if fused_linear_cross_entropy:
+        if model is not None:
+            model.forward = MethodType(qwen3_vl_moe_lce_forward, model)
+        else:
+            modeling_qwen3_vl_moe.Qwen3VLMoeForConditionalGeneration.forward = qwen3_vl_moe_lce_forward
+
+    if model is not None and rms_norm:
+        if isinstance(model, (Qwen3VLMoeForConditionalGeneration, Qwen3VLMoeModel)):
+            text_model: Qwen3VLMoeTextModel = model.language_model
+        elif isinstance(model, Qwen3VLMoeTextModel):
+            text_model = model
+        else:
+            raise TypeError(
+                f"Unsupported Qwen3VLMoe model type. `model` must be `Qwen3VLMoeForConditionalGeneration`, `Qwen3VLMoeModel` or `Qwen3VLMoeTextModel`. Got: {type(model)}"
+            )
+
+        _patch_qwen3_vl_moe_rms_norm = partial(_patch_rms_norm_module, offset=0.0, casting_mode="llama")
+
+        if text_model is not None:
+            _patch_qwen3_vl_moe_rms_norm(text_model.norm)
+            for decoder_layer in text_model.layers:
+                _patch_qwen3_vl_moe_rms_norm(decoder_layer.input_layernorm)
+                _patch_qwen3_vl_moe_rms_norm(decoder_layer.post_attention_layernorm)
+                self_attn = getattr(decoder_layer, "self_attn", None)
+                if self_attn is not None:
+                    if hasattr(self_attn, "q_norm") and self_attn.q_norm is not None:
+                        _patch_qwen3_vl_moe_rms_norm(self_attn.q_norm)
+                    if hasattr(self_attn, "k_norm") and self_attn.k_norm is not None:
+                        _patch_qwen3_vl_moe_rms_norm(self_attn.k_norm)
+
+
 def apply_liger_kernel_to_phi3(
     rope: bool = True,
     cross_entropy: bool = False,

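
A hedged usage sketch for the new Qwen3-VL entry point. Qwen3VLForConditionalGeneration is assumed to be available in the installed transformers version, the checkpoint id is a placeholder, and the import is shown from monkey_patch in case the name is not re-exported at package level:

    from transformers import Qwen3VLForConditionalGeneration  # requires a transformers build with Qwen3-VL
    from liger_kernel.transformers.monkey_patch import apply_liger_kernel_to_qwen3_vl

    # Patch before loading so the class-level swaps (rope, Qwen3VLTextRMSNorm, forward) apply at construction...
    apply_liger_kernel_to_qwen3_vl(rope=True, rms_norm=True, fused_linear_cross_entropy=True)
    model = Qwen3VLForConditionalGeneration.from_pretrained("path/to/qwen3-vl-checkpoint")

    # ...or patch an already-loaded instance; the function then also walks text_model.layers and
    # patches each input/post-attention/q_norm/k_norm RMSNorm in place, as in the loop above.
    apply_liger_kernel_to_qwen3_vl(model=model)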
@@ -1774,6 +1928,74 @@ def apply_liger_kernel_to_olmo2(
             _patch_rms_norm_module(decoder_layer.post_feedforward_layernorm, in_place=False)


+def apply_liger_kernel_to_olmo3(
+    rope: bool = True,
+    cross_entropy: bool = False,
+    fused_linear_cross_entropy: bool = True,
+    rms_norm: bool = True,
+    swiglu: bool = True,
+    model: PreTrainedModel = None,
+) -> None:
+    """
+    Apply Liger kernels to replace original implementation in HuggingFace Olmo3 models.
+
+    Args:
+        rope (bool): Whether to apply Liger's rotary position embedding. Default is True.
+        cross_entropy (bool): Whether to apply Liger's cross entropy loss. Default is False.
+        fused_linear_cross_entropy (bool):
+            Whether to apply Liger's fused linear cross entropy loss. Default is True.
+            `cross_entropy` and `fused_linear_cross_entropy` cannot both be True.
+            If `fused_linear_cross_entropy` is True, the logits will not be materialized but more memory efficient.
+        rms_norm (bool): Whether to apply Liger's RMSNorm. Default is True.
+        swiglu (bool): Whether to apply Liger's SwiGLU to Olmo3MLP. Default is True.
+        model (PreTrainedModel): The model instance to apply Liger kernels to, if the model has already been
+        loaded. Default is None.
+    """
+    assert not (cross_entropy and fused_linear_cross_entropy), (
+        "cross_entropy and fused_linear_cross_entropy cannot both be True."
+    )
+
+    from transformers.models.olmo3 import modeling_olmo3
+    from transformers.models.olmo3.modeling_olmo3 import Olmo3Model
+
+    from liger_kernel.transformers.model.olmo3 import lce_forward as olmo3_lce_forward
+    from liger_kernel.transformers.rms_norm import LigerRMSNormForOlmo2
+
+    # Olmo3 arch is very similar to Olmo2, so we can reuse all these components in the same way.
+    if rope:
+        modeling_olmo3.apply_rotary_pos_emb = liger_rotary_pos_emb
+    if rms_norm:
+        modeling_olmo3.Olmo3RMSNorm = LigerRMSNormForOlmo2  # same as olmo2
+    if swiglu:
+        modeling_olmo3.Olmo3MLP = LigerSwiGLUMLP
+    if cross_entropy:
+        from transformers.loss.loss_utils import nn
+
+        nn.functional.cross_entropy = liger_cross_entropy
+    if fused_linear_cross_entropy:
+        if model is not None:
+            model.forward = MethodType(olmo3_lce_forward, model)
+        else:
+            modeling_olmo3.Olmo3ForCausalLM.forward = olmo3_lce_forward
+
+    if model is not None:
+        # The model instance already exists, so we need to additionally patch the
+        # instance variables that reference already-instantiated modules
+
+        # get the base model from the model instance
+        base_model: Olmo3Model = getattr(model, model.base_model_prefix, model)
+
+        if rms_norm:
+            _patch_rms_norm_module(base_model.norm)
+
+        for decoder_layer in base_model.layers:
+            if swiglu:
+                _patch_swiglu_module(decoder_layer.mlp, LigerSwiGLUMLP)
+            if rms_norm:
+                _patch_rms_norm_module(decoder_layer.post_attention_layernorm, in_place=False)
+                _patch_rms_norm_module(decoder_layer.post_feedforward_layernorm, in_place=False)
+
+
 def apply_liger_kernel_to_glm4(
     rope: bool = False,
     cross_entropy: bool = False,

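
Like the other apply_* helpers, apply_liger_kernel_to_olmo3 works in two modes: called before model construction it swaps Olmo3RMSNorm, Olmo3MLP, the rope function and the ForCausalLM forward at class level; called with model= it additionally patches the already-instantiated modules. A sketch of both orders with a placeholder checkpoint path:

    from transformers import AutoModelForCausalLM
    from liger_kernel.transformers.monkey_patch import apply_liger_kernel_to_olmo3

    # Option 1: patch first, then load; new submodules are constructed from the Liger classes.
    apply_liger_kernel_to_olmo3()
    model = AutoModelForCausalLM.from_pretrained("path/to/olmo3-checkpoint")

    # Option 2: the model already exists; class swaps no longer reach its submodules, so pass the
    # instance and let the function patch base_model.norm and each decoder layer in place.
    model = AutoModelForCausalLM.from_pretrained("path/to/olmo3-checkpoint")
    apply_liger_kernel_to_olmo3(model=model)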
@@ -2038,6 +2260,7 @@ def apply_liger_kernel_to_internvl(
     cross_entropy: bool = False,
     fused_linear_cross_entropy: bool = True,
     rms_norm: bool = True,
+    layer_norm: bool = True,
     model: Optional[PreTrainedModel] = None,
     **kwargs,
 ) -> None:

@@ -2048,37 +2271,60 @@ def apply_liger_kernel_to_internvl(
     NOTE: InternVL is not available in transformers<4.52.1

     Args:
-        rope (bool): Whether to apply Liger's rotary position embedding. Default is True.
         cross_entropy (bool): Whether to apply Liger's cross entropy loss. Default is False.
         fused_linear_cross_entropy (bool):
             Whether to apply Liger's fused linear cross entropy loss. Default is True.
             `cross_entropy` and `fused_linear_cross_entropy` cannot both be True.
             If `fused_linear_cross_entropy` is True, the logits will not be materialized but more memory efficient.
         rms_norm (bool): Whether to apply Liger's RMSNorm. Default is True.
-
+        layer_norm (bool): Whether to apply Liger's LayerNorm. Default is True.
         model (PreTrainedModel): The model instance to apply Liger kernels to, if the model has already been
         loaded. Default is None.
     """
     assert not (cross_entropy and fused_linear_cross_entropy), (
         "cross_entropy and fused_linear_cross_entropy cannot both be True."
     )
+    import torch.nn as torch_nn

     from transformers.models.internvl import modeling_internvl
+    from transformers.models.internvl.modeling_internvl import InternVLForConditionalGeneration
+    from transformers.models.internvl.modeling_internvl import InternVLModel
+    from transformers.models.internvl.modeling_internvl import InternVLVisionLayer
+    from transformers.models.internvl.modeling_internvl import InternVLVisionModel
+    from transformers.models.internvl.modeling_internvl import InternVLVisionRMSNorm

+    from liger_kernel.transformers.layer_norm import LigerLayerNorm
     from liger_kernel.transformers.model.internvl import lce_forward as internvl_lce_forward
+    from liger_kernel.transformers.rms_norm import LigerRMSNorm
+
+    if layer_norm and model is None:
+        modeling_internvl.nn.LayerNorm = LigerLayerNorm

     if cross_entropy:
-        logger.
-
+        logger.info("Apply liger cross entropy")
+
+        from transformers.loss.loss_utils import nn
+
+        nn.functional.cross_entropy = liger_cross_entropy
     if fused_linear_cross_entropy:
         modeling_internvl.InternVLForConditionalGeneration.forward = internvl_lce_forward
     if rms_norm:
         modeling_internvl.InternVLVisionRMSNorm = LigerRMSNorm

     if model is not None:
-
+        # The model instance already exists, so we need to additionally patch the
+        # instance variables that reference already-instantiated modules
+        if isinstance(model, (InternVLForConditionalGeneration, InternVLModel)):
+            # NOTE: language_model and visual properties can be accessed throught conditional class.
+            text_model = model.language_model
+            vision_model: InternVLVisionModel = model.vision_tower
+        else:
+            raise TypeError(
+                f"Unsupported internvl model type. `model` must be `InternVLForConditionalGeneration`, `InternVLModel`. Got: {type(model)}"
+            )
+
+        text_model_name = model.config.text_config.model_type
         text_liger_fn = MODEL_TYPE_TO_APPLY_LIGER_FN.get(text_model_name, None)
-        vision_liger_fn = MODEL_TYPE_TO_APPLY_LIGER_FN.get(vision_model_name, None)

         kwargs = {"cross_entropy": False, "fused_linear_cross_entropy": False, **kwargs} | {"rms_norm": rms_norm}
         if text_liger_fn:

@@ -2091,25 +2337,33 @@ def apply_liger_kernel_to_internvl(
                 f"These parameters are not supported by {text_model_name}. Enter the remaining {list(text_kwargs.keys())} except for {list(remain_params)}\n"
                 f"Parameters accepted by {text_model_name}: {list(accept_params.keys())}"
             )
-            text_kwargs["model"] =
+            text_kwargs["model"] = text_model
             text_liger_fn(**text_kwargs)
         elif text_model_name not in MODEL_TYPE_TO_APPLY_LIGER_FN:
             logger.warning(f"{text_model_name} is not supported by Liger kernel.")

-
-
-
-
+        # Patch vision model RMSNorm layers
+        if rms_norm:
+            for encoder_layer in vision_model.encoder.layer:
+                encoder_layer: InternVLVisionLayer
+                if isinstance(encoder_layer.attention.q_norm, InternVLVisionRMSNorm):
+                    _patch_rms_norm_module(encoder_layer.attention.q_norm)
+                if isinstance(encoder_layer.attention.k_norm, InternVLVisionRMSNorm):
+                    _patch_rms_norm_module(encoder_layer.attention.k_norm)

-
-
-
-
-            )
-
-
-
-
+        # Patch vision model LayerNorm layers
+        if layer_norm:
+            # Patch layernorm
+            if isinstance(vision_model.layernorm, torch_nn.LayerNorm):
+                _patch_layer_norm_module(vision_model.layernorm)
+
+            # Patch encoder layers
+            for encoder_layer in vision_model.encoder.layer:
+                encoder_layer: InternVLVisionLayer
+                if isinstance(encoder_layer.layernorm_before, torch_nn.LayerNorm):
+                    _patch_layer_norm_module(encoder_layer.layernorm_before)
+                if isinstance(encoder_layer.layernorm_after, torch_nn.LayerNorm):
+                    _patch_layer_norm_module(encoder_layer.layernorm_after)


 def apply_liger_kernel_to_smolvlm(

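
The new layer_norm flag extends the InternVL patch to the vision tower: without a model instance it swaps nn.LayerNorm inside modeling_internvl, and with model= it patches vision_model.layernorm and each encoder layer's layernorm_before/layernorm_after in place, alongside the q_norm/k_norm RMSNorm patches. A short sketch; the auto class and checkpoint id are assumptions:

    from transformers import AutoModelForImageTextToText
    from liger_kernel.transformers.monkey_patch import apply_liger_kernel_to_internvl

    model = AutoModelForImageTextToText.from_pretrained("path/to/internvl-checkpoint")  # placeholder id
    # rms_norm covers the vision attention q_norm/k_norm; layer_norm covers the vision LayerNorms.
    apply_liger_kernel_to_internvl(rms_norm=True, layer_norm=True, model=model)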
@@ -2372,6 +2626,123 @@ def apply_liger_kernel_to_qwen3_next(
                     _patch_swiglu_module(expert, LigerQwen3MoeSwiGLUMLP)


+def apply_liger_kernel_to_hunyuan_v1_dense(
+    rope: bool = True,
+    cross_entropy: bool = False,
+    fused_linear_cross_entropy: bool = True,
+    rms_norm: bool = True,
+    swiglu: bool = True,
+    model: PreTrainedModel = None,
+) -> None:
+    """
+    Apply Liger kernels to replace original implementation in HuggingFace Hunyuan v1 dense models.
+    """
+    assert not (cross_entropy and fused_linear_cross_entropy), (
+        "cross_entropy and fused_linear_cross_entropy cannot both be True."
+    )
+
+    from transformers.models.hunyuan_v1_dense import modeling_hunyuan_v1_dense
+    from transformers.models.hunyuan_v1_dense.modeling_hunyuan_v1_dense import HunYuanDenseV1Model
+
+    from liger_kernel.transformers.model.hunyuan_v1 import lce_forward as hunyuan_v1_lce_forward
+    from liger_kernel.transformers.swiglu import LigerHunyuanV1SwiGLUMLP
+
+    if rope:
+        modeling_hunyuan_v1_dense.apply_rotary_pos_emb = liger_rotary_pos_emb
+
+    if rms_norm:
+        modeling_hunyuan_v1_dense.HunYuanDenseV1RMSNorm = LigerRMSNorm
+
+    if cross_entropy:
+        from transformers.loss.loss_utils import nn
+
+        nn.functional.cross_entropy = liger_cross_entropy
+
+    if fused_linear_cross_entropy:
+        if model is not None:
+            model.forward = MethodType(hunyuan_v1_lce_forward, model)
+        else:
+            modeling_hunyuan_v1_dense.HunYuanDenseV1ForCausalLM.forward = hunyuan_v1_lce_forward
+
+    if swiglu:
+        modeling_hunyuan_v1_dense.HunYuanDenseV1MLP = LigerHunyuanV1SwiGLUMLP
+
+    if model is not None:
+        # The model instance already exists, so we need to additionally patch the
+        # instance variables that reference already-instantiated modules
+
+        # get the base model from the model instance
+        base_model: HunYuanDenseV1Model = getattr(model, model.base_model_prefix, model)
+
+        if rms_norm:
+            _patch_rms_norm_module(base_model.norm)
+        for decoder_layer in base_model.layers:
+            if swiglu:
+                _patch_swiglu_module(decoder_layer.mlp, LigerHunyuanV1SwiGLUMLP)
+            if rms_norm:
+                _patch_rms_norm_module(decoder_layer.input_layernorm)
+                _patch_rms_norm_module(decoder_layer.post_attention_layernorm)
+
+
+def apply_liger_kernel_to_hunyuan_v1_moe(
+    rope: bool = True,
+    cross_entropy: bool = False,
+    fused_linear_cross_entropy: bool = True,
+    rms_norm: bool = True,
+    swiglu: bool = True,
+    model: PreTrainedModel = None,
+) -> None:
+    """
+    Apply Liger kernels to replace original implementation in HuggingFace Qwen3 models.
+    """
+    assert not (cross_entropy and fused_linear_cross_entropy), (
+        "cross_entropy and fused_linear_cross_entropy cannot both be True."
+    )
+
+    from transformers.models.hunyuan_v1_moe import modeling_hunyuan_v1_moe
+    from transformers.models.hunyuan_v1_moe.modeling_hunyuan_v1_moe import HunYuanMoEV1Model
+
+    from liger_kernel.transformers.model.hunyuan_v1 import lce_forward as hunyuan_v1_moe_lce_forward
+    from liger_kernel.transformers.swiglu import LigerHunyuanV1SwiGLUMLP
+
+    if rope:
+        modeling_hunyuan_v1_moe.apply_rotary_pos_emb = liger_rotary_pos_emb
+
+    if rms_norm:
+        modeling_hunyuan_v1_moe.HunYuanMoEV1RMSNorm = LigerRMSNorm
+
+    if cross_entropy:
+        from transformers.loss.loss_utils import nn
+
+        nn.functional.cross_entropy = liger_cross_entropy
+
+    if fused_linear_cross_entropy:
+        if model is not None:
+            model.forward = MethodType(hunyuan_v1_moe_lce_forward, model)
+        else:
+            modeling_hunyuan_v1_moe.HunYuanMoEV1ForCausalLM.forward = hunyuan_v1_moe_lce_forward
+
+    if swiglu:
+        modeling_hunyuan_v1_moe.HunYuanMoEV1MLP = LigerHunyuanV1SwiGLUMLP
+
+    if model is not None:
+        # The model instance already exists, so we need to additionally patch the
+        # instance variables that reference already-instantiated modules
+
+        # get the base model from the model instance
+        base_model: HunYuanMoEV1Model = getattr(model, model.base_model_prefix, model)
+
+        if rms_norm:
+            _patch_rms_norm_module(base_model.norm)
+        for decoder_layer in base_model.layers:
+            if swiglu:
+                for mlp_expert in decoder_layer.mlp.experts:
+                    _patch_swiglu_module(mlp_expert, LigerHunyuanV1SwiGLUMLP)
+            if rms_norm:
+                _patch_rms_norm_module(decoder_layer.input_layernorm)
+                _patch_rms_norm_module(decoder_layer.post_attention_layernorm)
+
+
 # Model type corresponds to the keys defined in transformers/models/auto/modeling_auto.py
 MODEL_TYPE_TO_APPLY_LIGER_FN = {
     "gemma": apply_liger_kernel_to_gemma,

@@ -2392,6 +2763,7 @@ MODEL_TYPE_TO_APPLY_LIGER_FN = {
     "mistral": apply_liger_kernel_to_mistral,
     "mixtral": apply_liger_kernel_to_mixtral,
     "olmo2": apply_liger_kernel_to_olmo2,
+    "olmo3": apply_liger_kernel_to_olmo3,
     "qwen2": apply_liger_kernel_to_qwen2,
     "qwen3": apply_liger_kernel_to_qwen3,
     "qwen3_moe": apply_liger_kernel_to_qwen3_moe,

@@ -2400,11 +2772,17 @@ MODEL_TYPE_TO_APPLY_LIGER_FN = {
     "qwen2_5_vl": apply_liger_kernel_to_qwen2_5_vl,
     "qwen2_5_vl_text": apply_liger_kernel_to_qwen2_5_vl,
     "qwen3_next": apply_liger_kernel_to_qwen3_next,
+    "qwen3_vl": apply_liger_kernel_to_qwen3_vl,
+    "qwen3_vl_text": apply_liger_kernel_to_qwen3_vl,
+    "qwen3_vl_moe": apply_liger_kernel_to_qwen3_vl_moe,
+    "qwen3_vl_moe_text": apply_liger_kernel_to_qwen3_vl_moe,
     "smollm3": apply_liger_kernel_to_smollm3,
     "phi3": apply_liger_kernel_to_phi3,
     "paligemma": apply_liger_kernel_to_paligemma,
     "falcon_h1": apply_liger_kernel_to_falcon_h1,
     "smolvlm": apply_liger_kernel_to_smolvlm,
+    "hunyuan_v1_dense": apply_liger_kernel_to_hunyuan_v1_dense,
+    "hunyuan_v1_moe": apply_liger_kernel_to_hunyuan_v1_moe,
 }

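
The keys in MODEL_TYPE_TO_APPLY_LIGER_FN follow config.model_type, which is how the multimodal wrappers above dispatch to the matching text-tower patch. A minimal lookup sketch using only names that appear in this diff:

    from liger_kernel.transformers.monkey_patch import MODEL_TYPE_TO_APPLY_LIGER_FN

    def patch_by_model_type(model):
        # model.config.model_type is e.g. "olmo3", "qwen3_vl", "hunyuan_v1_moe", ...
        apply_fn = MODEL_TYPE_TO_APPLY_LIGER_FN.get(model.config.model_type)
        if apply_fn is None:
            raise ValueError(f"{model.config.model_type} is not supported by Liger kernel.")
        apply_fn(model=model)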
@@ -1,3 +1,8 @@
+from typing import Optional
+from typing import Tuple
+
+import torch
+
 from liger_kernel.ops.rope import LigerRopeFunction


@@ -18,3 +23,41 @@ def liger_rotary_pos_emb(q, k, cos, sin, position_ids=None, unsqueeze_dim=1):
     """

     return LigerRopeFunction.apply(q, k, cos, sin, position_ids, unsqueeze_dim)
+
+
+def liger_rotary_pos_emb_with_cast(
+    q: torch.Tensor,
+    k: torch.Tensor,
+    cos: torch.Tensor,
+    sin: torch.Tensor,
+    position_ids: Optional[torch.Tensor] = None,
+    unsqueeze_dim: int = 1,
+) -> Tuple[torch.Tensor, torch.Tensor]:
+    orig_q_dtype, orig_k_dtype = q.dtype, k.dtype
+
+    q32 = q.to(torch.float32)
+    k32 = k.to(torch.float32)
+    cos32 = cos.to(torch.float32)
+    sin32 = sin.to(torch.float32)
+
+    q_out, k_out = liger_rotary_pos_emb(q32, k32, cos32, sin32, position_ids=position_ids, unsqueeze_dim=unsqueeze_dim)
+    return q_out.to(orig_q_dtype), k_out.to(orig_k_dtype)
+
+
+def liger_rotary_pos_emb_with_cast_and_leading_batch(
+    q: torch.Tensor,
+    k: torch.Tensor,
+    cos: torch.Tensor,
+    sin: torch.Tensor,
+    position_ids: Optional[torch.Tensor] = None,
+    unsqueeze_dim: int = 1,
+) -> Tuple[torch.Tensor, torch.Tensor]:
+    orig_q_dtype, orig_k_dtype = q.dtype, k.dtype
+
+    q32 = q.to(torch.float32).unsqueeze(0)
+    k32 = k.to(torch.float32).unsqueeze(0)
+    cos32 = cos.to(torch.float32).unsqueeze(0)
+    sin32 = sin.to(torch.float32).unsqueeze(0)
+
+    q_out, k_out = liger_rotary_pos_emb(q32, k32, cos32, sin32, position_ids=position_ids, unsqueeze_dim=unsqueeze_dim)
+    return q_out.to(orig_q_dtype).squeeze(0), k_out.to(orig_k_dtype).squeeze(0)

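
The two wrappers only change dtype handling and batching around the existing LigerRopeFunction: inputs are upcast to float32 for the kernel and cast back to the original q/k dtypes, and the *_with_cast_and_leading_batch variant additionally adds and then strips a leading batch dimension for callers (such as the vision rope path) that pass tensors without one. A small round-trip sketch; shapes are assumptions and a CUDA device is required by the underlying Triton kernel:

    import torch
    from liger_kernel.transformers.rope import liger_rotary_pos_emb_with_cast

    # assumed layout: q/k as (batch, num_heads, seq_len, head_dim), cos/sin as (batch, seq_len, head_dim)
    q = torch.randn(1, 8, 16, 64, dtype=torch.bfloat16, device="cuda")
    k = torch.randn(1, 8, 16, 64, dtype=torch.bfloat16, device="cuda")
    cos = torch.randn(1, 16, 64, dtype=torch.bfloat16, device="cuda")
    sin = torch.randn(1, 16, 64, dtype=torch.bfloat16, device="cuda")

    q_rot, k_rot = liger_rotary_pos_emb_with_cast(q, k, cos, sin)
    assert q_rot.dtype == torch.bfloat16 and k_rot.dtype == torch.bfloat16  # cast back after the fp32 kernel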
@@ -77,3 +77,20 @@ class LigerQwen3MoeSwiGLUMLP(nn.Module):

     def forward(self, x):
         return self.down_proj(LigerSiLUMulFunction.apply(self.gate_proj(x), self.up_proj(x)))
+
+
+class LigerHunyuanV1SwiGLUMLP(nn.Module):
+    def __init__(self, config, layer_idx=None, is_shared_mlp=False):
+        super().__init__()
+        self.config = config
+        self.hidden_size = config.hidden_size
+        self.intermediate_size = config.intermediate_size
+        self.gate_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False)
+        self.up_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False)
+        self.down_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias=False)
+        self.layer_idx = layer_idx
+        if config.hidden_act not in ["silu", "swish"]:
+            raise ValueError(f"Activation function {config.hidden_act} not supported.")
+
+    def forward(self, x):
+        return self.down_proj(LigerSiLUMulFunction.apply(self.gate_proj(x), self.up_proj(x)))
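
LigerHunyuanV1SwiGLUMLP mirrors the other Liger SwiGLU modules: three bias-free projections with the SiLU-and-multiply step fused by LigerSiLUMulFunction. A construction sketch with a stand-in config object carrying the attributes the module reads; the real HunYuan config classes are assumed to provide the same fields:

    import torch
    from types import SimpleNamespace
    from liger_kernel.transformers.swiglu import LigerHunyuanV1SwiGLUMLP

    # stand-in config with the fields used by __init__: hidden_size, intermediate_size, hidden_act
    cfg = SimpleNamespace(hidden_size=1024, intermediate_size=2816, hidden_act="silu")
    mlp = LigerHunyuanV1SwiGLUMLP(cfg, layer_idx=0).to("cuda", dtype=torch.bfloat16)

    x = torch.randn(2, 16, 1024, device="cuda", dtype=torch.bfloat16)
    y = mlp(x)  # down_proj(SiLU(gate_proj(x)) * up_proj(x)) with the elementwise part in one Triton kernel
    print(y.shape)  # torch.Size([2, 16, 1024])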