liger-kernel 0.6.0__py3-none-any.whl → 0.6.2__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- liger_kernel/chunked_loss/dpo_loss.py +54 -3
- liger_kernel/ops/fused_add_rms_norm.py +412 -0
- liger_kernel/ops/fused_linear_cross_entropy.py +21 -13
- liger_kernel/ops/layer_norm.py +126 -89
- liger_kernel/ops/llama4_rope.py +225 -0
- liger_kernel/ops/rms_norm.py +2 -2
- liger_kernel/ops/rope.py +1 -1
- liger_kernel/transformers/__init__.py +20 -0
- liger_kernel/transformers/experimental/__init__.py +5 -0
- liger_kernel/transformers/functional.py +7 -0
- liger_kernel/transformers/fused_add_rms_norm.py +39 -0
- liger_kernel/transformers/fused_linear_cross_entropy.py +3 -0
- liger_kernel/transformers/llama4_rope.py +93 -0
- liger_kernel/transformers/model/gemma3.py +1 -1
- liger_kernel/transformers/model/glm4v.py +150 -0
- liger_kernel/transformers/model/loss_utils.py +2 -0
- liger_kernel/transformers/model/mllama.py +4 -2
- liger_kernel/transformers/model/phi3.py +8 -159
- liger_kernel/transformers/model/smollm3.py +189 -0
- liger_kernel/transformers/monkey_patch.py +185 -32
- {liger_kernel-0.6.0.dist-info → liger_kernel-0.6.2.dist-info}/METADATA +12 -14
- {liger_kernel-0.6.0.dist-info → liger_kernel-0.6.2.dist-info}/RECORD +26 -19
- {liger_kernel-0.6.0.dist-info → liger_kernel-0.6.2.dist-info}/WHEEL +0 -0
- {liger_kernel-0.6.0.dist-info → liger_kernel-0.6.2.dist-info}/licenses/LICENSE +0 -0
- {liger_kernel-0.6.0.dist-info → liger_kernel-0.6.2.dist-info}/licenses/NOTICE +0 -0
- {liger_kernel-0.6.0.dist-info → liger_kernel-0.6.2.dist-info}/top_level.txt +0 -0
|
@@ -0,0 +1,189 @@
|
|
|
1
|
+
from typing import TYPE_CHECKING
|
|
2
|
+
from typing import List
|
|
3
|
+
from typing import Optional
|
|
4
|
+
from typing import Tuple
|
|
5
|
+
from typing import Union
|
|
6
|
+
|
|
7
|
+
import torch
|
|
8
|
+
|
|
9
|
+
from torch.distributed.fsdp import FullyShardedDataParallel
|
|
10
|
+
from transformers.modeling_outputs import CausalLMOutputWithPast
|
|
11
|
+
from transformers.utils.deprecation import deprecate_kwarg
|
|
12
|
+
|
|
13
|
+
from liger_kernel.transformers.fsdp import _FSDPForwardRedirection
|
|
14
|
+
from liger_kernel.transformers.model.loss_utils import LigerForCausalLMLoss
|
|
15
|
+
from liger_kernel.utils import PEFT_AVAILABLE
|
|
16
|
+
|
|
17
|
+
if TYPE_CHECKING:
|
|
18
|
+
from transformers.cache_utils import Cache
|
|
19
|
+
|
|
20
|
+
if PEFT_AVAILABLE:
|
|
21
|
+
from peft.utils.other import ModulesToSaveWrapper
|
|
22
|
+
|
|
23
|
+
|
|
24
|
+
@deprecate_kwarg("num_logits_to_keep", version="4.50", new_name="logits_to_keep")
def lce_forward(
    self,
    input_ids: torch.LongTensor = None,
    attention_mask: Optional[torch.Tensor] = None,
    position_ids: Optional[torch.LongTensor] = None,
    past_key_values: Optional[Union["Cache", List[torch.FloatTensor]]] = None,
    inputs_embeds: Optional[torch.FloatTensor] = None,
    labels: Optional[torch.LongTensor] = None,
    use_cache: Optional[bool] = None,
    output_attentions: Optional[bool] = None,
    output_hidden_states: Optional[bool] = None,
    return_dict: Optional[bool] = None,
    cache_position: Optional[torch.LongTensor] = None,
    logits_to_keep: Union[int, torch.Tensor] = 0,
    skip_logits: Optional[bool] = None,
    **kwargs,
) -> Union[Tuple, CausalLMOutputWithPast]:
    r"""
    Causal-LM forward for SmolLM3 that can compute the loss with Liger's fused linear cross entropy.

    Args:
        labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
            Labels for computing the masked language modeling loss. Indices should either be in `[0, ...,
            config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored
            (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`.

        logits_to_keep (`int` or `torch.Tensor`, *optional*):
            If an `int`, compute logits for the last `logits_to_keep` tokens. If `0`, calculate logits for all
            `input_ids` (special case). Only last token logits are needed for generation, and calculating them only for that
            token can save memory, which becomes pretty significant for long sequences or large vocabulary size.
            If a `torch.Tensor`, must be 1D corresponding to the indices to keep in the sequence length dimension.
            This is useful when using packed tensor format (single dimension for batch and sequence length).

        skip_logits (`bool`, *optional*):
            If `True`, logits are not materialized and the loss is computed by the Liger fused linear
            cross entropy path; `labels` or `shift_labels` (passed via kwargs) is then required. If `None`
            (default), logits are skipped only when the module is in training mode and `labels` or
            `shift_labels` is provided.

        **kwargs:
            Forwarded to `self.model`. A `shift_labels` entry (already-shifted labels) is removed after
            the model call and used for the loss; the remaining entries are forwarded to the loss.

    Returns:
        A `CausalLMOutputWithPast` when `return_dict` is true; otherwise a tuple `(loss, logits, *outputs[1:])`,
        with `loss` omitted when it is `None`. `logits` is `None` when logits were skipped.

    Example:

    ```python
    >>> from transformers import AutoTokenizer, SmolLM3ForCausalLM

    >>> model = SmolLM3ForCausalLM.from_pretrained("HuggingFaceTB/SmolLM3-3B")
    >>> tokenizer = AutoTokenizer.from_pretrained("HuggingFaceTB/SmolLM3-3B")

    >>> prompt = "Hey, are you conscious? Can you talk to me?"
    >>> inputs = tokenizer(prompt, return_tensors="pt")

    >>> # Generate
    >>> generate_ids = model.generate(inputs.input_ids, max_length=30)
    >>> tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]
    "Hey, are you conscious? Can you talk to me?\nI'm not conscious, but I can talk to you."
    ```"""

    output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
    output_hidden_states = (
        output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
    )
    return_dict = return_dict if return_dict is not None else self.config.use_return_dict

    # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn)
    outputs = self.model(
        input_ids=input_ids,
        attention_mask=attention_mask,
        position_ids=position_ids,
        past_key_values=past_key_values,
        inputs_embeds=inputs_embeds,
        use_cache=use_cache,
        output_attentions=output_attentions,
        output_hidden_states=output_hidden_states,
        return_dict=return_dict,
        cache_position=cache_position,
        **kwargs,
    )

    hidden_states = outputs[0]
    # Only compute necessary logits, and do not upcast them to float if we are not computing the loss.
    # An int of 0 gives slice(0, None), i.e. every position is kept.
    slice_indices = slice(-logits_to_keep, None) if isinstance(logits_to_keep, int) else logits_to_keep
    kept_hidden_states = hidden_states[:, slice_indices, :]

    # Take `shift_labels` out of kwargs so it is passed explicitly (and only once) to the loss below.
    shift_labels = kwargs.pop("shift_labels", None)
    logits = None
    loss = None

    if skip_logits and labels is None and shift_labels is None:
        raise ValueError("skip_logits is True, but labels and shift_labels are None")

    if skip_logits is None:
        # By default, if in training mode, don't materialize logits
        skip_logits = self.training and (labels is not None or shift_labels is not None)

    if skip_logits:
        loss = lce_maybe_trainable_lm_head(
            self,
            hidden_states=kept_hidden_states,
            hidden_size=self.config.hidden_size,
            labels=labels,
            shift_labels=shift_labels,
            **kwargs,
        )
    else:
        logits = self.lm_head(kept_hidden_states)
        # Either label form is enough to compute a loss; previously `shift_labels` was popped above
        # and then silently dropped here, so a shift_labels-only call produced no loss.
        if labels is not None or shift_labels is not None:
            loss = self.loss_function(
                logits=logits,
                labels=labels,
                vocab_size=self.config.vocab_size,
                shift_labels=shift_labels,
                **kwargs,
            )

    if not return_dict:
        output = (logits,) + outputs[1:]
        return (loss,) + output if loss is not None else output

    return CausalLMOutputWithPast(
        loss=loss,
        logits=logits,
        past_key_values=outputs.past_key_values,
        hidden_states=outputs.hidden_states,
        attentions=outputs.attentions,
    )
|
|
143
|
+
|
|
144
|
+
|
|
145
|
+
def lce_maybe_trainable_lm_head(self, hidden_states, hidden_size, labels, shift_labels, **loss_kwargs):
    """Compute the Liger causal-LM loss against this model's ``lm_head`` weights.

    ``self.lm_head`` may be wrapped by PEFT (``ModulesToSaveWrapper``) and/or by FSDP
    (``FullyShardedDataParallel``); both wrappers are handled so the kernel reads the
    weights of the underlying module.

    Args:
        self: The causal-LM model owning ``lm_head``.
        hidden_states: Final hidden states fed to the LM head.
        hidden_size: Hidden dimension, forwarded to the loss.
        labels: Unshifted labels, or ``None``.
        shift_labels: Already-shifted labels, or ``None``.
        **loss_kwargs: Extra keyword arguments forwarded to the loss.

    Returns:
        Whatever ``_liger_for_causal_lm_loss`` returns (the loss).
    """
    lm_head = self.lm_head

    # Unwrap the module if lm_head has been added as trainable module in PEFT LoRA configuration,
    # i.e. listed in the modules_to_save field of LoraConfig, so the lm_head weights are read
    # from the unwrapped module.
    # See https://huggingface.co/docs/peft/package_reference/lora for reference.
    if PEFT_AVAILABLE and isinstance(lm_head, ModulesToSaveWrapper):
        lm_head = lm_head.modules_to_save.default

    # If FSDP is used and lm_head is trainable, e.g., during full fine-tuning or with LoRA,
    # reading the lm_head module weights and calling the kernel must be done within FSDP forward pass
    # so the module entire parameters are summoned and kept in memory during the kernel execution.
    if isinstance(lm_head, FullyShardedDataParallel):
        return _FSDPForwardRedirection()(
            lm_head,
            _liger_for_causal_lm_loss,
            lm_head.module,
            hidden_states,
            hidden_size,
            labels,
            shift_labels,
            **loss_kwargs,
        )

    # FSDP is not used so we can read the lm_head weights and call the kernel directly.
    # Use the (possibly PEFT-unwrapped) module, matching the FSDP branch above; passing
    # `self.lm_head` here would ignore the unwrapping done earlier.
    return _liger_for_causal_lm_loss(
        lm_head=lm_head,
        hidden_states=hidden_states,
        hidden_size=hidden_size,
        labels=labels,
        shift_labels=shift_labels,
        **loss_kwargs,
    )
|
|
179
|
+
|
|
180
|
+
|
|
181
|
+
def _liger_for_causal_lm_loss(lm_head, hidden_states, hidden_size, labels, shift_labels, **loss_kwargs):
    """Call ``LigerForCausalLMLoss`` with the weight tensor of ``lm_head``.

    Kept as a plain module-level function so it can be handed to ``_FSDPForwardRedirection``
    as well as called directly.

    Args:
        lm_head: Module whose ``weight`` is used as the LM-head projection.
        hidden_states: Hidden states to project and score.
        hidden_size: Hidden dimension, forwarded unchanged.
        labels: Unshifted labels, or ``None``.
        shift_labels: Already-shifted labels, or ``None``.
        **loss_kwargs: Extra keyword arguments forwarded to ``LigerForCausalLMLoss``.

    Returns:
        The value returned by ``LigerForCausalLMLoss``.
    """
    head_weight = lm_head.weight
    return LigerForCausalLMLoss(
        hidden_states=hidden_states,
        lm_head_weight=head_weight,
        hidden_size=hidden_size,
        labels=labels,
        shift_labels=shift_labels,
        **loss_kwargs,
    )
|
|
@@ -26,9 +26,9 @@ from liger_kernel.transformers.model.mistral import lce_forward as mistral_lce_f
|
|
|
26
26
|
from liger_kernel.transformers.model.mixtral import lce_forward as mixtral_lce_forward
|
|
27
27
|
from liger_kernel.transformers.model.mixtral import lce_forward_deprecated as mixtral_lce_forward_deprecated
|
|
28
28
|
from liger_kernel.transformers.model.phi3 import lce_forward as phi3_lce_forward
|
|
29
|
-
from liger_kernel.transformers.model.phi3 import lce_forward_deprecated as phi3_lce_forward_deprecated
|
|
30
29
|
from liger_kernel.transformers.model.qwen2 import lce_forward as qwen2_lce_forward
|
|
31
30
|
from liger_kernel.transformers.model.qwen2 import lce_forward_deprecated as qwen2_lce_forward_deprecated
|
|
31
|
+
from liger_kernel.transformers.model.smollm3 import lce_forward as smollm3_lce_forward
|
|
32
32
|
from liger_kernel.transformers.qwen2vl_mrope import liger_multimodal_rotary_pos_emb
|
|
33
33
|
from liger_kernel.transformers.rms_norm import LigerRMSNorm
|
|
34
34
|
from liger_kernel.transformers.rope import liger_rotary_pos_emb
|
|
@@ -77,8 +77,8 @@ def _patch_rms_norm_module(module, offset=0.0, eps=1e-6, casting_mode="llama", i
|
|
|
77
77
|
_bind_method_to_module(module.modules_to_save.default, "extra_repr", LigerRMSNorm.extra_repr)
|
|
78
78
|
_bind_method_to_module(module.original_module, "forward", LigerRMSNorm.forward)
|
|
79
79
|
_bind_method_to_module(module.original_module, "extra_repr", LigerRMSNorm.extra_repr)
|
|
80
|
-
module.modules_to_save.default
|
|
81
|
-
module.original_module
|
|
80
|
+
_bind_method_to_module(module.modules_to_save.default, "_get_name", lambda self: LigerRMSNorm.__name__)
|
|
81
|
+
_bind_method_to_module(module.original_module, "_get_name", lambda self: LigerRMSNorm.__name__)
|
|
82
82
|
else:
|
|
83
83
|
module.offset = offset
|
|
84
84
|
module.casting_mode = casting_mode
|
|
@@ -87,7 +87,7 @@ def _patch_rms_norm_module(module, offset=0.0, eps=1e-6, casting_mode="llama", i
|
|
|
87
87
|
module.row_mode = row_mode
|
|
88
88
|
_bind_method_to_module(module, "forward", LigerRMSNorm.forward)
|
|
89
89
|
_bind_method_to_module(module, "extra_repr", LigerRMSNorm.extra_repr)
|
|
90
|
-
module
|
|
90
|
+
_bind_method_to_module(module, "_get_name", lambda self: LigerRMSNorm.__name__)
|
|
91
91
|
|
|
92
92
|
|
|
93
93
|
def _patch_layer_norm_module(module, eps=1e-6):
|
|
@@ -109,28 +109,28 @@ def _patch_layer_norm_module(module, eps=1e-6):
|
|
|
109
109
|
module.original_module.hidden_size = getattr(module, "hidden_size", None) or getattr(
|
|
110
110
|
module, "normalized_shape", None
|
|
111
111
|
)
|
|
112
|
-
_bind_method_to_module(module.modules_to_save.default, "forward",
|
|
113
|
-
_bind_method_to_module(module.modules_to_save.default, "extra_repr",
|
|
114
|
-
_bind_method_to_module(module.original_module, "forward",
|
|
115
|
-
_bind_method_to_module(module.original_module, "extra_repr",
|
|
116
|
-
module.modules_to_save.default
|
|
117
|
-
module.original_module
|
|
112
|
+
_bind_method_to_module(module.modules_to_save.default, "forward", LigerLayerNorm.forward)
|
|
113
|
+
_bind_method_to_module(module.modules_to_save.default, "extra_repr", LigerLayerNorm.extra_repr)
|
|
114
|
+
_bind_method_to_module(module.original_module, "forward", LigerLayerNorm.forward)
|
|
115
|
+
_bind_method_to_module(module.original_module, "extra_repr", LigerLayerNorm.extra_repr)
|
|
116
|
+
_bind_method_to_module(module.modules_to_save.default, "_get_name", lambda self: LigerLayerNorm.__name__)
|
|
117
|
+
_bind_method_to_module(module.original_module, "_get_name", lambda self: LigerLayerNorm.__name__)
|
|
118
118
|
else:
|
|
119
119
|
module.variance_epsilon = getattr(module, "variance_epsilon", None) or getattr(module, "eps", None) or eps
|
|
120
120
|
module.hidden_size = getattr(module, "hidden_size", None) or getattr(module, "normalized_shape", None)
|
|
121
121
|
_bind_method_to_module(module, "forward", LigerLayerNorm.forward)
|
|
122
122
|
_bind_method_to_module(module, "extra_repr", LigerLayerNorm.extra_repr)
|
|
123
|
-
module
|
|
123
|
+
_bind_method_to_module(module, "_get_name", lambda self: LigerLayerNorm.__name__)
|
|
124
124
|
|
|
125
125
|
|
|
126
126
|
def _patch_swiglu_module(module, liger_module):
    """Rebind ``module.forward`` to ``liger_module.forward`` and make ``_get_name`` report the Liger class.

    Args:
        module: The already-instantiated MLP module to patch.
        liger_module: The Liger SwiGLU class whose ``forward`` and ``__name__`` are used.
    """

    # NOTE(review): presumably overridden so the module prints under the Liger class name; confirm.
    def _liger_name(self):
        return liger_module.__name__

    _bind_method_to_module(module, "forward", liger_module.forward)
    _bind_method_to_module(module, "_get_name", _liger_name)
|
|
129
129
|
|
|
130
130
|
|
|
131
131
|
def _patch_geglu_module(module):
    """Rebind ``module.forward`` to ``LigerGEGLUMLP.forward`` and make ``_get_name`` report ``LigerGEGLUMLP``.

    Args:
        module: The already-instantiated MLP module to patch.
    """

    # NOTE(review): presumably overridden so the module prints under the Liger class name; confirm.
    def _liger_name(self):
        return LigerGEGLUMLP.__name__

    _bind_method_to_module(module, "forward", LigerGEGLUMLP.forward)
    _bind_method_to_module(module, "_get_name", _liger_name)
|
|
134
134
|
|
|
135
135
|
|
|
136
136
|
def apply_liger_kernel_to_granite(
|
|
@@ -290,6 +290,77 @@ def apply_liger_kernel_to_llama(
|
|
|
290
290
|
_patch_rms_norm_module(decoder_layer.post_attention_layernorm)
|
|
291
291
|
|
|
292
292
|
|
|
293
|
+
def apply_liger_kernel_to_smollm3(
    rope: bool = True,
    cross_entropy: bool = False,
    fused_linear_cross_entropy: bool = True,
    rms_norm: bool = True,
    swiglu: bool = True,
    model: PreTrainedModel = None,
) -> None:
    """
    Patch the HuggingFace SmolLM3 implementation with Liger kernels.

    Symbols in ``transformers.models.smollm3.modeling_smollm3`` are replaced, and when ``model``
    is given its already-instantiated submodules are patched as well.

    Args:
        rope (bool): Use Liger's rotary position embedding. Default is True.
        cross_entropy (bool): Use Liger's cross entropy loss. Default is False.
        fused_linear_cross_entropy (bool):
            Use Liger's fused linear cross entropy loss. Default is True.
            Cannot be combined with `cross_entropy`.
            When enabled, the logits are not materialized, which is more memory efficient.
        rms_norm (bool): Use Liger's RMSNorm. Default is True.
        swiglu (bool): Use Liger's SwiGLU MLP. Default is True.
        model (PreTrainedModel): An already-loaded model instance to patch in place. Default is None.
    """
    assert not (cross_entropy and fused_linear_cross_entropy), (
        "cross_entropy and fused_linear_cross_entropy cannot both be True."
    )

    from transformers.models.smollm3 import modeling_smollm3
    from transformers.models.smollm3.modeling_smollm3 import SmolLM3Model

    # Module-level replacements.
    if rope:
        modeling_smollm3.apply_rotary_pos_emb = liger_rotary_pos_emb
    if rms_norm:
        modeling_smollm3.SmolLM3RMSNorm = LigerRMSNorm
    if swiglu:
        modeling_smollm3.SmolLM3MLP = LigerSwiGLUMLP

    if cross_entropy:
        if transformer_version < version.parse(SUPPORTED_TRANSFORMER_VERSION):
            logger.warning(TRANSFORMER_DEPRECATION_WARNING)
            modeling_smollm3.CrossEntropyLoss = LigerCrossEntropyLoss
        else:
            from transformers.loss.loss_utils import nn

            nn.functional.cross_entropy = liger_cross_entropy

    if fused_linear_cross_entropy:
        if model is None:
            modeling_smollm3.SmolLM3ForCausalLM.forward = smollm3_lce_forward
        else:
            model.forward = MethodType(smollm3_lce_forward, model)

    if model is None:
        return

    # An existing instance still holds the original SmolLM3RMSNorm / SmolLM3MLP objects,
    # so patch those submodules directly.
    base_model: SmolLM3Model = getattr(model, model.base_model_prefix, model)

    if rms_norm:
        _patch_rms_norm_module(base_model.norm)

    for layer in base_model.layers:
        if swiglu:
            _patch_swiglu_module(layer.mlp, LigerSwiGLUMLP)
        if rms_norm:
            _patch_rms_norm_module(layer.input_layernorm)
            _patch_rms_norm_module(layer.post_attention_layernorm)
|
|
362
|
+
|
|
363
|
+
|
|
293
364
|
def apply_liger_kernel_to_llava(
|
|
294
365
|
cross_entropy: bool = False,
|
|
295
366
|
fused_linear_cross_entropy: bool = True,
|
|
@@ -377,7 +448,7 @@ def apply_liger_kernel_to_llava(
|
|
|
377
448
|
|
|
378
449
|
|
|
379
450
|
def apply_liger_kernel_to_llama4(
|
|
380
|
-
rope: bool =
|
|
451
|
+
rope: bool = True,
|
|
381
452
|
cross_entropy: bool = False,
|
|
382
453
|
fused_linear_cross_entropy: bool = True,
|
|
383
454
|
rms_norm: bool = True,
|
|
@@ -413,7 +484,9 @@ def apply_liger_kernel_to_llama4(
|
|
|
413
484
|
from liger_kernel.transformers.model.llama4 import lce_forward as llama4_lce_forward
|
|
414
485
|
|
|
415
486
|
if rope:
|
|
416
|
-
|
|
487
|
+
from liger_kernel.transformers.llama4_rope import apply_liger_llama4_rope_full
|
|
488
|
+
|
|
489
|
+
apply_liger_llama4_rope_full(modeling_llama4)
|
|
417
490
|
if rms_norm:
|
|
418
491
|
modeling_llama4.Llama4TextRMSNorm = LigerRMSNorm
|
|
419
492
|
if swiglu:
|
|
@@ -1603,25 +1676,14 @@ def apply_liger_kernel_to_phi3(
|
|
|
1603
1676
|
if swiglu:
|
|
1604
1677
|
modeling_phi3.Phi3MLP = LigerPhi3SwiGLUMLP
|
|
1605
1678
|
if cross_entropy:
|
|
1606
|
-
|
|
1607
|
-
from transformers.loss.loss_utils import nn
|
|
1679
|
+
from transformers.loss.loss_utils import nn
|
|
1608
1680
|
|
|
1609
|
-
|
|
1610
|
-
else:
|
|
1611
|
-
logger.warning(TRANSFORMER_DEPRECATION_WARNING)
|
|
1612
|
-
modeling_phi3.CrossEntropyLoss = LigerCrossEntropyLoss
|
|
1681
|
+
nn.functional.cross_entropy = liger_cross_entropy
|
|
1613
1682
|
if fused_linear_cross_entropy:
|
|
1614
|
-
if
|
|
1615
|
-
|
|
1616
|
-
|
|
1617
|
-
|
|
1618
|
-
modeling_phi3.Phi3ForCausalLM.forward = phi3_lce_forward
|
|
1619
|
-
else: # if version < 4.46.1
|
|
1620
|
-
logger.warning(TRANSFORMER_DEPRECATION_WARNING)
|
|
1621
|
-
if model is not None:
|
|
1622
|
-
model.forward = MethodType(phi3_lce_forward_deprecated, model)
|
|
1623
|
-
else:
|
|
1624
|
-
modeling_phi3.Phi3ForCausalLM.forward = phi3_lce_forward_deprecated
|
|
1683
|
+
if model is not None:
|
|
1684
|
+
model.forward = MethodType(phi3_lce_forward, model)
|
|
1685
|
+
else:
|
|
1686
|
+
modeling_phi3.Phi3ForCausalLM.forward = phi3_lce_forward
|
|
1625
1687
|
|
|
1626
1688
|
if model is not None:
|
|
1627
1689
|
# The model instance already exists, so we need to additionally patch the
|
|
@@ -1777,6 +1839,95 @@ def apply_liger_kernel_to_glm4(
|
|
|
1777
1839
|
_patch_rms_norm_module(decoder_layer.post_mlp_layernorm, in_place=False)
|
|
1778
1840
|
|
|
1779
1841
|
|
|
1842
|
+
def apply_liger_kernel_to_glm4v(
    rope: bool = False,
    cross_entropy: bool = False,
    fused_linear_cross_entropy: bool = True,
    rms_norm: bool = True,
    swiglu: bool = True,
    model: PreTrainedModel = None,
) -> None:
    """
    Apply Liger kernels to replace original implementation in HuggingFace GLM-4v models.

    Args:
        rope (bool): Whether to apply Liger's rotary position embedding. Default is False.
            Not supported; passing True raises NotImplementedError.
        cross_entropy (bool): Whether to apply Liger's cross entropy loss. Default is False.
        fused_linear_cross_entropy (bool):
            Whether to apply Liger's fused linear cross entropy loss. Default is True.
            `cross_entropy` and `fused_linear_cross_entropy` cannot both be True.
            If `fused_linear_cross_entropy` is True, the logits will not be materialized, which is more memory efficient.
        rms_norm (bool): Whether to apply Liger's RMSNorm. Default is True.
        swiglu (bool): Whether to apply Liger's SwiGLU Glm4MLP. Default is True.
        model (PreTrainedModel): The model instance to apply Liger kernels to, if the model has already been
        loaded. Must be a `Glm4vForConditionalGeneration`, `Glm4vModel` or `Glm4vTextModel`. Default is None.

    Raises:
        NotImplementedError: If `rope` is True.
        TypeError: If `model` is given but is not one of the supported GLM-4v classes.
    """
    assert not (cross_entropy and fused_linear_cross_entropy), (
        "cross_entropy and fused_linear_cross_entropy cannot both be True."
    )

    from transformers.models.glm4v import modeling_glm4v
    from transformers.models.glm4v.modeling_glm4v import Glm4vForConditionalGeneration
    from transformers.models.glm4v.modeling_glm4v import Glm4vModel
    from transformers.models.glm4v.modeling_glm4v import Glm4vTextModel
    from transformers.models.glm4v.modeling_glm4v import Glm4vVisionModel

    from liger_kernel.transformers.model.glm4v import lce_forward as glm4v_lce_forward
    from liger_kernel.transformers.rms_norm import LigerRMSNormForGlm4

    if rope:
        raise NotImplementedError("liger_rotary_pos_emb is not available for Glm4 models.")
    if rms_norm:
        modeling_glm4v.Glm4vRMSNorm = LigerRMSNormForGlm4
    if cross_entropy:
        from transformers.loss.loss_utils import nn

        nn.functional.cross_entropy = liger_cross_entropy
    if fused_linear_cross_entropy:
        if model is not None:
            model.forward = MethodType(glm4v_lce_forward, model)
        else:
            modeling_glm4v.Glm4vForConditionalGeneration.forward = glm4v_lce_forward

    if model is not None:
        # The model instance already exists, so we need to additionally patch the
        # instance variables that reference already-instantiated modules
        text_model: Glm4vTextModel
        if isinstance(model, (Glm4vForConditionalGeneration, Glm4vModel)):
            # Note: language_model and visual properties can be accessed through the conditional class for BC.
            # Not sure if it is subject to changes in the future.
            # Reference: https://github.com/huggingface/transformers/blob/main/src/transformers/models/glm4v/modeling_glm4v.py#L1305
            text_model = model.language_model
            vision_model: Glm4vVisionModel = model.visual
        elif isinstance(model, Glm4vTextModel):
            text_model = model
            vision_model = None
        else:
            # Note: Currently there's no support for patching vision model only. Feel free to raise an issue if needed.
            raise TypeError(
                f"Unsupported glm4.1v model type. `model` must be `Glm4vForConditionalGeneration`, `Glm4vModel` or `Glm4vTextModel`. Got: {type(model)}"
            )

        if vision_model is not None:
            for vision_block in vision_model.blocks:
                if rms_norm:
                    _patch_rms_norm_module(vision_block.norm1)
                    _patch_rms_norm_module(vision_block.norm2)
                if swiglu:
                    _patch_swiglu_module(vision_block.mlp, LigerSwiGLUMLP)

        if text_model is not None:
            # Mirror apply_liger_kernel_to_glm4, which patches GLM-4 decoder norms with in_place=False.
            # NOTE(review): presumably required because the post-sublayer norm outputs are added straight
            # onto the residual stream, so an in-place backward would overwrite a gradient that is still
            # shared with the residual branch; confirm against transformers' modeling_glm4v.
            if rms_norm:
                _patch_rms_norm_module(text_model.norm, in_place=False)
            for decoder_layer in text_model.layers:
                if swiglu:
                    _patch_swiglu_module(decoder_layer.mlp, LigerPhi3SwiGLUMLP)
                if rms_norm:
                    _patch_rms_norm_module(decoder_layer.input_layernorm, in_place=False)
                    _patch_rms_norm_module(decoder_layer.post_attention_layernorm, in_place=False)
                    _patch_rms_norm_module(decoder_layer.post_self_attn_layernorm, in_place=False)
                    _patch_rms_norm_module(decoder_layer.post_mlp_layernorm, in_place=False)
|
|
1929
|
+
|
|
1930
|
+
|
|
1780
1931
|
# Model type corresponds to the keys defined in transformers/models/auto/modeling_auto.py
|
|
1781
1932
|
MODEL_TYPE_TO_APPLY_LIGER_FN = {
|
|
1782
1933
|
"gemma": apply_liger_kernel_to_gemma,
|
|
@@ -1784,6 +1935,7 @@ MODEL_TYPE_TO_APPLY_LIGER_FN = {
|
|
|
1784
1935
|
"gemma3_text": apply_liger_kernel_to_gemma3_text,
|
|
1785
1936
|
"gemma3": apply_liger_kernel_to_gemma3,
|
|
1786
1937
|
"glm4": apply_liger_kernel_to_glm4,
|
|
1938
|
+
"glm4v": apply_liger_kernel_to_glm4v,
|
|
1787
1939
|
"llama": apply_liger_kernel_to_llama,
|
|
1788
1940
|
"llama4_text": apply_liger_kernel_to_llama4,
|
|
1789
1941
|
"llama4": apply_liger_kernel_to_llama4,
|
|
@@ -1801,6 +1953,7 @@ MODEL_TYPE_TO_APPLY_LIGER_FN = {
|
|
|
1801
1953
|
"qwen2_vl_text": apply_liger_kernel_to_qwen2_vl,
|
|
1802
1954
|
"qwen2_5_vl": apply_liger_kernel_to_qwen2_5_vl,
|
|
1803
1955
|
"qwen2_5_vl_text": apply_liger_kernel_to_qwen2_5_vl,
|
|
1956
|
+
"smollm3": apply_liger_kernel_to_smollm3,
|
|
1804
1957
|
"phi3": apply_liger_kernel_to_phi3,
|
|
1805
1958
|
"paligemma": apply_liger_kernel_to_paligemma,
|
|
1806
1959
|
}
|
|
@@ -1,6 +1,6 @@
|
|
|
1
1
|
Metadata-Version: 2.4
|
|
2
2
|
Name: liger_kernel
|
|
3
|
-
Version: 0.6.
|
|
3
|
+
Version: 0.6.2
|
|
4
4
|
Summary: Efficient Triton kernels for LLM Training
|
|
5
5
|
License: BSD 2-CLAUSE LICENSE
|
|
6
6
|
Copyright 2024 LinkedIn Corporation
|
|
@@ -84,7 +84,7 @@ Dynamic: requires-dist
|
|
|
84
84
|
</td>
|
|
85
85
|
<td style="padding: 10px;">
|
|
86
86
|
<a href="https://discord.gg/gpumode">
|
|
87
|
-
<img src="https://dcbadge.
|
|
87
|
+
<img src="https://dcbadge.limes.pink/api/server/gpumode?style=flat" alt="Join Our Discord">
|
|
88
88
|
</a>
|
|
89
89
|
</td>
|
|
90
90
|
</tr>
|
|
@@ -307,7 +307,7 @@ loss.backward()
|
|
|
307
307
|
| Qwen2-VL, & QVQ | `liger_kernel.transformers.apply_liger_kernel_to_qwen2_vl` | RMSNorm, LayerNorm, SwiGLU, CrossEntropyLoss, FusedLinearCrossEntropy |
|
|
308
308
|
| Qwen2.5-VL | `liger_kernel.transformers.apply_liger_kernel_to_qwen2_5_vl` | RMSNorm, SwiGLU, CrossEntropyLoss, FusedLinearCrossEntropy |
|
|
309
309
|
| Qwen3 | `liger_kernel.transformers.apply_liger_kernel_to_qwen3` | RoPE, RMSNorm, SwiGLU, CrossEntropyLoss, FusedLinearCrossEntropy |
|
|
310
|
-
| Qwen3 MoE | `
|
|
310
|
+
| Qwen3 MoE | `liger_kernel.transformers.apply_liger_kernel_to_qwen3_moe` | RoPE, RMSNorm, SwiGLU, CrossEntropyLoss, FusedLinearCrossEntropy |
|
|
311
311
|
| Phi3 & Phi3.5 | `liger_kernel.transformers.apply_liger_kernel_to_phi3` | RoPE, RMSNorm, SwiGLU, CrossEntropyLoss, FusedLinearCrossEntropy |
|
|
312
312
|
| Granite 3.0 & 3.1 | `liger_kernel.transformers.apply_liger_kernel_to_granite` | RoPE, RMSNorm, SwiGLU, CrossEntropyLoss |
|
|
313
313
|
| OLMo2 | `liger_kernel.transformers.apply_liger_kernel_to_olmo2` | RoPE, RMSNorm, SwiGLU, CrossEntropyLoss, FusedLinearCrossEntropy |
|
|
@@ -400,7 +400,7 @@ loss.backward()
|
|
|
400
400
|
</a>
|
|
401
401
|
</div>
|
|
402
402
|
<div style="display: block;">
|
|
403
|
-
<a href="https://github.com/linkedin/Liger-Kernel/actions/workflows/
|
|
403
|
+
<a href="https://github.com/linkedin/Liger-Kernel/actions/workflows/intel-ci.yml">
|
|
404
404
|
<img src="https://github.com/linkedin/Liger-Kernel/actions/workflows/intel-ci.yml/badge.svg?event=schedule" alt="Build">
|
|
405
405
|
</a>
|
|
406
406
|
</div>
|
|
@@ -414,21 +414,19 @@ loss.backward()
|
|
|
414
414
|
|
|
415
415
|
- For issues, create a GitHub ticket in this repository.
|
|
416
416
|
- For open discussion, join [our discord channel on GPUMode](https://discord.com/channels/1189498204333543425/1275130785933951039)
|
|
417
|
-
- For formal collaboration, send an email to yannchen@linkedin.com and
|
|
417
|
+
- For formal collaboration, send an email to Yanning Chen (yannchen@linkedin.com) and Zhipeng Wang (zhipwang@linkedin.com)
|
|
418
418
|
|
|
419
419
|
## Cite this work
|
|
420
420
|
|
|
421
421
|
Biblatex entry:
|
|
422
422
|
```bib
|
|
423
|
-
@
|
|
424
|
-
|
|
425
|
-
|
|
426
|
-
|
|
427
|
-
|
|
428
|
-
|
|
429
|
-
|
|
430
|
-
url={https://arxiv.org/abs/2410.10989},
|
|
431
|
-
journal={arXiv preprint arXiv:2410.10989},
|
|
423
|
+
@inproceedings{
|
|
424
|
+
hsu2025ligerkernel,
|
|
425
|
+
title={Liger-Kernel: Efficient Triton Kernels for {LLM} Training},
|
|
426
|
+
author={Pin-Lun Hsu and Yun Dai and Vignesh Kothapalli and Qingquan Song and Shao Tang and Siyu Zhu and Steven Shimizu and Shivam Sahni and Haowen Ning and Yanning Chen and Zhipeng Wang},
|
|
427
|
+
booktitle={Championing Open-source DEvelopment in ML Workshop @ ICML25},
|
|
428
|
+
year={2025},
|
|
429
|
+
url={https://openreview.net/forum?id=36SjAIT42G}
|
|
432
430
|
}
|
|
433
431
|
```
|
|
434
432
|
|