liger-kernel-nightly 0.5.10.dev20250704061237__py3-none-any.whl → 0.5.10.dev20250708114334__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -1,4 +1,3 @@
1
- from typing import List
2
1
  from typing import Optional
3
2
  from typing import Tuple
4
3
  from typing import Union
@@ -10,9 +9,7 @@ from transformers.cache_utils import Cache
10
9
  from transformers.cache_utils import HybridCache
11
10
  from transformers.modeling_outputs import CausalLMOutputWithPast
12
11
  from transformers.models.gemma3.modeling_gemma3 import Gemma3CausalLMOutputWithPast
13
- from transformers.utils import is_torchdynamo_compiling
14
12
  from transformers.utils import logging
15
- from transformers.utils.deprecation import deprecate_kwarg
16
13
 
17
14
  from liger_kernel.transformers.fused_linear_cross_entropy import LigerFusedLinearCrossEntropyLoss
18
15
  from liger_kernel.transformers.model.loss_utils import LigerForCausalLMLoss
@@ -20,7 +17,6 @@ from liger_kernel.transformers.model.loss_utils import LigerForCausalLMLoss
20
17
  logger = logging.get_logger(__name__)
21
18
 
22
19
 
23
- @deprecate_kwarg("num_logits_to_keep", version="4.50", new_name="logits_to_keep")
24
20
  def causal_forward(
25
21
  self,
26
22
  input_ids: torch.LongTensor = None,
@@ -139,14 +135,13 @@ def causal_forward(
139
135
  )
140
136
 
141
137
 
142
- @deprecate_kwarg("num_logits_to_keep", version="4.50", new_name="logits_to_keep")
143
138
  def multimodal_forward(
144
139
  self,
145
140
  input_ids: torch.LongTensor = None,
146
141
  pixel_values: torch.FloatTensor = None,
147
142
  attention_mask: Optional[torch.Tensor] = None,
148
143
  position_ids: Optional[torch.LongTensor] = None,
149
- past_key_values: Optional[Union[List[torch.FloatTensor], Cache]] = None,
144
+ past_key_values: Optional[Union[list[torch.FloatTensor], Cache]] = None,
150
145
  token_type_ids: Optional[torch.LongTensor] = None,
151
146
  cache_position: Optional[torch.LongTensor] = None,
152
147
  inputs_embeds: Optional[torch.FloatTensor] = None,
@@ -158,21 +153,12 @@ def multimodal_forward(
158
153
  logits_to_keep: Union[int, torch.Tensor] = 0,
159
154
  skip_logits: Optional[bool] = None,
160
155
  **lm_kwargs,
161
- ) -> Union[Tuple, Gemma3CausalLMOutputWithPast]:
156
+ ) -> Union[tuple, Gemma3CausalLMOutputWithPast]:
162
157
  r"""
163
- labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
164
- Labels for computing the masked language modeling loss. Indices should either be in `[0, ...,
165
- config.text_config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored
166
- (masked), the loss is only computed for the tokens with labels in `[0, ..., config.text_config.vocab_size]`.
167
-
168
- logits_to_keep (`int` or `torch.Tensor`, *optional*):
169
- If an `int`, compute logits for the last `logits_to_keep` tokens. If `0`, calculate logits for all
170
- `input_ids` (special case). Only last token logits are needed for generation, and calculating them only for that
171
- token can save memory, which becomes pretty significant for long sequences or large vocabulary size.
172
- If a `torch.Tensor`, must be 1D corresponding to the indices to keep in the sequence length dimension.
173
- This is useful when using packed tensor format (single dimension for batch and sequence length).
174
-
175
- Returns:
158
+ labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
159
+ Labels for computing the masked language modeling loss. Indices should either be in `[0, ...,
160
+ config.text_config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored
161
+ (masked), the loss is only computed for the tokens with labels in `[0, ..., config.text_config.vocab_size]`.
176
162
 
177
163
  Example:
178
164
 
@@ -181,23 +167,37 @@ def multimodal_forward(
181
167
  >>> import requests
182
168
  >>> from transformers import AutoProcessor, Gemma3ForConditionalGeneration
183
169
 
184
- >>> model = Gemma3ForConditionalGeneration.from_pretrained("google/Gemma3-test-224px-hf")
185
- >>> processor = AutoProcessor.from_pretrained("google/Gemma3-test-224px-hf")
186
-
187
- >>> prompt = "answer en Where is the cow standing?"
188
- >>> url = "https://huggingface.co/gv-hf/Gemma3-test-224px-hf/resolve/main/cow_beach_1.png"
189
- >>> image = Image.open(requests.get(url, stream=True).raw)
190
-
191
- >>> inputs = processor(images=image, text=prompt, return_tensors="pt")
192
-
170
+ >>> model = Gemma3ForConditionalGeneration.from_pretrained("google/gemma-3-4b-it")
171
+ >>> processor = AutoProcessor.from_pretrained("google/gemma-3-4b-it")
172
+
173
+ >>> messages = [
174
+ ... {
175
+ ... "role": "system",
176
+ ... "content": [
177
+ ... {"type": "text", "text": "You are a helpful assistant."}
178
+ ... ]
179
+ ... },
180
+ ... {
181
+ ... "role": "user", "content": [
182
+ ... {"type": "image", "url": "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/pipeline-cat-chonk.jpeg"},
183
+ ... {"type": "text", "text": "Where is the cat standing?"},
184
+ ... ]
185
+ ... },
186
+ ... ]
187
+
188
+ >>> inputs = processor.apply_chat_template(
189
+ ... messages,
190
+ ... tokenize=True,
191
+ ... return_dict=True,
192
+ ... return_tensors="pt",
193
+ ... add_generation_prompt=True
194
+ ... )
193
195
  >>> # Generate
194
- >>> generate_ids = model.generate(**inputs, max_length=30)
196
+ >>> generate_ids = model.generate(**inputs)
195
197
  >>> processor.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]
196
- "answer en Where is the cow standing?\nbeach"
197
- ```"""
198
-
199
- if (input_ids is None) ^ (inputs_embeds is not None):
200
- raise ValueError("You must specify exactly one of input_ids or inputs_embeds")
198
+ "user\nYou are a helpful assistant.\n\n\n\n\n\nWhere is the cat standing?\nmodel\nBased on the image, the cat is standing in a snowy area, likely outdoors. It appears to"
199
+ ```
200
+ """
201
201
 
202
202
  output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
203
203
  output_hidden_states = (
@@ -205,79 +205,30 @@ def multimodal_forward(
205
205
  )
206
206
  return_dict = return_dict if return_dict is not None else self.config.use_return_dict
207
207
 
208
- is_training = token_type_ids is not None and labels is not None
209
-
210
- # Replace image id woth PAD if the image token if OOV, to avoid index-errors
211
- if input_ids is not None and self.config.image_token_index >= self.vocab_size:
212
- special_image_mask = input_ids == self.config.image_token_index
213
- llm_input_ids = input_ids.clone()
214
- llm_input_ids[special_image_mask] = 0
215
- else:
216
- llm_input_ids = input_ids
217
-
218
- if inputs_embeds is None:
219
- inputs_embeds = self.get_input_embeddings()(llm_input_ids)
220
-
221
- if cache_position is None:
222
- past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0
223
- cache_position = torch.arange(
224
- past_seen_tokens, past_seen_tokens + inputs_embeds.shape[1], device=inputs_embeds.device
225
- )
226
-
227
- if position_ids is None:
228
- position_ids = cache_position.unsqueeze(0) + 1 # Gemma3 positions are 1-indexed
229
-
230
- # Merge text and images
231
- if pixel_values is not None:
232
- image_features = self.get_image_features(pixel_values)
233
-
234
- if input_ids is None:
235
- special_image_mask = inputs_embeds == self.get_input_embeddings()(
236
- torch.tensor(self.config.image_token_index, dtype=torch.long, device=inputs_embeds.device)
237
- )
238
- else:
239
- special_image_mask = (input_ids == self.config.image_token_index).unsqueeze(-1)
240
- special_image_mask = special_image_mask.expand_as(inputs_embeds).to(inputs_embeds.device)
241
-
242
- if not is_torchdynamo_compiling() and inputs_embeds[special_image_mask].numel() != image_features.numel():
243
- image_tokens_in_text = (special_image_mask).sum(dim=1).sum(dim=0)[0]
244
- raise ValueError(
245
- f"Number of images does not match number of special image tokens in the input text. "
246
- f"Got {image_tokens_in_text} image tokens in the text but {image_features.shape[0] * image_features.shape[1]} "
247
- "tokens from image embeddings."
248
- )
249
- image_features = image_features.to(inputs_embeds.device, inputs_embeds.dtype)
250
- inputs_embeds = inputs_embeds.masked_scatter(special_image_mask, image_features)
251
-
252
- # mask out pad-token-ids in labels for BC
253
- if labels is not None and self.pad_token_id in labels:
254
- logger.warning_once(
255
- "`labels` contains `pad_token_id` which will be masked with `config.ignore_index`. "
256
- "You have to mask out `pad_token_id` when preparing `labels`, this behavior will be removed in v.4.46.",
257
- )
258
- labels = torch.where(input_ids == self.pad_token_id, self.config.ignore_index, labels)
259
-
260
- causal_mask = self._update_causal_mask(
261
- attention_mask, token_type_ids, past_key_values, cache_position, inputs_embeds, is_training
262
- )
263
- outputs = self.language_model.model(
264
- attention_mask=causal_mask,
208
+ outputs = self.model(
209
+ input_ids=input_ids,
210
+ pixel_values=pixel_values,
211
+ token_type_ids=token_type_ids,
212
+ attention_mask=attention_mask,
265
213
  position_ids=position_ids,
266
214
  past_key_values=past_key_values,
267
215
  inputs_embeds=inputs_embeds,
268
216
  use_cache=use_cache,
217
+ labels=labels,
269
218
  output_attentions=output_attentions,
270
219
  output_hidden_states=output_hidden_states,
271
220
  return_dict=return_dict,
272
221
  cache_position=cache_position,
273
- logits_to_keep=logits_to_keep,
274
222
  **lm_kwargs,
275
223
  )
276
224
 
277
225
  hidden_states = outputs[0]
226
+
227
+ slice_indices = slice(-logits_to_keep, None) if isinstance(logits_to_keep, int) else logits_to_keep
228
+ kept_hidden_states = hidden_states[:, slice_indices, :]
229
+
278
230
  loss = None
279
231
  logits = None
280
-
281
232
  if skip_logits and labels is None:
282
233
  raise ValueError("skip_logits is True, but labels is None")
283
234
 
@@ -285,7 +236,7 @@ def multimodal_forward(
285
236
  skip_logits = self.training and (labels is not None)
286
237
 
287
238
  if skip_logits:
288
- shift_hidden_states = hidden_states[..., :-1, :]
239
+ shift_hidden_states = kept_hidden_states[..., :-1, :]
289
240
  shift_labels = labels[..., 1:]
290
241
 
291
242
  hidden_device = shift_hidden_states.device
@@ -306,7 +257,7 @@ def multimodal_forward(
306
257
  lce = LigerFusedLinearCrossEntropyLoss()
307
258
  loss = lce(self.language_model.lm_head.weight, shift_hidden_states, shift_labels)
308
259
  else:
309
- logits = self.language_model.lm_head(hidden_states)
260
+ logits = self.lm_head(kept_hidden_states)
310
261
  if labels is not None:
311
262
  # Upcast to float if we need to compute the loss to avoid potential precision issues
312
263
  logits = logits.float()
@@ -327,6 +278,7 @@ def multimodal_forward(
327
278
  flat_logits = shift_logits.view(-1, self.config.text_config.vocab_size)
328
279
  flat_labels = shift_labels.view(-1).to(shift_logits.device)
329
280
  loss = loss_fct(flat_logits, flat_labels)
281
+
330
282
  if not return_dict:
331
283
  output = (logits,) + outputs[1:]
332
284
  return (loss,) + output if loss is not None else output
@@ -337,5 +289,5 @@ def multimodal_forward(
337
289
  past_key_values=outputs.past_key_values,
338
290
  hidden_states=outputs.hidden_states,
339
291
  attentions=outputs.attentions,
340
- image_hidden_states=image_features if pixel_values is not None else None,
292
+ image_hidden_states=outputs.image_hidden_states,
341
293
  )
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.1
2
2
  Name: liger_kernel_nightly
3
- Version: 0.5.10.dev20250704061237
3
+ Version: 0.5.10.dev20250708114334
4
4
  Summary: Efficient Triton kernels for LLM Training
5
5
  License: BSD 2-CLAUSE LICENSE
6
6
  Copyright 2024 LinkedIn Corporation
@@ -68,7 +68,7 @@ liger_kernel/transformers/experimental/embedding.py,sha256=2P0QYdlFyFrG5OqTzTa1w
68
68
  liger_kernel/transformers/model/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
69
69
  liger_kernel/transformers/model/gemma.py,sha256=mNX-mIwV6jI4zfbrUHp0C468pOmjzsL7mjXipGt-eS0,10007
70
70
  liger_kernel/transformers/model/gemma2.py,sha256=R_JFPyWTk7RyA7D05ZiIaNO5pX8gWcvfWf-6rdCRMxs,11296
71
- liger_kernel/transformers/model/gemma3.py,sha256=JI4jj9K660HeRsofB6cpkCHBQ0OsazElArRtKUehUmw,15945
71
+ liger_kernel/transformers/model/gemma3.py,sha256=XbwoqOSPmtS0BPHgT8jZftTzplmiAicgBa6ocNcet8o,12800
72
72
  liger_kernel/transformers/model/glm4.py,sha256=GlnEhdGJuDIqp2R9qC54biY3HwV1tWmfpJm6ijoAsrM,5257
73
73
  liger_kernel/transformers/model/llama.py,sha256=i8jJgyZsMKWQ-zKloETLugtwFpUOdaWxLDceciFXKd4,12832
74
74
  liger_kernel/transformers/model/llama4.py,sha256=IgbB8sTh3dlETQnaNNy1bZLuXy-Nt7qmeAjF27ydGpg,4210
@@ -89,9 +89,9 @@ liger_kernel/transformers/trainer/__init__.py,sha256=p7yQfklV8-467qSz_ZMimkbDF7H
89
89
  liger_kernel/transformers/trainer/orpo_trainer.py,sha256=tX0h63aOFe3rNqTmk6JpMf75UPo981yzEa6TghnjS0Q,5370
90
90
  liger_kernel/triton/__init__.py,sha256=qCiCamzCRv6lpV8IqpAc9YMdNKC7GKurClWceQPnlis,92
91
91
  liger_kernel/triton/monkey_patch.py,sha256=Rd0hUHAzDkFfHvnX7-PBaNK5EKnZhtfM_h-fgQH9HPY,1568
92
- liger_kernel_nightly-0.5.10.dev20250704061237.dist-info/LICENSE,sha256=OhzLDHJ0to4a8sodVLELZiCFylZ1NAAYLs-HrjPy0ag,1312
93
- liger_kernel_nightly-0.5.10.dev20250704061237.dist-info/METADATA,sha256=YkNjPNalpBKear6X387VrD6wFnmWZYNZkghDYQ250DU,24536
94
- liger_kernel_nightly-0.5.10.dev20250704061237.dist-info/NOTICE,sha256=njwnoPZLh9AN8SJQzxvCGLHi-8X__AvWRze6joNXIY8,2066
95
- liger_kernel_nightly-0.5.10.dev20250704061237.dist-info/WHEEL,sha256=iAkIy5fosb7FzIOwONchHf19Qu7_1wCWyFNR5gu9nU0,91
96
- liger_kernel_nightly-0.5.10.dev20250704061237.dist-info/top_level.txt,sha256=2eghu4hA3LnkM7ElW92tQ8zegWKgSbeo-k-aGe1YnvY,13
97
- liger_kernel_nightly-0.5.10.dev20250704061237.dist-info/RECORD,,
92
+ liger_kernel_nightly-0.5.10.dev20250708114334.dist-info/LICENSE,sha256=OhzLDHJ0to4a8sodVLELZiCFylZ1NAAYLs-HrjPy0ag,1312
93
+ liger_kernel_nightly-0.5.10.dev20250708114334.dist-info/METADATA,sha256=iIItIP0R_XQyEmKgWzaFHcuWd3udR53eN6YFFZpi7co,24536
94
+ liger_kernel_nightly-0.5.10.dev20250708114334.dist-info/NOTICE,sha256=njwnoPZLh9AN8SJQzxvCGLHi-8X__AvWRze6joNXIY8,2066
95
+ liger_kernel_nightly-0.5.10.dev20250708114334.dist-info/WHEEL,sha256=iAkIy5fosb7FzIOwONchHf19Qu7_1wCWyFNR5gu9nU0,91
96
+ liger_kernel_nightly-0.5.10.dev20250708114334.dist-info/top_level.txt,sha256=2eghu4hA3LnkM7ElW92tQ8zegWKgSbeo-k-aGe1YnvY,13
97
+ liger_kernel_nightly-0.5.10.dev20250708114334.dist-info/RECORD,,