liger-kernel-nightly 0.5.10.dev20250624183504__py3-none-any.whl → 0.5.10.dev20250629005644__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- liger_kernel/transformers/__init__.py +3 -0
- liger_kernel/transformers/model/llama4.py +108 -0
- liger_kernel/transformers/monkey_patch.py +88 -0
- {liger_kernel_nightly-0.5.10.dev20250624183504.dist-info → liger_kernel_nightly-0.5.10.dev20250629005644.dist-info}/METADATA +2 -1
- {liger_kernel_nightly-0.5.10.dev20250624183504.dist-info → liger_kernel_nightly-0.5.10.dev20250629005644.dist-info}/RECORD +9 -8
- {liger_kernel_nightly-0.5.10.dev20250624183504.dist-info → liger_kernel_nightly-0.5.10.dev20250629005644.dist-info}/LICENSE +0 -0
- {liger_kernel_nightly-0.5.10.dev20250624183504.dist-info → liger_kernel_nightly-0.5.10.dev20250629005644.dist-info}/NOTICE +0 -0
- {liger_kernel_nightly-0.5.10.dev20250624183504.dist-info → liger_kernel_nightly-0.5.10.dev20250629005644.dist-info}/WHEEL +0 -0
- {liger_kernel_nightly-0.5.10.dev20250624183504.dist-info → liger_kernel_nightly-0.5.10.dev20250629005644.dist-info}/top_level.txt +0 -0
@@ -30,6 +30,7 @@ if TYPE_CHECKING:
|
|
30
30
|
from liger_kernel.transformers.monkey_patch import apply_liger_kernel_to_glm4 # noqa: F401
|
31
31
|
from liger_kernel.transformers.monkey_patch import apply_liger_kernel_to_granite # noqa: F401
|
32
32
|
from liger_kernel.transformers.monkey_patch import apply_liger_kernel_to_llama # noqa: F401
|
33
|
+
from liger_kernel.transformers.monkey_patch import apply_liger_kernel_to_llama4 # noqa: F401
|
33
34
|
from liger_kernel.transformers.monkey_patch import apply_liger_kernel_to_llava # noqa: F401
|
34
35
|
from liger_kernel.transformers.monkey_patch import apply_liger_kernel_to_mistral # noqa: F401
|
35
36
|
from liger_kernel.transformers.monkey_patch import apply_liger_kernel_to_mixtral # noqa: F401
|
@@ -87,6 +88,7 @@ def __getattr__(name: str):
|
|
87
88
|
"apply_liger_kernel_to_granite",
|
88
89
|
"apply_liger_kernel_to_llama",
|
89
90
|
"apply_liger_kernel_to_llava",
|
91
|
+
"apply_liger_kernel_to_llama4",
|
90
92
|
"apply_liger_kernel_to_mistral",
|
91
93
|
"apply_liger_kernel_to_mixtral",
|
92
94
|
"apply_liger_kernel_to_mllama",
|
@@ -141,6 +143,7 @@ if _TRANSFORMERS_AVAILABLE:
|
|
141
143
|
"apply_liger_kernel_to_granite",
|
142
144
|
"apply_liger_kernel_to_llama",
|
143
145
|
"apply_liger_kernel_to_llava",
|
146
|
+
"apply_liger_kernel_to_llama4",
|
144
147
|
"apply_liger_kernel_to_mistral",
|
145
148
|
"apply_liger_kernel_to_mixtral",
|
146
149
|
"apply_liger_kernel_to_mllama",
|
@@ -0,0 +1,108 @@
|
|
1
|
+
from typing import List
|
2
|
+
from typing import Optional
|
3
|
+
from typing import Tuple
|
4
|
+
from typing import Union
|
5
|
+
|
6
|
+
import torch
|
7
|
+
|
8
|
+
from transformers.cache_utils import Cache
|
9
|
+
from transformers.modeling_outputs import CausalLMOutputWithPast
|
10
|
+
|
11
|
+
from liger_kernel.transformers.model.loss_utils import LigerForCausalLMLoss
|
12
|
+
|
13
|
+
|
14
|
+
def lce_forward(
    self,
    input_ids: torch.LongTensor = None,
    attention_mask: Optional[torch.Tensor] = None,
    position_ids: Optional[torch.LongTensor] = None,
    past_key_values: Optional[Union[Cache, List[torch.FloatTensor]]] = None,
    inputs_embeds: Optional[torch.FloatTensor] = None,
    labels: Optional[torch.LongTensor] = None,
    use_cache: Optional[bool] = None,
    output_attentions: Optional[bool] = None,
    output_hidden_states: Optional[bool] = None,
    return_dict: Optional[bool] = None,
    cache_position: Optional[torch.LongTensor] = None,
    logits_to_keep: Union[int, torch.Tensor] = 0,
    **kwargs,
) -> Union[Tuple, CausalLMOutputWithPast]:
    r"""
    Liger replacement for ``Llama4ForCausalLM.forward``.

    In training mode, when ``labels`` or ``shift_labels`` are given, the loss is computed by
    ``LigerForCausalLMLoss`` directly from the kept hidden states and ``self.lm_head.weight``.
    The logits are never materialized in that case, so ``logits`` is returned as ``None``.
    Otherwise the logits are materialized with ``self.lm_head``. If ``labels`` or
    ``shift_labels`` are given, the loss is then computed with ``self.loss_function``.

    labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
        Labels for computing the masked language modeling loss. Indices should either be in `[0, ...,
        config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored
        (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`.
    logits_to_keep (`int` or `torch.Tensor`, defaults to `0`):
        If an int, only the last `logits_to_keep` sequence positions are kept (`0` keeps every position).
        If a tensor, it is used directly to index the sequence dimension of the hidden states.
    shift_labels (`torch.LongTensor`, *optional*, passed through `**kwargs`):
        Optional pre-shifted labels. They are forwarded to whichever loss is used, presumably following
        the Hugging Face `loss_function` convention for `shift_labels` (verify against the installed version).

    Returns:
        A `CausalLMOutputWithPast`. When `return_dict` resolves to False, the tuple produced by its
        `to_tuple()` is returned instead.

    Example:

    ```python
    >>> from transformers import AutoTokenizer, Llama4ForCausalLM

    >>> # Replace the placeholder with the Llama 4 text checkpoint you want to load.
    >>> model = Llama4ForCausalLM.from_pretrained("<llama4-checkpoint>")
    >>> tokenizer = AutoTokenizer.from_pretrained("<llama4-checkpoint>")

    >>> prompt = "Hey, are you conscious? Can you talk to me?"
    >>> inputs = tokenizer(prompt, return_tensors="pt")

    >>> # Generate
    >>> generate_ids = model.generate(inputs.input_ids, max_length=30)
    >>> tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]
    ```"""
    output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
    output_hidden_states = (
        output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
    )
    return_dict = return_dict if return_dict is not None else self.config.use_return_dict

    # The decoder is always asked for a ModelOutput so its fields can be read by name below.
    # The caller's `return_dict` preference is applied to this function's own output at the end.
    outputs = self.model(
        input_ids=input_ids,
        attention_mask=attention_mask,
        position_ids=position_ids,
        past_key_values=past_key_values,
        inputs_embeds=inputs_embeds,
        use_cache=use_cache,
        output_attentions=output_attentions,
        output_hidden_states=output_hidden_states,
        return_dict=True,
        cache_position=cache_position,
        **kwargs,
    )

    hidden_states = outputs[0]
    # Only project the positions whose logits are actually needed.
    slice_indices = slice(-logits_to_keep, None) if isinstance(logits_to_keep, int) else logits_to_keep
    kept_hidden_states = hidden_states[:, slice_indices, :]

    shift_labels = kwargs.pop("shift_labels", None)
    logits = None
    loss = None

    if self.training and (labels is not None or shift_labels is not None):
        # Fused linear + cross-entropy: the loss comes straight from the hidden states,
        # so the (potentially huge) logits tensor is never materialized.
        loss = LigerForCausalLMLoss(
            hidden_states=kept_hidden_states,
            lm_head_weight=self.lm_head.weight,
            labels=labels,
            shift_labels=shift_labels,
            hidden_size=self.config.hidden_size,
            **kwargs,
        )

    else:  # inference / eval mode: materialize the logits
        logits = self.lm_head(kept_hidden_states)
        if labels is not None or shift_labels is not None:
            # `shift_labels` was popped from `kwargs` above. Pass it explicitly so that it is not
            # silently dropped from the eval-loss computation.
            loss = self.loss_function(
                logits=logits,
                labels=labels,
                shift_labels=shift_labels,
                vocab_size=self.config.vocab_size,
                **kwargs,
            )

    output = CausalLMOutputWithPast(
        loss=loss,
        logits=logits,
        past_key_values=outputs.past_key_values,
        hidden_states=outputs.hidden_states,
        attentions=outputs.attentions,
    )
    # Honor the resolved `return_dict`; it used to be computed and then ignored.
    if not return_dict:
        return output.to_tuple()
    return output
|
@@ -363,6 +363,92 @@ def apply_liger_kernel_to_llava(
|
|
363
363
|
logger.warning(f"{vision_model_name} is not supported by Liger kernel.")
|
364
364
|
|
365
365
|
|
366
|
+
def apply_liger_kernel_to_llama4(
    rope: bool = False,
    cross_entropy: bool = False,
    fused_linear_cross_entropy: bool = True,
    rms_norm: bool = True,
    swiglu: bool = True,
    model: PreTrainedModel = None,
    layer_norm: bool = True,
) -> None:
    """
    Apply Liger kernels to replace original implementation in HuggingFace Llama4 models.

    Class-level patches are applied to ``transformers.models.llama4.modeling_llama4`` and affect
    models created afterwards. Pass ``model`` to also patch the modules of an
    already-instantiated model.

    Args:
        rope (bool): Whether to apply Liger's rotary position embedding. This is not supported for
            Llama4; passing True raises NotImplementedError. Default is False.
        cross_entropy (bool): Whether to apply Liger's cross entropy loss. Default is False.
        fused_linear_cross_entropy (bool):
            Whether to apply Liger's fused linear cross entropy loss. Default is True.
            `cross_entropy` and `fused_linear_cross_entropy` cannot both be True.
            If `fused_linear_cross_entropy` is True, the logits are not materialized, which is more memory efficient.
        rms_norm (bool): Whether to apply Liger's RMSNorm to the text model. Default is True.
        swiglu (bool): Whether to apply Liger's SwiGLU MLP. On an existing model, only the shared
            expert of MoE decoder layers is patched. Default is True.
        model (PreTrainedModel): The model instance to apply Liger kernels to, if the model has already been
            loaded. Must be a Llama4ForConditionalGeneration, Llama4ForCausalLM or Llama4TextModel.
            Default is None.
        layer_norm (bool): Whether to apply Liger's LayerNorm to the vision model's layer norms.
            This only has an effect when `model` contains a vision model. Default is True.

    Raises:
        AssertionError: If `cross_entropy` and `fused_linear_cross_entropy` are both True.
        NotImplementedError: If `rope` is True.
        ValueError: If `model` is not one of the supported Llama4 model types.
    """
    assert not (cross_entropy and fused_linear_cross_entropy), (
        "cross_entropy and fused_linear_cross_entropy cannot both be True."
    )

    from transformers.models.llama4 import modeling_llama4
    from transformers.models.llama4.modeling_llama4 import Llama4ForCausalLM
    from transformers.models.llama4.modeling_llama4 import Llama4ForConditionalGeneration
    from transformers.models.llama4.modeling_llama4 import Llama4TextModel
    from transformers.models.llama4.modeling_llama4 import Llama4VisionModel

    from liger_kernel.transformers.model.llama4 import lce_forward as llama4_lce_forward

    if rope:
        raise NotImplementedError("liger_rotary_pos_emb is not available for Llama4 models.")
    if rms_norm:
        modeling_llama4.Llama4TextRMSNorm = LigerRMSNorm
    if swiglu:
        # NOTE(review): any Llama4TextMLP that modeling_llama4 constructs after this point becomes a
        # LigerSwiGLUMLP. Confirm that LigerSwiGLUMLP accepts the same constructor arguments
        # (e.g. an `intermediate_size` keyword, if the installed transformers passes one).
        modeling_llama4.Llama4TextMLP = LigerSwiGLUMLP

    if cross_entropy:
        modeling_llama4.CrossEntropyLoss = LigerCrossEntropyLoss

    if fused_linear_cross_entropy:
        modeling_llama4.Llama4ForCausalLM.forward = llama4_lce_forward

    if model is not None:
        # The model instance already exists, so we need to additionally patch the
        # instance variables that reference already-instantiated modules.
        if isinstance(model, Llama4ForConditionalGeneration):
            language_model: Llama4ForCausalLM = model.language_model
            vision_model: Llama4VisionModel = model.vision_model
            text_model: Llama4TextModel = language_model.model
        elif isinstance(model, Llama4ForCausalLM):
            text_model = model.model
            vision_model = None
        elif isinstance(model, Llama4TextModel):
            text_model = model
            vision_model = None
        else:
            raise ValueError(f"Unsupported Llama4 model type: {type(model)}")

        if text_model is not None:
            if rms_norm:
                _patch_rms_norm_module(text_model.norm)
            for decoder_layer in text_model.layers:
                if swiglu:
                    feed_forward = decoder_layer.feed_forward
                    # NOTE(review): in transformers' Llama4, MoE decoder layers (`is_moe_layer`) hold
                    # a router + fused experts block rather than a plain gate/up/down MLP. Only its
                    # `shared_expert` is an MLP that the SwiGLU forward can drive. Confirm against
                    # the installed transformers version.
                    if getattr(decoder_layer, "is_moe_layer", False):
                        feed_forward = feed_forward.shared_expert
                    _patch_swiglu_module(feed_forward, LigerSwiGLUMLP)
                if rms_norm:
                    _patch_rms_norm_module(decoder_layer.input_layernorm)
                    _patch_rms_norm_module(decoder_layer.post_attention_layernorm)

        # All vision layer norms, including the pre/post norms, are governed by `layer_norm`.
        if vision_model is not None and layer_norm:
            _patch_layer_norm_module(vision_model.layernorm_pre)
            _patch_layer_norm_module(vision_model.layernorm_post)

            for layer in vision_model.model.layers:
                _patch_layer_norm_module(layer.input_layernorm)
                _patch_layer_norm_module(layer.post_attention_layernorm)
|
450
|
+
|
451
|
+
|
366
452
|
def apply_liger_kernel_to_mllama(
|
367
453
|
rope: bool = True,
|
368
454
|
cross_entropy: bool = False,
|
@@ -1605,6 +1691,8 @@ MODEL_TYPE_TO_APPLY_LIGER_FN = {
|
|
1605
1691
|
"gemma3": apply_liger_kernel_to_gemma3,
|
1606
1692
|
"glm4": apply_liger_kernel_to_glm4,
|
1607
1693
|
"llama": apply_liger_kernel_to_llama,
|
1694
|
+
"llama4_text": apply_liger_kernel_to_llama4,
|
1695
|
+
"llama4": apply_liger_kernel_to_llama4,
|
1608
1696
|
"llava": apply_liger_kernel_to_llava,
|
1609
1697
|
"granite": apply_liger_kernel_to_granite,
|
1610
1698
|
"mllama": apply_liger_kernel_to_mllama,
|
@@ -1,6 +1,6 @@
|
|
1
1
|
Metadata-Version: 2.1
|
2
2
|
Name: liger_kernel_nightly
|
3
|
-
Version: 0.5.10.dev20250624183504
|
3
|
+
Version: 0.5.10.dev20250629005644
|
4
4
|
Summary: Efficient Triton kernels for LLM Training
|
5
5
|
License: BSD 2-CLAUSE LICENSE
|
6
6
|
Copyright 2024 LinkedIn Corporation
|
@@ -290,6 +290,7 @@ loss.backward()
|
|
290
290
|
|
291
291
|
| **Model** | **API** | **Supported Operations** |
|
292
292
|
|-------------|--------------------------------------------------------------|-------------------------------------------------------------------------|
|
293
|
+
| Llama4 (Text & Multimodal) | `liger_kernel.transformers.apply_liger_kernel_to_llama4` | RMSNorm, LayerNorm, SwiGLU, CrossEntropyLoss, FusedLinearCrossEntropy |
|
293
294
|
| LLaMA 2 & 3 | `liger_kernel.transformers.apply_liger_kernel_to_llama` | RoPE, RMSNorm, SwiGLU, CrossEntropyLoss, FusedLinearCrossEntropy |
|
294
295
|
| LLaMA 3.2-Vision | `liger_kernel.transformers.apply_liger_kernel_to_mllama` | RoPE, RMSNorm, SwiGLU, CrossEntropyLoss, FusedLinearCrossEntropy |
|
295
296
|
| Mistral | `liger_kernel.transformers.apply_liger_kernel_to_mistral` | RoPE, RMSNorm, SwiGLU, CrossEntropyLoss, FusedLinearCrossEntropy |
|
@@ -38,7 +38,7 @@ liger_kernel/ops/tvd.py,sha256=FHJtLQI95ijqgg9UtaHpMAjSCiPxB6CduPwPMcGxelc,6405
|
|
38
38
|
liger_kernel/ops/utils.py,sha256=uoFKQqo-34N2TWQNvXMFywqGiOMMXNEVBxVojzlUAa0,3836
|
39
39
|
liger_kernel/ops/experimental/embedding.py,sha256=tolj3tItkzpSb30zWqDN2_yX4ectflaQ8HMyKyFIQc8,4172
|
40
40
|
liger_kernel/ops/experimental/mm_int8int2.py,sha256=TrS9lpwekrik_w5qE7AhMJD1bcq-OidjtbsW80oZ6IM,13314
|
41
|
-
liger_kernel/transformers/__init__.py,sha256=
|
41
|
+
liger_kernel/transformers/__init__.py,sha256=mWMEhOabqUkPimMOmkg9DawnO-vL9u_u-N4iIqfNZeg,7259
|
42
42
|
liger_kernel/transformers/auto_model.py,sha256=0qCTRZt280Bj_LcFdzo9hlaR-BWNazawXOGgoCZjgEg,1545
|
43
43
|
liger_kernel/transformers/cross_entropy.py,sha256=z3KTWQnFxr_IZaVjtYt0ZNEWQdDdYThN35xWkHlDGH0,1683
|
44
44
|
liger_kernel/transformers/dyt.py,sha256=i-4GPaMrl-jab9TVI5qN0-H9qycn_mCbV82ozU4nbmU,723
|
@@ -53,7 +53,7 @@ liger_kernel/transformers/grpo_loss.py,sha256=uAkUNKSnUGEOqa82L9w2e6AI1kcmG8K45-
|
|
53
53
|
liger_kernel/transformers/jsd.py,sha256=DGqRnxIZxsvxo0_tbbxX3b-sDbDjC_yKufyRIHCcScY,2979
|
54
54
|
liger_kernel/transformers/kl_div.py,sha256=WLffFbh1EExD2Eb1F7lN11fo9JJC-0751WJjZAF1Fj8,409
|
55
55
|
liger_kernel/transformers/layer_norm.py,sha256=c9pk3PEasOKYR0rhe5e5nNrnYKVCEW4VC8S6LpCq9EQ,906
|
56
|
-
liger_kernel/transformers/monkey_patch.py,sha256=
|
56
|
+
liger_kernel/transformers/monkey_patch.py,sha256=3KqEl_-WlXgUoEAEYgGs-SPolASshGem2ISFemzQAIc,81705
|
57
57
|
liger_kernel/transformers/multi_token_attention.py,sha256=l9VDICK0dfmifUDW668hGscP8AHq2rYcM2oGUa3baRQ,1751
|
58
58
|
liger_kernel/transformers/qwen2vl_mrope.py,sha256=5EwSqrMdsL9MYspeBMXBsNJKvH0MOmRrtJXAJlnnlOI,1047
|
59
59
|
liger_kernel/transformers/rms_norm.py,sha256=vkekcvTeWY8vL4H6hg3t0XeY0Ew_3OFMPHuzqlxPPVw,2719
|
@@ -70,6 +70,7 @@ liger_kernel/transformers/model/gemma2.py,sha256=ORmzklEAMpk93nToRo4d_ZJbM4ScVE2
|
|
70
70
|
liger_kernel/transformers/model/gemma3.py,sha256=JI4jj9K660HeRsofB6cpkCHBQ0OsazElArRtKUehUmw,15945
|
71
71
|
liger_kernel/transformers/model/glm4.py,sha256=GlnEhdGJuDIqp2R9qC54biY3HwV1tWmfpJm6ijoAsrM,5257
|
72
72
|
liger_kernel/transformers/model/llama.py,sha256=LcIxVfF0PXXWHBVJa6Ody_5fAtIpxQcI4jC_j-o51fU,12503
|
73
|
+
liger_kernel/transformers/model/llama4.py,sha256=IgbB8sTh3dlETQnaNNy1bZLuXy-Nt7qmeAjF27ydGpg,4210
|
73
74
|
liger_kernel/transformers/model/llava.py,sha256=bLCioday_SOm69ogMDBhy_4UsVkH2-BSl93-EXY6-7I,15076
|
74
75
|
liger_kernel/transformers/model/loss_utils.py,sha256=WWAMdiONPaXpIvxyOim_0igLrYh0yyOok5Q9_L9xvZw,1787
|
75
76
|
liger_kernel/transformers/model/mistral.py,sha256=okKkyashfFLfhjIT--f3JY6JHOslOtDI8U1dlpBC2Zs,5565
|
@@ -87,9 +88,9 @@ liger_kernel/transformers/trainer/__init__.py,sha256=p7yQfklV8-467qSz_ZMimkbDF7H
|
|
87
88
|
liger_kernel/transformers/trainer/orpo_trainer.py,sha256=tX0h63aOFe3rNqTmk6JpMf75UPo981yzEa6TghnjS0Q,5370
|
88
89
|
liger_kernel/triton/__init__.py,sha256=qCiCamzCRv6lpV8IqpAc9YMdNKC7GKurClWceQPnlis,92
|
89
90
|
liger_kernel/triton/monkey_patch.py,sha256=Rd0hUHAzDkFfHvnX7-PBaNK5EKnZhtfM_h-fgQH9HPY,1568
|
90
|
-
liger_kernel_nightly-0.5.10.
|
91
|
-
liger_kernel_nightly-0.5.10.
|
92
|
-
liger_kernel_nightly-0.5.10.
|
93
|
-
liger_kernel_nightly-0.5.10.
|
94
|
-
liger_kernel_nightly-0.5.10.
|
95
|
-
liger_kernel_nightly-0.5.10.
|
91
|
+
liger_kernel_nightly-0.5.10.dev20250629005644.dist-info/LICENSE,sha256=OhzLDHJ0to4a8sodVLELZiCFylZ1NAAYLs-HrjPy0ag,1312
|
92
|
+
liger_kernel_nightly-0.5.10.dev20250629005644.dist-info/METADATA,sha256=FMeKbXVH-02gQ_G0kVMIc6ftN9rv5WeQZ94Br45A9ek,24536
|
93
|
+
liger_kernel_nightly-0.5.10.dev20250629005644.dist-info/NOTICE,sha256=njwnoPZLh9AN8SJQzxvCGLHi-8X__AvWRze6joNXIY8,2066
|
94
|
+
liger_kernel_nightly-0.5.10.dev20250629005644.dist-info/WHEEL,sha256=iAkIy5fosb7FzIOwONchHf19Qu7_1wCWyFNR5gu9nU0,91
|
95
|
+
liger_kernel_nightly-0.5.10.dev20250629005644.dist-info/top_level.txt,sha256=2eghu4hA3LnkM7ElW92tQ8zegWKgSbeo-k-aGe1YnvY,13
|
96
|
+
liger_kernel_nightly-0.5.10.dev20250629005644.dist-info/RECORD,,
|
File without changes
|
File without changes
|
File without changes
|