liger-kernel-nightly 0.5.8.dev20250429233059__py3-none-any.whl → 0.5.8.dev20250502215739__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -26,6 +26,7 @@ if TYPE_CHECKING:
26
26
  from liger_kernel.transformers.monkey_patch import apply_liger_kernel_to_gemma2 # noqa: F401
27
27
  from liger_kernel.transformers.monkey_patch import apply_liger_kernel_to_gemma3 # noqa: F401
28
28
  from liger_kernel.transformers.monkey_patch import apply_liger_kernel_to_gemma3_text # noqa: F401
29
+ from liger_kernel.transformers.monkey_patch import apply_liger_kernel_to_glm4 # noqa: F401
29
30
  from liger_kernel.transformers.monkey_patch import apply_liger_kernel_to_granite # noqa: F401
30
31
  from liger_kernel.transformers.monkey_patch import apply_liger_kernel_to_llama # noqa: F401
31
32
  from liger_kernel.transformers.monkey_patch import apply_liger_kernel_to_llava # noqa: F401
@@ -79,6 +80,7 @@ def __getattr__(name: str):
79
80
  "apply_liger_kernel_to_gemma2",
80
81
  "apply_liger_kernel_to_gemma3",
81
82
  "apply_liger_kernel_to_gemma3_text",
83
+ "apply_liger_kernel_to_glm4",
82
84
  "apply_liger_kernel_to_granite",
83
85
  "apply_liger_kernel_to_llama",
84
86
  "apply_liger_kernel_to_llava",
@@ -129,6 +131,7 @@ if _TRANSFORMERS_AVAILABLE:
129
131
  "apply_liger_kernel_to_gemma2",
130
132
  "apply_liger_kernel_to_gemma3",
131
133
  "apply_liger_kernel_to_gemma3_text",
134
+ "apply_liger_kernel_to_glm4",
132
135
  "apply_liger_kernel_to_granite",
133
136
  "apply_liger_kernel_to_llama",
134
137
  "apply_liger_kernel_to_llava",
@@ -0,0 +1,123 @@
1
+ from typing import List
2
+ from typing import Optional
3
+ from typing import Tuple
4
+ from typing import Union
5
+
6
+ import torch
7
+
8
+ from transformers.modeling_outputs import CausalLMOutputWithPast
9
+ from transformers.models.glm4.modeling_glm4 import _CONFIG_FOR_DOC
10
+ from transformers.models.glm4.modeling_glm4 import GLM4_INPUTS_DOCSTRING
11
+ from transformers.utils import add_start_docstrings_to_model_forward
12
+ from transformers.utils import replace_return_docstrings
13
+ from transformers.utils.deprecation import deprecate_kwarg
14
+
15
+ from liger_kernel.transformers.model.loss_utils import LigerForCausalLMLoss
16
+
17
+
18
@deprecate_kwarg("num_logits_to_keep", version="4.50", new_name="logits_to_keep")
@add_start_docstrings_to_model_forward(GLM4_INPUTS_DOCSTRING)
@replace_return_docstrings(output_type=CausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC)
def lce_forward(
    self,
    input_ids: torch.LongTensor = None,
    attention_mask: Optional[torch.Tensor] = None,
    position_ids: Optional[torch.LongTensor] = None,
    past_key_values: Optional[List[torch.FloatTensor]] = None,
    inputs_embeds: Optional[torch.FloatTensor] = None,
    labels: Optional[torch.LongTensor] = None,
    use_cache: Optional[bool] = None,
    output_attentions: Optional[bool] = None,
    output_hidden_states: Optional[bool] = None,
    return_dict: Optional[bool] = None,
    cache_position: Optional[torch.LongTensor] = None,
    logits_to_keep: Union[int, torch.Tensor] = 0,
    **loss_kwargs,
) -> Union[Tuple, CausalLMOutputWithPast]:
    r"""
    Liger replacement for `Glm4ForCausalLM.forward`. In training mode with `labels` (or `shift_labels`), the loss
    is computed by `LigerForCausalLMLoss` directly from the final hidden states and `lm_head.weight`, so `logits`
    is returned as `None`. Otherwise logits are materialized and the model's own `loss_function` is used.

    Args:
        labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
            Labels for computing the masked language modeling loss. Indices should either be in `[0, ...,
            config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored
            (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`.

        logits_to_keep (`int` or `torch.Tensor`, *optional*):
            If an `int`, compute logits for the last `logits_to_keep` tokens. If `0`, calculate logits for all
            `input_ids` (special case). Only last token logits are needed for generation, and calculating them only for that
            token can save memory, which becomes pretty significant for long sequences or large vocabulary size.
            If a `torch.Tensor`, must be 1D corresponding to the indices to keep in the sequence length dimension.
            This is useful when using packed tensor format (single dimension for batch and sequence length).

        loss_kwargs:
            Extra keyword arguments forwarded to the loss computation. A `shift_labels` entry is removed from
            these kwargs and passed explicitly as `shift_labels=` to the loss in both training and inference.

    Returns:

    Example:

    ```python
    >>> from transformers import AutoTokenizer, Glm4ForCausalLM

    >>> model = Glm4ForCausalLM.from_pretrained("THUDM/GLM-4-9B-0414")
    >>> tokenizer = AutoTokenizer.from_pretrained("THUDM/GLM-4-9B-0414")

    >>> prompt = "Hey, are you conscious? Can you talk to me?"
    >>> inputs = tokenizer(prompt, return_tensors="pt")

    >>> # Generate
    >>> generate_ids = model.generate(inputs.input_ids, max_length=30)
    >>> tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]
    'Hey, are you conscious? Can you talk to me?\nI’m not sure if you’re conscious of this, but I’m'
    ```
    """
    output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
    output_hidden_states = (
        output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
    )
    return_dict = return_dict if return_dict is not None else self.config.use_return_dict

    # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn)
    outputs = self.model(
        input_ids=input_ids,
        attention_mask=attention_mask,
        position_ids=position_ids,
        past_key_values=past_key_values,
        inputs_embeds=inputs_embeds,
        use_cache=use_cache,
        output_attentions=output_attentions,
        output_hidden_states=output_hidden_states,
        return_dict=return_dict,
        cache_position=cache_position,
    )

    hidden_states = outputs[0]

    # Pulled out of loss_kwargs so it is passed exactly once, as an explicit keyword, in both branches below.
    shift_labels = loss_kwargs.pop("shift_labels", None)
    logits = None
    loss = None
    # if in training mode, don't materialize logits
    if self.training and (labels is not None or shift_labels is not None):
        loss = LigerForCausalLMLoss(
            hidden_states=hidden_states,
            lm_head_weight=self.lm_head.weight,
            labels=labels,
            shift_labels=shift_labels,
            hidden_size=self.config.hidden_size,
            **loss_kwargs,
        )

    else:  # if in inference mode materialize logits
        # An int of 0 gives slice(0, None), i.e. every position is kept.
        slice_indices = slice(-logits_to_keep, None) if isinstance(logits_to_keep, int) else logits_to_keep
        logits = self.lm_head(hidden_states[:, slice_indices, :])
        # Mirror the training-branch condition; without forwarding it here, a popped `shift_labels`
        # would be silently ignored during evaluation.
        if labels is not None or shift_labels is not None:
            loss = self.loss_function(
                logits=logits,
                labels=labels,
                shift_labels=shift_labels,
                vocab_size=self.config.vocab_size,
                **loss_kwargs,
            )

    if not return_dict:
        # Legacy tuple output. `self.model` was called with return_dict=False, so `outputs` presumably is a
        # plain tuple here and has no `.past_key_values` attribute; index it positionally instead.
        output = (logits,) + tuple(outputs[1:])
        return (loss,) + output if loss is not None else output

    return CausalLMOutputWithPast(
        loss=loss,
        logits=logits,
        past_key_values=outputs.past_key_values,
        hidden_states=outputs.hidden_states,
        attentions=outputs.attentions,
    )
@@ -17,6 +17,7 @@ from liger_kernel.transformers.model.gemma import lce_forward as gemma_lce_forwa
17
17
  from liger_kernel.transformers.model.gemma import lce_forward_deprecated as gemma_lce_forward_deprecated
18
18
  from liger_kernel.transformers.model.gemma2 import lce_forward as gemma2_lce_forward
19
19
  from liger_kernel.transformers.model.gemma2 import lce_forward_deprecated as gemma2_lce_forward_deprected
20
+ from liger_kernel.transformers.model.glm4 import lce_forward as glm4_lce_forward
20
21
  from liger_kernel.transformers.model.llama import lce_forward as llama_lce_forward
21
22
  from liger_kernel.transformers.model.llama import lce_forward_deprecated as llama_lce_forward_deprecated
22
23
  from liger_kernel.transformers.model.llava import lce_forward as llava_lce_forward
@@ -1319,12 +1320,76 @@ def apply_liger_kernel_to_olmo2(
1319
1320
  _patch_rms_norm_module(decoder_layer.post_feedforward_layernorm, in_place=False)
1320
1321
 
1321
1322
 
1323
def apply_liger_kernel_to_glm4(
    rope: bool = False,
    cross_entropy: bool = False,
    fused_linear_cross_entropy: bool = True,
    rms_norm: bool = True,
    swiglu: bool = True,
    model: PreTrainedModel = None,
) -> None:
    """
    Apply Liger kernels to replace original implementation in HuggingFace GLM-4 models.

    Args:
        rope (bool): Liger's rotary position embedding is not supported for GLM-4; passing True raises
            NotImplementedError. Default is False.
        cross_entropy (bool): Whether to apply Liger's cross entropy loss. Default is False.
            This rebinds `nn.functional.cross_entropy` as reached through `transformers.loss.loss_utils`
            (presumably `torch.nn`), so the replacement is likely visible process-wide, not only to GLM-4.
        fused_linear_cross_entropy (bool):
            Whether to apply Liger's fused linear cross entropy loss. Default is True.
            `cross_entropy` and `fused_linear_cross_entropy` cannot both be True.
            If `fused_linear_cross_entropy` is True, `Glm4ForCausalLM.forward` is replaced so that, in training
            mode with labels, logits are not materialized (more memory efficient). Outside training, logits are
            still computed.
        rms_norm (bool): Whether to apply Liger's RMSNorm. Default is True.
        swiglu (bool): Whether to apply Liger's SwiGLU Glm4MLP. Default is True.
        model (PreTrainedModel): The model instance to apply Liger kernels to, if the model has already been
            loaded. Default is None.

    Raises:
        AssertionError: If both `cross_entropy` and `fused_linear_cross_entropy` are True.
        NotImplementedError: If `rope` is True.
    """
    assert not (cross_entropy and fused_linear_cross_entropy), (
        "cross_entropy and fused_linear_cross_entropy cannot both be True."
    )

    from transformers.models.glm4 import modeling_glm4
    from transformers.models.glm4.modeling_glm4 import Glm4Model

    if rope:
        raise NotImplementedError("liger_rotary_pos_emb is not available for Glm4 models.")
    # Module-level rebinding: affects GLM-4 modules constructed after this call.
    if rms_norm:
        modeling_glm4.Glm4RMSNorm = partial(LigerRMSNorm, in_place=False)
    if swiglu:
        # Glm4MLP is swapped for the Phi-3 style Liger SwiGLU MLP implementation.
        modeling_glm4.Glm4MLP = LigerPhi3SwiGLUMLP
    if cross_entropy:
        from transformers.loss.loss_utils import nn

        nn.functional.cross_entropy = liger_cross_entropy
    if fused_linear_cross_entropy:
        # Patched on the class, so existing instances pick it up through normal method lookup.
        modeling_glm4.Glm4ForCausalLM.forward = glm4_lce_forward

    if model is not None:
        # The model instance already exists, so we need to additionally patch the
        # instance variables that reference already-instantiated modules

        # get the base model from the model instance
        base_model: Glm4Model = getattr(model, model.base_model_prefix, model)

        if rms_norm:
            _patch_rms_norm_module(base_model.norm, in_place=False)

        for decoder_layer in base_model.layers:
            if swiglu:
                _patch_swiglu_module(decoder_layer.mlp, LigerPhi3SwiGLUMLP)
            if rms_norm:
                # GLM-4 decoder layers carry four RMSNorms; every one of them is patched.
                _patch_rms_norm_module(decoder_layer.input_layernorm, in_place=False)
                _patch_rms_norm_module(decoder_layer.post_attention_layernorm, in_place=False)
                _patch_rms_norm_module(decoder_layer.post_self_attn_layernorm, in_place=False)
                _patch_rms_norm_module(decoder_layer.post_mlp_layernorm, in_place=False)
1384
+
1385
+
1322
1386
  # Model type corresponds to the keys defined in transformers/models/auto/modeling_auto.py
1323
1387
  MODEL_TYPE_TO_APPLY_LIGER_FN = {
1324
1388
  "gemma": apply_liger_kernel_to_gemma,
1325
1389
  "gemma2": apply_liger_kernel_to_gemma2,
1326
1390
  "gemma3_text": apply_liger_kernel_to_gemma3_text,
1327
1391
  "gemma3": apply_liger_kernel_to_gemma3,
1392
+ "glm4": apply_liger_kernel_to_glm4,
1328
1393
  "llama": apply_liger_kernel_to_llama,
1329
1394
  "llava": apply_liger_kernel_to_llava,
1330
1395
  "granite": apply_liger_kernel_to_granite,
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.1
2
2
  Name: liger_kernel_nightly
3
- Version: 0.5.8.dev20250429233059
3
+ Version: 0.5.8.dev20250502215739
4
4
  Summary: Efficient Triton kernels for LLM Training
5
5
  License: BSD 2-CLAUSE LICENSE
6
6
  Copyright 2024 LinkedIn Corporation
@@ -320,6 +320,7 @@ loss.backward()
320
320
  | Phi3 & Phi3.5 | `liger_kernel.transformers.apply_liger_kernel_to_phi3` | RoPE, RMSNorm, SwiGLU, CrossEntropyLoss, FusedLinearCrossEntropy |
321
321
  | Granite 3.0 & 3.1 | `liger_kernel.transformers.apply_liger_kernel_to_granite` | RoPE, RMSNorm, SwiGLU, CrossEntropyLoss |
322
322
  | OLMo2 | `liger_kernel.transformers.apply_liger_kernel_to_olmo2` | RoPE, RMSNorm, SwiGLU, CrossEntropyLoss, FusedLinearCrossEntropy |
323
+ | GLM-4 | `liger_kernel.transformers.apply_liger_kernel_to_glm4` | RMSNorm, SwiGLU, CrossEntropyLoss, FusedLinearCrossEntropy |
323
324
 
324
325
 
325
326
  ## Low-level APIs
@@ -33,7 +33,7 @@ liger_kernel/ops/tvd.py,sha256=FHJtLQI95ijqgg9UtaHpMAjSCiPxB6CduPwPMcGxelc,6405
33
33
  liger_kernel/ops/utils.py,sha256=uoFKQqo-34N2TWQNvXMFywqGiOMMXNEVBxVojzlUAa0,3836
34
34
  liger_kernel/ops/experimental/embedding.py,sha256=tolj3tItkzpSb30zWqDN2_yX4ectflaQ8HMyKyFIQc8,4172
35
35
  liger_kernel/ops/experimental/mm_int8int2.py,sha256=TrS9lpwekrik_w5qE7AhMJD1bcq-OidjtbsW80oZ6IM,13314
36
- liger_kernel/transformers/__init__.py,sha256=SH30Pt2ZqyQY-mmWQldg_r-5koowuymTIoU4F4e1KHk,6419
36
+ liger_kernel/transformers/__init__.py,sha256=sLAZ_8IxBuim06ZW96OzH1wSsOl5uXvD_OIW6vqOQUQ,6595
37
37
  liger_kernel/transformers/auto_model.py,sha256=0qCTRZt280Bj_LcFdzo9hlaR-BWNazawXOGgoCZjgEg,1545
38
38
  liger_kernel/transformers/cross_entropy.py,sha256=z3KTWQnFxr_IZaVjtYt0ZNEWQdDdYThN35xWkHlDGH0,1683
39
39
  liger_kernel/transformers/dyt.py,sha256=QMqqc14pkE0WhpRZvapfnNAun-6C0C_tHExL2ZJuCUA,648
@@ -46,7 +46,7 @@ liger_kernel/transformers/group_norm.py,sha256=6qMAWOprr4SzP0YhNVNGQIBpM5aUHplUD
46
46
  liger_kernel/transformers/jsd.py,sha256=DGqRnxIZxsvxo0_tbbxX3b-sDbDjC_yKufyRIHCcScY,2979
47
47
  liger_kernel/transformers/kl_div.py,sha256=WLffFbh1EExD2Eb1F7lN11fo9JJC-0751WJjZAF1Fj8,409
48
48
  liger_kernel/transformers/layer_norm.py,sha256=c9pk3PEasOKYR0rhe5e5nNrnYKVCEW4VC8S6LpCq9EQ,906
49
- liger_kernel/transformers/monkey_patch.py,sha256=QpfNU7MmVDGlBWIZ2RLTSyh0vuZ-si7H37SL-qOliUs,64393
49
+ liger_kernel/transformers/monkey_patch.py,sha256=G_6NyTO4jOV2lKuu8zhrjIf0L-QFuNw_T3dmukqyyzk,67381
50
50
  liger_kernel/transformers/qwen2vl_mrope.py,sha256=5EwSqrMdsL9MYspeBMXBsNJKvH0MOmRrtJXAJlnnlOI,1047
51
51
  liger_kernel/transformers/rms_norm.py,sha256=GqCEJuGt0YdqqlMcToE0Wp4A8YFquDa4UUSyH2uFW2A,1191
52
52
  liger_kernel/transformers/rope.py,sha256=ZTrTORSAyfcFIKjk6XEeYmk4ROH7xXED9L4g2NFntlE,999
@@ -58,6 +58,7 @@ liger_kernel/transformers/model/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm
58
58
  liger_kernel/transformers/model/gemma.py,sha256=uoZvur13XSvtUfiBIP25ZJXEGh4hB5KlB-fq_wpbavY,9940
59
59
  liger_kernel/transformers/model/gemma2.py,sha256=4sPxsnFVywZiNsOoxFM4nEAKB5m5_efnJR7pCEVsQw4,11047
60
60
  liger_kernel/transformers/model/gemma3.py,sha256=wGSNqaLRRgIGQ_r9esyhDezm2SkAGZflopoWoWR-nYY,16226
61
+ liger_kernel/transformers/model/glm4.py,sha256=E_k2FScBW5TvMCznlHVvLGySoeSAn5gO0Nv3zMmK3xM,5305
61
62
  liger_kernel/transformers/model/llama.py,sha256=7AQROxICv2oKSrf5fGJifz_vyuPBkGRXbm0xipUwQew,10617
62
63
  liger_kernel/transformers/model/llava.py,sha256=b0pEagjUbu2-eS9xegjyfl1DwIXLwZcNpff55ibaMbA,17601
63
64
  liger_kernel/transformers/model/loss_utils.py,sha256=WWAMdiONPaXpIvxyOim_0igLrYh0yyOok5Q9_L9xvZw,1787
@@ -74,9 +75,9 @@ liger_kernel/transformers/trainer/__init__.py,sha256=p7yQfklV8-467qSz_ZMimkbDF7H
74
75
  liger_kernel/transformers/trainer/orpo_trainer.py,sha256=pdekW7l6Qg_aqa5SYKYlSWUF8m3lkOFvFLcIMEHrz9s,8338
75
76
  liger_kernel/triton/__init__.py,sha256=qCiCamzCRv6lpV8IqpAc9YMdNKC7GKurClWceQPnlis,92
76
77
  liger_kernel/triton/monkey_patch.py,sha256=Rd0hUHAzDkFfHvnX7-PBaNK5EKnZhtfM_h-fgQH9HPY,1568
77
- liger_kernel_nightly-0.5.8.dev20250429233059.dist-info/LICENSE,sha256=OhzLDHJ0to4a8sodVLELZiCFylZ1NAAYLs-HrjPy0ag,1312
78
- liger_kernel_nightly-0.5.8.dev20250429233059.dist-info/METADATA,sha256=M3ZnXyCzfuYgFnBj7dbF6_i9YJ3OdWrRQDrbTkBB8rs,23297
79
- liger_kernel_nightly-0.5.8.dev20250429233059.dist-info/NOTICE,sha256=njwnoPZLh9AN8SJQzxvCGLHi-8X__AvWRze6joNXIY8,2066
80
- liger_kernel_nightly-0.5.8.dev20250429233059.dist-info/WHEEL,sha256=iAkIy5fosb7FzIOwONchHf19Qu7_1wCWyFNR5gu9nU0,91
81
- liger_kernel_nightly-0.5.8.dev20250429233059.dist-info/top_level.txt,sha256=2eghu4hA3LnkM7ElW92tQ8zegWKgSbeo-k-aGe1YnvY,13
82
- liger_kernel_nightly-0.5.8.dev20250429233059.dist-info/RECORD,,
78
+ liger_kernel_nightly-0.5.8.dev20250502215739.dist-info/LICENSE,sha256=OhzLDHJ0to4a8sodVLELZiCFylZ1NAAYLs-HrjPy0ag,1312
79
+ liger_kernel_nightly-0.5.8.dev20250502215739.dist-info/METADATA,sha256=WqdvDSWKWaKeFufQ8JrHxF31aTisXKM6eYgwewUFpik,23437
80
+ liger_kernel_nightly-0.5.8.dev20250502215739.dist-info/NOTICE,sha256=njwnoPZLh9AN8SJQzxvCGLHi-8X__AvWRze6joNXIY8,2066
81
+ liger_kernel_nightly-0.5.8.dev20250502215739.dist-info/WHEEL,sha256=iAkIy5fosb7FzIOwONchHf19Qu7_1wCWyFNR5gu9nU0,91
82
+ liger_kernel_nightly-0.5.8.dev20250502215739.dist-info/top_level.txt,sha256=2eghu4hA3LnkM7ElW92tQ8zegWKgSbeo-k-aGe1YnvY,13
83
+ liger_kernel_nightly-0.5.8.dev20250502215739.dist-info/RECORD,,