liger-kernel-nightly 0.5.5.dev20250402212634__py3-none-any.whl → 0.5.6.dev20250403001329__py3-none-any.whl

This diff shows the contents of two publicly released versions of the package as published to their public registry. It is provided for informational purposes only and reflects the changes between the package versions as they appear in that registry.
@@ -10,6 +10,8 @@ from liger_kernel.transformers.monkey_patch import _apply_liger_kernel # noqa:
  from liger_kernel.transformers.monkey_patch import _apply_liger_kernel_to_instance # noqa: F401
  from liger_kernel.transformers.monkey_patch import apply_liger_kernel_to_gemma # noqa: F401
  from liger_kernel.transformers.monkey_patch import apply_liger_kernel_to_gemma2 # noqa: F401
+ from liger_kernel.transformers.monkey_patch import apply_liger_kernel_to_gemma3 # noqa: F401
+ from liger_kernel.transformers.monkey_patch import apply_liger_kernel_to_gemma3_text # noqa: F401
  from liger_kernel.transformers.monkey_patch import apply_liger_kernel_to_granite # noqa: F401
  from liger_kernel.transformers.monkey_patch import apply_liger_kernel_to_llama # noqa: F401
  from liger_kernel.transformers.monkey_patch import apply_liger_kernel_to_llava # noqa: F401
@@ -0,0 +1,8 @@
+ from .rms_norm import LigerRMSNorm
+
+
+ class LigerRMSNormForGemma3(LigerRMSNorm):
+     """Gemma3RMSNorm has a dim argument not hidden_size used in q_norm and k_norm."""
+
+     def __init__(self, dim, eps=0.000001, offset=1.0, casting_mode="gemma", init_fn="zeros", in_place=False):
+         super().__init__(dim, eps, offset, casting_mode, init_fn, in_place)
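The shim above only adapts the constructor: Gemma3 builds its norm layers (including the per-head `q_norm`/`k_norm`) with a `dim` keyword, which is forwarded to `LigerRMSNorm`'s hidden-size argument. A minimal sketch of constructing it as a drop-in for `Gemma3RMSNorm(dim=..., eps=...)`; the sizes are illustrative assumptions, and the Triton-backed forward pass assumes a CUDA device:

```python
import torch

from liger_kernel.transformers.gema3_rms import LigerRMSNormForGemma3

# Illustrative size; Gemma3 would pass e.g. the attention head_dim for q_norm/k_norm.
norm = LigerRMSNormForGemma3(dim=256, eps=1e-6)  # same call shape as Gemma3RMSNorm(dim=256, eps=1e-6)
print(norm.weight.shape)  # torch.Size([256]); init_fn="zeros" with offset=1.0 matches Gemma's (1 + w) scaling

if torch.cuda.is_available():  # the underlying RMSNorm kernel runs on GPU
    x = torch.randn(2, 8, 256, device="cuda")
    print(norm.cuda()(x).shape)  # torch.Size([2, 8, 256])
```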
@@ -0,0 +1,335 @@
+ from typing import List
+ from typing import Optional
+ from typing import Tuple
+ from typing import Union
+
+ import torch
+ import torch.nn as nn
+
+ from transformers.cache_utils import Cache
+ from transformers.cache_utils import HybridCache
+ from transformers.modeling_outputs import CausalLMOutputWithPast
+ from transformers.models.gemma3.modeling_gemma3 import _CONFIG_FOR_DOC
+ from transformers.models.gemma3.modeling_gemma3 import GEMMA3_INPUTS_DOCSTRING
+ from transformers.models.gemma3.modeling_gemma3 import Gemma3CausalLMOutputWithPast
+ from transformers.utils import add_start_docstrings_to_model_forward
+ from transformers.utils import is_torchdynamo_compiling
+ from transformers.utils import logging
+ from transformers.utils import replace_return_docstrings
+ from transformers.utils.deprecation import deprecate_kwarg
+
+ from liger_kernel.transformers.fused_linear_cross_entropy import LigerFusedLinearCrossEntropyLoss
+ from liger_kernel.transformers.model.loss_utils import LigerForCausalLMLoss
+
+ logger = logging.get_logger(__name__)
+
+
+ @deprecate_kwarg("num_logits_to_keep", version="4.50", new_name="logits_to_keep")
+ @add_start_docstrings_to_model_forward(GEMMA3_INPUTS_DOCSTRING)
+ @replace_return_docstrings(output_type=CausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC)
+ def causal_forward(
+     self,
+     input_ids: torch.LongTensor = None,
+     attention_mask: Optional[torch.Tensor] = None,
+     position_ids: Optional[torch.LongTensor] = None,
+     past_key_values: Optional[HybridCache] = None,
+     inputs_embeds: Optional[torch.FloatTensor] = None,
+     labels: Optional[torch.LongTensor] = None,
+     use_cache: Optional[bool] = None,
+     output_attentions: Optional[bool] = None,
+     output_hidden_states: Optional[bool] = None,
+     return_dict: Optional[bool] = None,
+     cache_position: Optional[torch.LongTensor] = None,
+     logits_to_keep: Union[int, torch.Tensor] = 0,
+     **loss_kwargs,
+ ) -> Union[Tuple, CausalLMOutputWithPast]:
+     r"""
+     labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
+         Labels for computing the masked language modeling loss. Indices should either be in `[0, ...,
+         config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored
+         (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`.
+
+     logits_to_keep (`int` or `torch.Tensor`, *optional*):
+         If an `int`, compute logits for the last `logits_to_keep` tokens. If `0`, calculate logits for all
+         `input_ids` (special case). Only last token logits are needed for generation, and calculating them only for that
+         token can save memory, which becomes pretty significant for long sequences or large vocabulary size.
+         If a `torch.Tensor`, must be 1D corresponding to the indices to keep in the sequence length dimension.
+         This is useful when using packed tensor format (single dimension for batch and sequence length).
+
+     Returns:
+
+     Example:
+
+     ```python
+     >>> from transformers import AutoTokenizer, Gemma3ForCausalLM
+
+     >>> model = Gemma3ForCausalLM.from_pretrained("google/gemma-2-9b")
+     >>> tokenizer = AutoTokenizer.from_pretrained("google/gemma-2-9b")
+
+     >>> prompt = "What is your favorite condiment?"
+     >>> inputs = tokenizer(prompt, return_tensors="pt")
+
+     >>> # Generate
+     >>> generate_ids = model.generate(inputs.input_ids, max_length=30)
+     >>> tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]
+     "What is your favorite condiment?"
+     ```"""
+
+     if self.training and self.config._attn_implementation != "eager":
+         logger.warning_once(
+             "It is strongly recommended to train Gemma3 models with the `eager` attention implementation "
+             f"instead of `{self.config._attn_implementation}`. Use `eager` with `AutoModelForCausalLM.from_pretrained('<path-to-checkpoint>', attn_implementation='eager')`."
+         )
+     output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
+     output_hidden_states = (
+         output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
+     )
+     return_dict = return_dict if return_dict is not None else self.config.use_return_dict
+     # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn)
+     outputs = self.model(
+         input_ids=input_ids,
+         attention_mask=attention_mask,
+         position_ids=position_ids,
+         past_key_values=past_key_values,
+         inputs_embeds=inputs_embeds,
+         use_cache=use_cache,
+         output_attentions=output_attentions,
+         output_hidden_states=output_hidden_states,
+         return_dict=return_dict,
+         cache_position=cache_position,
+         **loss_kwargs,
+     )
+
+     hidden_states = outputs[0]
+     # Only compute necessary logits, and do not upcast them to float if we are not computing the loss
+     slice_indices = slice(-logits_to_keep, None) if isinstance(logits_to_keep, int) else logits_to_keep
+     kept_hidden_states = hidden_states[:, slice_indices, :]
+     loss = None
+     logits = None
+     if self.training and (labels is not None):
+         loss = LigerForCausalLMLoss(
+             hidden_states=kept_hidden_states,
+             lm_head_weight=self.lm_head.weight,
+             labels=labels,
+             hidden_size=self.config.hidden_size,
+             softcap=self.config.final_logit_softcapping,
+             **loss_kwargs,
+         )
+
+     else:
+         logits = self.lm_head(kept_hidden_states)
+         if self.config.final_logit_softcapping is not None:
+             logits = logits / self.config.final_logit_softcapping
+             logits = torch.tanh(logits)
+             logits = logits * self.config.final_logit_softcapping
+         if labels is not None:
+             loss = self.loss_function(logits, labels, self.vocab_size, **loss_kwargs)
+
+     if not return_dict:
+         output = (logits,) + outputs[1:]
+         return (loss,) + output if loss is not None else output
+
+     return CausalLMOutputWithPast(
+         loss=loss,
+         logits=logits,
+         past_key_values=outputs.past_key_values,
+         hidden_states=outputs.hidden_states,
+         attentions=outputs.attentions,
+     )
+
+
+ @deprecate_kwarg("num_logits_to_keep", version="4.50", new_name="logits_to_keep")
+ @add_start_docstrings_to_model_forward(GEMMA3_INPUTS_DOCSTRING)
+ @replace_return_docstrings(output_type=Gemma3CausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC)
+ def multimodal_forward(
+     self,
+     input_ids: torch.LongTensor = None,
+     pixel_values: torch.FloatTensor = None,
+     attention_mask: Optional[torch.Tensor] = None,
+     position_ids: Optional[torch.LongTensor] = None,
+     past_key_values: Optional[Union[List[torch.FloatTensor], Cache]] = None,
+     token_type_ids: Optional[torch.LongTensor] = None,
+     cache_position: Optional[torch.LongTensor] = None,
+     inputs_embeds: Optional[torch.FloatTensor] = None,
+     labels: Optional[torch.LongTensor] = None,
+     use_cache: Optional[bool] = None,
+     output_attentions: Optional[bool] = None,
+     output_hidden_states: Optional[bool] = None,
+     return_dict: Optional[bool] = None,
+     logits_to_keep: Union[int, torch.Tensor] = 0,
+     **lm_kwargs,
+ ) -> Union[Tuple, Gemma3CausalLMOutputWithPast]:
+     r"""
+     labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
+         Labels for computing the masked language modeling loss. Indices should either be in `[0, ...,
+         config.text_config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored
+         (masked), the loss is only computed for the tokens with labels in `[0, ..., config.text_config.vocab_size]`.
+
+     logits_to_keep (`int` or `torch.Tensor`, *optional*):
+         If an `int`, compute logits for the last `logits_to_keep` tokens. If `0`, calculate logits for all
+         `input_ids` (special case). Only last token logits are needed for generation, and calculating them only for that
+         token can save memory, which becomes pretty significant for long sequences or large vocabulary size.
+         If a `torch.Tensor`, must be 1D corresponding to the indices to keep in the sequence length dimension.
+         This is useful when using packed tensor format (single dimension for batch and sequence length).
+
+     Returns:
+
+     Example:
+
+     ```python
+     >>> from PIL import Image
+     >>> import requests
+     >>> from transformers import AutoProcessor, Gemma3ForConditionalGeneration
+
+     >>> model = Gemma3ForConditionalGeneration.from_pretrained("google/Gemma3-test-224px-hf")
+     >>> processor = AutoProcessor.from_pretrained("google/Gemma3-test-224px-hf")
+
+     >>> prompt = "answer en Where is the cow standing?"
+     >>> url = "https://huggingface.co/gv-hf/Gemma3-test-224px-hf/resolve/main/cow_beach_1.png"
+     >>> image = Image.open(requests.get(url, stream=True).raw)
+
+     >>> inputs = processor(images=image, text=prompt, return_tensors="pt")
+
+     >>> # Generate
+     >>> generate_ids = model.generate(**inputs, max_length=30)
+     >>> processor.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]
+     "answer en Where is the cow standing?\nbeach"
+     ```"""
+
+     if (input_ids is None) ^ (inputs_embeds is not None):
+         raise ValueError("You must specify exactly one of input_ids or inputs_embeds")
+
+     output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
+     output_hidden_states = (
+         output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
+     )
+     return_dict = return_dict if return_dict is not None else self.config.use_return_dict
+
+     is_training = token_type_ids is not None and labels is not None
+
+     # Replace image id woth PAD if the image token if OOV, to avoid index-errors
+     if input_ids is not None and self.config.image_token_index >= self.vocab_size:
+         special_image_mask = input_ids == self.config.image_token_index
+         llm_input_ids = input_ids.clone()
+         llm_input_ids[special_image_mask] = 0
+     else:
+         llm_input_ids = input_ids
+
+     if inputs_embeds is None:
+         inputs_embeds = self.get_input_embeddings()(llm_input_ids)
+
+     if cache_position is None:
+         past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0
+         cache_position = torch.arange(
+             past_seen_tokens, past_seen_tokens + inputs_embeds.shape[1], device=inputs_embeds.device
+         )
+
+     if position_ids is None:
+         position_ids = cache_position.unsqueeze(0) + 1 # Gemma3 positions are 1-indexed
+
+     # Merge text and images
+     if pixel_values is not None:
+         image_features = self.get_image_features(pixel_values)
+
+         if input_ids is None:
+             special_image_mask = inputs_embeds == self.get_input_embeddings()(
+                 torch.tensor(self.config.image_token_index, dtype=torch.long, device=inputs_embeds.device)
+             )
+         else:
+             special_image_mask = (input_ids == self.config.image_token_index).unsqueeze(-1)
+             special_image_mask = special_image_mask.expand_as(inputs_embeds).to(inputs_embeds.device)
+
+         if not is_torchdynamo_compiling() and inputs_embeds[special_image_mask].numel() != image_features.numel():
+             image_tokens_in_text = (special_image_mask).sum(dim=1).sum(dim=0)[0]
+             raise ValueError(
+                 f"Number of images does not match number of special image tokens in the input text. "
+                 f"Got {image_tokens_in_text} image tokens in the text but {image_features.shape[0] * image_features.shape[1]} "
+                 "tokens from image embeddings."
+             )
+         image_features = image_features.to(inputs_embeds.device, inputs_embeds.dtype)
+         inputs_embeds = inputs_embeds.masked_scatter(special_image_mask, image_features)
+
+     # mask out pad-token-ids in labels for BC
+     if labels is not None and self.pad_token_id in labels:
+         logger.warning_once(
+             "`labels` contains `pad_token_id` which will be masked with `config.ignore_index`. "
+             "You have to mask out `pad_token_id` when preparing `labels`, this behavior will be removed in v.4.46.",
+         )
+         labels = torch.where(input_ids == self.pad_token_id, self.config.ignore_index, labels)
+
+     causal_mask = self._update_causal_mask(
+         attention_mask, token_type_ids, past_key_values, cache_position, inputs_embeds, is_training
+     )
+     outputs = self.language_model.model(
+         attention_mask=causal_mask,
+         position_ids=position_ids,
+         past_key_values=past_key_values,
+         inputs_embeds=inputs_embeds,
+         use_cache=use_cache,
+         output_attentions=output_attentions,
+         output_hidden_states=output_hidden_states,
+         return_dict=return_dict,
+         cache_position=cache_position,
+         logits_to_keep=logits_to_keep,
+         **lm_kwargs,
+     )
+
+     hidden_states = outputs[0]
+     loss = None
+     logits = None
+
+     if self.training and (labels is not None):
+         shift_hidden_states = hidden_states[..., :-1, :]
+         shift_labels = labels[..., 1:]
+
+         hidden_device = shift_hidden_states.device
+         if attention_mask is not None:
+             # we use the input attention mask to shift the hidden_states and labels, because it is 2D.
+             # we also crop attn mask in case it is longer, which happens in PrefixTuning with peft
+             shift_attention_mask = attention_mask[:, -shift_hidden_states.shape[1] :].to(hidden_device)
+             shift_hidden_states = shift_hidden_states[shift_attention_mask.to(hidden_device) != 0].contiguous()
+             shift_labels = shift_labels[shift_attention_mask.to(shift_labels.device) != 0].contiguous()
+         else:
+             shift_hidden_states = shift_hidden_states.contiguous()
+             shift_labels = shift_labels.contiguous()
+
+         # Flatten hidden state
+         shift_hidden_states = shift_hidden_states.view(-1, self.config.text_config.hidden_size)
+         shift_labels = shift_labels.view(-1).to(hidden_device)
+
+         lce = LigerFusedLinearCrossEntropyLoss()
+         loss = lce(self.language_model.lm_head.weight, shift_hidden_states, shift_labels)
+     else:
+         logits = self.language_model.lm_head(hidden_states)
+         if labels is not None:
+             # Upcast to float if we need to compute the loss to avoid potential precision issues
+             logits = logits.float()
+             shift_logits = logits[..., :-1, :]
+             shift_labels = labels[..., 1:]
+             if attention_mask is not None:
+                 # we use the input attention mask to shift the logits and labels, because it is 2D.
+                 # we also crop attn mask in case it is longer, which happens in PrefixTuning with peft
+                 shift_attention_mask = attention_mask[:, -shift_logits.shape[1] :].to(logits.device)
+                 shift_logits = shift_logits[shift_attention_mask.to(logits.device) != 0].contiguous()
+                 shift_labels = shift_labels[shift_attention_mask.to(shift_labels.device) != 0].contiguous()
+             else:
+                 shift_logits = shift_logits.contiguous()
+                 shift_labels = shift_labels.contiguous()
+             # Flatten the tokens
+             loss_fct = nn.CrossEntropyLoss()
+
+             flat_logits = shift_logits.view(-1, self.config.text_config.vocab_size)
+             flat_labels = shift_labels.view(-1).to(shift_logits.device)
+             loss = loss_fct(flat_logits, flat_labels)
+     if not return_dict:
+         output = (logits,) + outputs[1:]
+         return (loss,) + output if loss is not None else output
+
+     return Gemma3CausalLMOutputWithPast(
+         loss=loss,
+         logits=logits,
+         past_key_values=outputs.past_key_values,
+         hidden_states=outputs.hidden_states,
+         attentions=outputs.attentions,
+         image_hidden_states=image_features if pixel_values is not None else None,
+     )
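In the training branches above, the fused path hands the `lm_head` weight and the shifted hidden states directly to the Liger kernel, so the `(tokens, vocab_size)` logits tensor is never materialized. A self-contained sketch of that shift-and-flatten pattern, with made-up sizes; the fused Triton kernel assumes CUDA tensors:

```python
import torch

from liger_kernel.transformers.fused_linear_cross_entropy import LigerFusedLinearCrossEntropyLoss

# Illustrative sizes only; a real Gemma3 hidden size and vocabulary are much larger.
batch, seq_len, hidden, vocab = 2, 16, 64, 1000
device = "cuda"  # assumption: a GPU is available for the Triton kernel

hidden_states = torch.randn(batch, seq_len, hidden, device=device, requires_grad=True)
labels = torch.randint(0, vocab, (batch, seq_len), device=device)
lm_head_weight = torch.randn(vocab, hidden, device=device, requires_grad=True)

# Next-token shift, then flatten to (tokens, hidden) / (tokens,), as in multimodal_forward.
shift_hidden = hidden_states[..., :-1, :].contiguous().view(-1, hidden)
shift_labels = labels[..., 1:].contiguous().view(-1)

lce = LigerFusedLinearCrossEntropyLoss()
loss = lce(lm_head_weight, shift_hidden, shift_labels)  # logits are never materialized
loss.backward()
```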
@@ -694,6 +694,177 @@ def apply_liger_kernel_to_gemma2(
                  _patch_rms_norm_module_for_gemma2(decoder_layer.post_feedforward_layernorm)


+ def apply_liger_kernel_to_gemma3_text(
+     rope: bool = True,
+     cross_entropy: bool = False,
+     fused_linear_cross_entropy: bool = True,
+     rms_norm: bool = True,
+     geglu: bool = True,
+     model: PreTrainedModel = None,
+ ) -> None:
+     """
+     Apply Liger kernels to replace original implementation in HuggingFace Gemma3
+
+     Args:
+         rope (bool): Whether to apply Liger's rotary position embedding. Default is True.
+         cross_entropy (bool): Whether to apply Liger's cross entropy loss. Default is False.
+         fused_linear_cross_entropy (bool):
+             Whether to apply Liger's fused linear cross entropy loss. Default is True.
+             `cross_entropy` and `fused_linear_cross_entropy` cannot both be True.
+             If `fused_linear_cross_entropy` is True, the logits will not be materialized but more memory efficient.
+         rms_norm (bool): Whether to apply Liger's RMSNorm. Default is True.
+         geglu (bool): Whether to apply Liger's GeGLU MLP. Default is True.
+         model (PreTrainedModel): The model instance to apply Liger kernels to, if the model has already been
+         loaded. Default is None.
+     """
+     assert not (cross_entropy and fused_linear_cross_entropy), (
+         "cross_entropy and fused_linear_cross_entropy cannot both be True."
+     )
+
+     from transformers.models.gemma3 import modeling_gemma3
+     from transformers.models.gemma3.modeling_gemma3 import Gemma3DecoderLayer
+     from transformers.models.gemma3.modeling_gemma3 import Gemma3ForCausalLM
+
+     from liger_kernel.transformers.gema3_rms import LigerRMSNormForGemma3
+     from liger_kernel.transformers.model.gemma3 import causal_forward
+
+     _patch_rms_norm_module_for_gemma3 = partial(
+         _patch_rms_norm_module, offset=1.0, casting_mode="gemma", in_place=False
+     )
+
+     if rope:
+         modeling_gemma3.apply_rotary_pos_emb = liger_rotary_pos_emb
+
+     if rms_norm:
+         modeling_gemma3.Gemma3RMSNorm = LigerRMSNormForGemma3
+
+     if geglu:
+         modeling_gemma3.Gemma3MLP = LigerGEGLUMLP
+
+     # Handle loss function
+     if cross_entropy:
+         from transformers.loss.loss_utils import nn
+
+         nn.functional.cross_entropy = liger_cross_entropy
+
+     if fused_linear_cross_entropy:
+         modeling_gemma3.Gemma3ForCausalLM.forward = causal_forward
+
+     if model is not None:
+         # The model instance already exists, so we need to additionally patch the
+         # instance variables that reference already-instantiated modules
+
+         if isinstance(model, Gemma3ForCausalLM):
+             # get the base model from the model instance
+             base_model = model.model
+
+             if rms_norm:
+                 _patch_rms_norm_module_for_gemma3(base_model.norm)
+
+             for decoder_layer in base_model.layers:
+                 decoder_layer: Gemma3DecoderLayer
+                 if geglu:
+                     _bind_method_to_module(decoder_layer.mlp, "forward", LigerGEGLUMLP.forward)
+                 if rms_norm:
+                     _patch_rms_norm_module_for_gemma3(decoder_layer.input_layernorm)
+                     _patch_rms_norm_module_for_gemma3(decoder_layer.post_attention_layernorm)
+                     _patch_rms_norm_module_for_gemma3(decoder_layer.pre_feedforward_layernorm)
+                     _patch_rms_norm_module_for_gemma3(decoder_layer.post_feedforward_layernorm)
+                     _patch_rms_norm_module_for_gemma3(decoder_layer.self_attn.q_norm)
+                     _patch_rms_norm_module_for_gemma3(decoder_layer.self_attn.k_norm)
+
+         else:
+             raise TypeError("The model must be Gemma3ForCausalLM.")
+
+
+ def apply_liger_kernel_to_gemma3(
+     rope: bool = True,
+     cross_entropy: bool = False,
+     fused_linear_cross_entropy: bool = True,
+     layer_norm: bool = True,
+     rms_norm: bool = True,
+     geglu: bool = True,
+     model: PreTrainedModel = None,
+ ) -> None:
+     """
+     Apply Liger kernels to replace original implementation in HuggingFace Gemma3
+
+     Args:
+         rope (bool): Whether to apply Liger's rotary position embedding. Default is True.
+         cross_entropy (bool): Whether to apply Liger's cross entropy loss. Default is False.
+         fused_linear_cross_entropy (bool):
+             Whether to apply Liger's fused linear cross entropy loss. Default is True.
+             `cross_entropy` and `fused_linear_cross_entropy` cannot both be True.
+             If `fused_linear_cross_entropy` is True, the logits will not be materialized but more memory efficient.
+         layer_norm (bool): Whether to apply Liger's LayerNorm. Default is True.
+         rms_norm (bool): Whether to apply Liger's RMSNorm. Default is True.
+         geglu (bool): Whether to apply Liger's GeGLU MLP. Default is True.
+         model (PreTrainedModel): The model instance to apply Liger kernels to, if the model has already been
+         loaded. Default is None.
+     """
+     assert not (cross_entropy and fused_linear_cross_entropy), (
+         "cross_entropy and fused_linear_cross_entropy cannot both be True."
+     )
+
+     from transformers.models.gemma3 import modeling_gemma3
+     from transformers.models.gemma3.modeling_gemma3 import Gemma3ForConditionalGeneration
+     from transformers.models.siglip import modeling_siglip
+     from transformers.models.siglip.modeling_siglip import SiglipEncoderLayer
+     from transformers.models.siglip.modeling_siglip import SiglipVisionModel
+
+     from liger_kernel.transformers.model.gemma3 import multimodal_forward
+
+     _patch_rms_norm_module_for_gemma3 = partial(
+         _patch_rms_norm_module, offset=1.0, casting_mode="gemma", in_place=False
+     )
+
+     if layer_norm:
+         modeling_siglip.nn.LayerNorm = LigerLayerNorm
+
+     apply_liger_kernel_to_gemma3_text(
+         rope=rope, cross_entropy=False, fused_linear_cross_entropy=False, rms_norm=rms_norm, geglu=geglu
+     )
+
+     if cross_entropy:
+         modeling_gemma3.nn.CrossEntropyLoss = LigerCrossEntropyLoss
+
+     if fused_linear_cross_entropy:
+         modeling_gemma3.Gemma3ForConditionalGeneration.forward = multimodal_forward
+
+     if model is not None:
+         # The model instance already exists, so we need to additionally patch the
+         # instance variables that reference already-instantiated modules
+
+         if isinstance(model, Gemma3ForConditionalGeneration):
+             if isinstance(model.vision_tower, SiglipVisionModel):
+                 vision_tower = model.vision_tower
+
+                 _patch_layer_norm_module(vision_tower.vision_model.post_layernorm)
+
+                 for layer in vision_tower.vision_model.encoder.layers:
+                     layer: SiglipEncoderLayer
+                     if layer_norm:
+                         _patch_layer_norm_module(layer.layer_norm1)
+                         _patch_layer_norm_module(layer.layer_norm2)
+             else:
+                 raise TypeError("The vision tower must be SiglipVisionModel")
+
+             if rms_norm:
+                 _patch_rms_norm_module_for_gemma3(model.multi_modal_projector.mm_soft_emb_norm)
+
+             apply_liger_kernel_to_gemma3_text(
+                 rope=rope,
+                 cross_entropy=False,
+                 fused_linear_cross_entropy=False,
+                 rms_norm=rms_norm,
+                 geglu=geglu,
+                 model=model.language_model,
+             )
+
+         else:
+             raise TypeError("The model must be Gemma3ForConditionalGeneration.")
+
+
  def apply_liger_kernel_to_paligemma(
      rope: bool = True,
      cross_entropy: bool = False,
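Both new patch functions follow the existing `apply_liger_kernel_to_*` convention: call them before the model is constructed to patch the `transformers` classes, or pass `model=` to patch an already-instantiated model in place. A minimal usage sketch for the text-only entry point; the checkpoint name and dtype are illustrative assumptions, not taken from the diff:

```python
import torch
from transformers import AutoModelForCausalLM

from liger_kernel.transformers import apply_liger_kernel_to_gemma3_text

# Patch the Gemma3 text stack (RoPE, RMSNorm, GeGLU, fused linear cross entropy).
apply_liger_kernel_to_gemma3_text()

model = AutoModelForCausalLM.from_pretrained(
    "google/gemma-3-1b-it",  # assumed text-only Gemma3 checkpoint
    torch_dtype=torch.bfloat16,
    attn_implementation="eager",  # the patched forward warns if training without eager attention
)

# Alternatively, patch a model instance that is already loaded:
# apply_liger_kernel_to_gemma3_text(model=model)
```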
@@ -1152,6 +1323,8 @@ def apply_liger_kernel_to_olmo2(
  MODEL_TYPE_TO_APPLY_LIGER_FN = {
      "gemma": apply_liger_kernel_to_gemma,
      "gemma2": apply_liger_kernel_to_gemma2,
+     "gemma3_text": apply_liger_kernel_to_gemma3_text,
+     "gemma3": apply_liger_kernel_to_gemma3,
      "llama": apply_liger_kernel_to_llama,
      "llava": apply_liger_kernel_to_llava,
      "granite": apply_liger_kernel_to_granite,
@@ -1,6 +1,6 @@
  Metadata-Version: 2.1
  Name: liger_kernel_nightly
- Version: 0.5.5.dev20250402212634
+ Version: 0.5.6.dev20250403001329
  Summary: Efficient Triton kernels for LLM Training
  License: BSD 2-CLAUSE LICENSE
  Copyright 2024 LinkedIn Corporation
@@ -311,6 +311,8 @@ loss.backward()
  | Mixtral | `liger_kernel.transformers.apply_liger_kernel_to_mixtral` | RoPE, RMSNorm, SwiGLU, CrossEntropyLoss, FusedLinearCrossEntropy |
  | Gemma1 | `liger_kernel.transformers.apply_liger_kernel_to_gemma` | RoPE, RMSNorm, GeGLU, CrossEntropyLoss, FusedLinearCrossEntropy |
  | Gemma2 | `liger_kernel.transformers.apply_liger_kernel_to_gemma2` | RoPE, RMSNorm, GeGLU, CrossEntropyLoss, FusedLinearCrossEntropy |
+ | Gemma3 (Text) | `liger_kernel.transformers.apply_liger_kernel_to_gemma3_text` | RoPE, RMSNorm, GeGLU, CrossEntropyLoss, FusedLinearCrossEntropy |
+ | Gemma3 (Multimodal) | `liger_kernel.transformers.apply_liger_kernel_to_gemma3` | LayerNorm, RoPE, RMSNorm, GeGLU, CrossEntropyLoss, FusedLinearCrossEntropy |
  | Paligemma, Paligemma2, & Paligemma2 Mix | `liger_kernel.transformers.apply_liger_kernel_to_paligemma` | LayerNorm, RoPE, RMSNorm, GeGLU, CrossEntropyLoss, FusedLinearCrossEntropy |
  | Qwen2, Qwen2.5, & QwQ | `liger_kernel.transformers.apply_liger_kernel_to_qwen2` | RoPE, RMSNorm, SwiGLU, CrossEntropyLoss, FusedLinearCrossEntropy |
  | Qwen2-VL, & QVQ | `liger_kernel.transformers.apply_liger_kernel_to_qwen2_vl` | RMSNorm, LayerNorm, SwiGLU, CrossEntropyLoss, FusedLinearCrossEntropy |
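For the multimodal row, `apply_liger_kernel_to_gemma3` additionally swaps the SigLIP vision tower's LayerNorms for `LigerLayerNorm`, as the monkey patch above shows. A short sketch of using it; the checkpoint name is an assumption:

```python
from transformers import Gemma3ForConditionalGeneration

from liger_kernel.transformers import apply_liger_kernel_to_gemma3

apply_liger_kernel_to_gemma3()  # patch the classes before loading
model = Gemma3ForConditionalGeneration.from_pretrained("google/gemma-3-4b-it")

# Or patch a model that is already loaded (vision tower LayerNorms included):
# apply_liger_kernel_to_gemma3(model=model)
```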
@@ -33,7 +33,7 @@ liger_kernel/ops/tvd.py,sha256=FHJtLQI95ijqgg9UtaHpMAjSCiPxB6CduPwPMcGxelc,6405
  liger_kernel/ops/utils.py,sha256=uoFKQqo-34N2TWQNvXMFywqGiOMMXNEVBxVojzlUAa0,3836
  liger_kernel/ops/experimental/embedding.py,sha256=tolj3tItkzpSb30zWqDN2_yX4ectflaQ8HMyKyFIQc8,4172
  liger_kernel/ops/experimental/mm_int8int2.py,sha256=TrS9lpwekrik_w5qE7AhMJD1bcq-OidjtbsW80oZ6IM,13314
- liger_kernel/transformers/__init__.py,sha256=t70gqygxH63iz-B0MOdZx4AEgA8MfqU1G7N6dvIneCY,2618
+ liger_kernel/transformers/__init__.py,sha256=23EWILU5JpoZjEx_yZCQMe6lprEsBJGHVt5GQBbwptg,2811
  liger_kernel/transformers/auto_model.py,sha256=0qCTRZt280Bj_LcFdzo9hlaR-BWNazawXOGgoCZjgEg,1545
  liger_kernel/transformers/cross_entropy.py,sha256=z3KTWQnFxr_IZaVjtYt0ZNEWQdDdYThN35xWkHlDGH0,1683
  liger_kernel/transformers/dyt.py,sha256=QMqqc14pkE0WhpRZvapfnNAun-6C0C_tHExL2ZJuCUA,648
@@ -41,11 +41,12 @@ liger_kernel/transformers/functional.py,sha256=4h9Pdx_iINBqfv2Zod_c27qOpYXDDwbdV
  liger_kernel/transformers/fused_linear_cross_entropy.py,sha256=09Rt7FZzLH42VOcIbQ4dlQd0o3Rlb4vk6fqiOQ7WTD8,1778
  liger_kernel/transformers/fused_linear_jsd.py,sha256=bZ4otCvWBuOnA5XdQL-FzZVItJlDt-ht9e_pG7PG93E,3999
  liger_kernel/transformers/geglu.py,sha256=mrgqzIUVd6lN7fkDKLkw5YaESDxDtFgbot430WwPVOQ,1107
+ liger_kernel/transformers/gema3_rms.py,sha256=LTmZOXe6WEnv6ZroW-kU1TE2B36-z5v8OLmKr3XEVFo,353
  liger_kernel/transformers/group_norm.py,sha256=6qMAWOprr4SzP0YhNVNGQIBpM5aUHplUD2VuGJrMBz0,2173
  liger_kernel/transformers/jsd.py,sha256=DGqRnxIZxsvxo0_tbbxX3b-sDbDjC_yKufyRIHCcScY,2979
  liger_kernel/transformers/kl_div.py,sha256=WLffFbh1EExD2Eb1F7lN11fo9JJC-0751WJjZAF1Fj8,409
  liger_kernel/transformers/layer_norm.py,sha256=c9pk3PEasOKYR0rhe5e5nNrnYKVCEW4VC8S6LpCq9EQ,906
- liger_kernel/transformers/monkey_patch.py,sha256=95afvIrZA9xSWLNIJspBLbz8lxv2Y5gfZke7MyqoOX8,56965
+ liger_kernel/transformers/monkey_patch.py,sha256=QpfNU7MmVDGlBWIZ2RLTSyh0vuZ-si7H37SL-qOliUs,64393
  liger_kernel/transformers/qwen2vl_mrope.py,sha256=5EwSqrMdsL9MYspeBMXBsNJKvH0MOmRrtJXAJlnnlOI,1047
  liger_kernel/transformers/rms_norm.py,sha256=GqCEJuGt0YdqqlMcToE0Wp4A8YFquDa4UUSyH2uFW2A,1191
  liger_kernel/transformers/rope.py,sha256=ZTrTORSAyfcFIKjk6XEeYmk4ROH7xXED9L4g2NFntlE,999
@@ -56,6 +57,7 @@ liger_kernel/transformers/experimental/embedding.py,sha256=2P0QYdlFyFrG5OqTzTa1w
  liger_kernel/transformers/model/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
  liger_kernel/transformers/model/gemma.py,sha256=7cBTljzh-8_ACBhYl6NUfj5_ux92YRlmnAU5gfDAQAI,9312
  liger_kernel/transformers/model/gemma2.py,sha256=X0FOIhvFlTrmWI7Ws06wUkutgHW3lWtLOnnHp1NgZ3A,10403
+ liger_kernel/transformers/model/gemma3.py,sha256=PjAfFtupT9EW0sb57Hx8UJXcnvq9HFgNndeAE4EqyPw,16086
  liger_kernel/transformers/model/llama.py,sha256=d9rBaK8e8RSMCFHdgom9ZHuXOlnh6U_o-GkAFGRNGOY,9989
  liger_kernel/transformers/model/llava.py,sha256=b0pEagjUbu2-eS9xegjyfl1DwIXLwZcNpff55ibaMbA,17601
  liger_kernel/transformers/model/loss_utils.py,sha256=Z-fUrf-cUDUjUIH7Tl9OL2hT8nmtx7ES3kg8syuWKy4,1476
@@ -72,9 +74,9 @@ liger_kernel/transformers/trainer/__init__.py,sha256=p7yQfklV8-467qSz_ZMimkbDF7H
  liger_kernel/transformers/trainer/orpo_trainer.py,sha256=pdekW7l6Qg_aqa5SYKYlSWUF8m3lkOFvFLcIMEHrz9s,8338
  liger_kernel/triton/__init__.py,sha256=qCiCamzCRv6lpV8IqpAc9YMdNKC7GKurClWceQPnlis,92
  liger_kernel/triton/monkey_patch.py,sha256=Rd0hUHAzDkFfHvnX7-PBaNK5EKnZhtfM_h-fgQH9HPY,1568
- liger_kernel_nightly-0.5.5.dev20250402212634.dist-info/LICENSE,sha256=OhzLDHJ0to4a8sodVLELZiCFylZ1NAAYLs-HrjPy0ag,1312
- liger_kernel_nightly-0.5.5.dev20250402212634.dist-info/METADATA,sha256=PaFO566AhWjPHX3kn2S83vBHlK0N6LgyYjXL8SvH2qs,22959
- liger_kernel_nightly-0.5.5.dev20250402212634.dist-info/NOTICE,sha256=njwnoPZLh9AN8SJQzxvCGLHi-8X__AvWRze6joNXIY8,2066
- liger_kernel_nightly-0.5.5.dev20250402212634.dist-info/WHEEL,sha256=iAkIy5fosb7FzIOwONchHf19Qu7_1wCWyFNR5gu9nU0,91
- liger_kernel_nightly-0.5.5.dev20250402212634.dist-info/top_level.txt,sha256=2eghu4hA3LnkM7ElW92tQ8zegWKgSbeo-k-aGe1YnvY,13
- liger_kernel_nightly-0.5.5.dev20250402212634.dist-info/RECORD,,
+ liger_kernel_nightly-0.5.6.dev20250403001329.dist-info/LICENSE,sha256=OhzLDHJ0to4a8sodVLELZiCFylZ1NAAYLs-HrjPy0ag,1312
+ liger_kernel_nightly-0.5.6.dev20250403001329.dist-info/METADATA,sha256=XTYWQ-SEGTr7X52X8TIIceAeNhIYMf3lzRf7LXP1vHM,23297
+ liger_kernel_nightly-0.5.6.dev20250403001329.dist-info/NOTICE,sha256=njwnoPZLh9AN8SJQzxvCGLHi-8X__AvWRze6joNXIY8,2066
+ liger_kernel_nightly-0.5.6.dev20250403001329.dist-info/WHEEL,sha256=iAkIy5fosb7FzIOwONchHf19Qu7_1wCWyFNR5gu9nU0,91
+ liger_kernel_nightly-0.5.6.dev20250403001329.dist-info/top_level.txt,sha256=2eghu4hA3LnkM7ElW92tQ8zegWKgSbeo-k-aGe1YnvY,13
+ liger_kernel_nightly-0.5.6.dev20250403001329.dist-info/RECORD,,