liger-kernel 0.5.8__py3-none-any.whl → 0.5.10__py3-none-any.whl
This diff compares the contents of two publicly released versions of the package as published to a supported public registry. It is provided for informational purposes only and reflects the versions exactly as they appear in that registry.
- liger_kernel/chunked_loss/dpo_loss.py +8 -1
- liger_kernel/chunked_loss/fused_linear_preference.py +0 -1
- liger_kernel/chunked_loss/jsd_loss.py +2 -2
- liger_kernel/ops/cross_entropy.py +4 -1
- liger_kernel/ops/dyt.py +113 -179
- liger_kernel/ops/fused_linear_cross_entropy.py +4 -3
- liger_kernel/ops/grpo_loss.py +310 -0
- liger_kernel/ops/sparsemax.py +167 -0
- liger_kernel/transformers/__init__.py +11 -0
- liger_kernel/transformers/dyt.py +5 -3
- liger_kernel/transformers/fsdp.py +55 -0
- liger_kernel/transformers/functional.py +8 -0
- liger_kernel/transformers/fused_linear_cross_entropy.py +1 -2
- liger_kernel/transformers/grpo_loss.py +98 -0
- liger_kernel/transformers/model/gemma.py +8 -12
- liger_kernel/transformers/model/gemma2.py +8 -10
- liger_kernel/transformers/model/gemma3.py +3 -9
- liger_kernel/transformers/model/glm4.py +119 -0
- liger_kernel/transformers/model/llama.py +64 -15
- liger_kernel/transformers/model/llava.py +0 -8
- liger_kernel/transformers/model/mistral.py +8 -10
- liger_kernel/transformers/model/mixtral.py +8 -12
- liger_kernel/transformers/model/mllama.py +8 -11
- liger_kernel/transformers/model/olmo2.py +8 -10
- liger_kernel/transformers/model/paligemma.py +0 -8
- liger_kernel/transformers/model/phi3.py +8 -12
- liger_kernel/transformers/model/qwen2.py +8 -12
- liger_kernel/transformers/model/qwen2_5_vl.py +3 -7
- liger_kernel/transformers/model/qwen2_vl.py +3 -7
- liger_kernel/transformers/model/qwen3.py +112 -0
- liger_kernel/transformers/model/qwen3_moe.py +128 -0
- liger_kernel/transformers/monkey_patch.py +243 -13
- liger_kernel/transformers/sparsemax.py +16 -0
- liger_kernel/transformers/swiglu.py +21 -0
- liger_kernel/transformers/trainer/orpo_trainer.py +1 -53
- liger_kernel/utils.py +11 -0
- {liger_kernel-0.5.8.dist-info → liger_kernel-0.5.10.dist-info}/METADATA +36 -20
- {liger_kernel-0.5.8.dist-info → liger_kernel-0.5.10.dist-info}/RECORD +42 -34
- {liger_kernel-0.5.8.dist-info → liger_kernel-0.5.10.dist-info}/WHEEL +1 -1
- {liger_kernel-0.5.8.dist-info → liger_kernel-0.5.10.dist-info}/licenses/LICENSE +0 -0
- {liger_kernel-0.5.8.dist-info → liger_kernel-0.5.10.dist-info}/licenses/NOTICE +0 -0
- {liger_kernel-0.5.8.dist-info → liger_kernel-0.5.10.dist-info}/top_level.txt +0 -0
liger_kernel/transformers/monkey_patch.py CHANGED

@@ -35,6 +35,13 @@ from liger_kernel.transformers.swiglu import LigerBlockSparseTop2MLP
 from liger_kernel.transformers.swiglu import LigerPhi3SwiGLUMLP
 from liger_kernel.transformers.swiglu import LigerSwiGLUMLP
 
+try:
+    import peft
+
+    PEFT_AVAILABLE = True
+except ImportError:
+    PEFT_AVAILABLE = False
+
 transformer_version = version.parse(transformers.__version__)
 
 logger = logging.getLogger(__name__)
@@ -48,22 +55,68 @@ def _bind_method_to_module(module, method_name: str, new_method: Callable):
 
 
 def _patch_rms_norm_module(module, offset=0.0, eps=1e-6, casting_mode="llama", in_place=True):
-    module.offset = offset
-    module.casting_mode = casting_mode
-    module.variance_epsilon = getattr(module, "variance_epsilon", None) or getattr(module, "eps", None) or eps
-    module.in_place = in_place
-    _bind_method_to_module(module, "forward", LigerRMSNorm.forward)
-    _bind_method_to_module(module, "extra_repr", LigerRMSNorm.extra_repr)
-    module.__class__.__name__ = LigerRMSNorm.__name__
+    # Check if the module is a PEFT ModulesToSaveWrapper
+    # If it is, we need to patch the modules_to_save.default and original_modules
+    if PEFT_AVAILABLE and isinstance(module, peft.utils.other.ModulesToSaveWrapper):
+        module.modules_to_save.default.offset = offset
+        module.modules_to_save.default.casting_mode = casting_mode
+        module.modules_to_save.default.variance_epsilon = (
+            getattr(module, "variance_epsilon", None) or getattr(module, "eps", None) or eps
+        )
+        module.modules_to_save.default.in_place = in_place
+        module.original_module.offset = offset
+        module.original_module.casting_mode = casting_mode
+        module.original_module.variance_epsilon = (
+            getattr(module, "variance_epsilon", None) or getattr(module, "eps", None) or eps
+        )
+        module.original_module.in_place = in_place
+        _bind_method_to_module(module.modules_to_save.default, "forward", LigerRMSNorm.forward)
+        _bind_method_to_module(module.modules_to_save.default, "extra_repr", LigerRMSNorm.extra_repr)
+        _bind_method_to_module(module.original_module, "forward", LigerRMSNorm.forward)
+        _bind_method_to_module(module.original_module, "extra_repr", LigerRMSNorm.extra_repr)
+        module.modules_to_save.default.__class__.__name__ = LigerRMSNorm.__name__
+        module.original_module.__class__.__name__ = LigerRMSNorm.__name__
+    else:
+        module.offset = offset
+        module.casting_mode = casting_mode
+        module.variance_epsilon = getattr(module, "variance_epsilon", None) or getattr(module, "eps", None) or eps
+        module.in_place = in_place
+        _bind_method_to_module(module, "forward", LigerRMSNorm.forward)
+        _bind_method_to_module(module, "extra_repr", LigerRMSNorm.extra_repr)
+        module.__class__.__name__ = LigerRMSNorm.__name__
 
 
 def _patch_layer_norm_module(module, eps=1e-6):
-    module.variance_epsilon = getattr(module, "variance_epsilon", None) or getattr(module, "eps", None) or eps
-    module.hidden_size = getattr(module, "hidden_size", None) or getattr(module, "normalized_shape", None)
-    _bind_method_to_module(module, "forward", LigerLayerNorm.forward)
-    _bind_method_to_module(module, "extra_repr", LigerLayerNorm.extra_repr)
-    module.__class__.__name__ = LigerLayerNorm.__name__
+    # Check if the module is a PEFT ModulesToSaveWrapper
+    # If it is, we need to patch the modules_to_save.default and original_modules
+    if PEFT_AVAILABLE and isinstance(module, peft.utils.other.ModulesToSaveWrapper):
+        module.hidden_size = module.normalized_shape
+        _bind_method_to_module(module, "forward", LigerLayerNorm.forward)
+        _bind_method_to_module(module, "extra_repr", LigerLayerNorm.extra_repr)
+        module.modules_to_save.default.variance_epsilon = (
+            getattr(module, "variance_epsilon", None) or getattr(module, "eps", None) or eps
+        )
+        module.original_module.hidden_size = getattr(module, "hidden_size", None) or getattr(
+            module, "normalized_shape", None
+        )
+        module.original_module.variance_epsilon = (
+            getattr(module, "variance_epsilon", None) or getattr(module, "eps", None) or eps
+        )
+        module.original_module.hidden_size = getattr(module, "hidden_size", None) or getattr(
+            module, "normalized_shape", None
+        )
+        _bind_method_to_module(module.modules_to_save.default, "forward", LigerRMSNorm.forward)
+        _bind_method_to_module(module.modules_to_save.default, "extra_repr", LigerRMSNorm.extra_repr)
+        _bind_method_to_module(module.original_module, "forward", LigerRMSNorm.forward)
+        _bind_method_to_module(module.original_module, "extra_repr", LigerRMSNorm.extra_repr)
+        module.modules_to_save.default.__class__.__name__ = LigerLayerNorm.__name__
+        module.original_module.__class__.__name__ = LigerLayerNorm.__name__
+    else:
+        module.variance_epsilon = getattr(module, "variance_epsilon", None) or getattr(module, "eps", None) or eps
+        module.hidden_size = getattr(module, "hidden_size", None) or getattr(module, "normalized_shape", None)
+        _bind_method_to_module(module, "forward", LigerLayerNorm.forward)
+        _bind_method_to_module(module, "extra_repr", LigerLayerNorm.extra_repr)
+        module.__class__.__name__ = LigerLayerNorm.__name__
 
 
 def _patch_swiglu_module(module, liger_module):
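The PEFT-aware branch above changes which objects get patched, but the mechanism is unchanged: a replacement `forward` is bound onto the existing module instance. A minimal, self-contained sketch of that mechanism (illustrative only; `NaiveRMSNorm` and `fused_like_forward` are made-up stand-ins, not Liger code):

```python
# Illustrative sketch of instance-level method binding, the trick behind
# _bind_method_to_module / _patch_rms_norm_module. Names here are invented.
import types

import torch
import torch.nn as nn


class NaiveRMSNorm(nn.Module):
    def __init__(self, hidden_size, eps=1e-6):
        super().__init__()
        self.weight = nn.Parameter(torch.ones(hidden_size))
        self.variance_epsilon = eps

    def forward(self, x):
        variance = x.pow(2).mean(-1, keepdim=True)
        return self.weight * x * torch.rsqrt(variance + self.variance_epsilon)


def fused_like_forward(self, x):
    # stand-in for LigerRMSNorm.forward; the real patch binds the Triton-backed version
    variance = x.pow(2).mean(-1, keepdim=True)
    return self.weight * x * torch.rsqrt(variance + self.variance_epsilon)


norm = NaiveRMSNorm(8)
# bind the replacement forward to this one instance, leaving the class untouched
norm.forward = types.MethodType(fused_like_forward, norm)
print(norm(torch.randn(2, 8)).shape)  # torch.Size([2, 8])
```

When PEFT wraps a norm layer in a `ModulesToSaveWrapper`, `modules_to_save.default` and `original_module` are distinct instances, which is why the new branch binds the replacement methods onto both.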
@@ -1048,6 +1101,115 @@ def apply_liger_kernel_to_qwen2(
     print("Applied Liger kernels to Qwen2")
 
 
+def apply_liger_kernel_to_qwen3(
+    rope: bool = True,
+    cross_entropy: bool = False,
+    fused_linear_cross_entropy: bool = True,
+    rms_norm: bool = True,
+    swiglu: bool = True,
+    model: PreTrainedModel = None,
+) -> None:
+    """
+    Apply Liger kernels to replace original implementation in HuggingFace Qwen3 models.
+    """
+    assert not (cross_entropy and fused_linear_cross_entropy), (
+        "cross_entropy and fused_linear_cross_entropy cannot both be True."
+    )
+
+    from transformers.models.qwen3 import modeling_qwen3
+    from transformers.models.qwen3.modeling_qwen3 import Qwen3Model
+
+    from liger_kernel.transformers.model.qwen3 import lce_forward as qwen3_lce_forward
+
+    if rope:
+        modeling_qwen3.apply_rotary_pos_emb = liger_rotary_pos_emb
+
+    if rms_norm:
+        modeling_qwen3.Qwen3RMSNorm = LigerRMSNorm
+
+    if cross_entropy:
+        from transformers.loss.loss_utils import nn
+
+        nn.functional.cross_entropy = liger_cross_entropy
+
+    if fused_linear_cross_entropy:
+        modeling_qwen3.Qwen3ForCausalLM.forward = qwen3_lce_forward
+
+    if swiglu:
+        modeling_qwen3.Qwen3MLP = LigerSwiGLUMLP
+
+    if model is not None:
+        # The model instance already exists, so we need to additionally patch the
+        # instance variables that reference already-instantiated modules
+
+        # get the base model from the model instance
+        base_model: Qwen3Model = getattr(model, model.base_model_prefix, model)
+
+        if rms_norm:
+            _patch_rms_norm_module(base_model.norm)
+        for decoder_layer in base_model.layers:
+            if swiglu:
+                _patch_swiglu_module(decoder_layer.mlp, LigerSwiGLUMLP)
+            if rms_norm:
+                _patch_rms_norm_module(decoder_layer.input_layernorm)
+                _patch_rms_norm_module(decoder_layer.post_attention_layernorm)
+
+
+def apply_liger_kernel_to_qwen3_moe(
+    rope: bool = True,
+    cross_entropy: bool = False,
+    fused_linear_cross_entropy: bool = True,
+    rms_norm: bool = True,
+    swiglu: bool = True,
+    model: PreTrainedModel = None,
+) -> None:
+    """
+    Apply Liger kernels to replace original implementation in HuggingFace Qwen3 models.
+    """
+    assert not (cross_entropy and fused_linear_cross_entropy), (
+        "cross_entropy and fused_linear_cross_entropy cannot both be True."
+    )
+
+    from transformers.models.qwen3_moe import modeling_qwen3_moe
+    from transformers.models.qwen3_moe.modeling_qwen3_moe import Qwen3MoeModel
+
+    from liger_kernel.transformers.model.qwen3_moe import lce_forward as qwen3_lce_forward
+    from liger_kernel.transformers.swiglu import LigerQwen3MoeSwiGLUMLP
+
+    if rope:
+        modeling_qwen3_moe.apply_rotary_pos_emb = liger_rotary_pos_emb
+
+    if rms_norm:
+        modeling_qwen3_moe.Qwen3MoeRMSNorm = LigerRMSNorm
+
+    if cross_entropy:
+        from transformers.loss.loss_utils import nn
+
+        nn.functional.cross_entropy = liger_cross_entropy
+
+    if fused_linear_cross_entropy:
+        modeling_qwen3_moe.Qwen3MoeForCausalLM.forward = qwen3_lce_forward
+
+    if swiglu:
+        modeling_qwen3_moe.Qwen3MoeMLP = LigerQwen3MoeSwiGLUMLP
+
+    if model is not None:
+        # The model instance already exists, so we need to additionally patch the
+        # instance variables that reference already-instantiated modules
+
+        # get the base model from the model instance
+        base_model: Qwen3MoeModel = getattr(model, model.base_model_prefix, model)
+
+        if rms_norm:
+            _patch_rms_norm_module(base_model.norm)
+        for decoder_layer in base_model.layers:
+            if swiglu:
+                _patch_swiglu_module(decoder_layer.mlp, LigerQwen3MoeSwiGLUMLP)
+            if rms_norm:
+                _patch_rms_norm_module(decoder_layer.input_layernorm)
+                _patch_rms_norm_module(decoder_layer.post_attention_layernorm)
+
+
 def apply_liger_kernel_to_qwen2_vl(
     rope: bool = True,
     cross_entropy: bool = False,
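For reference, the new entry points follow the same calling convention as the existing `apply_liger_kernel_to_*` functions. A usage sketch, assuming the function is re-exported from `liger_kernel.transformers` like its siblings (the checkpoint name is illustrative):

```python
import torch
from transformers import AutoModelForCausalLM

from liger_kernel.transformers import apply_liger_kernel_to_qwen3

# illustrative checkpoint name
model = AutoModelForCausalLM.from_pretrained("Qwen/Qwen3-0.6B", torch_dtype=torch.bfloat16)
apply_liger_kernel_to_qwen3(
    rope=True,
    rms_norm=True,
    swiglu=True,
    fused_linear_cross_entropy=True,  # avoids materializing the full logits during training
    model=model,  # needed here because the modules were already instantiated
)
```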
@@ -1319,12 +1481,78 @@ def apply_liger_kernel_to_olmo2(
             _patch_rms_norm_module(decoder_layer.post_feedforward_layernorm, in_place=False)
 
 
+def apply_liger_kernel_to_glm4(
+    rope: bool = False,
+    cross_entropy: bool = False,
+    fused_linear_cross_entropy: bool = True,
+    rms_norm: bool = True,
+    swiglu: bool = True,
+    model: PreTrainedModel = None,
+) -> None:
+    """
+    Apply Liger kernels to replace original implementation in HuggingFace GLM-4 models.
+
+    Args:
+        rope (bool): Whether to apply Liger's rotary position embedding. Default is False.
+        cross_entropy (bool): Whether to apply Liger's cross entropy loss. Default is False.
+        fused_linear_cross_entropy (bool):
+            Whether to apply Liger's fused linear cross entropy loss. Default is True.
+            `cross_entropy` and `fused_linear_cross_entropy` cannot both be True.
+            If `fused_linear_cross_entropy` is True, the logits will not be materialized but more memory efficient.
+        rms_norm (bool): Whether to apply Liger's RMSNorm. Default is True.
+        swiglu (bool): Whether to apply Liger's SwiGLU Glm4MLP. Default is True.
+        model (PreTrainedModel): The model instance to apply Liger kernels to, if the model has already been
+        loaded. Default is None.
+    """
+    assert not (cross_entropy and fused_linear_cross_entropy), (
+        "cross_entropy and fused_linear_cross_entropy cannot both be True."
+    )
+
+    from transformers.models.glm4 import modeling_glm4
+    from transformers.models.glm4.modeling_glm4 import Glm4Model
+
+    from liger_kernel.transformers.model.glm4 import lce_forward as glm4_lce_forward
+
+    if rope:
+        raise NotImplementedError("liger_rotary_pos_emb is not available for Glm4 models.")
+    if rms_norm:
+        modeling_glm4.Glm4RMSNorm = partial(LigerRMSNorm, in_place=False)
+    if swiglu:
+        modeling_glm4.Glm4MLP = LigerPhi3SwiGLUMLP
+    if cross_entropy:
+        from transformers.loss.loss_utils import nn
+
+        nn.functional.cross_entropy = liger_cross_entropy
+    if fused_linear_cross_entropy:
+        modeling_glm4.Glm4ForCausalLM.forward = glm4_lce_forward
+
+    if model is not None:
+        # The model instance already exists, so we need to additionally patch the
+        # instance variables that reference already-instantiated modules
+
+        # get the base model from the model instance
+        base_model: Glm4Model = getattr(model, model.base_model_prefix, model)
+
+        if rms_norm:
+            _patch_rms_norm_module(base_model.norm, in_place=False)
+
+        for decoder_layer in base_model.layers:
+            if swiglu:
+                _patch_swiglu_module(decoder_layer.mlp, LigerPhi3SwiGLUMLP)
+            if rms_norm:
+                _patch_rms_norm_module(decoder_layer.input_layernorm, in_place=False)
+                _patch_rms_norm_module(decoder_layer.post_attention_layernorm, in_place=False)
+                _patch_rms_norm_module(decoder_layer.post_self_attn_layernorm, in_place=False)
+                _patch_rms_norm_module(decoder_layer.post_mlp_layernorm, in_place=False)
+
+
 # Model type corresponds to the keys defined in transformers/models/auto/modeling_auto.py
 MODEL_TYPE_TO_APPLY_LIGER_FN = {
     "gemma": apply_liger_kernel_to_gemma,
     "gemma2": apply_liger_kernel_to_gemma2,
     "gemma3_text": apply_liger_kernel_to_gemma3_text,
     "gemma3": apply_liger_kernel_to_gemma3,
+    "glm4": apply_liger_kernel_to_glm4,
     "llama": apply_liger_kernel_to_llama,
     "llava": apply_liger_kernel_to_llava,
     "granite": apply_liger_kernel_to_granite,

@@ -1334,6 +1562,8 @@ MODEL_TYPE_TO_APPLY_LIGER_FN = {
     "mixtral": apply_liger_kernel_to_mixtral,
     "olmo2": apply_liger_kernel_to_olmo2,
     "qwen2": apply_liger_kernel_to_qwen2,
+    "qwen3": apply_liger_kernel_to_qwen3,
+    "qwen3_moe": apply_liger_kernel_to_qwen3_moe,
     "qwen2_vl": apply_liger_kernel_to_qwen2_vl,
     "qwen2_5_vl": apply_liger_kernel_to_qwen2_5_vl,
     "phi3": apply_liger_kernel_to_phi3,
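The `MODEL_TYPE_TO_APPLY_LIGER_FN` registry keys on `config.model_type`, so the newly registered `glm4`, `qwen3`, and `qwen3_moe` entries are picked up by the registry-driven helpers automatically. A rough sketch of that dispatch path (checkpoint name is illustrative):

```python
from transformers import AutoConfig

from liger_kernel.transformers import AutoLigerKernelForCausalLM
from liger_kernel.transformers.monkey_patch import MODEL_TYPE_TO_APPLY_LIGER_FN

config = AutoConfig.from_pretrained("Qwen/Qwen3-8B")      # illustrative checkpoint
print(config.model_type in MODEL_TYPE_TO_APPLY_LIGER_FN)  # True now that "qwen3" is registered

# AutoLigerKernelForCausalLM looks up the model type, applies the matching
# apply_liger_kernel_to_* patches, then loads the model as usual.
model = AutoLigerKernelForCausalLM.from_pretrained("Qwen/Qwen3-8B")
```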
liger_kernel/transformers/sparsemax.py ADDED

@@ -0,0 +1,16 @@
+import torch
+import torch.nn as nn
+
+from liger_kernel.ops.sparsemax import LigerSparsemaxFunction
+
+
+class LigerSparsemax(nn.Module):
+    def __init__(self, dim: int = -1):
+        super().__init__()
+        self.dim = dim
+
+    def forward(self, x: torch.Tensor) -> torch.Tensor:
+        return LigerSparsemaxFunction.apply(x, self.dim)
+
+    def extra_repr(self) -> str:
+        return f"dim={self.dim}"
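The new `LigerSparsemax` module is a drop-in analogue of `nn.Softmax(dim=...)` that computes sparsemax, i.e. the Euclidean projection of the logits onto the probability simplex, so low-scoring entries come out exactly zero. A usage sketch (assumes a CUDA device, since the op is backed by a Triton kernel):

```python
import torch

from liger_kernel.transformers import LigerSparsemax

sparsemax = LigerSparsemax(dim=-1)
logits = torch.tensor([[2.0, 1.0, 0.1, -1.0]], device="cuda")
probs = sparsemax(logits)
print(probs)          # [[1., 0., 0., 0.]] -- low-scoring entries are exactly 0
print(probs.sum(-1))  # rows still sum to 1
```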
liger_kernel/transformers/swiglu.py CHANGED

@@ -56,3 +56,24 @@ class LigerPhi3SwiGLUMLP(nn.Module):
         up_states = self.gate_up_proj(x)
         gate, up_states = up_states.chunk(2, dim=-1)
         return self.down_proj(LigerSiLUMulFunction.apply(gate, up_states))
+
+
+class LigerQwen3MoeSwiGLUMLP(nn.Module):
+    """
+    Patch Qwen3MoeMLP to use LigerSiLUMulFunction.
+    https://github.com/huggingface/transformers/blob/v4.51.3/src/transformers/models/qwen3_moe/modular_qwen3_moe.py#L57
+    """
+
+    def __init__(self, config, intermediate_size=None):
+        super().__init__()
+        self.config = config
+        self.hidden_size = config.hidden_size
+        self.intermediate_size = intermediate_size if intermediate_size is not None else config.intermediate_size
+        self.gate_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False)
+        self.up_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False)
+        self.down_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias=False)
+        if config.hidden_act not in ["silu", "swish"]:
+            raise ValueError(f"Activation function {config.hidden_act} not supported.")
+
+    def forward(self, x):
+        return self.down_proj(LigerSiLUMulFunction.apply(self.gate_proj(x), self.up_proj(x)))
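`LigerQwen3MoeSwiGLUMLP` keeps the stock Qwen3MoeMLP projections and only swaps the activation-and-multiply for the fused `LigerSiLUMulFunction`; numerically it computes the usual SwiGLU expression. A plain-PyTorch reference sketch of that expression (the helper name below is ours, not the library's):

```python
import torch.nn.functional as F


def reference_swiglu(x, gate_proj, up_proj, down_proj):
    # what LigerSiLUMulFunction.apply(gate_proj(x), up_proj(x)) computes, unfused
    return down_proj(F.silu(gate_proj(x)) * up_proj(x))
```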
liger_kernel/transformers/trainer/orpo_trainer.py CHANGED

@@ -1,5 +1,3 @@
-from typing import Any
-from typing import Callable
 from typing import Dict
 from typing import List
 from typing import Literal

@@ -13,57 +11,7 @@ from torch.distributed.fsdp import FullyShardedDataParallel
 from trl.trainer import ORPOTrainer
 
 from liger_kernel.chunked_loss import LigerFusedLinearORPOLoss
-
-
-class _FSDPForwardRedirection:
-    """
-    Modified based on
-    https://github.com/Lightning-AI/pytorch-lightning/blob/d3f9c83d6efa4f1def36aa6c199600946cdb9117/src/lightning/pytorch/strategies/strategy.py#L601-L648
-    Redirect a method call through FullyShardedDataParallel.forward so that the FSDP module's root pre-forward and
-    post-forward can be properly executed around the method call.
-    This is needed in cases where we call a submodule of a FSDP module. For instance, when we want to call only
-    the `LlamaModel` part out of a FSDP-wrapped `LlamaForCausalLM` to get the hidden states without involving
-    GPU-memory-heavy `lm_head` and cross entropy computation, doing this directly (i.e. `model.model.forward()`)
-    will not work because the first `nn.Embedding` layer is not independently wrapped as a FSDP module (because of
-    the transformer-based wrapping policy), and not calling it through FSDP root module forward will not all-gather
-    its parameter, thus resulting in "RuntimeError: 'weight' must be 2-D" error. Similarly, if we want to call just
-    the `lm_head` part of a model, we need this trick too to properly get its params all-gathered.
-    """
-
-    def __call__(
-        self,
-        wrapper_module: FullyShardedDataParallel,
-        method: Callable,
-        *args: Any,
-        **kwargs: Any,
-    ):
-        """Reroutes a method call through the `wrapper_module`'s `forward` method.
-        Args:
-            wrapper_module: The module that has `original_module` wrapped.
-            original_module: The module that was wrapped inside `wrapper_module`.
-            method_name: The name of the method that should be called on the `original_module` after inputs get
-                redirected through the `wrapper_module`'s `forward` method.
-            *args: The positional arguments to the method `method_name`. They will get passed to a patched
-                `forward` method instead.
-            **kwargs: The keyword arguments to the method `method_name`. They will get passed to a patched
-                `forward` method instead.
-        """
-        assert isinstance(wrapper_module, FullyShardedDataParallel)
-        original_module = wrapper_module._fsdp_wrapped_module
-        original_forward = original_module.forward
-
-        def wrapped_forward(*_args: Any, **_kwargs: Any) -> Any:
-            # Unpatch ourselves immediately before calling the method `method_name`
-            # because itself may want to call the real `forward`
-            original_module.forward = original_forward  # type: ignore[method-assign]
-            # Call the actual method e.g. `.training_step(...)`
-            out = method(*_args, **_kwargs)
-            return out
-
-        # Patch the original_module's forward so we can redirect the arguments back to the real method
-        original_module.forward = wrapped_forward  # type: ignore[method-assign]
-        wrapper_output = wrapper_module(*args, **kwargs)
-        return wrapper_output
+from liger_kernel.transformers.fsdp import _FSDPForwardRedirection
 
 
 class LigerORPOTrainer(ORPOTrainer):
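`_FSDPForwardRedirection` itself is unchanged; it now lives in `liger_kernel/transformers/fsdp.py` (added in this release) and is imported here. A fragmentary sketch of how it is meant to be used, per the docstring above (`fsdp_model`, `input_ids`, and `attention_mask` are assumed to already exist):

```python
from liger_kernel.transformers.fsdp import _FSDPForwardRedirection

redirect = _FSDPForwardRedirection()
# Call only the decoder of an FSDP-wrapped *ForCausalLM through the FSDP root
# forward, so the root pre-/post-forward hooks run and parameters are
# all-gathered, without touching the memory-heavy lm_head + cross-entropy path.
decoder = fsdp_model._fsdp_wrapped_module.model  # e.g. the inner LlamaModel
outputs = redirect(fsdp_model, decoder, input_ids, attention_mask=attention_mask)
hidden_states = outputs.last_hidden_state
```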
liger_kernel/utils.py CHANGED

@@ -1,6 +1,17 @@
+try:
+    import peft  # noqa: F401
+
+    PEFT_AVAILABLE = True
+except ImportError:
+    PEFT_AVAILABLE = False
+
 import torch
 
 
+def is_peft_available():
+    return PEFT_AVAILABLE
+
+
 def infer_device():
     """
     Get current device name based on available devices
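Downstream code can now branch on the new helper instead of repeating the try/except import; a minimal sketch:

```python
from liger_kernel.utils import is_peft_available

if is_peft_available():
    # safe to reference PEFT types, e.g. the ModulesToSaveWrapper handled in monkey_patch.py
    from peft.utils.other import ModulesToSaveWrapper
```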
{liger_kernel-0.5.8.dist-info → liger_kernel-0.5.10.dist-info}/METADATA CHANGED

@@ -1,6 +1,6 @@
 Metadata-Version: 2.4
 Name: liger_kernel
-Version: 0.5.8
+Version: 0.5.10
 Summary: Efficient Triton kernels for LLM Training
 License: BSD 2-CLAUSE LICENSE
         Copyright 2024 LinkedIn Corporation

@@ -59,7 +59,6 @@ Dynamic: requires-dist
     <th style="padding: 10px;" colspan="2">Stable</th>
     <th style="padding: 10px;" colspan="2">Nightly</th>
     <th style="padding: 10px;">Discord</th>
-    <th style="padding: 10px;">Build</th>
   </tr>
   <tr>
     <td style="padding: 10px;">

@@ -87,23 +86,6 @@ Dynamic: requires-dist
         <img src="https://dcbadge.vercel.app/api/server/gpumode?style=flat" alt="Join Our Discord">
       </a>
     </td>
-    <td style="padding: 10px;">
-      <div style="display: block;">
-        <a href="https://github.com/linkedin/Liger-Kernel/actions/workflows/nvi-ci.yml">
-          <img src="https://github.com/linkedin/Liger-Kernel/actions/workflows/nvi-ci.yml/badge.svg?event=schedule" alt="Build">
-        </a>
-      </div>
-      <div style="display: block;">
-        <a href="https://github.com/linkedin/Liger-Kernel/actions/workflows/amd-ci.yml">
-          <img src="https://github.com/linkedin/Liger-Kernel/actions/workflows/amd-ci.yml/badge.svg?event=schedule" alt="Build">
-        </a>
-      </div>
-      <div style="display: block;">
-        <a href="https://github.com/linkedin/Liger-Kernel/actions/workflows/amd-ci.yml">
-          <img src="https://github.com/linkedin/Liger-Kernel/actions/workflows/intel-ci.yml/badge.svg?event=schedule" alt="Build">
-        </a>
-      </div>
-    </td>
   </tr>
 </table>
 

@@ -320,9 +302,12 @@ loss.backward()
 | Qwen2, Qwen2.5, & QwQ | `liger_kernel.transformers.apply_liger_kernel_to_qwen2` | RoPE, RMSNorm, SwiGLU, CrossEntropyLoss, FusedLinearCrossEntropy |
 | Qwen2-VL, & QVQ | `liger_kernel.transformers.apply_liger_kernel_to_qwen2_vl` | RMSNorm, LayerNorm, SwiGLU, CrossEntropyLoss, FusedLinearCrossEntropy |
 | Qwen2.5-VL | `liger_kernel.transformers.apply_liger_kernel_to_qwen2_5_vl` | RMSNorm, SwiGLU, CrossEntropyLoss, FusedLinearCrossEntropy |
+| Qwen3 | `liger_kernel.transformers.apply_liger_kernel_to_qwen3` | RoPE, RMSNorm, SwiGLU, CrossEntropyLoss, FusedLinearCrossEntropy |
+| Qwen3 MoE | `liger_kernel_transformers.apply_liger_kernel_to_qwen3_moe` | RoPE, RMSNorm, SwiGLU, CrossEntropyLoss, FusedLinearCrossEntropy |
 | Phi3 & Phi3.5 | `liger_kernel.transformers.apply_liger_kernel_to_phi3` | RoPE, RMSNorm, SwiGLU, CrossEntropyLoss, FusedLinearCrossEntropy |
 | Granite 3.0 & 3.1 | `liger_kernel.transformers.apply_liger_kernel_to_granite` | RoPE, RMSNorm, SwiGLU, CrossEntropyLoss |
 | OLMo2 | `liger_kernel.transformers.apply_liger_kernel_to_olmo2` | RoPE, RMSNorm, SwiGLU, CrossEntropyLoss, FusedLinearCrossEntropy |
+| GLM-4 | `liger_kernel.transformers.apply_liger_kernel_to_glm4` | RoPE, RMSNorm, SwiGLU, CrossEntropyLoss, FusedLinearCrossEntropy |
 
 
 ## Low-level APIs

@@ -340,7 +325,8 @@ loss.backward()
 | SwiGLU | `liger_kernel.transformers.LigerSwiGLUMLP` |
 | GeGLU | `liger_kernel.transformers.LigerGEGLUMLP` |
 | CrossEntropy | `liger_kernel.transformers.LigerCrossEntropyLoss` |
-| Fused Linear CrossEntropy
+| Fused Linear CrossEntropy | `liger_kernel.transformers.LigerFusedLinearCrossEntropyLoss`|
+| Sparsemax | `liger_kernel.transformers.LigerSparsemax` |
 
 
 ### Alignment Kernels

@@ -388,6 +374,36 @@ loss.backward()
 - [Axolotl](https://axolotl.ai/): Integrating Liger Kernel into Axolotl.
 - [Llama-Factory](https://github.com/hiyouga/LLaMA-Factory): Integrating Liger Kernel into Llama-Factory.
 
+
+## CI status
+
+<table style="width: 100%; text-align: center; border-collapse: collapse;">
+  <tr>
+    <th style="padding: 10px;">Build</th>
+  </tr>
+  <tr>
+    <td style="padding: 10px;">
+      <div style="display: block;">
+        <a href="https://github.com/linkedin/Liger-Kernel/actions/workflows/nvi-ci.yml">
+          <img src="https://github.com/linkedin/Liger-Kernel/actions/workflows/nvi-ci.yml/badge.svg?event=schedule" alt="Build">
+        </a>
+      </div>
+      <div style="display: block;">
+        <a href="https://github.com/linkedin/Liger-Kernel/actions/workflows/amd-ci.yml">
+          <img src="https://github.com/linkedin/Liger-Kernel/actions/workflows/amd-ci.yml/badge.svg?event=schedule" alt="Build">
+        </a>
+      </div>
+      <div style="display: block;">
+        <a href="https://github.com/linkedin/Liger-Kernel/actions/workflows/amd-ci.yml">
+          <img src="https://github.com/linkedin/Liger-Kernel/actions/workflows/intel-ci.yml/badge.svg?event=schedule" alt="Build">
+        </a>
+      </div>
+    </td>
+  </tr>
+</table>
+
+
+
 ## Contact
 
 - For issues, create a Github ticket in this repository