liger-kernel-nightly 0.5.10.dev20250704061237__py3-none-any.whl → 0.5.10.dev20250707212543__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- liger_kernel/transformers/model/gemma3.py +49 -97
- {liger_kernel_nightly-0.5.10.dev20250704061237.dist-info → liger_kernel_nightly-0.5.10.dev20250707212543.dist-info}/METADATA +1 -1
- {liger_kernel_nightly-0.5.10.dev20250704061237.dist-info → liger_kernel_nightly-0.5.10.dev20250707212543.dist-info}/RECORD +7 -7
- {liger_kernel_nightly-0.5.10.dev20250704061237.dist-info → liger_kernel_nightly-0.5.10.dev20250707212543.dist-info}/LICENSE +0 -0
- {liger_kernel_nightly-0.5.10.dev20250704061237.dist-info → liger_kernel_nightly-0.5.10.dev20250707212543.dist-info}/NOTICE +0 -0
- {liger_kernel_nightly-0.5.10.dev20250704061237.dist-info → liger_kernel_nightly-0.5.10.dev20250707212543.dist-info}/WHEEL +0 -0
- {liger_kernel_nightly-0.5.10.dev20250704061237.dist-info → liger_kernel_nightly-0.5.10.dev20250707212543.dist-info}/top_level.txt +0 -0
@@ -1,4 +1,3 @@
|
|
1
|
-
from typing import List
|
2
1
|
from typing import Optional
|
3
2
|
from typing import Tuple
|
4
3
|
from typing import Union
|
@@ -10,9 +9,7 @@ from transformers.cache_utils import Cache
|
|
10
9
|
from transformers.cache_utils import HybridCache
|
11
10
|
from transformers.modeling_outputs import CausalLMOutputWithPast
|
12
11
|
from transformers.models.gemma3.modeling_gemma3 import Gemma3CausalLMOutputWithPast
|
13
|
-
from transformers.utils import is_torchdynamo_compiling
|
14
12
|
from transformers.utils import logging
|
15
|
-
from transformers.utils.deprecation import deprecate_kwarg
|
16
13
|
|
17
14
|
from liger_kernel.transformers.fused_linear_cross_entropy import LigerFusedLinearCrossEntropyLoss
|
18
15
|
from liger_kernel.transformers.model.loss_utils import LigerForCausalLMLoss
|
@@ -20,7 +17,6 @@ from liger_kernel.transformers.model.loss_utils import LigerForCausalLMLoss
|
|
20
17
|
logger = logging.get_logger(__name__)
|
21
18
|
|
22
19
|
|
23
|
-
@deprecate_kwarg("num_logits_to_keep", version="4.50", new_name="logits_to_keep")
|
24
20
|
def causal_forward(
|
25
21
|
self,
|
26
22
|
input_ids: torch.LongTensor = None,
|
@@ -139,14 +135,13 @@ def causal_forward(
|
|
139
135
|
)
|
140
136
|
|
141
137
|
|
142
|
-
@deprecate_kwarg("num_logits_to_keep", version="4.50", new_name="logits_to_keep")
|
143
138
|
def multimodal_forward(
|
144
139
|
self,
|
145
140
|
input_ids: torch.LongTensor = None,
|
146
141
|
pixel_values: torch.FloatTensor = None,
|
147
142
|
attention_mask: Optional[torch.Tensor] = None,
|
148
143
|
position_ids: Optional[torch.LongTensor] = None,
|
149
|
-
past_key_values: Optional[Union[
|
144
|
+
past_key_values: Optional[Union[list[torch.FloatTensor], Cache]] = None,
|
150
145
|
token_type_ids: Optional[torch.LongTensor] = None,
|
151
146
|
cache_position: Optional[torch.LongTensor] = None,
|
152
147
|
inputs_embeds: Optional[torch.FloatTensor] = None,
|
@@ -158,21 +153,12 @@ def multimodal_forward(
|
|
158
153
|
logits_to_keep: Union[int, torch.Tensor] = 0,
|
159
154
|
skip_logits: Optional[bool] = None,
|
160
155
|
**lm_kwargs,
|
161
|
-
) -> Union[
|
156
|
+
) -> Union[tuple, Gemma3CausalLMOutputWithPast]:
|
162
157
|
r"""
|
163
|
-
|
164
|
-
|
165
|
-
|
166
|
-
|
167
|
-
|
168
|
-
logits_to_keep (`int` or `torch.Tensor`, *optional*):
|
169
|
-
If an `int`, compute logits for the last `logits_to_keep` tokens. If `0`, calculate logits for all
|
170
|
-
`input_ids` (special case). Only last token logits are needed for generation, and calculating them only for that
|
171
|
-
token can save memory, which becomes pretty significant for long sequences or large vocabulary size.
|
172
|
-
If a `torch.Tensor`, must be 1D corresponding to the indices to keep in the sequence length dimension.
|
173
|
-
This is useful when using packed tensor format (single dimension for batch and sequence length).
|
174
|
-
|
175
|
-
Returns:
|
158
|
+
labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
|
159
|
+
Labels for computing the masked language modeling loss. Indices should either be in `[0, ...,
|
160
|
+
config.text_config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored
|
161
|
+
(masked), the loss is only computed for the tokens with labels in `[0, ..., config.text_config.vocab_size]`.
|
176
162
|
|
177
163
|
Example:
|
178
164
|
|
@@ -181,23 +167,37 @@ def multimodal_forward(
|
|
181
167
|
>>> import requests
|
182
168
|
>>> from transformers import AutoProcessor, Gemma3ForConditionalGeneration
|
183
169
|
|
184
|
-
>>> model = Gemma3ForConditionalGeneration.from_pretrained("google/
|
185
|
-
>>> processor = AutoProcessor.from_pretrained("google/
|
186
|
-
|
187
|
-
>>>
|
188
|
-
|
189
|
-
|
190
|
-
|
191
|
-
|
192
|
-
|
170
|
+
>>> model = Gemma3ForConditionalGeneration.from_pretrained("google/gemma-3-4b-it")
|
171
|
+
>>> processor = AutoProcessor.from_pretrained("google/gemma-3-4b-it")
|
172
|
+
|
173
|
+
>>> messages = [
|
174
|
+
... {
|
175
|
+
... "role": "system",
|
176
|
+
... "content": [
|
177
|
+
... {"type": "text", "text": "You are a helpful assistant."}
|
178
|
+
... ]
|
179
|
+
... },
|
180
|
+
... {
|
181
|
+
... "role": "user", "content": [
|
182
|
+
... {"type": "image", "url": "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/pipeline-cat-chonk.jpeg"},
|
183
|
+
... {"type": "text", "text": "Where is the cat standing?"},
|
184
|
+
... ]
|
185
|
+
... },
|
186
|
+
... ]
|
187
|
+
|
188
|
+
>>> inputs = processor.apply_chat_template(
|
189
|
+
... messages,
|
190
|
+
... tokenize=True,
|
191
|
+
... return_dict=True,
|
192
|
+
... return_tensors="pt",
|
193
|
+
... add_generation_prompt=True
|
194
|
+
... )
|
193
195
|
>>> # Generate
|
194
|
-
>>> generate_ids = model.generate(**inputs
|
196
|
+
>>> generate_ids = model.generate(**inputs)
|
195
197
|
>>> processor.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]
|
196
|
-
"
|
197
|
-
```
|
198
|
-
|
199
|
-
if (input_ids is None) ^ (inputs_embeds is not None):
|
200
|
-
raise ValueError("You must specify exactly one of input_ids or inputs_embeds")
|
198
|
+
"user\nYou are a helpful assistant.\n\n\n\n\n\nWhere is the cat standing?\nmodel\nBased on the image, the cat is standing in a snowy area, likely outdoors. It appears to"
|
199
|
+
```
|
200
|
+
"""
|
201
201
|
|
202
202
|
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
|
203
203
|
output_hidden_states = (
|
@@ -205,79 +205,30 @@ def multimodal_forward(
|
|
205
205
|
)
|
206
206
|
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
|
207
207
|
|
208
|
-
|
209
|
-
|
210
|
-
|
211
|
-
|
212
|
-
|
213
|
-
llm_input_ids = input_ids.clone()
|
214
|
-
llm_input_ids[special_image_mask] = 0
|
215
|
-
else:
|
216
|
-
llm_input_ids = input_ids
|
217
|
-
|
218
|
-
if inputs_embeds is None:
|
219
|
-
inputs_embeds = self.get_input_embeddings()(llm_input_ids)
|
220
|
-
|
221
|
-
if cache_position is None:
|
222
|
-
past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0
|
223
|
-
cache_position = torch.arange(
|
224
|
-
past_seen_tokens, past_seen_tokens + inputs_embeds.shape[1], device=inputs_embeds.device
|
225
|
-
)
|
226
|
-
|
227
|
-
if position_ids is None:
|
228
|
-
position_ids = cache_position.unsqueeze(0) + 1 # Gemma3 positions are 1-indexed
|
229
|
-
|
230
|
-
# Merge text and images
|
231
|
-
if pixel_values is not None:
|
232
|
-
image_features = self.get_image_features(pixel_values)
|
233
|
-
|
234
|
-
if input_ids is None:
|
235
|
-
special_image_mask = inputs_embeds == self.get_input_embeddings()(
|
236
|
-
torch.tensor(self.config.image_token_index, dtype=torch.long, device=inputs_embeds.device)
|
237
|
-
)
|
238
|
-
else:
|
239
|
-
special_image_mask = (input_ids == self.config.image_token_index).unsqueeze(-1)
|
240
|
-
special_image_mask = special_image_mask.expand_as(inputs_embeds).to(inputs_embeds.device)
|
241
|
-
|
242
|
-
if not is_torchdynamo_compiling() and inputs_embeds[special_image_mask].numel() != image_features.numel():
|
243
|
-
image_tokens_in_text = (special_image_mask).sum(dim=1).sum(dim=0)[0]
|
244
|
-
raise ValueError(
|
245
|
-
f"Number of images does not match number of special image tokens in the input text. "
|
246
|
-
f"Got {image_tokens_in_text} image tokens in the text but {image_features.shape[0] * image_features.shape[1]} "
|
247
|
-
"tokens from image embeddings."
|
248
|
-
)
|
249
|
-
image_features = image_features.to(inputs_embeds.device, inputs_embeds.dtype)
|
250
|
-
inputs_embeds = inputs_embeds.masked_scatter(special_image_mask, image_features)
|
251
|
-
|
252
|
-
# mask out pad-token-ids in labels for BC
|
253
|
-
if labels is not None and self.pad_token_id in labels:
|
254
|
-
logger.warning_once(
|
255
|
-
"`labels` contains `pad_token_id` which will be masked with `config.ignore_index`. "
|
256
|
-
"You have to mask out `pad_token_id` when preparing `labels`, this behavior will be removed in v.4.46.",
|
257
|
-
)
|
258
|
-
labels = torch.where(input_ids == self.pad_token_id, self.config.ignore_index, labels)
|
259
|
-
|
260
|
-
causal_mask = self._update_causal_mask(
|
261
|
-
attention_mask, token_type_ids, past_key_values, cache_position, inputs_embeds, is_training
|
262
|
-
)
|
263
|
-
outputs = self.language_model.model(
|
264
|
-
attention_mask=causal_mask,
|
208
|
+
outputs = self.model(
|
209
|
+
input_ids=input_ids,
|
210
|
+
pixel_values=pixel_values,
|
211
|
+
token_type_ids=token_type_ids,
|
212
|
+
attention_mask=attention_mask,
|
265
213
|
position_ids=position_ids,
|
266
214
|
past_key_values=past_key_values,
|
267
215
|
inputs_embeds=inputs_embeds,
|
268
216
|
use_cache=use_cache,
|
217
|
+
labels=labels,
|
269
218
|
output_attentions=output_attentions,
|
270
219
|
output_hidden_states=output_hidden_states,
|
271
220
|
return_dict=return_dict,
|
272
221
|
cache_position=cache_position,
|
273
|
-
logits_to_keep=logits_to_keep,
|
274
222
|
**lm_kwargs,
|
275
223
|
)
|
276
224
|
|
277
225
|
hidden_states = outputs[0]
|
226
|
+
|
227
|
+
slice_indices = slice(-logits_to_keep, None) if isinstance(logits_to_keep, int) else logits_to_keep
|
228
|
+
kept_hidden_states = hidden_states[:, slice_indices, :]
|
229
|
+
|
278
230
|
loss = None
|
279
231
|
logits = None
|
280
|
-
|
281
232
|
if skip_logits and labels is None:
|
282
233
|
raise ValueError("skip_logits is True, but labels is None")
|
283
234
|
|
@@ -285,7 +236,7 @@ def multimodal_forward(
|
|
285
236
|
skip_logits = self.training and (labels is not None)
|
286
237
|
|
287
238
|
if skip_logits:
|
288
|
-
shift_hidden_states =
|
239
|
+
shift_hidden_states = kept_hidden_states[..., :-1, :]
|
289
240
|
shift_labels = labels[..., 1:]
|
290
241
|
|
291
242
|
hidden_device = shift_hidden_states.device
|
@@ -306,7 +257,7 @@ def multimodal_forward(
|
|
306
257
|
lce = LigerFusedLinearCrossEntropyLoss()
|
307
258
|
loss = lce(self.language_model.lm_head.weight, shift_hidden_states, shift_labels)
|
308
259
|
else:
|
309
|
-
logits = self.
|
260
|
+
logits = self.lm_head(kept_hidden_states)
|
310
261
|
if labels is not None:
|
311
262
|
# Upcast to float if we need to compute the loss to avoid potential precision issues
|
312
263
|
logits = logits.float()
|
@@ -327,6 +278,7 @@ def multimodal_forward(
|
|
327
278
|
flat_logits = shift_logits.view(-1, self.config.text_config.vocab_size)
|
328
279
|
flat_labels = shift_labels.view(-1).to(shift_logits.device)
|
329
280
|
loss = loss_fct(flat_logits, flat_labels)
|
281
|
+
|
330
282
|
if not return_dict:
|
331
283
|
output = (logits,) + outputs[1:]
|
332
284
|
return (loss,) + output if loss is not None else output
|
@@ -337,5 +289,5 @@ def multimodal_forward(
|
|
337
289
|
past_key_values=outputs.past_key_values,
|
338
290
|
hidden_states=outputs.hidden_states,
|
339
291
|
attentions=outputs.attentions,
|
340
|
-
image_hidden_states=
|
292
|
+
image_hidden_states=outputs.image_hidden_states,
|
341
293
|
)
|
@@ -68,7 +68,7 @@ liger_kernel/transformers/experimental/embedding.py,sha256=2P0QYdlFyFrG5OqTzTa1w
|
|
68
68
|
liger_kernel/transformers/model/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
|
69
69
|
liger_kernel/transformers/model/gemma.py,sha256=mNX-mIwV6jI4zfbrUHp0C468pOmjzsL7mjXipGt-eS0,10007
|
70
70
|
liger_kernel/transformers/model/gemma2.py,sha256=R_JFPyWTk7RyA7D05ZiIaNO5pX8gWcvfWf-6rdCRMxs,11296
|
71
|
-
liger_kernel/transformers/model/gemma3.py,sha256=
|
71
|
+
liger_kernel/transformers/model/gemma3.py,sha256=XbwoqOSPmtS0BPHgT8jZftTzplmiAicgBa6ocNcet8o,12800
|
72
72
|
liger_kernel/transformers/model/glm4.py,sha256=GlnEhdGJuDIqp2R9qC54biY3HwV1tWmfpJm6ijoAsrM,5257
|
73
73
|
liger_kernel/transformers/model/llama.py,sha256=i8jJgyZsMKWQ-zKloETLugtwFpUOdaWxLDceciFXKd4,12832
|
74
74
|
liger_kernel/transformers/model/llama4.py,sha256=IgbB8sTh3dlETQnaNNy1bZLuXy-Nt7qmeAjF27ydGpg,4210
|
@@ -89,9 +89,9 @@ liger_kernel/transformers/trainer/__init__.py,sha256=p7yQfklV8-467qSz_ZMimkbDF7H
|
|
89
89
|
liger_kernel/transformers/trainer/orpo_trainer.py,sha256=tX0h63aOFe3rNqTmk6JpMf75UPo981yzEa6TghnjS0Q,5370
|
90
90
|
liger_kernel/triton/__init__.py,sha256=qCiCamzCRv6lpV8IqpAc9YMdNKC7GKurClWceQPnlis,92
|
91
91
|
liger_kernel/triton/monkey_patch.py,sha256=Rd0hUHAzDkFfHvnX7-PBaNK5EKnZhtfM_h-fgQH9HPY,1568
|
92
|
-
liger_kernel_nightly-0.5.10.
|
93
|
-
liger_kernel_nightly-0.5.10.
|
94
|
-
liger_kernel_nightly-0.5.10.
|
95
|
-
liger_kernel_nightly-0.5.10.
|
96
|
-
liger_kernel_nightly-0.5.10.
|
97
|
-
liger_kernel_nightly-0.5.10.
|
92
|
+
liger_kernel_nightly-0.5.10.dev20250707212543.dist-info/LICENSE,sha256=OhzLDHJ0to4a8sodVLELZiCFylZ1NAAYLs-HrjPy0ag,1312
|
93
|
+
liger_kernel_nightly-0.5.10.dev20250707212543.dist-info/METADATA,sha256=PppDWZi5ORp7NVsDSFuduKl0RAmLEz80PDCmuAKd9-g,24536
|
94
|
+
liger_kernel_nightly-0.5.10.dev20250707212543.dist-info/NOTICE,sha256=njwnoPZLh9AN8SJQzxvCGLHi-8X__AvWRze6joNXIY8,2066
|
95
|
+
liger_kernel_nightly-0.5.10.dev20250707212543.dist-info/WHEEL,sha256=iAkIy5fosb7FzIOwONchHf19Qu7_1wCWyFNR5gu9nU0,91
|
96
|
+
liger_kernel_nightly-0.5.10.dev20250707212543.dist-info/top_level.txt,sha256=2eghu4hA3LnkM7ElW92tQ8zegWKgSbeo-k-aGe1YnvY,13
|
97
|
+
liger_kernel_nightly-0.5.10.dev20250707212543.dist-info/RECORD,,
|
File without changes
|
File without changes
|
File without changes
|