liger-kernel 0.6.1__py3-none-any.whl → 0.6.3__py3-none-any.whl
This diff compares the contents of two publicly released versions of the package as they appear in their public registry. It is provided for informational purposes only and reflects the changes between those published versions.
- liger_kernel/chunked_loss/dpo_loss.py +54 -3
- liger_kernel/chunked_loss/fused_linear_ppo.py +4 -0
- liger_kernel/chunked_loss/grpo_loss.py +38 -4
- liger_kernel/chunked_loss/jsd_loss.py +5 -2
- liger_kernel/ops/cross_entropy.py +59 -53
- liger_kernel/ops/fused_linear_cross_entropy.py +83 -17
- liger_kernel/ops/layer_norm.py +4 -6
- liger_kernel/ops/llama4_rope.py +225 -0
- liger_kernel/ops/poly_norm.py +386 -0
- liger_kernel/transformers/__init__.py +32 -0
- liger_kernel/transformers/experimental/__init__.py +5 -0
- liger_kernel/transformers/functional.py +9 -0
- liger_kernel/transformers/fused_linear_cross_entropy.py +8 -1
- liger_kernel/transformers/llama4_rope.py +93 -0
- liger_kernel/transformers/model/falcon_h1.py +108 -0
- liger_kernel/transformers/model/gemma.py +2 -1
- liger_kernel/transformers/model/gemma2.py +8 -2
- liger_kernel/transformers/model/gemma3.py +27 -2
- liger_kernel/transformers/model/glm4.py +2 -1
- liger_kernel/transformers/model/glm4v.py +151 -0
- liger_kernel/transformers/model/glm4v_moe.py +153 -0
- liger_kernel/transformers/model/internvl.py +150 -0
- liger_kernel/transformers/model/llama.py +2 -1
- liger_kernel/transformers/model/llama4.py +2 -1
- liger_kernel/transformers/model/llava.py +6 -2
- liger_kernel/transformers/model/loss_utils.py +3 -0
- liger_kernel/transformers/model/mistral.py +2 -1
- liger_kernel/transformers/model/mixtral.py +8 -2
- liger_kernel/transformers/model/mllama.py +6 -3
- liger_kernel/transformers/model/olmo2.py +2 -1
- liger_kernel/transformers/model/paligemma.py +19 -0
- liger_kernel/transformers/model/phi3.py +10 -160
- liger_kernel/transformers/model/qwen2.py +2 -1
- liger_kernel/transformers/model/qwen2_5_vl.py +7 -2
- liger_kernel/transformers/model/qwen2_vl.py +7 -2
- liger_kernel/transformers/model/qwen3.py +2 -1
- liger_kernel/transformers/model/qwen3_moe.py +8 -2
- liger_kernel/transformers/model/qwen3_next.py +134 -0
- liger_kernel/transformers/model/smollm3.py +2 -1
- liger_kernel/transformers/model/smolvlm.py +158 -0
- liger_kernel/transformers/monkey_patch.py +552 -23
- liger_kernel/transformers/multi_token_attention.py +1 -1
- liger_kernel/transformers/poly_norm.py +42 -0
- liger_kernel/transformers/rms_norm.py +7 -0
- {liger_kernel-0.6.1.dist-info → liger_kernel-0.6.3.dist-info}/METADATA +14 -11
- {liger_kernel-0.6.1.dist-info → liger_kernel-0.6.3.dist-info}/RECORD +50 -39
- {liger_kernel-0.6.1.dist-info → liger_kernel-0.6.3.dist-info}/WHEEL +0 -0
- {liger_kernel-0.6.1.dist-info → liger_kernel-0.6.3.dist-info}/licenses/LICENSE +0 -0
- {liger_kernel-0.6.1.dist-info → liger_kernel-0.6.3.dist-info}/licenses/NOTICE +0 -0
- {liger_kernel-0.6.1.dist-info → liger_kernel-0.6.3.dist-info}/top_level.txt +0 -0
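
For orientation, the patched forwards listed above are normally activated through the monkey-patch layer rather than imported directly. A minimal usage sketch follows; the model id and the exact keyword flags are illustrative assumptions, not taken from this diff, so check the README of the installed version for the options it supports.

```python
# Minimal usage sketch (not from this diff): enabling Liger kernels on a
# Hugging Face model via the patch API exposed in liger_kernel.transformers.
import torch
import transformers

from liger_kernel.transformers import AutoLigerKernelForCausalLM
from liger_kernel.transformers import apply_liger_kernel_to_llama

# Option 1: patch the transformers Llama implementation in place,
# then load the model as usual.
apply_liger_kernel_to_llama(
    rope=True,
    rms_norm=True,
    swiglu=True,
    fused_linear_cross_entropy=True,  # avoids materializing full logits during training
)
model = transformers.AutoModelForCausalLM.from_pretrained(
    "meta-llama/Llama-3.1-8B-Instruct", torch_dtype=torch.bfloat16
)

# Option 2: let the Auto class pick and apply the matching patch for the checkpoint.
model = AutoLigerKernelForCausalLM.from_pretrained(
    "meta-llama/Llama-3.1-8B-Instruct", torch_dtype=torch.bfloat16
)
```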
@@ -5,131 +5,12 @@ from typing import Union
 
 import torch
 
-from
+from transformers.modeling_outputs import BaseModelOutputWithPast
 from transformers.modeling_outputs import CausalLMOutputWithPast
-from transformers.utils.deprecation import deprecate_kwarg
 
-from liger_kernel.transformers.fused_linear_cross_entropy import LigerFusedLinearCrossEntropyLoss
 from liger_kernel.transformers.model.loss_utils import LigerForCausalLMLoss
 
 
-def lce_forward_deprecated(
-    self,
-    input_ids: torch.LongTensor = None,
-    attention_mask: Optional[torch.Tensor] = None,
-    position_ids: Optional[torch.LongTensor] = None,
-    past_key_values: Optional[List[torch.FloatTensor]] = None,
-    inputs_embeds: Optional[torch.FloatTensor] = None,
-    labels: Optional[torch.LongTensor] = None,
-    use_cache: Optional[bool] = None,
-    output_attentions: Optional[bool] = None,
-    output_hidden_states: Optional[bool] = None,
-    return_dict: Optional[bool] = None,
-    cache_position: Optional[torch.LongTensor] = None,
-    skip_logits: Optional[bool] = None,
-) -> Union[Tuple, CausalLMOutputWithPast]:
-    r"""
-    Copy paste phi3 forward from transfomers v4.44.2 but replace torch cross entropy with liger fused linear cross entropy
-
-
-    Args:
-        labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
-            Labels for computing the masked language modeling loss. Indices should either be in `[0, ...,
-            config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored
-            (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`.
-
-    Returns:
-
-    Example:
-
-    ```python
-    >>> from transformers import AutoTokenizer, Phi3ForCausalLM
-
-    >>> model = Phi3ForCausalLM.from_pretrained("microsoft/phi-3-mini-4k-instruct")
-    >>> tokenizer = AutoTokenizer.from_pretrained("microsoft/phi-3-mini-4k-instruct")
-
-    >>> prompt = "This is an example script ."
-    >>> inputs = tokenizer(prompt, return_tensors="pt")
-
-    >>> # Generate
-    >>> generate_ids = model.generate(inputs.input_ids, max_length=30)
-    >>> tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]
-    'This is an example script .\n Certainly! Below is a sample script that demonstrates a simple task, such as calculating the sum'
-    ```"""
-
-    output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
-    output_hidden_states = (
-        output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
-    )
-    return_dict = return_dict if return_dict is not None else self.config.use_return_dict
-
-    # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn)
-    outputs = self.model(
-        input_ids=input_ids,
-        attention_mask=attention_mask,
-        position_ids=position_ids,
-        past_key_values=past_key_values,
-        inputs_embeds=inputs_embeds,
-        use_cache=use_cache,
-        output_attentions=output_attentions,
-        output_hidden_states=output_hidden_states,
-        return_dict=return_dict,
-    )
-
-    hidden_states = outputs[0]
-
-    loss = None
-    logits = None
-
-    if skip_logits and labels is None:
-        raise ValueError("skip_logits is True, but labels is None")
-
-    if skip_logits is None:
-        # By default, if in training mode, don't materialize logits
-        skip_logits = self.training and labels is not None
-
-    if skip_logits:
-        shift_hidden_states = hidden_states[..., :-1, :].contiguous()
-        shift_labels = labels[..., 1:].contiguous()
-
-        # flatten tokens
-        shift_hidden_states = shift_hidden_states.view(-1, self.config.hidden_size)
-        shift_labels = shift_labels.view(-1)
-
-        lce = LigerFusedLinearCrossEntropyLoss()
-        loss = lce(self.lm_head.weight, shift_hidden_states, shift_labels)
-    else:
-        logits = self.lm_head(hidden_states)
-
-        loss = None
-        if labels is not None:
-            # Upcast to float if we need to compute the loss to avoid potential precision issues
-            logits = logits.float()
-            # Shift so that tokens < n predict n
-            shift_logits = logits[..., :-1, :].contiguous()
-            shift_labels = labels[..., 1:].contiguous()
-            # Flatten the tokens
-            loss_fct = CrossEntropyLoss()
-            shift_logits = shift_logits.view(-1, self.config.vocab_size)
-            shift_labels = shift_labels.view(-1)
-            # Enable model parallelism
-            shift_labels = shift_labels.to(shift_logits.device)
-            loss = loss_fct(shift_logits, shift_labels)
-
-    if not return_dict:
-        output = (logits,) + outputs[1:]
-        return (loss,) + output if loss is not None else output
-
-    return CausalLMOutputWithPast(
-        loss=loss,
-        logits=logits,
-        past_key_values=outputs.past_key_values,
-        hidden_states=outputs.hidden_states,
-        attentions=outputs.attentions,
-    )
-
-
-@deprecate_kwarg("num_logits_to_keep", version="4.50", new_name="logits_to_keep")
 def lce_forward(
     self,
     input_ids: torch.LongTensor = None,
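The hunk above drops the deprecated phi3 forward that instantiated `LigerFusedLinearCrossEntropyLoss` directly. That loss module remains part of the package (see `liger_kernel/transformers/fused_linear_cross_entropy.py` in the file list), and the removed code documents its calling convention: the `lm_head` weight first, then the flattened hidden states and labels. Below is a minimal sketch under assumed shapes, dtype, and a CUDA device (the kernels are Triton-based):

```python
# Sketch of the calling convention shown in the removed code above:
# LigerFusedLinearCrossEntropyLoss fuses the lm_head matmul with cross entropy, so the
# (num_tokens, vocab_size) logits tensor is never materialized. Shapes are illustrative.
import torch

from liger_kernel.transformers.fused_linear_cross_entropy import LigerFusedLinearCrossEntropyLoss

num_tokens, hidden_size, vocab_size = 4096, 2048, 32064
hidden = torch.randn(num_tokens, hidden_size, device="cuda", dtype=torch.bfloat16, requires_grad=True)
lm_head_weight = torch.randn(vocab_size, hidden_size, device="cuda", dtype=torch.bfloat16, requires_grad=True)
labels = torch.randint(0, vocab_size, (num_tokens,), device="cuda")

lce = LigerFusedLinearCrossEntropyLoss()
loss = lce(lm_head_weight, hidden, labels)  # same argument order as the removed phi3 path
loss.backward()
```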
@@ -148,73 +29,41 @@ def lce_forward
     **kwargs,
 ) -> Union[Tuple, CausalLMOutputWithPast]:
     r"""
-    Args:
-        labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
-            Labels for computing the masked language modeling loss. Indices should either be in `[0, ...,
-            config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored
-            (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`.
-
-        logits_to_keep (`int` or `torch.Tensor`, *optional*):
-            If an `int`, compute logits for the last `logits_to_keep` tokens. If `0`, calculate logits for all
-            `input_ids` (special case). Only last token logits are needed for generation, and calculating them only for that
-            token can save memory, which becomes pretty significant for long sequences or large vocabulary size.
-            If a `torch.Tensor`, must be 1D corresponding to the indices to keep in the sequence length dimension.
-            This is useful when using packed tensor format (single dimension for batch and sequence length).
-
-    Returns:
-
     Example:
 
     ```python
     >>> from transformers import AutoTokenizer, Phi3ForCausalLM
 
-    >>> model = Phi3ForCausalLM.from_pretrained("
-    >>> tokenizer = AutoTokenizer.from_pretrained("
+    >>> model = Phi3ForCausalLM.from_pretrained("meta-phi3/Phi3-2-7b-hf")
+    >>> tokenizer = AutoTokenizer.from_pretrained("meta-phi3/Phi3-2-7b-hf")
 
-    >>> prompt = "
+    >>> prompt = "Hey, are you conscious? Can you talk to me?"
     >>> inputs = tokenizer(prompt, return_tensors="pt")
 
     >>> # Generate
     >>> generate_ids = model.generate(inputs.input_ids, max_length=30)
     >>> tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]
-
+    "Hey, are you conscious? Can you talk to me?\nI'm not conscious, but I can talk to you."
     ```"""
 
-    from transformers.models.phi3.modeling_phi3 import logging
-
-    logger = logging.get_logger(__name__)
-
-    if (
-        use_cache
-        and self.config.rope_scaling
-        and cache_position is not None
-        and cache_position[0] == self.config.original_max_position_embeddings
-    ):
-        logger.warning(
-            f"If you are not using the generate method, you may encounter nonsensical outputs after the {self.config.original_max_position_embeddings}th token, as the KV cache needs to be recomputed."
-        )
-
     output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
     output_hidden_states = (
         output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
     )
     return_dict = return_dict if return_dict is not None else self.config.use_return_dict
 
-
-    outputs = self.model(
+    outputs: BaseModelOutputWithPast = self.model(
         input_ids=input_ids,
         attention_mask=attention_mask,
         position_ids=position_ids,
         past_key_values=past_key_values,
         inputs_embeds=inputs_embeds,
         use_cache=use_cache,
-
-        output_hidden_states=output_hidden_states,
-        return_dict=return_dict,
+        cache_position=cache_position,
         **kwargs,
     )
 
-    hidden_states = outputs
+    hidden_states = outputs.last_hidden_state
     # Only compute necessary logits, and do not upcast them to float if we are not computing the loss
     slice_indices = slice(-logits_to_keep, None) if isinstance(logits_to_keep, int) else logits_to_keep
     kept_hidden_states = hidden_states[:, slice_indices, :]
@@ -242,10 +91,11 @@ def lce_forward
 
     else:
         logits = self.lm_head(kept_hidden_states)
-        if labels is not None:
+        if labels is not None or shift_labels is not None:
             loss = self.loss_function(
                 logits=logits,
                 labels=labels,
+                shift_labels=shift_labels,
                 vocab_size=self.config.vocab_size,
                 **kwargs,
             )
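This hunk, and several of the ones below, relax the loss guard from `if labels is not None:` to also accept a pre-shifted `shift_labels` tensor, which the patched forwards pop from `**kwargs` and pass through to `loss_function` or `LigerForCausalLMLoss`. The helper below is purely illustrative (it is not part of the package); it only shows how such a caller-supplied tensor relates to ordinary `labels`:

```python
import torch

IGNORE_INDEX = -100  # tokens with this label are excluded from the loss

def make_shift_labels(labels: torch.Tensor) -> torch.Tensor:
    """Hypothetical helper: pre-shift labels so position i is supervised by token i + 1,
    which is the form a caller would pass via the new `shift_labels` kwarg."""
    shift_labels = labels.new_full(labels.shape, IGNORE_INDEX)
    shift_labels[..., :-1] = labels[..., 1:]
    return shift_labels

labels = torch.tensor([[1, 2, 3, IGNORE_INDEX]])
print(make_shift_labels(labels))  # tensor([[   2,    3, -100, -100]])
```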
@@ -228,10 +228,11 @@ def lce_forward
 
     else:
         logits = self.lm_head(kept_hidden_states)
-        if labels is not None:
+        if labels is not None or shift_labels is not None:
             loss = self.loss_function(
                 logits=logits,
                 labels=labels,
+                shift_labels=shift_labels,
                 vocab_size=self.config.vocab_size,
                 **kwargs,
             )
@@ -133,8 +133,13 @@ def lce_forward
     logits = self.lm_head(hidden_states)
 
     loss = None
-    if labels is not None:
-        loss = self.loss_function(
+    if labels is not None or shift_labels is not None:
+        loss = self.loss_function(
+            logits=logits,
+            labels=labels,
+            shift_labels=shift_labels,
+            vocab_size=self.config.vocab_size,
+        )
 
     if not return_dict:
         output = (logits,) + outputs[1:]
@@ -129,8 +129,13 @@ def lce_forward
     logits = self.lm_head(hidden_states)
 
     loss = None
-    if labels is not None:
-        loss = self.loss_function(
+    if labels is not None or shift_labels is not None:
+        loss = self.loss_function(
+            logits=logits,
+            labels=labels,
+            shift_labels=shift_labels,
+            vocab_size=self.config.vocab_size,
+        )
 
     return Qwen2VLCausalLMOutputWithPast(
         loss=loss,
@@ -103,10 +103,11 @@ def lce_forward
 
     else:
         logits = self.lm_head(kept_hidden_states)
-        if labels is not None:
+        if labels is not None or shift_labels is not None:
             loss = self.loss_function(
                 logits=logits,
                 labels=labels,
+                shift_labels=shift_labels,
                 vocab_size=self.config.vocab_size,
                 **kwargs,
             )
@@ -107,8 +107,14 @@ def lce_forward
         )
     else:  # if in inference model materialize logits
         logits = self.lm_head(kept_hidden_states)
-        if labels is not None:
-            loss = self.loss_function(
+        if labels is not None or shift_labels is not None:
+            loss = self.loss_function(
+                logits=logits,
+                labels=labels,
+                shift_labels=shift_labels,
+                vocab_size=self.vocab_size,
+                **kwargs,
+            )
 
     aux_loss = None
     if output_router_logits:
@@ -0,0 +1,134 @@
+from typing import TYPE_CHECKING
+from typing import List
+from typing import Optional
+from typing import Union
+
+import torch
+
+from transformers.modeling_outputs import MoeCausalLMOutputWithPast
+from transformers.modeling_outputs import MoeModelOutputWithPast
+
+if TYPE_CHECKING:
+    from transformers.models.qwen3_next.modeling_qwen3_next import load_balancing_loss_func
+
+from liger_kernel.transformers.model.loss_utils import LigerForCausalLMLoss
+
+
+def lce_forward(
+    self,
+    input_ids: Optional[torch.LongTensor] = None,
+    attention_mask: Optional[torch.Tensor] = None,
+    position_ids: Optional[torch.LongTensor] = None,
+    past_key_values: Optional[List[torch.FloatTensor]] = None,
+    inputs_embeds: Optional[torch.FloatTensor] = None,
+    labels: Optional[torch.LongTensor] = None,
+    use_cache: Optional[bool] = None,
+    output_attentions: Optional[bool] = None,
+    output_hidden_states: Optional[bool] = None,
+    output_router_logits: Optional[bool] = None,
+    cache_position: Optional[torch.LongTensor] = None,
+    logits_to_keep: Union[int, torch.Tensor] = 0,
+    skip_logits: Optional[bool] = None,
+    **kwargs,
+) -> MoeCausalLMOutputWithPast:
+    r"""
+    labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
+        Labels for computing the masked language modeling loss. Indices should either be in `[0, ...,
+        config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored
+        (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`.
+
+    logits_to_keep (`int` or `torch.Tensor`, *optional*):
+        If an `int`, compute logits for the last `logits_to_keep` tokens. If `0`, calculate logits for all
+        `input_ids` (special case). Only last token logits are needed for generation, and calculating them only for that
+        token can save memory, which becomes pretty significant for long sequences or large vocabulary size.
+        If a `torch.Tensor`, must be 1D corresponding to the indices to keep in the sequence length dimension.
+        This is useful when using packed tensor format (single dimension for batch and sequence length).
+
+    Returns:
+
+    Example:
+
+    ```python
+    >>> from transformers import AutoModelForCausalLM, AutoTokenizer
+
+    >>> model = AutoModelForCausalLM.from_pretrained("Qwen/Qwen3-Next-80B-A3B-Instruct")
+    >>> tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen3-Next-80B-A3B-Instruct")
+
+    >>> prompt = "Give me a short introduction to large language model."
+    >>> inputs = tokenizer(prompt, return_tensors="pt")
+
+    >>> # Generate
+    >>> generate_ids = model.generate(inputs.input_ids, max_length=30)
+    >>> tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]
+    "Hey, are you conscious? Can you talk to me?\nI'm not conscious, but I can talk to you."
+    ```"""
+    output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
+    output_router_logits = (
+        output_router_logits if output_router_logits is not None else self.config.output_router_logits
+    )
+
+    output_hidden_states = (
+        output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
+    )
+
+    # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn)
+    outputs: MoeModelOutputWithPast = self.model(
+        input_ids=input_ids,
+        attention_mask=attention_mask,
+        position_ids=position_ids,
+        past_key_values=past_key_values,
+        inputs_embeds=inputs_embeds,
+        use_cache=use_cache,
+        output_attentions=output_attentions,
+        output_hidden_states=output_hidden_states,
+        output_router_logits=output_router_logits,
+        cache_position=cache_position,
+        **kwargs,
+    )
+
+    hidden_states = outputs.last_hidden_state
+    # Only compute necessary logits, and do not upcast them to float if we are not computing the loss
+    slice_indices = slice(-logits_to_keep, None) if isinstance(logits_to_keep, int) else logits_to_keep
+    kept_hidden_states = hidden_states[:, slice_indices, :]
+
+    shift_labels = kwargs.pop("shift_labels", None)
+    logits = None
+    loss = None
+
+    if skip_logits is None:
+        skip_logits = self.training and (labels is not None or shift_labels is not None)
+
+    if skip_logits:
+        loss = LigerForCausalLMLoss(
+            hidden_states=kept_hidden_states,
+            lm_head_weight=self.lm_head.weight,
+            labels=labels,
+            shift_labels=shift_labels,
+            hidden_size=self.config.hidden_size,
+            **kwargs,
+        )
+    else:  # if in inference model materialize logits
+        logits = self.lm_head(kept_hidden_states)
+        if labels is not None or shift_labels is not None:
+            loss = self.loss_function(logits, labels, self.vocab_size, **kwargs)
+
+    aux_loss = None
+    if output_router_logits:
+        aux_loss = load_balancing_loss_func(
+            outputs.router_logits,
+            self.num_experts,
+            self.num_experts_per_tok,
+            attention_mask,
+        )
+        if labels is not None:
+            loss += self.router_aux_loss_coef * aux_loss.to(loss.device)  # make sure to reside in the same device
+
+    return MoeCausalLMOutputWithPast(
+        loss=loss,
+        aux_loss=aux_loss,
+        logits=logits,
+        past_key_values=outputs.past_key_values,
+        hidden_states=outputs.hidden_states,
+        attentions=outputs.attentions,
+        router_logits=outputs.router_logits,
+    )
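Both new model files (this Qwen3-Next forward and the SmolVLM one further down) gate logits materialization on a `skip_logits` argument, defaulting to the fused path only during training when labels of some form are present. The standalone restatement below is a hypothetical sketch of that default, not code from the package:

```python
from typing import Optional

def resolve_skip_logits(
    training: bool,
    labels: Optional[object],
    shift_labels: Optional[object],
    skip_logits: Optional[bool] = None,
) -> bool:
    """Hypothetical restatement of the default used by the new forwards: skip
    materializing logits only during training with labels (or shift_labels)
    present, i.e. exactly when the fused linear + cross-entropy path applies."""
    if skip_logits is not None:
        return skip_logits  # an explicit caller choice always wins
    return training and (labels is not None or shift_labels is not None)

print(resolve_skip_logits(training=False, labels=None, shift_labels=None))   # False: generation path
print(resolve_skip_logits(training=True, labels=object(), shift_labels=None))  # True: fused loss path
```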
@@ -121,10 +121,11 @@ def lce_forward
 
     else:
         logits = self.lm_head(kept_hidden_states)
-        if labels is not None:
+        if labels is not None or shift_labels is not None:
             loss = self.loss_function(
                 logits=logits,
                 labels=labels,
+                shift_labels=shift_labels,
                 vocab_size=self.config.vocab_size,
                 **kwargs,
             )
@@ -0,0 +1,158 @@
+from typing import TYPE_CHECKING
+from typing import Optional
+from typing import Union
+
+import torch
+
+from transformers.models.smolvlm.modeling_smolvlm import SmolVLMCausalLMOutputWithPast
+from transformers.processing_utils import Unpack
+from transformers.utils.generic import can_return_tuple
+
+from liger_kernel.transformers.model.loss_utils import LigerForCausalLMLoss
+
+if TYPE_CHECKING:
+    from transformers.cache_utils import Cache
+    from transformers.utils.generic import TransformersKwargs
+
+
+# Forward adapted to enable fused Linear + CE without materializing logits.
+# Mirrors the pattern used for other multimodal models (e.g., InternVL, LLaVA).
+@can_return_tuple
+def lce_forward(
+    self,
+    input_ids: Optional[torch.LongTensor] = None,
+    attention_mask: Optional[torch.Tensor] = None,
+    position_ids: Optional[torch.LongTensor] = None,
+    past_key_values: Optional["Cache"] = None,
+    inputs_embeds: Optional[torch.FloatTensor] = None,
+    pixel_values: Optional[torch.FloatTensor] = None,
+    pixel_attention_mask: Optional[torch.BoolTensor] = None,
+    image_hidden_states: Optional[torch.FloatTensor] = None,
+    labels: Optional[torch.LongTensor] = None,
+    use_cache: Optional[bool] = None,
+    output_attentions: Optional[bool] = None,
+    output_hidden_states: Optional[bool] = None,
+    cache_position: Optional[torch.LongTensor] = None,
+    return_dict: Optional[bool] = None,
+    logits_to_keep: Union[int, torch.Tensor] = 0,
+    skip_logits: Optional[bool] = None,  # Added argument for liger-kernel
+    **lm_kwargs: Unpack["TransformersKwargs"],  # renamed from kwargs
+) -> Union[tuple, SmolVLMCausalLMOutputWithPast]:
+    r"""
+    pixel_attention_mask (`torch.Tensor` of shape `(batch_size, image_size, image_size)`, *optional*):
+        Mask to avoid performing attention on padding pixel indices.
+    image_hidden_states (`torch.FloatTensor` of shape `(batch_size, num_channels, image_size, image_size)`):
+        The hidden states of the image encoder after modality projection.
+    labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
+        Labels for computing the masked language modeling loss. Indices should either be in `[0, ...,
+        config.vocab_size]` or `model.image_token_id`. Tokens with indices set to `model.image_token_id` are
+        ignored (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`.
+
+    Example:
+
+    ```python
+    >>> import requests
+    >>> import torch
+    >>> from PIL import Image
+    >>> from io import BytesIO
+
+    >>> from transformers import AutoProcessor, AutoModelForImageTextToText
+    >>> from transformers.image_utils import load_image
+
+    >>> # Note that passing the image urls (instead of the actual pil images) to the processor is also possible
+    >>> image1 = load_image("https://cdn.britannica.com/61/93061-050-99147DCE/Statue-of-Liberty-Island-New-York-Bay.jpg")
+    >>> image2 = load_image("https://cdn.britannica.com/59/94459-050-DBA42467/Skyline-Chicago.jpg")
+    >>> image3 = load_image("https://cdn.britannica.com/68/170868-050-8DDE8263/Golden-Gate-Bridge-San-Francisco.jpg")
+
+    >>> processor = AutoProcessor.from_pretrained("HuggingFaceTB/SmolVLM2-2.2B-Instruct")
+    >>> model = AutoModelForImageTextToText.from_pretrained("HuggingFaceTB/SmolVLM2-2.2B-Instruct", dtype=torch.bfloat16, device_map="auto")
+
+    >>> # Create inputs
+    >>> messages = [
+    ...     {
+    ...         "role": "user",
+    ...         "content": [
+    ...             {"type": "video", "path": path/to/video},
+    ...             {"type": "text", "text": "What is happening in this video?"},
+    ...         ]
+    ...     }
+    ... ]
+
+    >>> inputs = processor.apply_chat_template([messages], add_generation_prompt=True)
+
+    >>> # Generate
+    >>> generated_ids = model.generate(**inputs, max_new_tokens=256)
+    >>> generated_texts = processor.batch_decode(generated_ids, skip_special_tokens=True)
+
+    >>> print(generated_texts)
+    ```"""
+    output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
+    output_hidden_states = (
+        output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
+    )
+    return_dict = return_dict if return_dict is not None else self.config.use_return_dict
+
+    # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn)
+    outputs = self.model(
+        input_ids=input_ids,
+        attention_mask=attention_mask,
+        position_ids=position_ids,
+        past_key_values=past_key_values,
+        inputs_embeds=inputs_embeds,
+        pixel_values=pixel_values,
+        pixel_attention_mask=pixel_attention_mask,
+        image_hidden_states=image_hidden_states,
+        use_cache=use_cache,
+        output_attentions=output_attentions,
+        output_hidden_states=output_hidden_states,
+        cache_position=cache_position,
+        return_dict=True,
+        **lm_kwargs,
+    )
+
+    # Copied from llava.py
+    hidden_states = outputs[0]
+    # Only compute necessary logits, and do not upcast them to float if we are not computing the loss
+    slice_indices = slice(-logits_to_keep, None) if isinstance(logits_to_keep, int) else logits_to_keep
+    kept_hidden_states = hidden_states[:, slice_indices, :]
+
+    shift_labels = lm_kwargs.pop("shift_labels", None)
+    logits = None
+    loss = None
+
+    if skip_logits and labels is None and shift_labels is None:
+        raise ValueError("skip_logits is True, but labels and shift_labels are None")
+
+    if skip_logits is None:
+        # By default, if in training mode, don't materialize logits
+        skip_logits = self.training and (labels is not None or shift_labels is not None)
+
+    if skip_logits:
+        loss = LigerForCausalLMLoss(
+            hidden_states=kept_hidden_states,
+            lm_head_weight=self.lm_head.weight,
+            labels=labels,
+            shift_labels=shift_labels,
+            hidden_size=self.config.text_config.hidden_size,
+            **lm_kwargs,
+        )
+
+    else:
+        logits = self.lm_head(kept_hidden_states)
+        if labels is not None or shift_labels is not None:
+            loss = self.loss_function(
+                logits=logits, labels=labels, vocab_size=self.config.text_config.vocab_size, **lm_kwargs
+            )
+
+    if not return_dict:
+        output = (logits,) + outputs[1:]
+        return (loss,) + output if loss is not None else output
+
+    return SmolVLMCausalLMOutputWithPast(
+        loss=loss,
+        logits=logits,
+        past_key_values=outputs.past_key_values,
+        hidden_states=outputs.hidden_states,
+        attentions=outputs.attentions,
+        image_hidden_states=outputs.image_hidden_states,
+    )