liger-kernel-nightly 0.5.5.dev20250318183047__py3-none-any.whl → 0.5.5.dev20250320214749__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release: this version of liger-kernel-nightly might be problematic.

@@ -21,6 +21,190 @@ from liger_kernel.transformers.fused_linear_cross_entropy import LigerFusedLinea
  logger = logging.get_logger(__name__)


+ @add_start_docstrings_to_model_forward(PALIGEMMA_INPUTS_DOCSTRING)
+ @replace_return_docstrings(output_type=PaliGemmaCausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC)
+ def lce_forward_deprecated(
+     self,
+     input_ids: torch.LongTensor = None,
+     pixel_values: torch.FloatTensor = None,
+     attention_mask: Optional[torch.Tensor] = None,
+     position_ids: Optional[torch.LongTensor] = None,
+     past_key_values: Optional[Union[List[torch.FloatTensor], Cache]] = None,
+     token_type_ids: Optional[torch.LongTensor] = None,
+     cache_position: Optional[torch.LongTensor] = None,
+     inputs_embeds: Optional[torch.FloatTensor] = None,
+     labels: Optional[torch.LongTensor] = None,
+     use_cache: Optional[bool] = None,
+     output_attentions: Optional[bool] = None,
+     output_hidden_states: Optional[bool] = None,
+     return_dict: Optional[bool] = None,
+ ) -> Union[Tuple, PaliGemmaCausalLMOutputWithPast]:
+     r"""
+     Args:
+         labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
+             Labels for computing the masked language modeling loss. Indices should either be in `[0, ...,
+             config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored
+             (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`.
+
+     Returns:
+
+     Example:
+
+     ```python
+     >>> from PIL import Image
+     >>> import requests
+     >>> from transformers import AutoProcessor, PaliGemmaForConditionalGeneration
+
+     >>> model = PaliGemmaForConditionalGeneration.from_pretrained("google/PaliGemma-test-224px-hf")
+     >>> processor = AutoProcessor.from_pretrained("google/PaliGemma-test-224px-hf")
+
+     >>> prompt = "answer en Where is the cow standing?"
+     >>> url = "https://huggingface.co/gv-hf/PaliGemma-test-224px-hf/resolve/main/cow_beach_1.png"
+     >>> image = Image.open(requests.get(url, stream=True).raw)
+
+     >>> inputs = processor(text=prompt, images=image, return_tensors="pt")
+
+     >>> # Generate
+     >>> generate_ids = model.generate(**inputs, max_length=30)
+     >>> processor.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]
+     "answer en Where is the cow standing?\nbeach"
+     ```"""
+
+     if (input_ids is None) ^ (inputs_embeds is not None):
+         raise ValueError(
+             "You cannot specify both input_ids and inputs_embeds at the same time, and must specify either one"
+         )
+
+     output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
+     output_hidden_states = (
+         output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
+     )
+     return_dict = return_dict if return_dict is not None else self.config.use_return_dict
+
+     # the attention mask is turned 4d after, we keep track of the original one
+     input_attention_mask = attention_mask
+
+     if inputs_embeds is None:
+         # 1. Extra the input embeddings
+         inputs_embeds = self.get_input_embeddings()(input_ids)
+
+         # 2. Merge text and images
+         if pixel_values is not None and input_ids.shape[1] != 1:
+             image_outputs = self.vision_tower(pixel_values.to(inputs_embeds.dtype))
+             selected_image_feature = image_outputs.last_hidden_state
+             image_features = self.multi_modal_projector(selected_image_feature)
+
+             if cache_position is None:
+                 cache_position = torch.arange(inputs_embeds.shape[1], device=inputs_embeds.device)
+             inputs_embeds, attention_mask, labels, position_ids = self._merge_input_ids_with_image_features(
+                 image_features, inputs_embeds, input_ids, attention_mask, labels, token_type_ids, cache_position
+             )
+
+         else:
+             # In case input_ids.shape[1] == 1 & pixel_values==None & past_key_values != None, we are in the case of
+             # generation with cache
+             if past_key_values is not None and pixel_values is not None and input_ids.shape[1] == 1:
+                 # Retrieve the first layer to inspect the logits and mask out the hidden states
+                 # that are set to 0
+                 # TODO @molbap this will only work for dynamic cache.
+                 first_layer_past_key_value = past_key_values[0][0][:, :, :, 0]
+
+                 # Sum all dimensions of head_dim (-2) to avoid random errors such as: https://github.com/huggingface/transformers/pull/28032#issuecomment-1863691941
+                 batch_index, non_attended_tokens = torch.where(first_layer_past_key_value.float().sum(-2) == 0)
+
+                 # Get the target length
+                 target_seqlen = cache_position[-1] + 1
+                 extended_attention_mask = torch.ones(
+                     (attention_mask.shape[0], target_seqlen - attention_mask.shape[1] + 1),
+                     dtype=attention_mask.dtype,
+                     device=attention_mask.device,
+                 )
+                 # Filter out only the tokens that can be un-attended, this can happen
+                 # if one uses PaliGemma+ Fused modules where the cache on the
+                 # first iteration is already big enough, or if one passes custom cache
+                 valid_indices = non_attended_tokens < extended_attention_mask.size(-1)
+                 new_batch_index = batch_index[valid_indices]
+                 new_non_attended_tokens = non_attended_tokens[valid_indices]
+
+                 # Zero-out the places where we don't need to attend
+                 extended_attention_mask[new_batch_index, new_non_attended_tokens] = 0
+
+                 attention_mask = torch.cat((attention_mask, extended_attention_mask), dim=1)
+                 position_ids = torch.sum(attention_mask, dim=1).unsqueeze(-1) - 1
+
+     attention_mask = attention_mask.to(inputs_embeds.dtype)
+     outputs = self.language_model.model(
+         attention_mask=attention_mask,
+         position_ids=position_ids,
+         past_key_values=past_key_values,
+         inputs_embeds=inputs_embeds,
+         use_cache=use_cache,
+         output_attentions=output_attentions,
+         output_hidden_states=output_hidden_states,
+         return_dict=return_dict,
+         cache_position=cache_position,
+     )
+
+     hidden_states = outputs[0]
+
+     loss = None
+     logits = None
+
+     if self.training and (labels is not None):
+         shift_hidden_states = hidden_states[..., :-1, :]
+         shift_labels = labels[..., 1:]
+
+         hidden_device = shift_hidden_states.device
+
+         if attention_mask is not None:
+             # we use the input attention mask to shift the hidden_states and labels, because it is 2D.
+             # we also crop attn mask in case it is longer, which happens in PrefixTuning with peft
+             shift_attention_mask = attention_mask[:, -shift_hidden_states.shape[1] :].to(hidden_device)
+             shift_hidden_states = shift_hidden_states[shift_attention_mask.to(hidden_device) != 0].contiguous()
+             shift_labels = shift_labels[shift_attention_mask.to(shift_labels.device) != 0].contiguous()
+         else:
+             shift_hidden_states = shift_hidden_states.contiguous()
+             shift_labels = shift_labels.contiguous()
+
+         # Flatten hidden state
+         shift_hidden_states = shift_hidden_states.view(-1, self.config.text_config.hidden_size)
+         shift_labels = shift_labels.view(-1).to(hidden_device)
+
+         lce = LigerFusedLinearCrossEntropyLoss()
+         loss = lce(self.language_model.lm_head.weight, shift_hidden_states, shift_labels)
+
+     else:
+         logits = self.language_model.lm_head(hidden_states)
+         if labels is not None:
+             shift_logits = logits[..., :-1, :]
+             shift_labels = labels[..., 1:]
+             if input_attention_mask is not None:
+                 # we use the input attention mask to shift the logits and labels, because it is 2D.
+                 shift_attention_mask = input_attention_mask[..., 1:]
+                 shift_logits = shift_logits[shift_attention_mask.to(logits.device) != 0].contiguous()
+                 shift_labels = shift_labels[shift_attention_mask.to(shift_labels.device) != 0].contiguous()
+             else:
+                 shift_logits = shift_logits.contiguous()
+                 shift_labels = shift_labels.contiguous()
+             # Flatten the tokens
+             loss_fct = CrossEntropyLoss()
+
+             flat_logits = shift_logits.view(-1, self.config.vocab_size)
+             flat_labels = shift_labels.view(-1).to(shift_logits.device)
+             loss = loss_fct(flat_logits, flat_labels)
+     if not return_dict:
+         output = (logits,) + outputs[1:]
+         return (loss,) + output if loss is not None else output
+
+     return PaliGemmaCausalLMOutputWithPast(
+         loss=loss,
+         logits=logits,
+         past_key_values=outputs.past_key_values,
+         hidden_states=outputs.hidden_states,
+         attentions=outputs.attentions,
+     )
+
+
  @deprecate_kwarg("num_logits_to_keep", version="4.50", new_name="logits_to_keep")
  @add_start_docstrings_to_model_forward(PALIGEMMA_INPUTS_DOCSTRING)
  @replace_return_docstrings(output_type=PaliGemmaCausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC)
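
The deprecated forward above avoids materializing the full logits tensor during training: it hands the `lm_head` weight and the flattened, shifted hidden states directly to `LigerFusedLinearCrossEntropyLoss`, which fuses the output projection with the cross-entropy loss. Below is a minimal standalone sketch of that call pattern (not part of the diff); the shapes are made up, and a CUDA device is assumed since Liger's Triton kernels run on GPU.

```python
import torch

from liger_kernel.transformers.fused_linear_cross_entropy import LigerFusedLinearCrossEntropyLoss

# Illustrative sizes only; in the patched forward they come from the model config.
num_tokens, hidden_size, vocab_size = 32, 64, 1000

# Flattened, shifted hidden states and labels, mirroring shift_hidden_states / shift_labels above.
hidden_states = torch.randn(num_tokens, hidden_size, device="cuda")
labels = torch.randint(0, vocab_size, (num_tokens,), device="cuda")
lm_head_weight = torch.randn(vocab_size, hidden_size, device="cuda", requires_grad=True)

# Same argument order as in the patched forward: lm_head weight, hidden states, labels.
lce = LigerFusedLinearCrossEntropyLoss()
loss = lce(lm_head_weight, hidden_states, labels)
loss.backward()
```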
@@ -631,6 +631,7 @@ def apply_liger_kernel_to_paligemma(

      # PaliGemma submodules are ['vision_tower', 'multi_modal_projector', 'language_model']

+     from transformers.models.gemma.modeling_gemma import GemmaForCausalLM
      from transformers.models.gemma2.modeling_gemma2 import Gemma2ForCausalLM
      from transformers.models.paligemma import modeling_paligemma
      from transformers.models.paligemma.modeling_paligemma import PaliGemmaForConditionalGeneration
@@ -639,6 +640,7 @@ def apply_liger_kernel_to_paligemma(
      from transformers.models.siglip.modeling_siglip import SiglipVisionModel

      from liger_kernel.transformers.model.paligemma import lce_forward
+     from liger_kernel.transformers.model.paligemma import lce_forward_deprecated

      # The vision_tower is a SiglipVisionModel
      if layer_norm:
@@ -647,13 +649,22 @@ def apply_liger_kernel_to_paligemma(
      # SiglipMLP is standard FFN so LigerGEGLUMLP is not compatible
      # The multi_modal_projector is Linear, nothing to do

-     # The language_model is Gemma2ForCausalLM
-     apply_liger_kernel_to_gemma2(rope=rope, cross_entropy=False, fused_linear_cross_entropy=False, geglu=geglu)
+     # The language_model is GemmaForCausalLM or Gemma2ForCausalLM
+     apply_liger_kernel_to_gemma(
+         rope=rope, cross_entropy=False, fused_linear_cross_entropy=False, rms_norm=rms_norm, geglu=geglu
+     )
+     apply_liger_kernel_to_gemma2(
+         rope=rope, cross_entropy=False, fused_linear_cross_entropy=False, rms_norm=rms_norm, geglu=geglu
+     )
      # Handle loss function
      if cross_entropy:
          modeling_paligemma.nn.CrossEntropyLoss = LigerCrossEntropyLoss
      if fused_linear_cross_entropy:
-         modeling_paligemma.PaliGemmaForConditionalGeneration.forward = lce_forward
+         if transformer_version >= version.parse(SUPPORTED_TRANSFORMER_VERSION):
+             modeling_paligemma.PaliGemmaForConditionalGeneration.forward = lce_forward
+         else:  # if version < 4.46.1
+             logger.warning(TRANSFORMER_DEPRECATION_WARNING)
+             modeling_paligemma.PaliGemmaForConditionalGeneration.forward = lce_forward_deprecated

      if model is not None:
          # The model instance already exists, so we need to additionally patch the
@@ -672,16 +683,31 @@ def apply_liger_kernel_to_paligemma(
                  _patch_layer_norm_module(layer.layer_norm1)
                  _patch_layer_norm_module(layer.layer_norm2)

-         language_model: Gemma2ForCausalLM = model.language_model
-
-         apply_liger_kernel_to_gemma2(
-             rope=rope,
-             cross_entropy=False,
-             fused_linear_cross_entropy=False,
-             rms_norm=rms_norm,
-             geglu=geglu,
-             model=language_model,
-         )
+         language_model = model.language_model
+
+         if isinstance(language_model, GemmaForCausalLM):
+             apply_liger_kernel_to_gemma(
+                 rope=rope,
+                 cross_entropy=False,
+                 fused_linear_cross_entropy=False,
+                 rms_norm=rms_norm,
+                 geglu=geglu,
+                 model=language_model,
+             )
+
+         elif isinstance(language_model, Gemma2ForCausalLM):
+             apply_liger_kernel_to_gemma2(
+                 rope=rope,
+                 cross_entropy=False,
+                 fused_linear_cross_entropy=False,
+                 rms_norm=rms_norm,
+                 geglu=geglu,
+                 model=language_model,
+             )
+         else:
+             raise TypeError(
+                 "The language_model of a PaliGemma model must be either GemmaForCausalLM or Gemma2ForCausalLM."
+             )


  def apply_liger_kernel_to_qwen2(
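
Taken together, the monkey-patch changes dispatch PaliGemma's language model to either the Gemma or the Gemma2 Liger patch and fall back to `lce_forward_deprecated` on transformers releases older than the supported version. The sketch below (not part of the diff) shows how the entry point might be invoked, using the keyword names visible in the hunks; the flag values and the checkpoint id are illustrative assumptions, not defaults documented here.

```python
import torch
from transformers import PaliGemmaForConditionalGeneration

from liger_kernel.transformers import apply_liger_kernel_to_paligemma

# Patch the transformers PaliGemma classes in place; the flags mirror the
# parameters that appear in the diff (rope, cross_entropy, fused_linear_cross_entropy,
# rms_norm, geglu). Values chosen here are illustrative.
apply_liger_kernel_to_paligemma(
    rope=True,
    cross_entropy=False,
    fused_linear_cross_entropy=True,
    rms_norm=True,
    geglu=True,
)

# Illustrative checkpoint id; any PaliGemma checkpoint whose language model is
# GemmaForCausalLM or Gemma2ForCausalLM matches the isinstance dispatch added above.
model = PaliGemmaForConditionalGeneration.from_pretrained(
    "google/paligemma-3b-pt-224",
    torch_dtype=torch.bfloat16,
)

# An already-instantiated model can instead be patched by passing `model=model`,
# which is the code path the isinstance dispatch in the diff extends.
```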
@@ -1,6 +1,6 @@
  Metadata-Version: 2.1
  Name: liger_kernel_nightly
- Version: 0.5.5.dev20250318183047
+ Version: 0.5.5.dev20250320214749
  Summary: Efficient Triton kernels for LLM Training
  License: BSD 2-CLAUSE LICENSE
  Copyright 2024 LinkedIn Corporation
@@ -43,7 +43,7 @@ liger_kernel/transformers/group_norm.py,sha256=6qMAWOprr4SzP0YhNVNGQIBpM5aUHplUD
  liger_kernel/transformers/jsd.py,sha256=DGqRnxIZxsvxo0_tbbxX3b-sDbDjC_yKufyRIHCcScY,2979
  liger_kernel/transformers/kl_div.py,sha256=WLffFbh1EExD2Eb1F7lN11fo9JJC-0751WJjZAF1Fj8,409
  liger_kernel/transformers/layer_norm.py,sha256=c9pk3PEasOKYR0rhe5e5nNrnYKVCEW4VC8S6LpCq9EQ,906
- liger_kernel/transformers/monkey_patch.py,sha256=1Vzt_8UUMgO4t1ui7fNkKMcDfnWoCZfe9iyqeYSbe1w,50851
+ liger_kernel/transformers/monkey_patch.py,sha256=qRCgchODu6AuO8la6uAnrDEA-sSP9ADt8IOp4kl-Dd0,52053
  liger_kernel/transformers/qwen2vl_mrope.py,sha256=5EwSqrMdsL9MYspeBMXBsNJKvH0MOmRrtJXAJlnnlOI,1047
  liger_kernel/transformers/rms_norm.py,sha256=GqCEJuGt0YdqqlMcToE0Wp4A8YFquDa4UUSyH2uFW2A,1191
  liger_kernel/transformers/rope.py,sha256=ZTrTORSAyfcFIKjk6XEeYmk4ROH7xXED9L4g2NFntlE,999
@@ -60,7 +60,7 @@ liger_kernel/transformers/model/mistral.py,sha256=o7tyl1sPWPfZwwrBLRlryHlSI8I55v
  liger_kernel/transformers/model/mixtral.py,sha256=T0ITv2-PkR8VErVOVUizoS4EzjmARyR7GFh0tXDB_i4,11089
  liger_kernel/transformers/model/mllama.py,sha256=RCKtwnGOMFYIbtt1zUQ15Cyv4eNpHkTWcgkmG2EEs2I,10804
  liger_kernel/transformers/model/olmo2.py,sha256=5M8kczp4D-jvbjcV7cKATIJGF34xd-Rs-PPdKZWSIlY,4685
- liger_kernel/transformers/model/paligemma.py,sha256=C_Pb1qqxZl0J0fyXlwp1jTwNXckK9xuoSLHXy3rkWsE,10298
+ liger_kernel/transformers/model/paligemma.py,sha256=GNReT6tVZt3ON6aaa9ovg8mnu1hYocSx9OhgC7b-_28,19191
  liger_kernel/transformers/model/phi3.py,sha256=NmU2DuU1Huwha6K7YSsJCnvQfUovTTGlsfBZhbx0UoI,9951
  liger_kernel/transformers/model/qwen2.py,sha256=t7NotBHoebsPqNSxwaf9DXTg8jxgB5BdunSGqYOE0hQ,9240
  liger_kernel/transformers/model/qwen2_5_vl.py,sha256=70BnHZjx6eQWTwi3zc5SMwxTeOOA4Tbdkfy6IYRcTaM,9289
@@ -69,9 +69,9 @@ liger_kernel/transformers/trainer/__init__.py,sha256=p7yQfklV8-467qSz_ZMimkbDF7H
  liger_kernel/transformers/trainer/orpo_trainer.py,sha256=pdekW7l6Qg_aqa5SYKYlSWUF8m3lkOFvFLcIMEHrz9s,8338
  liger_kernel/triton/__init__.py,sha256=qCiCamzCRv6lpV8IqpAc9YMdNKC7GKurClWceQPnlis,92
  liger_kernel/triton/monkey_patch.py,sha256=Rd0hUHAzDkFfHvnX7-PBaNK5EKnZhtfM_h-fgQH9HPY,1568
- liger_kernel_nightly-0.5.5.dev20250318183047.dist-info/LICENSE,sha256=OhzLDHJ0to4a8sodVLELZiCFylZ1NAAYLs-HrjPy0ag,1312
- liger_kernel_nightly-0.5.5.dev20250318183047.dist-info/METADATA,sha256=iXbBoxaUi6eIZIh18U5BHGauA2Ol0b_GcVuZKfWtnxE,22832
- liger_kernel_nightly-0.5.5.dev20250318183047.dist-info/NOTICE,sha256=njwnoPZLh9AN8SJQzxvCGLHi-8X__AvWRze6joNXIY8,2066
- liger_kernel_nightly-0.5.5.dev20250318183047.dist-info/WHEEL,sha256=iAkIy5fosb7FzIOwONchHf19Qu7_1wCWyFNR5gu9nU0,91
- liger_kernel_nightly-0.5.5.dev20250318183047.dist-info/top_level.txt,sha256=2eghu4hA3LnkM7ElW92tQ8zegWKgSbeo-k-aGe1YnvY,13
- liger_kernel_nightly-0.5.5.dev20250318183047.dist-info/RECORD,,
+ liger_kernel_nightly-0.5.5.dev20250320214749.dist-info/LICENSE,sha256=OhzLDHJ0to4a8sodVLELZiCFylZ1NAAYLs-HrjPy0ag,1312
+ liger_kernel_nightly-0.5.5.dev20250320214749.dist-info/METADATA,sha256=WqbzHO3j_NRFdVkkvIfjevIYWO1ojp9D4NAV6hkIRV4,22832
+ liger_kernel_nightly-0.5.5.dev20250320214749.dist-info/NOTICE,sha256=njwnoPZLh9AN8SJQzxvCGLHi-8X__AvWRze6joNXIY8,2066
+ liger_kernel_nightly-0.5.5.dev20250320214749.dist-info/WHEEL,sha256=iAkIy5fosb7FzIOwONchHf19Qu7_1wCWyFNR5gu9nU0,91
+ liger_kernel_nightly-0.5.5.dev20250320214749.dist-info/top_level.txt,sha256=2eghu4hA3LnkM7ElW92tQ8zegWKgSbeo-k-aGe1YnvY,13
+ liger_kernel_nightly-0.5.5.dev20250320214749.dist-info/RECORD,,