liger-kernel 0.6.0__py3-none-any.whl → 0.6.2__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -0,0 +1,189 @@
1
+ from typing import TYPE_CHECKING
2
+ from typing import List
3
+ from typing import Optional
4
+ from typing import Tuple
5
+ from typing import Union
6
+
7
+ import torch
8
+
9
+ from torch.distributed.fsdp import FullyShardedDataParallel
10
+ from transformers.modeling_outputs import CausalLMOutputWithPast
11
+ from transformers.utils.deprecation import deprecate_kwarg
12
+
13
+ from liger_kernel.transformers.fsdp import _FSDPForwardRedirection
14
+ from liger_kernel.transformers.model.loss_utils import LigerForCausalLMLoss
15
+ from liger_kernel.utils import PEFT_AVAILABLE
16
+
17
+ if TYPE_CHECKING:
18
+ from transformers.cache_utils import Cache
19
+
20
+ if PEFT_AVAILABLE:
21
+ from peft.utils.other import ModulesToSaveWrapper
22
+
23
+
24
@deprecate_kwarg("num_logits_to_keep", version="4.50", new_name="logits_to_keep")
def lce_forward(
    self,
    input_ids: torch.LongTensor = None,
    attention_mask: Optional[torch.Tensor] = None,
    position_ids: Optional[torch.LongTensor] = None,
    past_key_values: Optional[Union["Cache", List[torch.FloatTensor]]] = None,
    inputs_embeds: Optional[torch.FloatTensor] = None,
    labels: Optional[torch.LongTensor] = None,
    use_cache: Optional[bool] = None,
    output_attentions: Optional[bool] = None,
    output_hidden_states: Optional[bool] = None,
    return_dict: Optional[bool] = None,
    cache_position: Optional[torch.LongTensor] = None,
    logits_to_keep: Union[int, torch.Tensor] = 0,
    skip_logits: Optional[bool] = None,
    **kwargs,
) -> Union[Tuple, CausalLMOutputWithPast]:
    r"""
    Liger replacement for `SmolLM3ForCausalLM.forward` that can compute the causal-LM loss with Liger's
    fused linear cross-entropy instead of materializing the full logits tensor.

    Args:
        labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
            Labels for computing the masked language modeling loss. Indices should either be in `[0, ...,
            config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored
            (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`.

        logits_to_keep (`int` or `torch.Tensor`, *optional*):
            If an `int`, compute logits for the last `logits_to_keep` tokens. If `0`, calculate logits for all
            `input_ids` (special case). Only last token logits are needed for generation, and calculating them only for that
            token can save memory, which becomes pretty significant for long sequences or large vocabulary size.
            If a `torch.Tensor`, must be 1D corresponding to the indices to keep in the sequence length dimension.
            This is useful when using packed tensor format (single dimension for batch and sequence length).

        skip_logits (`bool`, *optional*):
            If `True`, the loss is computed by `lce_maybe_trainable_lm_head` directly from the hidden states and
            the returned `logits` is `None`; this requires `labels` or `shift_labels`. If `None` (default), logits
            are skipped only when the module is in training mode and `labels` or `shift_labels` is provided.

        kwargs:
            Forwarded to `self.model` and to the loss. An optional `shift_labels` entry (pre-shifted labels) is
            popped after the base-model call and passed explicitly to whichever loss path is taken.

    Returns:
        A `CausalLMOutputWithPast` when `return_dict` (or `config.use_return_dict`) is truthy. Otherwise a tuple
        `(loss, logits, *outputs[1:])`, where `loss` is left out when no loss was computed.

    Raises:
        ValueError: If `skip_logits` is True but neither `labels` nor `shift_labels` is provided.

    Example:

    ```python
    >>> from transformers import AutoTokenizer, SmolLM3ForCausalLM

    >>> model = SmolLM3ForCausalLM.from_pretrained("HuggingFaceTB/SmolLM3-3B")
    >>> tokenizer = AutoTokenizer.from_pretrained("HuggingFaceTB/SmolLM3-3B")

    >>> prompt = "Hey, are you conscious? Can you talk to me?"
    >>> inputs = tokenizer(prompt, return_tensors="pt")

    >>> # Generate
    >>> generate_ids = model.generate(inputs.input_ids, max_length=30)
    >>> tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]
    "Hey, are you conscious? Can you talk to me?\nI'm not conscious, but I can talk to you."
    ```"""

    output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
    output_hidden_states = (
        output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
    )
    return_dict = return_dict if return_dict is not None else self.config.use_return_dict

    # outputs[0] is the final hidden state; the remaining fields are passed through to the return value.
    outputs = self.model(
        input_ids=input_ids,
        attention_mask=attention_mask,
        position_ids=position_ids,
        past_key_values=past_key_values,
        inputs_embeds=inputs_embeds,
        use_cache=use_cache,
        output_attentions=output_attentions,
        output_hidden_states=output_hidden_states,
        return_dict=return_dict,
        cache_position=cache_position,
        **kwargs,
    )

    hidden_states = outputs[0]
    # Keep only the sequence positions selected by `logits_to_keep` (0 -> slice(0, None) keeps all of them).
    slice_indices = slice(-logits_to_keep, None) if isinstance(logits_to_keep, int) else logits_to_keep
    kept_hidden_states = hidden_states[:, slice_indices, :]

    # Popped so it is passed explicitly to the loss below rather than duplicated through **kwargs.
    shift_labels = kwargs.pop("shift_labels", None)
    logits = None
    loss = None

    if skip_logits and labels is None and shift_labels is None:
        raise ValueError("skip_logits is True, but labels and shift_labels are None")

    if skip_logits is None:
        # By default, if in training mode, don't materialize logits
        skip_logits = self.training and (labels is not None or shift_labels is not None)

    if skip_logits:
        loss = lce_maybe_trainable_lm_head(
            self,
            hidden_states=kept_hidden_states,
            hidden_size=self.config.hidden_size,
            labels=labels,
            shift_labels=shift_labels,
            **kwargs,
        )

    else:
        logits = self.lm_head(kept_hidden_states)
        # Compute the loss whenever either form of labels is available, and forward `shift_labels` so this
        # path uses the same targets as the fused (skip_logits) path.
        if labels is not None or shift_labels is not None:
            loss = self.loss_function(
                logits=logits,
                labels=labels,
                shift_labels=shift_labels,
                vocab_size=self.config.vocab_size,
                **kwargs,
            )

    if not return_dict:
        output = (logits,) + outputs[1:]
        return (loss,) + output if loss is not None else output

    return CausalLMOutputWithPast(
        loss=loss,
        logits=logits,
        past_key_values=outputs.past_key_values,
        hidden_states=outputs.hidden_states,
        attentions=outputs.attentions,
    )
143
+
144
+
145
def lce_maybe_trainable_lm_head(self, hidden_states, hidden_size, labels, shift_labels, **loss_kwargs):
    """
    Compute the Liger fused linear cross-entropy loss from hidden states and the model's lm_head.

    Two possible wrappers around `self.lm_head` are handled:
    - A PEFT `ModulesToSaveWrapper` (lm_head listed in `modules_to_save`) is unwrapped to its
      `modules_to_save.default` module, so that module's weights are the ones read.
    - An FSDP `FullyShardedDataParallel` module is not read directly; instead the loss is run through
      `_FSDPForwardRedirection` with the wrapped `lm_head.module`.

    Args:
        hidden_states: Hidden states passed to the loss as-is.
        hidden_size: Hidden dimension, forwarded to the loss.
        labels: Unshifted labels, or None.
        shift_labels: Pre-shifted labels, or None.
        **loss_kwargs: Extra keyword arguments forwarded to the loss.

    Returns:
        Whatever `_liger_for_causal_lm_loss` returns (the output of `LigerForCausalLMLoss`).
    """
    lm_head = self.lm_head

    # Unwrap the module if lm_head has been added as trainable module in PEFT LoRA configuration,
    # i.e. listed in the modules_to_save field of LoraConfig, so the lm_head weights are read
    # from the unwrapped module.
    # See https://huggingface.co/docs/peft/package_reference/lora for reference.
    if PEFT_AVAILABLE and isinstance(lm_head, ModulesToSaveWrapper):
        lm_head = lm_head.modules_to_save.default

    # If FSDP is used and lm_head is trainable, e.g., during full fine-tuning or with LoRA,
    # reading the lm_head module weights and calling the kernel must be done within FSDP forward pass
    # so the module entire parameters are summoned and kept in memory during the kernel execution.
    if isinstance(lm_head, FullyShardedDataParallel):
        return _FSDPForwardRedirection()(
            lm_head,
            _liger_for_causal_lm_loss,
            lm_head.module,
            hidden_states,
            hidden_size,
            labels,
            shift_labels,
            **loss_kwargs,
        )

    # FSDP is not used so we can read the lm_head weights and call the kernel directly.
    # Use the (possibly PEFT-unwrapped) local `lm_head`, not `self.lm_head`, so the weights come
    # from the module selected above.
    return _liger_for_causal_lm_loss(
        lm_head=lm_head,
        hidden_states=hidden_states,
        hidden_size=hidden_size,
        labels=labels,
        shift_labels=shift_labels,
        **loss_kwargs,
    )
179
+
180
+
181
def _liger_for_causal_lm_loss(lm_head, hidden_states, hidden_size, labels, shift_labels, **loss_kwargs):
    """
    Call `LigerForCausalLMLoss` with the weight of the given `lm_head` module.

    This is a standalone function because `lce_maybe_trainable_lm_head` passes it as the callable
    run by `_FSDPForwardRedirection`, as well as calling it directly.

    Args:
        lm_head: Module whose `.weight` is used as the output-projection weight.
        hidden_states: Hidden states forwarded to the loss.
        hidden_size: Hidden dimension forwarded to the loss.
        labels: Unshifted labels, or None.
        shift_labels: Pre-shifted labels, or None.
        **loss_kwargs: Additional keyword arguments forwarded to the loss.

    Returns:
        The value returned by `LigerForCausalLMLoss`.
    """
    projection_weight = lm_head.weight
    return LigerForCausalLMLoss(
        hidden_states=hidden_states,
        lm_head_weight=projection_weight,
        labels=labels,
        hidden_size=hidden_size,
        shift_labels=shift_labels,
        **loss_kwargs,
    )
@@ -26,9 +26,9 @@ from liger_kernel.transformers.model.mistral import lce_forward as mistral_lce_f
26
26
  from liger_kernel.transformers.model.mixtral import lce_forward as mixtral_lce_forward
27
27
  from liger_kernel.transformers.model.mixtral import lce_forward_deprecated as mixtral_lce_forward_deprecated
28
28
  from liger_kernel.transformers.model.phi3 import lce_forward as phi3_lce_forward
29
- from liger_kernel.transformers.model.phi3 import lce_forward_deprecated as phi3_lce_forward_deprecated
30
29
  from liger_kernel.transformers.model.qwen2 import lce_forward as qwen2_lce_forward
31
30
  from liger_kernel.transformers.model.qwen2 import lce_forward_deprecated as qwen2_lce_forward_deprecated
31
+ from liger_kernel.transformers.model.smollm3 import lce_forward as smollm3_lce_forward
32
32
  from liger_kernel.transformers.qwen2vl_mrope import liger_multimodal_rotary_pos_emb
33
33
  from liger_kernel.transformers.rms_norm import LigerRMSNorm
34
34
  from liger_kernel.transformers.rope import liger_rotary_pos_emb
@@ -77,8 +77,8 @@ def _patch_rms_norm_module(module, offset=0.0, eps=1e-6, casting_mode="llama", i
77
77
  _bind_method_to_module(module.modules_to_save.default, "extra_repr", LigerRMSNorm.extra_repr)
78
78
  _bind_method_to_module(module.original_module, "forward", LigerRMSNorm.forward)
79
79
  _bind_method_to_module(module.original_module, "extra_repr", LigerRMSNorm.extra_repr)
80
- module.modules_to_save.default.__class__.__name__ = LigerRMSNorm.__name__
81
- module.original_module.__class__.__name__ = LigerRMSNorm.__name__
80
+ _bind_method_to_module(module.modules_to_save.default, "_get_name", lambda self: LigerRMSNorm.__name__)
81
+ _bind_method_to_module(module.original_module, "_get_name", lambda self: LigerRMSNorm.__name__)
82
82
  else:
83
83
  module.offset = offset
84
84
  module.casting_mode = casting_mode
@@ -87,7 +87,7 @@ def _patch_rms_norm_module(module, offset=0.0, eps=1e-6, casting_mode="llama", i
87
87
  module.row_mode = row_mode
88
88
  _bind_method_to_module(module, "forward", LigerRMSNorm.forward)
89
89
  _bind_method_to_module(module, "extra_repr", LigerRMSNorm.extra_repr)
90
- module.__class__.__name__ = LigerRMSNorm.__name__
90
+ _bind_method_to_module(module, "_get_name", lambda self: LigerRMSNorm.__name__)
91
91
 
92
92
 
93
93
  def _patch_layer_norm_module(module, eps=1e-6):
@@ -109,28 +109,28 @@ def _patch_layer_norm_module(module, eps=1e-6):
109
109
  module.original_module.hidden_size = getattr(module, "hidden_size", None) or getattr(
110
110
  module, "normalized_shape", None
111
111
  )
112
- _bind_method_to_module(module.modules_to_save.default, "forward", LigerRMSNorm.forward)
113
- _bind_method_to_module(module.modules_to_save.default, "extra_repr", LigerRMSNorm.extra_repr)
114
- _bind_method_to_module(module.original_module, "forward", LigerRMSNorm.forward)
115
- _bind_method_to_module(module.original_module, "extra_repr", LigerRMSNorm.extra_repr)
116
- module.modules_to_save.default.__class__.__name__ = LigerLayerNorm.__name__
117
- module.original_module.__class__.__name__ = LigerLayerNorm.__name__
112
+ _bind_method_to_module(module.modules_to_save.default, "forward", LigerLayerNorm.forward)
113
+ _bind_method_to_module(module.modules_to_save.default, "extra_repr", LigerLayerNorm.extra_repr)
114
+ _bind_method_to_module(module.original_module, "forward", LigerLayerNorm.forward)
115
+ _bind_method_to_module(module.original_module, "extra_repr", LigerLayerNorm.extra_repr)
116
+ _bind_method_to_module(module.modules_to_save.default, "_get_name", lambda self: LigerLayerNorm.__name__)
117
+ _bind_method_to_module(module.original_module, "_get_name", lambda self: LigerLayerNorm.__name__)
118
118
  else:
119
119
  module.variance_epsilon = getattr(module, "variance_epsilon", None) or getattr(module, "eps", None) or eps
120
120
  module.hidden_size = getattr(module, "hidden_size", None) or getattr(module, "normalized_shape", None)
121
121
  _bind_method_to_module(module, "forward", LigerLayerNorm.forward)
122
122
  _bind_method_to_module(module, "extra_repr", LigerLayerNorm.extra_repr)
123
- module.__class__.__name__ = LigerLayerNorm.__name__
123
+ _bind_method_to_module(module, "_get_name", lambda self: LigerLayerNorm.__name__)
124
124
 
125
125
 
126
126
def _patch_swiglu_module(module, liger_module):
    """
    Patch an already-instantiated MLP so it runs `liger_module.forward`.

    `_get_name` is rebound as well, so the module reports `liger_module.__name__` as its name
    (torch's `nn.Module.__repr__` uses `_get_name`). This replaces renaming the shared class.
    It presumably affects only this instance, assuming `_bind_method_to_module` binds onto the
    instance; verify against its definition.

    Args:
        module: The MLP module instance to patch.
        liger_module: The Liger MLP class whose `forward` and `__name__` are used.
    """
    _bind_method_to_module(module, "forward", liger_module.forward)

    def _reported_name(_self):
        return liger_module.__name__

    _bind_method_to_module(module, "_get_name", _reported_name)
129
129
 
130
130
 
131
131
def _patch_geglu_module(module):
    """
    Patch an already-instantiated MLP so it runs `LigerGEGLUMLP.forward`.

    `_get_name` is rebound as well, so the module reports `LigerGEGLUMLP` as its name
    (torch's `nn.Module.__repr__` uses `_get_name`). This replaces renaming the shared class.

    Args:
        module: The MLP module instance to patch.
    """
    _bind_method_to_module(module, "forward", LigerGEGLUMLP.forward)

    def _reported_name(_self):
        return LigerGEGLUMLP.__name__

    _bind_method_to_module(module, "_get_name", _reported_name)
134
134
 
135
135
 
136
136
  def apply_liger_kernel_to_granite(
@@ -290,6 +290,77 @@ def apply_liger_kernel_to_llama(
290
290
  _patch_rms_norm_module(decoder_layer.post_attention_layernorm)
291
291
 
292
292
 
293
def apply_liger_kernel_to_smollm3(
    rope: bool = True,
    cross_entropy: bool = False,
    fused_linear_cross_entropy: bool = True,
    rms_norm: bool = True,
    swiglu: bool = True,
    model: PreTrainedModel = None,
) -> None:
    """
    Apply Liger kernels to replace original implementation in HuggingFace SmolLM3 model

    Module-level symbols in `transformers.models.smollm3.modeling_smollm3` are always replaced, which
    affects models created afterwards. When `model` is given, its already-instantiated RMSNorm and MLP
    modules are patched too, and its `forward` is rebound instead of the class-level forward.

    Args:
        rope (bool): Whether to apply Liger's rotary position embedding. Default is True.
        cross_entropy (bool): Whether to apply Liger's cross entropy loss. Default is False.
        fused_linear_cross_entropy (bool):
            Whether to apply Liger's fused linear cross entropy loss. Default is True.
            `cross_entropy` and `fused_linear_cross_entropy` cannot both be True.
            If `fused_linear_cross_entropy` is True, the logits will not be materialized but more memory efficient.
        rms_norm (bool): Whether to apply Liger's RMSNorm. Default is True.
        swiglu (bool): Whether to apply Liger's SwiGLU MLP. Default is True.
        model (PreTrainedModel): The model instance to apply Liger kernels to, if the model has already been
        loaded. Default is None.

    Raises:
        AssertionError: If both `cross_entropy` and `fused_linear_cross_entropy` are True.
    """

    assert not (cross_entropy and fused_linear_cross_entropy), (
        "cross_entropy and fused_linear_cross_entropy cannot both be True."
    )

    from transformers.models.smollm3 import modeling_smollm3
    from transformers.models.smollm3.modeling_smollm3 import SmolLM3Model

    if rope:
        modeling_smollm3.apply_rotary_pos_emb = liger_rotary_pos_emb
    if rms_norm:
        modeling_smollm3.SmolLM3RMSNorm = LigerRMSNorm
    if swiglu:
        modeling_smollm3.SmolLM3MLP = LigerSwiGLUMLP

    if cross_entropy:
        # The `modeling_smollm3` import above only succeeds on transformers releases that route the
        # causal-LM loss through `transformers.loss.loss_utils`. No legacy version branch is needed,
        # which matches the phi3/glm4v patchers.
        from transformers.loss.loss_utils import nn

        nn.functional.cross_entropy = liger_cross_entropy

    if fused_linear_cross_entropy:
        if model is not None:
            model.forward = MethodType(smollm3_lce_forward, model)
        else:
            modeling_smollm3.SmolLM3ForCausalLM.forward = smollm3_lce_forward

    if model is not None:
        # The model instance already exists, so we need to additionally patch the
        # instance variables that reference already-instantiated modules (e.g. SmolLM3RMSNorm or SmolLM3MLP)

        # get the base model from the model instance
        base_model: SmolLM3Model = getattr(model, model.base_model_prefix, model)

        if rms_norm:
            _patch_rms_norm_module(base_model.norm)

        for decoder_layer in base_model.layers:
            if swiglu:
                _patch_swiglu_module(decoder_layer.mlp, LigerSwiGLUMLP)
            if rms_norm:
                _patch_rms_norm_module(decoder_layer.input_layernorm)
                _patch_rms_norm_module(decoder_layer.post_attention_layernorm)
362
+
363
+
293
364
  def apply_liger_kernel_to_llava(
294
365
  cross_entropy: bool = False,
295
366
  fused_linear_cross_entropy: bool = True,
@@ -377,7 +448,7 @@ def apply_liger_kernel_to_llava(
377
448
 
378
449
 
379
450
  def apply_liger_kernel_to_llama4(
380
- rope: bool = False,
451
+ rope: bool = True,
381
452
  cross_entropy: bool = False,
382
453
  fused_linear_cross_entropy: bool = True,
383
454
  rms_norm: bool = True,
@@ -413,7 +484,9 @@ def apply_liger_kernel_to_llama4(
413
484
  from liger_kernel.transformers.model.llama4 import lce_forward as llama4_lce_forward
414
485
 
415
486
  if rope:
416
- raise NotImplementedError("liger_rotary_pos_emb is not available for Llama4 models.")
487
+ from liger_kernel.transformers.llama4_rope import apply_liger_llama4_rope_full
488
+
489
+ apply_liger_llama4_rope_full(modeling_llama4)
417
490
  if rms_norm:
418
491
  modeling_llama4.Llama4TextRMSNorm = LigerRMSNorm
419
492
  if swiglu:
@@ -1603,25 +1676,14 @@ def apply_liger_kernel_to_phi3(
1603
1676
  if swiglu:
1604
1677
  modeling_phi3.Phi3MLP = LigerPhi3SwiGLUMLP
1605
1678
  if cross_entropy:
1606
- if transformer_version >= version.parse(SUPPORTED_TRANSFORMER_VERSION):
1607
- from transformers.loss.loss_utils import nn
1679
+ from transformers.loss.loss_utils import nn
1608
1680
 
1609
- nn.functional.cross_entropy = liger_cross_entropy
1610
- else:
1611
- logger.warning(TRANSFORMER_DEPRECATION_WARNING)
1612
- modeling_phi3.CrossEntropyLoss = LigerCrossEntropyLoss
1681
+ nn.functional.cross_entropy = liger_cross_entropy
1613
1682
  if fused_linear_cross_entropy:
1614
- if transformer_version >= version.parse(SUPPORTED_TRANSFORMER_VERSION):
1615
- if model is not None:
1616
- model.forward = MethodType(phi3_lce_forward, model)
1617
- else:
1618
- modeling_phi3.Phi3ForCausalLM.forward = phi3_lce_forward
1619
- else: # if version < 4.46.1
1620
- logger.warning(TRANSFORMER_DEPRECATION_WARNING)
1621
- if model is not None:
1622
- model.forward = MethodType(phi3_lce_forward_deprecated, model)
1623
- else:
1624
- modeling_phi3.Phi3ForCausalLM.forward = phi3_lce_forward_deprecated
1683
+ if model is not None:
1684
+ model.forward = MethodType(phi3_lce_forward, model)
1685
+ else:
1686
+ modeling_phi3.Phi3ForCausalLM.forward = phi3_lce_forward
1625
1687
 
1626
1688
  if model is not None:
1627
1689
  # The model instance already exists, so we need to additionally patch the
@@ -1777,6 +1839,95 @@ def apply_liger_kernel_to_glm4(
1777
1839
  _patch_rms_norm_module(decoder_layer.post_mlp_layernorm, in_place=False)
1778
1840
 
1779
1841
 
1842
def apply_liger_kernel_to_glm4v(
    rope: bool = False,
    cross_entropy: bool = False,
    fused_linear_cross_entropy: bool = True,
    rms_norm: bool = True,
    swiglu: bool = True,
    model: PreTrainedModel = None,
) -> None:
    """
    Apply Liger kernels to replace original implementation in HuggingFace GLM-4v models.

    Args:
        rope (bool): Whether to apply Liger's rotary position embedding. Default is False.
            Not supported for GLM-4v; passing True raises NotImplementedError.
        cross_entropy (bool): Whether to apply Liger's cross entropy loss. Default is False.
        fused_linear_cross_entropy (bool):
            Whether to apply Liger's fused linear cross entropy loss. Default is True.
            `cross_entropy` and `fused_linear_cross_entropy` cannot both be True.
            If `fused_linear_cross_entropy` is True, the logits will not be materialized but more memory efficient.
        rms_norm (bool): Whether to apply Liger's RMSNorm. Default is True.
        swiglu (bool): Whether to apply Liger's SwiGLU MLP to the vision-block and text-decoder MLPs.
            Default is True. Only takes effect when `model` is given, because this function does not replace
            any MLP class at module level.
        model (PreTrainedModel): The model instance to apply Liger kernels to, if the model has already been
        loaded. Default is None.

    Raises:
        AssertionError: If both `cross_entropy` and `fused_linear_cross_entropy` are True.
        NotImplementedError: If `rope` is True.
        TypeError: If `model` is not a `Glm4vForConditionalGeneration`, `Glm4vModel` or `Glm4vTextModel`.
    """
    assert not (cross_entropy and fused_linear_cross_entropy), (
        "cross_entropy and fused_linear_cross_entropy cannot both be True."
    )

    from transformers.models.glm4v import modeling_glm4v
    from transformers.models.glm4v.modeling_glm4v import Glm4vForConditionalGeneration
    from transformers.models.glm4v.modeling_glm4v import Glm4vModel
    from transformers.models.glm4v.modeling_glm4v import Glm4vTextModel
    from transformers.models.glm4v.modeling_glm4v import Glm4vVisionModel

    from liger_kernel.transformers.model.glm4v import lce_forward as glm4v_lce_forward
    from liger_kernel.transformers.rms_norm import LigerRMSNormForGlm4

    if rope:
        raise NotImplementedError("liger_rotary_pos_emb is not available for Glm4 models.")
    if rms_norm:
        modeling_glm4v.Glm4vRMSNorm = LigerRMSNormForGlm4
    if cross_entropy:
        from transformers.loss.loss_utils import nn

        nn.functional.cross_entropy = liger_cross_entropy
    if fused_linear_cross_entropy:
        if model is not None:
            model.forward = MethodType(glm4v_lce_forward, model)
        else:
            modeling_glm4v.Glm4vForConditionalGeneration.forward = glm4v_lce_forward

    if model is not None:
        # The model instance already exists, so we need to additionally patch the
        # instance variables that reference already-instantiated modules
        if isinstance(model, (Glm4vForConditionalGeneration, Glm4vModel)):
            # Note: language_model and visual properties can be accessed through conditional class for BC.
            # Not sure if it is subject to changes in the future.
            # Reference: https://github.com/huggingface/transformers/blob/main/src/transformers/models/glm4v/modeling_glm4v.py#L1305
            text_model: Glm4vTextModel = model.language_model
            vision_model: Glm4vVisionModel = model.visual
        elif isinstance(model, Glm4vTextModel):
            text_model: Glm4vTextModel = model
            vision_model = None
        else:
            # Note: Currently there's no support for patching vision model only. Feel free to raise an issue if needed.
            raise TypeError(
                f"Unsupported glm4.1v model type. `model` must be `Glm4vForConditionalGeneration`, `Glm4vModel` or `Glm4vTextModel`. Got: {type(model)}"
            )

        if vision_model is not None:
            for vision_block in vision_model.blocks:
                if rms_norm:
                    _patch_rms_norm_module(vision_block.norm1)
                    _patch_rms_norm_module(vision_block.norm2)
                if swiglu:
                    _patch_swiglu_module(vision_block.mlp, LigerSwiGLUMLP)

        if text_model is not None:
            # Text-model norms are patched with in_place=False, as apply_liger_kernel_to_glm4 does.
            # In GLM-4-style decoder layers the post_self_attn / post_mlp norm outputs are presumably
            # added straight onto the residual, so their output gradient is shared with the residual
            # branch. An in-place backward would overwrite it.
            if rms_norm:
                _patch_rms_norm_module(text_model.norm, in_place=False)
            for decoder_layer in text_model.layers:
                if swiglu:
                    _patch_swiglu_module(decoder_layer.mlp, LigerPhi3SwiGLUMLP)
                if rms_norm:
                    _patch_rms_norm_module(decoder_layer.input_layernorm, in_place=False)
                    _patch_rms_norm_module(decoder_layer.post_attention_layernorm, in_place=False)
                    _patch_rms_norm_module(decoder_layer.post_self_attn_layernorm, in_place=False)
                    _patch_rms_norm_module(decoder_layer.post_mlp_layernorm, in_place=False)
1929
+
1930
+
1780
1931
  # Model type corresponds to the keys defined in transformers/models/auto/modeling_auto.py
1781
1932
  MODEL_TYPE_TO_APPLY_LIGER_FN = {
1782
1933
  "gemma": apply_liger_kernel_to_gemma,
@@ -1784,6 +1935,7 @@ MODEL_TYPE_TO_APPLY_LIGER_FN = {
1784
1935
  "gemma3_text": apply_liger_kernel_to_gemma3_text,
1785
1936
  "gemma3": apply_liger_kernel_to_gemma3,
1786
1937
  "glm4": apply_liger_kernel_to_glm4,
1938
+ "glm4v": apply_liger_kernel_to_glm4v,
1787
1939
  "llama": apply_liger_kernel_to_llama,
1788
1940
  "llama4_text": apply_liger_kernel_to_llama4,
1789
1941
  "llama4": apply_liger_kernel_to_llama4,
@@ -1801,6 +1953,7 @@ MODEL_TYPE_TO_APPLY_LIGER_FN = {
1801
1953
  "qwen2_vl_text": apply_liger_kernel_to_qwen2_vl,
1802
1954
  "qwen2_5_vl": apply_liger_kernel_to_qwen2_5_vl,
1803
1955
  "qwen2_5_vl_text": apply_liger_kernel_to_qwen2_5_vl,
1956
+ "smollm3": apply_liger_kernel_to_smollm3,
1804
1957
  "phi3": apply_liger_kernel_to_phi3,
1805
1958
  "paligemma": apply_liger_kernel_to_paligemma,
1806
1959
  }
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.4
2
2
  Name: liger_kernel
3
- Version: 0.6.0
3
+ Version: 0.6.2
4
4
  Summary: Efficient Triton kernels for LLM Training
5
5
  License: BSD 2-CLAUSE LICENSE
6
6
  Copyright 2024 LinkedIn Corporation
@@ -84,7 +84,7 @@ Dynamic: requires-dist
84
84
  </td>
85
85
  <td style="padding: 10px;">
86
86
  <a href="https://discord.gg/gpumode">
87
- <img src="https://dcbadge.vercel.app/api/server/gpumode?style=flat" alt="Join Our Discord">
87
+ <img src="https://dcbadge.limes.pink/api/server/gpumode?style=flat" alt="Join Our Discord">
88
88
  </a>
89
89
  </td>
90
90
  </tr>
@@ -307,7 +307,7 @@ loss.backward()
307
307
  | Qwen2-VL, & QVQ | `liger_kernel.transformers.apply_liger_kernel_to_qwen2_vl` | RMSNorm, LayerNorm, SwiGLU, CrossEntropyLoss, FusedLinearCrossEntropy |
308
308
  | Qwen2.5-VL | `liger_kernel.transformers.apply_liger_kernel_to_qwen2_5_vl` | RMSNorm, SwiGLU, CrossEntropyLoss, FusedLinearCrossEntropy |
309
309
  | Qwen3 | `liger_kernel.transformers.apply_liger_kernel_to_qwen3` | RoPE, RMSNorm, SwiGLU, CrossEntropyLoss, FusedLinearCrossEntropy |
310
- | Qwen3 MoE | `liger_kernel_transformers.apply_liger_kernel_to_qwen3_moe` | RoPE, RMSNorm, SwiGLU, CrossEntropyLoss, FusedLinearCrossEntropy |
310
+ | Qwen3 MoE | `liger_kernel.transformers.apply_liger_kernel_to_qwen3_moe` | RoPE, RMSNorm, SwiGLU, CrossEntropyLoss, FusedLinearCrossEntropy |
311
311
  | Phi3 & Phi3.5 | `liger_kernel.transformers.apply_liger_kernel_to_phi3` | RoPE, RMSNorm, SwiGLU, CrossEntropyLoss, FusedLinearCrossEntropy |
312
312
  | Granite 3.0 & 3.1 | `liger_kernel.transformers.apply_liger_kernel_to_granite` | RoPE, RMSNorm, SwiGLU, CrossEntropyLoss |
313
313
  | OLMo2 | `liger_kernel.transformers.apply_liger_kernel_to_olmo2` | RoPE, RMSNorm, SwiGLU, CrossEntropyLoss, FusedLinearCrossEntropy |
@@ -400,7 +400,7 @@ loss.backward()
400
400
  </a>
401
401
  </div>
402
402
  <div style="display: block;">
403
- <a href="https://github.com/linkedin/Liger-Kernel/actions/workflows/amd-ci.yml">
403
+ <a href="https://github.com/linkedin/Liger-Kernel/actions/workflows/intel-ci.yml">
404
404
  <img src="https://github.com/linkedin/Liger-Kernel/actions/workflows/intel-ci.yml/badge.svg?event=schedule" alt="Build">
405
405
  </a>
406
406
  </div>
@@ -414,21 +414,19 @@ loss.backward()
414
414
 
415
415
  - For issues, create a Github ticket in this repository
416
416
  - For open discussion, join [our discord channel on GPUMode](https://discord.com/channels/1189498204333543425/1275130785933951039)
417
- - For formal collaboration, send an email to yannchen@linkedin.com and hning@linkedin.com
417
+ - For formal collaboration, send an email to Yanning Chen (yannchen@linkedin.com) and Zhipeng Wang (zhipwang@linkedin.com)
418
418
 
419
419
  ## Cite this work
420
420
 
421
421
  Biblatex entry:
422
422
  ```bib
423
- @article{hsu2024ligerkernelefficienttriton,
424
- title={Liger Kernel: Efficient Triton Kernels for LLM Training},
425
- author={Pin-Lun Hsu and Yun Dai and Vignesh Kothapalli and Qingquan Song and Shao Tang and Siyu Zhu and Steven Shimizu and Shivam Sahni and Haowen Ning and Yanning Chen},
426
- year={2024},
427
- eprint={2410.10989},
428
- archivePrefix={arXiv},
429
- primaryClass={cs.LG},
430
- url={https://arxiv.org/abs/2410.10989},
431
- journal={arXiv preprint arXiv:2410.10989},
423
+ @inproceedings{
424
+ hsu2025ligerkernel,
425
+ title={Liger-Kernel: Efficient Triton Kernels for {LLM} Training},
426
+ author={Pin-Lun Hsu and Yun Dai and Vignesh Kothapalli and Qingquan Song and Shao Tang and Siyu Zhu and Steven Shimizu and Shivam Sahni and Haowen Ning and Yanning Chen and Zhipeng Wang},
427
+ booktitle={Championing Open-source DEvelopment in ML Workshop @ ICML25},
428
+ year={2025},
429
+ url={https://openreview.net/forum?id=36SjAIT42G}
432
430
  }
433
431
  ```
434
432