liger-kernel 0.5.8__py3-none-any.whl → 0.5.10__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (42)
  1. liger_kernel/chunked_loss/dpo_loss.py +8 -1
  2. liger_kernel/chunked_loss/fused_linear_preference.py +0 -1
  3. liger_kernel/chunked_loss/jsd_loss.py +2 -2
  4. liger_kernel/ops/cross_entropy.py +4 -1
  5. liger_kernel/ops/dyt.py +113 -179
  6. liger_kernel/ops/fused_linear_cross_entropy.py +4 -3
  7. liger_kernel/ops/grpo_loss.py +310 -0
  8. liger_kernel/ops/sparsemax.py +167 -0
  9. liger_kernel/transformers/__init__.py +11 -0
  10. liger_kernel/transformers/dyt.py +5 -3
  11. liger_kernel/transformers/fsdp.py +55 -0
  12. liger_kernel/transformers/functional.py +8 -0
  13. liger_kernel/transformers/fused_linear_cross_entropy.py +1 -2
  14. liger_kernel/transformers/grpo_loss.py +98 -0
  15. liger_kernel/transformers/model/gemma.py +8 -12
  16. liger_kernel/transformers/model/gemma2.py +8 -10
  17. liger_kernel/transformers/model/gemma3.py +3 -9
  18. liger_kernel/transformers/model/glm4.py +119 -0
  19. liger_kernel/transformers/model/llama.py +64 -15
  20. liger_kernel/transformers/model/llava.py +0 -8
  21. liger_kernel/transformers/model/mistral.py +8 -10
  22. liger_kernel/transformers/model/mixtral.py +8 -12
  23. liger_kernel/transformers/model/mllama.py +8 -11
  24. liger_kernel/transformers/model/olmo2.py +8 -10
  25. liger_kernel/transformers/model/paligemma.py +0 -8
  26. liger_kernel/transformers/model/phi3.py +8 -12
  27. liger_kernel/transformers/model/qwen2.py +8 -12
  28. liger_kernel/transformers/model/qwen2_5_vl.py +3 -7
  29. liger_kernel/transformers/model/qwen2_vl.py +3 -7
  30. liger_kernel/transformers/model/qwen3.py +112 -0
  31. liger_kernel/transformers/model/qwen3_moe.py +128 -0
  32. liger_kernel/transformers/monkey_patch.py +243 -13
  33. liger_kernel/transformers/sparsemax.py +16 -0
  34. liger_kernel/transformers/swiglu.py +21 -0
  35. liger_kernel/transformers/trainer/orpo_trainer.py +1 -53
  36. liger_kernel/utils.py +11 -0
  37. {liger_kernel-0.5.8.dist-info → liger_kernel-0.5.10.dist-info}/METADATA +36 -20
  38. {liger_kernel-0.5.8.dist-info → liger_kernel-0.5.10.dist-info}/RECORD +42 -34
  39. {liger_kernel-0.5.8.dist-info → liger_kernel-0.5.10.dist-info}/WHEEL +1 -1
  40. {liger_kernel-0.5.8.dist-info → liger_kernel-0.5.10.dist-info}/licenses/LICENSE +0 -0
  41. {liger_kernel-0.5.8.dist-info → liger_kernel-0.5.10.dist-info}/licenses/NOTICE +0 -0
  42. {liger_kernel-0.5.8.dist-info → liger_kernel-0.5.10.dist-info}/top_level.txt +0 -0
liger_kernel/transformers/monkey_patch.py CHANGED
@@ -35,6 +35,13 @@ from liger_kernel.transformers.swiglu import LigerBlockSparseTop2MLP
  from liger_kernel.transformers.swiglu import LigerPhi3SwiGLUMLP
  from liger_kernel.transformers.swiglu import LigerSwiGLUMLP
 
+ try:
+     import peft
+
+     PEFT_AVAILABLE = True
+ except ImportError:
+     PEFT_AVAILABLE = False
+
  transformer_version = version.parse(transformers.__version__)
 
  logger = logging.getLogger(__name__)
@@ -48,22 +55,68 @@ def _bind_method_to_module(module, method_name: str, new_method: Callable):
 
 
  def _patch_rms_norm_module(module, offset=0.0, eps=1e-6, casting_mode="llama", in_place=True):
-     module.offset = offset
-     module.casting_mode = casting_mode
-     module.variance_epsilon = getattr(module, "variance_epsilon", None) or getattr(module, "eps", None) or eps
-     module.in_place = in_place
-     _bind_method_to_module(module, "forward", LigerRMSNorm.forward)
-     _bind_method_to_module(module, "extra_repr", LigerRMSNorm.extra_repr)
-     module.__class__.__name__ = LigerRMSNorm.__name__
+     # Check if the module is a PEFT ModulesToSaveWrapper
+     # If it is, we need to patch the modules_to_save.default and original_modules
+     if PEFT_AVAILABLE and isinstance(module, peft.utils.other.ModulesToSaveWrapper):
+         module.modules_to_save.default.offset = offset
+         module.modules_to_save.default.casting_mode = casting_mode
+         module.modules_to_save.default.variance_epsilon = (
+             getattr(module, "variance_epsilon", None) or getattr(module, "eps", None) or eps
+         )
+         module.modules_to_save.default.in_place = in_place
+         module.original_module.offset = offset
+         module.original_module.casting_mode = casting_mode
+         module.original_module.variance_epsilon = (
+             getattr(module, "variance_epsilon", None) or getattr(module, "eps", None) or eps
+         )
+         module.original_module.in_place = in_place
+         _bind_method_to_module(module.modules_to_save.default, "forward", LigerRMSNorm.forward)
+         _bind_method_to_module(module.modules_to_save.default, "extra_repr", LigerRMSNorm.extra_repr)
+         _bind_method_to_module(module.original_module, "forward", LigerRMSNorm.forward)
+         _bind_method_to_module(module.original_module, "extra_repr", LigerRMSNorm.extra_repr)
+         module.modules_to_save.default.__class__.__name__ = LigerRMSNorm.__name__
+         module.original_module.__class__.__name__ = LigerRMSNorm.__name__
+     else:
+         module.offset = offset
+         module.casting_mode = casting_mode
+         module.variance_epsilon = getattr(module, "variance_epsilon", None) or getattr(module, "eps", None) or eps
+         module.in_place = in_place
+         _bind_method_to_module(module, "forward", LigerRMSNorm.forward)
+         _bind_method_to_module(module, "extra_repr", LigerRMSNorm.extra_repr)
+         module.__class__.__name__ = LigerRMSNorm.__name__
 
 
  def _patch_layer_norm_module(module, eps=1e-6):
-     module.variance_epsilon = getattr(module, "variance_epsilon", None) or getattr(module, "eps", None) or eps
-     module.hidden_size = getattr(module, "hidden_size", None) or getattr(module, "normalized_shape", None)
-
-     _bind_method_to_module(module, "forward", LigerLayerNorm.forward)
-     _bind_method_to_module(module, "extra_repr", LigerLayerNorm.extra_repr)
-     module.__class__.__name__ = LigerLayerNorm.__name__
+     # Check if the module is a PEFT ModulesToSaveWrapper
+     # If it is, we need to patch the modules_to_save.default and original_modules
+     if PEFT_AVAILABLE and isinstance(module, peft.utils.other.ModulesToSaveWrapper):
+         module.hidden_size = module.normalized_shape
+         _bind_method_to_module(module, "forward", LigerLayerNorm.forward)
+         _bind_method_to_module(module, "extra_repr", LigerLayerNorm.extra_repr)
+         module.modules_to_save.default.variance_epsilon = (
+             getattr(module, "variance_epsilon", None) or getattr(module, "eps", None) or eps
+         )
+         module.original_module.hidden_size = getattr(module, "hidden_size", None) or getattr(
+             module, "normalized_shape", None
+         )
+         module.original_module.variance_epsilon = (
+             getattr(module, "variance_epsilon", None) or getattr(module, "eps", None) or eps
+         )
+         module.original_module.hidden_size = getattr(module, "hidden_size", None) or getattr(
+             module, "normalized_shape", None
+         )
+         _bind_method_to_module(module.modules_to_save.default, "forward", LigerRMSNorm.forward)
+         _bind_method_to_module(module.modules_to_save.default, "extra_repr", LigerRMSNorm.extra_repr)
+         _bind_method_to_module(module.original_module, "forward", LigerRMSNorm.forward)
+         _bind_method_to_module(module.original_module, "extra_repr", LigerRMSNorm.extra_repr)
+         module.modules_to_save.default.__class__.__name__ = LigerLayerNorm.__name__
+         module.original_module.__class__.__name__ = LigerLayerNorm.__name__
+     else:
+         module.variance_epsilon = getattr(module, "variance_epsilon", None) or getattr(module, "eps", None) or eps
+         module.hidden_size = getattr(module, "hidden_size", None) or getattr(module, "normalized_shape", None)
+         _bind_method_to_module(module, "forward", LigerLayerNorm.forward)
+         _bind_method_to_module(module, "extra_repr", LigerLayerNorm.extra_repr)
+         module.__class__.__name__ = LigerLayerNorm.__name__
 
 
  def _patch_swiglu_module(module, liger_module):
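To make the new PEFT handling concrete, here is a minimal illustrative sketch (not part of this diff) of the plain-module dispatch path in `_patch_rms_norm_module`; the `LlamaRMSNorm` example module is an assumption, and only attributes and bound methods change here, so no Triton kernel is launched:

```python
# Illustrative sketch only: the plain-module path of _patch_rms_norm_module.
# When PEFT is installed and the module is a ModulesToSaveWrapper, the same call
# instead patches both `module.modules_to_save.default` and `module.original_module`.
from transformers.models.llama.modeling_llama import LlamaRMSNorm

from liger_kernel.transformers.monkey_patch import _patch_rms_norm_module

norm = LlamaRMSNorm(hidden_size=64, eps=1e-6)
_patch_rms_norm_module(norm)        # falls into the `else` branch for a plain module
print(type(norm).__name__)          # "LigerRMSNorm" after the class-name rebinding
print(norm.offset, norm.in_place)   # 0.0 True
```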
@@ -1048,6 +1101,115 @@ def apply_liger_kernel_to_qwen2(
      print("Applied Liger kernels to Qwen2")
 
 
+ def apply_liger_kernel_to_qwen3(
+     rope: bool = True,
+     cross_entropy: bool = False,
+     fused_linear_cross_entropy: bool = True,
+     rms_norm: bool = True,
+     swiglu: bool = True,
+     model: PreTrainedModel = None,
+ ) -> None:
+     """
+     Apply Liger kernels to replace original implementation in HuggingFace Qwen3 models.
+     """
+     assert not (cross_entropy and fused_linear_cross_entropy), (
+         "cross_entropy and fused_linear_cross_entropy cannot both be True."
+     )
+
+     from transformers.models.qwen3 import modeling_qwen3
+     from transformers.models.qwen3.modeling_qwen3 import Qwen3Model
+
+     from liger_kernel.transformers.model.qwen3 import lce_forward as qwen3_lce_forward
+
+     if rope:
+         modeling_qwen3.apply_rotary_pos_emb = liger_rotary_pos_emb
+
+     if rms_norm:
+         modeling_qwen3.Qwen3RMSNorm = LigerRMSNorm
+
+     if cross_entropy:
+         from transformers.loss.loss_utils import nn
+
+         nn.functional.cross_entropy = liger_cross_entropy
+
+     if fused_linear_cross_entropy:
+         modeling_qwen3.Qwen3ForCausalLM.forward = qwen3_lce_forward
+
+     if swiglu:
+         modeling_qwen3.Qwen3MLP = LigerSwiGLUMLP
+
+     if model is not None:
+         # The model instance already exists, so we need to additionally patch the
+         # instance variables that reference already-instantiated modules
+
+         # get the base model from the model instance
+         base_model: Qwen3Model = getattr(model, model.base_model_prefix, model)
+
+         if rms_norm:
+             _patch_rms_norm_module(base_model.norm)
+         for decoder_layer in base_model.layers:
+             if swiglu:
+                 _patch_swiglu_module(decoder_layer.mlp, LigerSwiGLUMLP)
+             if rms_norm:
+                 _patch_rms_norm_module(decoder_layer.input_layernorm)
+                 _patch_rms_norm_module(decoder_layer.post_attention_layernorm)
+
+
+ def apply_liger_kernel_to_qwen3_moe(
+     rope: bool = True,
+     cross_entropy: bool = False,
+     fused_linear_cross_entropy: bool = True,
+     rms_norm: bool = True,
+     swiglu: bool = True,
+     model: PreTrainedModel = None,
+ ) -> None:
+     """
+     Apply Liger kernels to replace original implementation in HuggingFace Qwen3 models.
+     """
+     assert not (cross_entropy and fused_linear_cross_entropy), (
+         "cross_entropy and fused_linear_cross_entropy cannot both be True."
+     )
+
+     from transformers.models.qwen3_moe import modeling_qwen3_moe
+     from transformers.models.qwen3_moe.modeling_qwen3_moe import Qwen3MoeModel
+
+     from liger_kernel.transformers.model.qwen3_moe import lce_forward as qwen3_lce_forward
+     from liger_kernel.transformers.swiglu import LigerQwen3MoeSwiGLUMLP
+
+     if rope:
+         modeling_qwen3_moe.apply_rotary_pos_emb = liger_rotary_pos_emb
+
+     if rms_norm:
+         modeling_qwen3_moe.Qwen3MoeRMSNorm = LigerRMSNorm
+
+     if cross_entropy:
+         from transformers.loss.loss_utils import nn
+
+         nn.functional.cross_entropy = liger_cross_entropy
+
+     if fused_linear_cross_entropy:
+         modeling_qwen3_moe.Qwen3MoeForCausalLM.forward = qwen3_lce_forward
+
+     if swiglu:
+         modeling_qwen3_moe.Qwen3MoeMLP = LigerQwen3MoeSwiGLUMLP
+
+     if model is not None:
+         # The model instance already exists, so we need to additionally patch the
+         # instance variables that reference already-instantiated modules
+
+         # get the base model from the model instance
+         base_model: Qwen3MoeModel = getattr(model, model.base_model_prefix, model)
+
+         if rms_norm:
+             _patch_rms_norm_module(base_model.norm)
+         for decoder_layer in base_model.layers:
+             if swiglu:
+                 _patch_swiglu_module(decoder_layer.mlp, LigerQwen3MoeSwiGLUMLP)
+             if rms_norm:
+                 _patch_rms_norm_module(decoder_layer.input_layernorm)
+                 _patch_rms_norm_module(decoder_layer.post_attention_layernorm)
+
+
  def apply_liger_kernel_to_qwen2_vl(
      rope: bool = True,
      cross_entropy: bool = False,
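As a rough usage sketch (not taken from this diff), the new entry point can be called before loading so that freshly constructed Qwen3 modules are already Liger-backed, or afterwards on an existing instance via `model=`; the checkpoint name below is illustrative and a Transformers release with Qwen3 support is assumed:

```python
# Hedged usage sketch for apply_liger_kernel_to_qwen3 (checkpoint name is illustrative).
from transformers import AutoModelForCausalLM

from liger_kernel.transformers.monkey_patch import apply_liger_kernel_to_qwen3

# Patch the Qwen3 modeling classes first, then load: new instances use Liger kernels.
apply_liger_kernel_to_qwen3(rope=True, rms_norm=True, swiglu=True, fused_linear_cross_entropy=True)
model = AutoModelForCausalLM.from_pretrained("Qwen/Qwen3-0.6B")

# Alternatively, patch a model that is already instantiated:
# apply_liger_kernel_to_qwen3(model=model)
```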
@@ -1319,12 +1481,78 @@ def apply_liger_kernel_to_olmo2(
              _patch_rms_norm_module(decoder_layer.post_feedforward_layernorm, in_place=False)
 
 
+ def apply_liger_kernel_to_glm4(
+     rope: bool = False,
+     cross_entropy: bool = False,
+     fused_linear_cross_entropy: bool = True,
+     rms_norm: bool = True,
+     swiglu: bool = True,
+     model: PreTrainedModel = None,
+ ) -> None:
+     """
+     Apply Liger kernels to replace original implementation in HuggingFace GLM-4 models.
+
+     Args:
+         rope (bool): Whether to apply Liger's rotary position embedding. Default is False.
+         cross_entropy (bool): Whether to apply Liger's cross entropy loss. Default is False.
+         fused_linear_cross_entropy (bool):
+             Whether to apply Liger's fused linear cross entropy loss. Default is True.
+             `cross_entropy` and `fused_linear_cross_entropy` cannot both be True.
+             If `fused_linear_cross_entropy` is True, the logits will not be materialized but more memory efficient.
+         rms_norm (bool): Whether to apply Liger's RMSNorm. Default is True.
+         swiglu (bool): Whether to apply Liger's SwiGLU Glm4MLP. Default is True.
+         model (PreTrainedModel): The model instance to apply Liger kernels to, if the model has already been
+             loaded. Default is None.
+     """
+     assert not (cross_entropy and fused_linear_cross_entropy), (
+         "cross_entropy and fused_linear_cross_entropy cannot both be True."
+     )
+
+     from transformers.models.glm4 import modeling_glm4
+     from transformers.models.glm4.modeling_glm4 import Glm4Model
+
+     from liger_kernel.transformers.model.glm4 import lce_forward as glm4_lce_forward
+
+     if rope:
+         raise NotImplementedError("liger_rotary_pos_emb is not available for Glm4 models.")
+     if rms_norm:
+         modeling_glm4.Glm4RMSNorm = partial(LigerRMSNorm, in_place=False)
+     if swiglu:
+         modeling_glm4.Glm4MLP = LigerPhi3SwiGLUMLP
+     if cross_entropy:
+         from transformers.loss.loss_utils import nn
+
+         nn.functional.cross_entropy = liger_cross_entropy
+     if fused_linear_cross_entropy:
+         modeling_glm4.Glm4ForCausalLM.forward = glm4_lce_forward
+
+     if model is not None:
+         # The model instance already exists, so we need to additionally patch the
+         # instance variables that reference already-instantiated modules
+
+         # get the base model from the model instance
+         base_model: Glm4Model = getattr(model, model.base_model_prefix, model)
+
+         if rms_norm:
+             _patch_rms_norm_module(base_model.norm, in_place=False)
+
+         for decoder_layer in base_model.layers:
+             if swiglu:
+                 _patch_swiglu_module(decoder_layer.mlp, LigerPhi3SwiGLUMLP)
+             if rms_norm:
+                 _patch_rms_norm_module(decoder_layer.input_layernorm, in_place=False)
+                 _patch_rms_norm_module(decoder_layer.post_attention_layernorm, in_place=False)
+                 _patch_rms_norm_module(decoder_layer.post_self_attn_layernorm, in_place=False)
+                 _patch_rms_norm_module(decoder_layer.post_mlp_layernorm, in_place=False)
+
+
  # Model type corresponds to the keys defined in transformers/models/auto/modeling_auto.py
  MODEL_TYPE_TO_APPLY_LIGER_FN = {
      "gemma": apply_liger_kernel_to_gemma,
      "gemma2": apply_liger_kernel_to_gemma2,
      "gemma3_text": apply_liger_kernel_to_gemma3_text,
      "gemma3": apply_liger_kernel_to_gemma3,
+     "glm4": apply_liger_kernel_to_glm4,
      "llama": apply_liger_kernel_to_llama,
      "llava": apply_liger_kernel_to_llava,
      "granite": apply_liger_kernel_to_granite,
@@ -1334,6 +1562,8 @@ MODEL_TYPE_TO_APPLY_LIGER_FN = {
      "mixtral": apply_liger_kernel_to_mixtral,
      "olmo2": apply_liger_kernel_to_olmo2,
      "qwen2": apply_liger_kernel_to_qwen2,
+     "qwen3": apply_liger_kernel_to_qwen3,
+     "qwen3_moe": apply_liger_kernel_to_qwen3_moe,
      "qwen2_vl": apply_liger_kernel_to_qwen2_vl,
      "qwen2_5_vl": apply_liger_kernel_to_qwen2_5_vl,
      "phi3": apply_liger_kernel_to_phi3,
liger_kernel/transformers/sparsemax.py ADDED
@@ -0,0 +1,16 @@
+ import torch
+ import torch.nn as nn
+
+ from liger_kernel.ops.sparsemax import LigerSparsemaxFunction
+
+
+ class LigerSparsemax(nn.Module):
+     def __init__(self, dim: int = -1):
+         super().__init__()
+         self.dim = dim
+
+     def forward(self, x: torch.Tensor) -> torch.Tensor:
+         return LigerSparsemaxFunction.apply(x, self.dim)
+
+     def extra_repr(self) -> str:
+         return f"dim={self.dim}"
liger_kernel/transformers/swiglu.py CHANGED
@@ -56,3 +56,24 @@ class LigerPhi3SwiGLUMLP(nn.Module):
          up_states = self.gate_up_proj(x)
          gate, up_states = up_states.chunk(2, dim=-1)
          return self.down_proj(LigerSiLUMulFunction.apply(gate, up_states))
+
+
+ class LigerQwen3MoeSwiGLUMLP(nn.Module):
+     """
+     Patch Qwen3MoeMLP to use LigerSiLUMulFunction.
+     https://github.com/huggingface/transformers/blob/v4.51.3/src/transformers/models/qwen3_moe/modular_qwen3_moe.py#L57
+     """
+
+     def __init__(self, config, intermediate_size=None):
+         super().__init__()
+         self.config = config
+         self.hidden_size = config.hidden_size
+         self.intermediate_size = intermediate_size if intermediate_size is not None else config.intermediate_size
+         self.gate_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False)
+         self.up_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False)
+         self.down_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias=False)
+         if config.hidden_act not in ["silu", "swish"]:
+             raise ValueError(f"Activation function {config.hidden_act} not supported.")
+
+     def forward(self, x):
+         return self.down_proj(LigerSiLUMulFunction.apply(self.gate_proj(x), self.up_proj(x)))
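A minimal sketch (not from the diff) of what the new expert MLP computes, i.e. down_proj(silu(gate_proj(x)) * up_proj(x)) with the SiLU-and-multiply step fused by LigerSiLUMulFunction; a plain namespace stands in for Qwen3MoeConfig, and CUDA with Triton is assumed:

```python
from types import SimpleNamespace

import torch

from liger_kernel.transformers.swiglu import LigerQwen3MoeSwiGLUMLP

# Stand-in for Qwen3MoeConfig; only the fields the module reads are provided.
config = SimpleNamespace(hidden_size=64, intermediate_size=256, hidden_act="silu")

# MoE experts typically pass a smaller per-expert width via `intermediate_size`.
expert = LigerQwen3MoeSwiGLUMLP(config, intermediate_size=32).cuda()
out = expert(torch.randn(2, 4, 64, device="cuda"))
print(out.shape)  # torch.Size([2, 4, 64])
```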
liger_kernel/transformers/trainer/orpo_trainer.py CHANGED
@@ -1,5 +1,3 @@
- from typing import Any
- from typing import Callable
  from typing import Dict
  from typing import List
  from typing import Literal
@@ -13,57 +11,7 @@ from torch.distributed.fsdp import FullyShardedDataParallel
  from trl.trainer import ORPOTrainer
 
  from liger_kernel.chunked_loss import LigerFusedLinearORPOLoss
-
-
- class _FSDPForwardRedirection:
-     """
-     Modified based on
-     https://github.com/Lightning-AI/pytorch-lightning/blob/d3f9c83d6efa4f1def36aa6c199600946cdb9117/src/lightning/pytorch/strategies/strategy.py#L601-L648
-     Redirect a method call through FullyShardedDataParallel.forward so that the FSDP module's root pre-forward and
-     post-forward can be properly executed around the method call.
-     This is needed in cases where we call a submodule of a FSDP module. For instance, when we want to call only
-     the `LlamaModel` part out of a FSDP-wrapped `LlamaForCausalLM` to get the hidden states without involving
-     GPU-memory-heavy `lm_head` and cross entropy computation, doing this directly (i.e. `model.model.forward()`)
-     will not work because the first `nn.Embedding` layer is not independently wrapped as a FSDP module (because of
-     the transformer-based wrapping policy), and not calling it through FSDP root module forward will not all-gather
-     its parameter, thus resulting in "RuntimeError: 'weight' must be 2-D" error. Similarly, if we want to call just
-     the `lm_head` part of a model, we need this trick too to properly get its params all-gathered.
-     """
-
-     def __call__(
-         self,
-         wrapper_module: FullyShardedDataParallel,
-         method: Callable,
-         *args: Any,
-         **kwargs: Any,
-     ):
-         """Reroutes a method call through the `wrapper_module`'s `forward` method.
-         Args:
-             wrapper_module: The module that has `original_module` wrapped.
-             original_module: The module that was wrapped inside `wrapper_module`.
-             method_name: The name of the method that should be called on the `original_module` after inputs get
-                 redirected through the `wrapper_module`'s `forward` method.
-             *args: The positional arguments to the method `method_name`. They will get passed to a patched
-                 `forward` method instead.
-             **kwargs: The keyword arguments to the method `method_name`. They will get passed to a patched
-                 `forward` method instead.
-         """
-         assert isinstance(wrapper_module, FullyShardedDataParallel)
-         original_module = wrapper_module._fsdp_wrapped_module
-         original_forward = original_module.forward
-
-         def wrapped_forward(*_args: Any, **_kwargs: Any) -> Any:
-             # Unpatch ourselves immediately before calling the method `method_name`
-             # because itself may want to call the real `forward`
-             original_module.forward = original_forward  # type: ignore[method-assign]
-             # Call the actual method e.g. `.training_step(...)`
-             out = method(*_args, **_kwargs)
-             return out
-
-         # Patch the original_module's forward so we can redirect the arguments back to the real method
-         original_module.forward = wrapped_forward  # type: ignore[method-assign]
-         wrapper_output = wrapper_module(*args, **kwargs)
-         return wrapper_output
+ from liger_kernel.transformers.fsdp import _FSDPForwardRedirection
 
 
  class LigerORPOTrainer(ORPOTrainer):
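For context, a hedged sketch (assumed usage, not shown in this diff) of how the relocated `_FSDPForwardRedirection` helper is typically invoked, so that a submodule call still passes through the FSDP root forward and gets its parameters all-gathered:

```python
from torch.distributed.fsdp import FullyShardedDataParallel

from liger_kernel.transformers.fsdp import _FSDPForwardRedirection


def hidden_states_only(fsdp_model: FullyShardedDataParallel, **batch):
    """Run only the wrapped base model (e.g. `LlamaModel`) through the FSDP root forward,
    skipping the memory-heavy lm_head / cross-entropy path. Illustrative; assumes an
    already FSDP-wrapped *ForCausalLM module."""
    causal_lm = fsdp_model._fsdp_wrapped_module
    return _FSDPForwardRedirection()(fsdp_model, causal_lm.model.forward, **batch)
```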
liger_kernel/utils.py CHANGED
@@ -1,6 +1,17 @@
+ try:
+     import peft  # noqa: F401
+
+     PEFT_AVAILABLE = True
+ except ImportError:
+     PEFT_AVAILABLE = False
+
  import torch
 
 
+ def is_peft_available():
+     return PEFT_AVAILABLE
+
+
  def infer_device():
      """
      Get current device name based on available devices
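A small illustrative guard built on the new helper (not part of the diff); the wrapper-class import path matches the one referenced in monkey_patch.py above:

```python
from liger_kernel.utils import is_peft_available

if is_peft_available():
    # Safe to import: the top-level `import peft` already succeeded.
    from peft.utils.other import ModulesToSaveWrapper  # noqa: F401
```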
{liger_kernel-0.5.8.dist-info → liger_kernel-0.5.10.dist-info}/METADATA CHANGED
@@ -1,6 +1,6 @@
  Metadata-Version: 2.4
  Name: liger_kernel
- Version: 0.5.8
+ Version: 0.5.10
  Summary: Efficient Triton kernels for LLM Training
  License: BSD 2-CLAUSE LICENSE
          Copyright 2024 LinkedIn Corporation
@@ -59,7 +59,6 @@ Dynamic: requires-dist
      <th style="padding: 10px;" colspan="2">Stable</th>
      <th style="padding: 10px;" colspan="2">Nightly</th>
      <th style="padding: 10px;">Discord</th>
-     <th style="padding: 10px;">Build</th>
    </tr>
    <tr>
      <td style="padding: 10px;">
@@ -87,23 +86,6 @@ Dynamic: requires-dist
        <img src="https://dcbadge.vercel.app/api/server/gpumode?style=flat" alt="Join Our Discord">
      </a>
    </td>
-   <td style="padding: 10px;">
-     <div style="display: block;">
-       <a href="https://github.com/linkedin/Liger-Kernel/actions/workflows/nvi-ci.yml">
-         <img src="https://github.com/linkedin/Liger-Kernel/actions/workflows/nvi-ci.yml/badge.svg?event=schedule" alt="Build">
-       </a>
-     </div>
-     <div style="display: block;">
-       <a href="https://github.com/linkedin/Liger-Kernel/actions/workflows/amd-ci.yml">
-         <img src="https://github.com/linkedin/Liger-Kernel/actions/workflows/amd-ci.yml/badge.svg?event=schedule" alt="Build">
-       </a>
-     </div>
-     <div style="display: block;">
-       <a href="https://github.com/linkedin/Liger-Kernel/actions/workflows/amd-ci.yml">
-         <img src="https://github.com/linkedin/Liger-Kernel/actions/workflows/intel-ci.yml/badge.svg?event=schedule" alt="Build">
-       </a>
-     </div>
-   </td>
  </tr>
 </table>
 
@@ -320,9 +302,12 @@ loss.backward()
  | Qwen2, Qwen2.5, & QwQ | `liger_kernel.transformers.apply_liger_kernel_to_qwen2` | RoPE, RMSNorm, SwiGLU, CrossEntropyLoss, FusedLinearCrossEntropy |
  | Qwen2-VL, & QVQ | `liger_kernel.transformers.apply_liger_kernel_to_qwen2_vl` | RMSNorm, LayerNorm, SwiGLU, CrossEntropyLoss, FusedLinearCrossEntropy |
  | Qwen2.5-VL | `liger_kernel.transformers.apply_liger_kernel_to_qwen2_5_vl` | RMSNorm, SwiGLU, CrossEntropyLoss, FusedLinearCrossEntropy |
+ | Qwen3 | `liger_kernel.transformers.apply_liger_kernel_to_qwen3` | RoPE, RMSNorm, SwiGLU, CrossEntropyLoss, FusedLinearCrossEntropy |
+ | Qwen3 MoE | `liger_kernel_transformers.apply_liger_kernel_to_qwen3_moe` | RoPE, RMSNorm, SwiGLU, CrossEntropyLoss, FusedLinearCrossEntropy |
  | Phi3 & Phi3.5 | `liger_kernel.transformers.apply_liger_kernel_to_phi3` | RoPE, RMSNorm, SwiGLU, CrossEntropyLoss, FusedLinearCrossEntropy |
  | Granite 3.0 & 3.1 | `liger_kernel.transformers.apply_liger_kernel_to_granite` | RoPE, RMSNorm, SwiGLU, CrossEntropyLoss |
  | OLMo2 | `liger_kernel.transformers.apply_liger_kernel_to_olmo2` | RoPE, RMSNorm, SwiGLU, CrossEntropyLoss, FusedLinearCrossEntropy |
+ | GLM-4 | `liger_kernel.transformers.apply_liger_kernel_to_glm4` | RoPE, RMSNorm, SwiGLU, CrossEntropyLoss, FusedLinearCrossEntropy |
 
 
  ## Low-level APIs
@@ -340,7 +325,8 @@ loss.backward()
  | SwiGLU | `liger_kernel.transformers.LigerSwiGLUMLP` |
  | GeGLU | `liger_kernel.transformers.LigerGEGLUMLP` |
  | CrossEntropy | `liger_kernel.transformers.LigerCrossEntropyLoss` |
- | Fused Linear CrossEntropy | `liger_kernel.transformers.LigerFusedLinearCrossEntropyLoss`|
+ | Fused Linear CrossEntropy | `liger_kernel.transformers.LigerFusedLinearCrossEntropyLoss`|
+ | Sparsemax | `liger_kernel.transformers.LigerSparsemax` |
 
 
  ### Alignment Kernels
@@ -388,6 +374,36 @@ loss.backward()
  - [Axolotl](https://axolotl.ai/): Integrating Liger Kernel into Axolotl.
  - [Llama-Factory](https://github.com/hiyouga/LLaMA-Factory): Integrating Liger Kernel into Llama-Factory.
 
+
+ ## CI status
+
+ <table style="width: 100%; text-align: center; border-collapse: collapse;">
+   <tr>
+     <th style="padding: 10px;">Build</th>
+   </tr>
+   <tr>
+     <td style="padding: 10px;">
+       <div style="display: block;">
+         <a href="https://github.com/linkedin/Liger-Kernel/actions/workflows/nvi-ci.yml">
+           <img src="https://github.com/linkedin/Liger-Kernel/actions/workflows/nvi-ci.yml/badge.svg?event=schedule" alt="Build">
+         </a>
+       </div>
+       <div style="display: block;">
+         <a href="https://github.com/linkedin/Liger-Kernel/actions/workflows/amd-ci.yml">
+           <img src="https://github.com/linkedin/Liger-Kernel/actions/workflows/amd-ci.yml/badge.svg?event=schedule" alt="Build">
+         </a>
+       </div>
+       <div style="display: block;">
+         <a href="https://github.com/linkedin/Liger-Kernel/actions/workflows/amd-ci.yml">
+           <img src="https://github.com/linkedin/Liger-Kernel/actions/workflows/intel-ci.yml/badge.svg?event=schedule" alt="Build">
+         </a>
+       </div>
+     </td>
+   </tr>
+ </table>
+
+
+
  ## Contact
 
  - For issues, create a Github ticket in this repository