liger-kernel-nightly 0.6.2.dev20251011154427__py3-none-any.whl → 0.6.4.dev20251202054858__py3-none-any.whl
This diff shows the contents of two publicly released versions of the package as they appear in their public registry and is provided for informational purposes only.
Potentially problematic release: this version of liger-kernel-nightly might be problematic.
- liger_kernel/chunked_loss/cosine_similarity_loss.py +13 -4
- liger_kernel/chunked_loss/fused_linear_distillation.py +13 -2
- liger_kernel/chunked_loss/fused_linear_ppo.py +21 -5
- liger_kernel/chunked_loss/grpo_loss.py +8 -5
- liger_kernel/chunked_loss/jsd_loss.py +18 -5
- liger_kernel/ops/cross_entropy.py +65 -11
- liger_kernel/ops/dyt.py +5 -2
- liger_kernel/ops/fused_add_rms_norm.py +5 -1
- liger_kernel/ops/fused_linear_cross_entropy.py +43 -13
- liger_kernel/ops/geglu.py +2 -1
- liger_kernel/ops/group_norm.py +2 -1
- liger_kernel/ops/grpo_loss.py +3 -1
- liger_kernel/ops/layer_norm.py +86 -66
- liger_kernel/ops/poly_norm.py +390 -0
- liger_kernel/ops/rms_norm.py +7 -2
- liger_kernel/ops/tiled_mlp.py +136 -0
- liger_kernel/ops/utils.py +2 -0
- liger_kernel/transformers/__init__.py +27 -0
- liger_kernel/transformers/cross_entropy.py +8 -3
- liger_kernel/transformers/functional.py +29 -6
- liger_kernel/transformers/fused_linear_cross_entropy.py +8 -3
- liger_kernel/transformers/grpo_loss.py +56 -1
- liger_kernel/transformers/model/falcon_h1.py +19 -5
- liger_kernel/transformers/model/gemma.py +17 -6
- liger_kernel/transformers/model/gemma2.py +14 -5
- liger_kernel/transformers/model/gemma3.py +25 -12
- liger_kernel/transformers/model/glm4.py +16 -4
- liger_kernel/transformers/model/glm4v.py +16 -4
- liger_kernel/transformers/model/glm4v_moe.py +23 -4
- liger_kernel/transformers/model/hunyuan_v1.py +134 -0
- liger_kernel/transformers/model/internvl.py +12 -5
- liger_kernel/transformers/model/llama.py +14 -5
- liger_kernel/transformers/model/llama4.py +16 -4
- liger_kernel/transformers/model/llava.py +12 -4
- liger_kernel/transformers/model/loss_utils.py +31 -3
- liger_kernel/transformers/model/mistral.py +15 -6
- liger_kernel/transformers/model/mixtral.py +16 -7
- liger_kernel/transformers/model/mllama.py +12 -4
- liger_kernel/transformers/model/olmo2.py +16 -4
- liger_kernel/transformers/model/olmo3.py +142 -0
- liger_kernel/transformers/model/output_classes.py +147 -0
- liger_kernel/transformers/model/paligemma.py +22 -5
- liger_kernel/transformers/model/phi3.py +14 -7
- liger_kernel/transformers/model/qwen2.py +16 -3
- liger_kernel/transformers/model/qwen2_5_vl.py +14 -6
- liger_kernel/transformers/model/qwen2_vl.py +16 -4
- liger_kernel/transformers/model/qwen3.py +20 -5
- liger_kernel/transformers/model/qwen3_moe.py +19 -5
- liger_kernel/transformers/model/qwen3_next.py +146 -0
- liger_kernel/transformers/model/qwen3_vl.py +150 -0
- liger_kernel/transformers/model/qwen3_vl_moe.py +126 -0
- liger_kernel/transformers/model/smollm3.py +15 -6
- liger_kernel/transformers/model/smolvlm.py +158 -0
- liger_kernel/transformers/monkey_patch.py +594 -19
- liger_kernel/transformers/poly_norm.py +42 -0
- liger_kernel/transformers/rms_norm.py +7 -0
- liger_kernel/transformers/rope.py +43 -0
- liger_kernel/transformers/swiglu.py +17 -0
- liger_kernel/transformers/tiled_mlp.py +133 -0
- liger_kernel/utils.py +25 -0
- {liger_kernel_nightly-0.6.2.dev20251011154427.dist-info → liger_kernel_nightly-0.6.4.dev20251202054858.dist-info}/METADATA +4 -1
- liger_kernel_nightly-0.6.4.dev20251202054858.dist-info/RECORD +118 -0
- liger_kernel_nightly-0.6.2.dev20251011154427.dist-info/RECORD +0 -107
- {liger_kernel_nightly-0.6.2.dev20251011154427.dist-info → liger_kernel_nightly-0.6.4.dev20251202054858.dist-info}/LICENSE +0 -0
- {liger_kernel_nightly-0.6.2.dev20251011154427.dist-info → liger_kernel_nightly-0.6.4.dev20251202054858.dist-info}/NOTICE +0 -0
- {liger_kernel_nightly-0.6.2.dev20251011154427.dist-info → liger_kernel_nightly-0.6.4.dev20251202054858.dist-info}/WHEEL +0 -0
- {liger_kernel_nightly-0.6.2.dev20251011154427.dist-info → liger_kernel_nightly-0.6.4.dev20251202054858.dist-info}/top_level.txt +0 -0
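The bulk of this release extends Liger's monkey-patching surface: new fused lce_forward implementations for Qwen3-VL, Qwen3-VL-MoE, Olmo3, HunYuan v1 and SmolVLM, new PolyNorm and tiled-MLP ops, and shared output classes that carry a token_accuracy field. As a rough sketch of how such patches are consumed, the example below uses the long-standing apply_liger_kernel_to_llama entry point; the keyword arguments shown and the names of any newly added apply_liger_kernel_to_* helpers for the models in this diff are assumptions, not verified against this wheel.

```python
# Minimal sketch: turning on Liger's fused kernels before loading a model.
# apply_liger_kernel_to_llama is an existing liger-kernel API; the exact kwargs
# accepted by this nightly build are an assumption here.
import torch
from transformers import AutoModelForCausalLM

from liger_kernel.transformers import apply_liger_kernel_to_llama

# Patch the HF Llama classes in place so RMSNorm, RoPE, SwiGLU and the
# fused linear + cross-entropy forward come from Liger's Triton kernels.
apply_liger_kernel_to_llama(
    rms_norm=True,
    rope=True,
    swiglu=True,
    fused_linear_cross_entropy=True,
)

model = AutoModelForCausalLM.from_pretrained(
    "meta-llama/Llama-3.2-1B",  # placeholder checkpoint, any Llama model works
    torch_dtype=torch.bfloat16,
)
```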
liger_kernel/transformers/model/qwen3_vl.py (new file, 150 lines added)

@@ -0,0 +1,150 @@
from typing import List
from typing import Optional
from typing import Tuple
from typing import Union

import torch

from transformers.utils import can_return_tuple

from liger_kernel.transformers.model.loss_utils import LigerForCausalLMLoss
from liger_kernel.transformers.model.loss_utils import unpack_cross_entropy_result
from liger_kernel.transformers.model.output_classes import LigerQwen3VLCausalLMOutputWithPast


@can_return_tuple
def lce_forward(
    self,
    input_ids: torch.LongTensor = None,
    attention_mask: Optional[torch.Tensor] = None,
    position_ids: Optional[torch.LongTensor] = None,
    past_key_values: Optional[List[torch.FloatTensor]] = None,
    inputs_embeds: Optional[torch.FloatTensor] = None,
    labels: Optional[torch.LongTensor] = None,
    use_cache: Optional[bool] = None,
    output_attentions: Optional[bool] = None,
    output_hidden_states: Optional[bool] = None,
    return_dict: Optional[bool] = None,
    pixel_values: Optional[torch.Tensor] = None,
    pixel_values_videos: Optional[torch.FloatTensor] = None,
    image_grid_thw: Optional[torch.LongTensor] = None,
    video_grid_thw: Optional[torch.LongTensor] = None,
    rope_deltas: Optional[torch.LongTensor] = None,
    cache_position: Optional[torch.LongTensor] = None,
    second_per_grid_ts: Optional[torch.Tensor] = None,
    skip_logits: Optional[bool] = None,
    **kwargs,
) -> Union[Tuple, LigerQwen3VLCausalLMOutputWithPast]:
    """
    labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
        Labels for computing the masked language modeling loss. Indices should either be in `[0, ...,
        config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored
        (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`.
    pixel_values_videos (`torch.FloatTensor` of shape `(seq_length, num_channels * temporal_size * image_size * image_size)):
        The tensors corresponding to the input videos. Pixel values can be obtained using
        [`AutoImageProcessor`]. See [`Qwen2_5_VLImageProcessor.__call__`] for details. [`Qwen2_5_VLProcessor`] uses
        [`Qwen2_5_VLImageProcessor`] for processing videos.
    image_grid_thw (`torch.LongTensor` of shape `(num_images, 3)`, *optional*):
        The temporal, height and width of feature shape of each image in LLM.
    video_grid_thw (`torch.LongTensor` of shape `(num_videos, 3)`, *optional*):
        The temporal, height and width of feature shape of each video in LLM.
    rope_deltas (`torch.LongTensor` of shape `(batch_size, )`, *optional*):
        The rope index difference between sequence length and multimodal rope.
    second_per_grid_ts (`torch.Tensor` of shape `(num_videos)`, *optional*):
        The time interval (in seconds) for each grid along the temporal dimension in the 3D position IDs.
    Example:
    ```python
    >>> from PIL import Image
    >>> import requests
    >>> from transformers import AutoProcessor, Qwen3VLForConditionalGeneration
    >>> model = Qwen3VLForConditionalGeneration.from_pretrained("Qwen/Qwen3-VL")
    >>> processor = AutoProcessor.from_pretrained("Qwen/Qwen3-VL")
    >>> messages = [
        {
            "role": "user",
            "content": [
                {"type": "image"},
                {"type": "text", "text": "What is shown in this image?"},
            ],
        },
    ]
    >>> url = "https://www.ilankelman.org/stopsigns/australia.jpg"
    >>> image = Image.open(requests.get(url, stream=True).raw)
    >>> text = processor.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
    >>> inputs = processor(text=[text], images=[image], vision_infos=[vision_infos])
    >>> # Generate
    >>> generate_ids = model.generate(inputs.input_ids, max_length=30)
    >>> tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]
    "The image shows a street scene with a red stop sign in the foreground. In the background, there is a large red gate with Chinese characters ..."
    ```"""

    output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
    output_hidden_states = (
        output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
    )
    return_dict = return_dict if return_dict is not None else self.config.use_return_dict

    outputs = self.model(
        input_ids=input_ids,
        pixel_values=pixel_values,
        pixel_values_videos=pixel_values_videos,
        image_grid_thw=image_grid_thw,
        video_grid_thw=video_grid_thw,
        second_per_grid_ts=second_per_grid_ts,
        position_ids=position_ids,
        attention_mask=attention_mask,
        past_key_values=past_key_values,
        inputs_embeds=inputs_embeds,
        use_cache=use_cache,
        output_attentions=output_attentions,
        output_hidden_states=output_hidden_states,
        return_dict=return_dict,
        cache_position=cache_position,
        **kwargs,
    )

    hidden_states = outputs[0]

    shift_labels = kwargs.pop("shift_labels", None)
    loss = None
    logits = None
    token_accuracy = None

    if skip_logits and labels is None and shift_labels is None:
        raise ValueError("skip_logits is True, but labels and shift_labels are None")

    if skip_logits is None:
        skip_logits = self.training and (labels is not None or shift_labels is not None)

    if skip_logits:
        result = LigerForCausalLMLoss(
            hidden_states=hidden_states,
            lm_head_weight=self.lm_head.weight,
            labels=labels,
            shift_labels=shift_labels,
            hidden_size=self.config.text_config.hidden_size,
            **kwargs,
        )
        loss, _, token_accuracy = unpack_cross_entropy_result(result)
    else:
        logits = self.lm_head(hidden_states)

        loss = None
        if labels is not None:
            loss = self.loss_function(logits=logits, labels=labels, vocab_size=self.config.text_config.vocab_size)

    if not return_dict:
        output = (logits,) + outputs[1:]
        output = (loss,) + output if loss is not None else output
        output = output + (token_accuracy,) if token_accuracy is not None else output
        return output

    return LigerQwen3VLCausalLMOutputWithPast(
        loss=loss,
        logits=logits,
        past_key_values=outputs.past_key_values,
        hidden_states=outputs.hidden_states,
        attentions=outputs.attentions,
        rope_deltas=outputs.rope_deltas,
        token_accuracy=token_accuracy,
    )
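The skip_logits branch above is the core of the patch: during training, the full (batch, seq_len, vocab) logits tensor is never materialized because LigerForCausalLMLoss fuses the lm_head projection with the cross-entropy reduction. The snippet below is a plain-PyTorch illustration of that memory argument, chunking over rows instead of using a Triton kernel; it is not Liger's implementation and the helper name is invented for the example.

```python
import torch
import torch.nn.functional as F


def chunked_linear_cross_entropy(hidden, lm_head_weight, labels, chunk_size=1024):
    """Illustration only: CE(hidden @ W.T, labels) without holding the full
    (N, vocab) logits in memory at once. A real fused kernel (like Liger's)
    also avoids saving per-chunk logits for backward; this loop only shows
    the forward-memory idea."""
    hidden = hidden.reshape(-1, hidden.size(-1))   # (N, H)
    labels = labels.reshape(-1)                    # (N,)
    total = hidden.new_zeros(())
    count = 0
    for start in range(0, hidden.size(0), chunk_size):
        h = hidden[start : start + chunk_size]
        y = labels[start : start + chunk_size]
        logits = h @ lm_head_weight.T              # only (chunk, vocab) exists here
        mask = y != -100                           # ignore padded label positions
        if mask.any():
            total = total + F.cross_entropy(logits[mask], y[mask], reduction="sum")
            count += int(mask.sum())
    return total / max(count, 1)
```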
liger_kernel/transformers/model/qwen3_vl_moe.py (new file, 126 lines added)

@@ -0,0 +1,126 @@
from typing import List
from typing import Optional
from typing import Tuple
from typing import Union

import torch

from transformers.models.qwen3_vl_moe.modeling_qwen3_vl_moe import load_balancing_loss_func
from transformers.utils import can_return_tuple

from liger_kernel.transformers.model.loss_utils import LigerForCausalLMLoss
from liger_kernel.transformers.model.loss_utils import unpack_cross_entropy_result
from liger_kernel.transformers.model.output_classes import LigerQwen3VLMoeCausalLMOutputWithPast


@can_return_tuple
def lce_forward(
    self,
    input_ids: torch.LongTensor = None,
    attention_mask: Optional[torch.Tensor] = None,
    position_ids: Optional[torch.LongTensor] = None,
    past_key_values: Optional[List[torch.FloatTensor]] = None,
    inputs_embeds: Optional[torch.FloatTensor] = None,
    labels: Optional[torch.LongTensor] = None,
    use_cache: Optional[bool] = None,
    output_attentions: Optional[bool] = None,
    output_hidden_states: Optional[bool] = None,
    return_dict: Optional[bool] = None,
    pixel_values: Optional[torch.Tensor] = None,
    pixel_values_videos: Optional[torch.FloatTensor] = None,
    image_grid_thw: Optional[torch.LongTensor] = None,
    video_grid_thw: Optional[torch.LongTensor] = None,
    rope_deltas: Optional[torch.LongTensor] = None,
    cache_position: Optional[torch.LongTensor] = None,
    second_per_grid_ts: Optional[torch.Tensor] = None,
    skip_logits: Optional[bool] = None,
    **kwargs,
) -> Union[Tuple, LigerQwen3VLMoeCausalLMOutputWithPast]:
    """
    Qwen3-VL-MoE forward with fused linear cross entropy support mirroring Qwen3-VL behaviour.
    """

    output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
    output_hidden_states = (
        output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
    )
    return_dict = return_dict if return_dict is not None else self.config.use_return_dict

    outputs = self.model(
        input_ids=input_ids,
        pixel_values=pixel_values,
        pixel_values_videos=pixel_values_videos,
        image_grid_thw=image_grid_thw,
        video_grid_thw=video_grid_thw,
        second_per_grid_ts=second_per_grid_ts,
        position_ids=position_ids,
        attention_mask=attention_mask,
        past_key_values=past_key_values,
        inputs_embeds=inputs_embeds,
        use_cache=use_cache,
        output_attentions=output_attentions,
        output_hidden_states=output_hidden_states,
        return_dict=return_dict,
        cache_position=cache_position,
        **kwargs,
    )

    hidden_states = outputs[0]

    shift_labels = kwargs.pop("shift_labels", None)
    loss = None
    logits = None
    token_accuracy = None

    if skip_logits and labels is None and shift_labels is None:
        raise ValueError("skip_logits is True, but labels and shift_labels are None")

    if skip_logits is None:
        skip_logits = self.training and (labels is not None or shift_labels is not None)

    if skip_logits:
        result = LigerForCausalLMLoss(
            hidden_states=hidden_states,
            lm_head_weight=self.lm_head.weight,
            labels=labels,
            shift_labels=shift_labels,
            hidden_size=self.config.text_config.hidden_size,
            **kwargs,
        )
        loss, _, token_accuracy = unpack_cross_entropy_result(result)
    else:
        logits = self.lm_head(hidden_states)

        if labels is not None:
            loss = self.loss_function(logits=logits, labels=labels, vocab_size=self.config.text_config.vocab_size)

    # Compute auxiliary load-balancing loss for MoE when requested
    aux_loss = None
    if kwargs.get("output_router_logits", False):
        aux_loss = load_balancing_loss_func(
            outputs.router_logits,
            self.config.text_config.num_experts,
            self.config.text_config.num_experts_per_tok,
            attention_mask,
        )
        # If we computed training loss, add the scaled aux loss to it
        if loss is not None and aux_loss is not None:
            loss = loss + self.config.text_config.router_aux_loss_coef * aux_loss.to(loss.device)

    if not return_dict:
        output = (logits,) + outputs[1:]
        output = (loss,) + output if loss is not None else output
        output = output + (aux_loss,) if aux_loss is not None else output
        output = output + (token_accuracy,) if token_accuracy is not None else output
        return output

    return LigerQwen3VLMoeCausalLMOutputWithPast(
        loss=loss,
        logits=logits,
        past_key_values=outputs.past_key_values,
        hidden_states=outputs.hidden_states,
        attentions=outputs.attentions,
        rope_deltas=outputs.rope_deltas,
        aux_loss=aux_loss,
        token_accuracy=token_accuracy,
    )
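The MoE variant differs from the dense Qwen3-VL path only in the optional router term: when output_router_logits is passed, load_balancing_loss_func is evaluated over the router logits and the result, scaled by router_aux_loss_coef, is added to the training loss. The snippet below is a simplified, single-layer version of that Switch-Transformer-style balancing loss, written from the standard definition; the real load_balancing_loss_func in transformers additionally handles per-layer tuples of router logits and attention masks.

```python
import torch
import torch.nn.functional as F


def simple_load_balancing_loss(router_logits, num_experts, top_k):
    """Simplified single-layer balancing loss: pushes the routing distribution
    toward a uniform spread of tokens over experts.
    router_logits: (num_tokens, num_experts)."""
    probs = F.softmax(router_logits, dim=-1)                 # (T, E)
    _, selected = torch.topk(probs, top_k, dim=-1)           # (T, k) expert indices
    expert_mask = F.one_hot(selected, num_experts).float()   # (T, k, E)
    tokens_per_expert = expert_mask.mean(dim=(0, 1))         # fraction of assignments per expert
    router_prob_per_expert = probs.mean(dim=0)               # mean routing probability per expert
    return num_experts * torch.sum(tokens_per_expert * router_prob_per_expert)


# In the forward above the term is folded in as:
#   loss = loss + config.text_config.router_aux_loss_coef * aux_loss
```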
@@ -7,11 +7,12 @@ from typing import Union
 import torch
 
 from torch.distributed.fsdp import FullyShardedDataParallel
-from transformers.modeling_outputs import CausalLMOutputWithPast
 from transformers.utils.deprecation import deprecate_kwarg
 
 from liger_kernel.transformers.fsdp import _FSDPForwardRedirection
 from liger_kernel.transformers.model.loss_utils import LigerForCausalLMLoss
+from liger_kernel.transformers.model.loss_utils import unpack_cross_entropy_result
+from liger_kernel.transformers.model.output_classes import LigerCausalLMOutputWithPast
 from liger_kernel.utils import PEFT_AVAILABLE
 
 if TYPE_CHECKING:
@@ -38,7 +39,7 @@ def lce_forward(
     logits_to_keep: Union[int, torch.Tensor] = 0,
     skip_logits: Optional[bool] = None,
     **kwargs,
-) -> Union[Tuple, CausalLMOutputWithPast]:
+) -> Union[Tuple, LigerCausalLMOutputWithPast]:
     r"""
     Args:
         labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
@@ -101,6 +102,8 @@ def lce_forward(
     shift_labels = kwargs.pop("shift_labels", None)
     logits = None
     loss = None
+    token_accuracy = None
+
     # if in training mode, don't materialize logits
     if skip_logits and labels is None and shift_labels is None:
         raise ValueError("skip_logits is True, but labels and shift_labels are None")
@@ -109,8 +112,9 @@ def lce_forward(
         # By default, if in training mode, don't materialize logits
         skip_logits = self.training and (labels is not None or shift_labels is not None)
 
+    # Compute loss
     if skip_logits:
-        loss = lce_maybe_trainable_lm_head(
+        result = lce_maybe_trainable_lm_head(
             self,
             hidden_states=kept_hidden_states,
             hidden_size=self.config.hidden_size,
@@ -118,6 +122,7 @@ def lce_forward(
             shift_labels=shift_labels,
             **kwargs,
         )
+        loss, _, token_accuracy = unpack_cross_entropy_result(result)
 
     else:
         logits = self.lm_head(kept_hidden_states)
@@ -131,15 +136,19 @@ def lce_forward(
             )
 
     if not return_dict:
-        output = (logits,) + outputs[1:]
-        return (loss,) + output if loss is not None else output
+        output_tuple = (logits,) + outputs[1:]
+        output = (loss,) + output_tuple if loss is not None else output_tuple
+        output = output + (token_accuracy,) if token_accuracy is not None else output
+        return output
 
-    return CausalLMOutputWithPast(
+    # Return custom output class with token_accuracy field
+    return LigerCausalLMOutputWithPast(
         loss=loss,
         logits=logits,
         past_key_values=outputs.past_key_values,
         hidden_states=outputs.hidden_states,
         attentions=outputs.attentions,
+        token_accuracy=token_accuracy,
     )
 
 
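The recurring change in these hunks is that the fused loss path now returns a richer result: unpack_cross_entropy_result splits it into the loss, an unused middle value, and an optional token_accuracy that is threaded into the new Liger*OutputWithPast classes. Conceptually, token accuracy is just the fraction of non-ignored next-token predictions that match the shifted labels; the sketch below states that definition in terms of materialized logits, purely as an illustration of the statistic the fused kernel reports without ever building those logits.

```python
import torch


def token_accuracy_from_logits(logits, labels, ignore_index=-100):
    """Fraction of next-token predictions that hit their label, ignoring padding.
    logits: (B, T, V); labels: (B, T). Illustration of the metric only; the
    Liger kernel computes it inside the fused linear + cross-entropy pass."""
    shift_logits = logits[:, :-1, :]      # position t predicts token t + 1
    shift_labels = labels[:, 1:]
    preds = shift_logits.argmax(dim=-1)
    mask = shift_labels != ignore_index
    correct = (preds == shift_labels) & mask
    return correct.sum().float() / mask.sum().clamp(min=1).float()
```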
liger_kernel/transformers/model/smolvlm.py (new file, 158 lines added)

@@ -0,0 +1,158 @@
from typing import TYPE_CHECKING
from typing import Optional
from typing import Union

import torch

from transformers.models.smolvlm.modeling_smolvlm import SmolVLMCausalLMOutputWithPast
from transformers.processing_utils import Unpack
from transformers.utils.generic import can_return_tuple

from liger_kernel.transformers.model.loss_utils import LigerForCausalLMLoss

if TYPE_CHECKING:
    from transformers.cache_utils import Cache
    from transformers.utils.generic import TransformersKwargs


# Forward adapted to enable fused Linear + CE without materializing logits.
# Mirrors the pattern used for other multimodal models (e.g., InternVL, LLaVA).
@can_return_tuple
def lce_forward(
    self,
    input_ids: Optional[torch.LongTensor] = None,
    attention_mask: Optional[torch.Tensor] = None,
    position_ids: Optional[torch.LongTensor] = None,
    past_key_values: Optional["Cache"] = None,
    inputs_embeds: Optional[torch.FloatTensor] = None,
    pixel_values: Optional[torch.FloatTensor] = None,
    pixel_attention_mask: Optional[torch.BoolTensor] = None,
    image_hidden_states: Optional[torch.FloatTensor] = None,
    labels: Optional[torch.LongTensor] = None,
    use_cache: Optional[bool] = None,
    output_attentions: Optional[bool] = None,
    output_hidden_states: Optional[bool] = None,
    cache_position: Optional[torch.LongTensor] = None,
    return_dict: Optional[bool] = None,
    logits_to_keep: Union[int, torch.Tensor] = 0,
    skip_logits: Optional[bool] = None,  # Added argument for liger-kernel
    **lm_kwargs: Unpack["TransformersKwargs"],  # renamed from kwargs
) -> Union[tuple, SmolVLMCausalLMOutputWithPast]:
    r"""
    pixel_attention_mask (`torch.Tensor` of shape `(batch_size, image_size, image_size)`, *optional*):
        Mask to avoid performing attention on padding pixel indices.
    image_hidden_states (`torch.FloatTensor` of shape `(batch_size, num_channels, image_size, image_size)`):
        The hidden states of the image encoder after modality projection.
    labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
        Labels for computing the masked language modeling loss. Indices should either be in `[0, ...,
        config.vocab_size]` or `model.image_token_id`. Tokens with indices set to `model.image_token_id` are
        ignored (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`.

    Example:

    ```python
    >>> import requests
    >>> import torch
    >>> from PIL import Image
    >>> from io import BytesIO

    >>> from transformers import AutoProcessor, AutoModelForImageTextToText
    >>> from transformers.image_utils import load_image

    >>> # Note that passing the image urls (instead of the actual pil images) to the processor is also possible
    >>> image1 = load_image("https://cdn.britannica.com/61/93061-050-99147DCE/Statue-of-Liberty-Island-New-York-Bay.jpg")
    >>> image2 = load_image("https://cdn.britannica.com/59/94459-050-DBA42467/Skyline-Chicago.jpg")
    >>> image3 = load_image("https://cdn.britannica.com/68/170868-050-8DDE8263/Golden-Gate-Bridge-San-Francisco.jpg")

    >>> processor = AutoProcessor.from_pretrained("HuggingFaceTB/SmolVLM2-2.2B-Instruct")
    >>> model = AutoModelForImageTextToText.from_pretrained("HuggingFaceTB/SmolVLM2-2.2B-Instruct", dtype=torch.bfloat16, device_map="auto")

    >>> # Create inputs
    >>> messages = [
    ...     {
    ...         "role": "user",
    ...         "content": [
    ...             {"type": "video", "path": path/to/video},
    ...             {"type": "text", "text": "What is happening in this video?"},
    ...         ]
    ...     }
    ... ]

    >>> inputs = processor.apply_chat_template([messages], add_generation_prompt=True)

    >>> # Generate
    >>> generated_ids = model.generate(**inputs, max_new_tokens=256)
    >>> generated_texts = processor.batch_decode(generated_ids, skip_special_tokens=True)

    >>> print(generated_texts)
    ```"""
    output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
    output_hidden_states = (
        output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
    )
    return_dict = return_dict if return_dict is not None else self.config.use_return_dict

    # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn)
    outputs = self.model(
        input_ids=input_ids,
        attention_mask=attention_mask,
        position_ids=position_ids,
        past_key_values=past_key_values,
        inputs_embeds=inputs_embeds,
        pixel_values=pixel_values,
        pixel_attention_mask=pixel_attention_mask,
        image_hidden_states=image_hidden_states,
        use_cache=use_cache,
        output_attentions=output_attentions,
        output_hidden_states=output_hidden_states,
        cache_position=cache_position,
        return_dict=True,
        **lm_kwargs,
    )

    # Copied from llava.py
    hidden_states = outputs[0]
    # Only compute necessary logits, and do not upcast them to float if we are not computing the loss
    slice_indices = slice(-logits_to_keep, None) if isinstance(logits_to_keep, int) else logits_to_keep
    kept_hidden_states = hidden_states[:, slice_indices, :]

    shift_labels = lm_kwargs.pop("shift_labels", None)
    logits = None
    loss = None

    if skip_logits and labels is None and shift_labels is None:
        raise ValueError("skip_logits is True, but labels and shift_labels are None")

    if skip_logits is None:
        # By default, if in training mode, don't materialize logits
        skip_logits = self.training and (labels is not None or shift_labels is not None)

    if skip_logits:
        loss = LigerForCausalLMLoss(
            hidden_states=kept_hidden_states,
            lm_head_weight=self.lm_head.weight,
            labels=labels,
            shift_labels=shift_labels,
            hidden_size=self.config.text_config.hidden_size,
            **lm_kwargs,
        )

    else:
        logits = self.lm_head(kept_hidden_states)
        if labels is not None or shift_labels is not None:
            loss = self.loss_function(
                logits=logits, labels=labels, vocab_size=self.config.text_config.vocab_size, **lm_kwargs
            )

    if not return_dict:
        output = (logits,) + outputs[1:]
        return (loss,) + output if loss is not None else output

    return SmolVLMCausalLMOutputWithPast(
        loss=loss,
        logits=logits,
        past_key_values=outputs.past_key_values,
        hidden_states=outputs.hidden_states,
        attentions=outputs.attentions,
        image_hidden_states=outputs.image_hidden_states,
    )
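One detail of the SmolVLM forward worth calling out: before either loss path runs, the hidden states are sliced with logits_to_keep, so at generation time only the positions that actually need logits reach the lm_head. The slicing semantics are plain PyTorch indexing, illustrated below; the default of 0 keeps the whole sequence because slice(-0, None) is slice(0, None).

```python
import torch

hidden_states = torch.randn(2, 10, 16)  # (batch, seq_len, hidden)

# int form: keep only the last N positions (N = 0 keeps everything)
logits_to_keep = 1
slice_indices = slice(-logits_to_keep, None) if isinstance(logits_to_keep, int) else logits_to_keep
print(hidden_states[:, slice_indices, :].shape)  # torch.Size([2, 1, 16])

# tensor form: keep explicit positions, e.g. the final token of each packed segment
keep = torch.tensor([3, 9])
print(hidden_states[:, keep, :].shape)           # torch.Size([2, 2, 16])
```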