liger-kernel-nightly 0.6.3.dev20251105190428__py3-none-any.whl → 0.6.3.dev20251105235313__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- liger_kernel/ops/cross_entropy.py +59 -9
- liger_kernel/ops/fused_linear_cross_entropy.py +27 -4
- liger_kernel/transformers/cross_entropy.py +8 -3
- liger_kernel/transformers/functional.py +24 -6
- liger_kernel/transformers/fused_linear_cross_entropy.py +8 -3
- liger_kernel/transformers/model/falcon_h1.py +19 -5
- liger_kernel/transformers/model/gemma.py +17 -6
- liger_kernel/transformers/model/gemma2.py +14 -5
- liger_kernel/transformers/model/gemma3.py +25 -12
- liger_kernel/transformers/model/glm4.py +16 -4
- liger_kernel/transformers/model/glm4v.py +16 -4
- liger_kernel/transformers/model/glm4v_moe.py +19 -4
- liger_kernel/transformers/model/internvl.py +12 -5
- liger_kernel/transformers/model/llama.py +14 -5
- liger_kernel/transformers/model/llama4.py +16 -4
- liger_kernel/transformers/model/llava.py +12 -4
- liger_kernel/transformers/model/loss_utils.py +31 -3
- liger_kernel/transformers/model/mistral.py +15 -6
- liger_kernel/transformers/model/mixtral.py +16 -7
- liger_kernel/transformers/model/mllama.py +12 -4
- liger_kernel/transformers/model/olmo2.py +16 -4
- liger_kernel/transformers/model/output_classes.py +147 -0
- liger_kernel/transformers/model/paligemma.py +22 -5
- liger_kernel/transformers/model/phi3.py +14 -7
- liger_kernel/transformers/model/qwen2.py +16 -3
- liger_kernel/transformers/model/qwen2_5_vl.py +14 -6
- liger_kernel/transformers/model/qwen2_vl.py +16 -4
- liger_kernel/transformers/model/qwen3.py +18 -5
- liger_kernel/transformers/model/qwen3_moe.py +19 -5
- liger_kernel/transformers/model/qwen3_next.py +17 -5
- liger_kernel/transformers/model/qwen3_vl.py +11 -5
- liger_kernel/transformers/model/qwen3_vl_moe.py +12 -5
- liger_kernel/transformers/model/smollm3.py +15 -6
- liger_kernel/transformers/monkey_patch.py +4 -2
- {liger_kernel_nightly-0.6.3.dev20251105190428.dist-info → liger_kernel_nightly-0.6.3.dev20251105235313.dist-info}/METADATA +1 -1
- {liger_kernel_nightly-0.6.3.dev20251105190428.dist-info → liger_kernel_nightly-0.6.3.dev20251105235313.dist-info}/RECORD +40 -39
- {liger_kernel_nightly-0.6.3.dev20251105190428.dist-info → liger_kernel_nightly-0.6.3.dev20251105235313.dist-info}/LICENSE +0 -0
- {liger_kernel_nightly-0.6.3.dev20251105190428.dist-info → liger_kernel_nightly-0.6.3.dev20251105235313.dist-info}/NOTICE +0 -0
- {liger_kernel_nightly-0.6.3.dev20251105190428.dist-info → liger_kernel_nightly-0.6.3.dev20251105235313.dist-info}/WHEEL +0 -0
- {liger_kernel_nightly-0.6.3.dev20251105190428.dist-info → liger_kernel_nightly-0.6.3.dev20251105235313.dist-info}/top_level.txt +0 -0
@@ -11,6 +11,8 @@ from transformers.utils.deprecation import deprecate_kwarg
 
 from liger_kernel.transformers.fused_linear_cross_entropy import LigerFusedLinearCrossEntropyLoss
 from liger_kernel.transformers.model.loss_utils import LigerForCausalLMLoss
+from liger_kernel.transformers.model.loss_utils import unpack_cross_entropy_result
+from liger_kernel.transformers.model.output_classes import LigerCausalLMOutputWithPast
 
 
 def lce_forward_deprecated(
@@ -145,7 +147,7 @@ def lce_forward(
     logits_to_keep: Union[int, torch.Tensor] = 0,
     skip_logits: Optional[bool] = None,
     **kwargs,
-) -> Union[Tuple,
+) -> Union[Tuple, LigerCausalLMOutputWithPast]:
     r"""
     Args:
         labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
@@ -208,6 +210,7 @@ def lce_forward(
     shift_labels = kwargs.pop("shift_labels", None)
     logits = None
     loss = None
+    token_accuracy = None
 
     if skip_logits and labels is None and shift_labels is None:
         raise ValueError("skip_logits is True, but labels and shift_labels are None")
@@ -216,8 +219,9 @@ def lce_forward(
         # By default, if in training mode, don't materialize logits
         skip_logits = self.training and (labels is not None or shift_labels is not None)
 
+    # Compute loss
     if skip_logits:
-
+        result = LigerForCausalLMLoss(
             hidden_states=kept_hidden_states,
             lm_head_weight=self.lm_head.weight,
             labels=labels,
@@ -225,6 +229,7 @@ def lce_forward(
             hidden_size=self.config.hidden_size,
             **kwargs,
         )
+        loss, _, token_accuracy = unpack_cross_entropy_result(result)
 
     else:
         logits = self.lm_head(kept_hidden_states)
@@ -237,10 +242,18 @@ def lce_forward(
             **kwargs,
         )
 
-
+    if not return_dict:
+        output_tuple = (logits,) + outputs[1:]
+        output = (loss,) + output_tuple if loss is not None else output_tuple
+        output = output + (token_accuracy,) if token_accuracy is not None else output
+        return output
+
+    # Return custom output class with token accuracy field
+    return LigerCausalLMOutputWithPast(
         loss=loss,
         logits=logits,
         past_key_values=outputs.past_key_values,
         hidden_states=outputs.hidden_states,
         attentions=outputs.attentions,
+        token_accuracy=token_accuracy,
     )

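The hunks above swap the stock `CausalLMOutputWithPast` return type for a new `LigerCausalLMOutputWithPast` carrying a `token_accuracy` field. The class itself lives in the new `liger_kernel/transformers/model/output_classes.py` (+147 lines), whose contents are not part of this diff; a minimal sketch of what such a class would need to look like, assuming it simply extends the Hugging Face output, is:

```python
# Hypothetical sketch only; the real output_classes.py (147 new lines, not shown
# in this diff) may define these classes differently.
from dataclasses import dataclass
from typing import Optional

import torch
from transformers.modeling_outputs import CausalLMOutputWithPast


@dataclass
class LigerCausalLMOutputWithPast(CausalLMOutputWithPast):
    # Standard causal-LM output plus the per-token accuracy the patched forwards populate.
    token_accuracy: Optional[torch.Tensor] = None
```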
@@ -5,10 +5,11 @@ from typing import Union
 
 import torch
 
-from transformers.models.qwen2_5_vl.modeling_qwen2_5_vl import Qwen2_5_VLCausalLMOutputWithPast
 from transformers.utils import can_return_tuple
 
 from liger_kernel.transformers.model.loss_utils import LigerForCausalLMLoss
+from liger_kernel.transformers.model.loss_utils import unpack_cross_entropy_result
+from liger_kernel.transformers.model.output_classes import LigerQwen2_5_VLCausalLMOutputWithPast
 
 
 @can_return_tuple
@@ -33,7 +34,7 @@ def lce_forward(
     second_per_grid_ts: Optional[torch.Tensor] = None,
     skip_logits: Optional[bool] = None,
     **kwargs,
-) -> Union[Tuple,
+) -> Union[Tuple, LigerQwen2_5_VLCausalLMOutputWithPast]:
     r"""
     labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
         Labels for computing the masked language modeling loss. Indices should either be in `[0, ...,
@@ -113,6 +114,7 @@ def lce_forward(
     shift_labels = kwargs.pop("shift_labels", None)
     loss = None
     logits = None
+    token_accuracy = None
 
     if skip_logits and labels is None and shift_labels is None:
         raise ValueError("skip_logits is True, but labels and shift_labels are None")
@@ -120,8 +122,9 @@ def lce_forward(
     if skip_logits is None:
         skip_logits = self.training and (labels is not None or shift_labels is not None)
 
+    # Compute loss
     if skip_logits:
-
+        result = LigerForCausalLMLoss(
             hidden_states=hidden_states,
             lm_head_weight=self.lm_head.weight,
             labels=labels,
@@ -129,6 +132,7 @@ def lce_forward(
             hidden_size=self.config.hidden_size,
             **kwargs,
         )
+        loss, _, token_accuracy = unpack_cross_entropy_result(result)
     else:
         logits = self.lm_head(hidden_states)
 
@@ -142,14 +146,18 @@ def lce_forward(
         )
 
     if not return_dict:
-
-
+        output_tuple = (logits,) + outputs[1:]
+        output = (loss,) + output_tuple if loss is not None else output_tuple
+        output = output + (token_accuracy,) if token_accuracy is not None else output
+        return output
 
-
+    # Return Qwen2.5-VL output with token accuracy
+    return LigerQwen2_5_VLCausalLMOutputWithPast(
         loss=loss,
         logits=logits,
         past_key_values=outputs.past_key_values,
         hidden_states=outputs.hidden_states,
         attentions=outputs.attentions,
         rope_deltas=outputs.rope_deltas,
+        token_accuracy=token_accuracy,
     )

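Each patched forward now unpacks the fused-loss result via `loss, _, token_accuracy = unpack_cross_entropy_result(result)`. The helper is added to `loss_utils.py`, but its body falls outside the visible hunks; purely as an assumption, a helper compatible with that call site could be shaped like the sketch below (the real implementation, and the meaning of the discarded middle element, may differ):

```python
# Hypothetical stand-in for liger_kernel.transformers.model.loss_utils.unpack_cross_entropy_result,
# written only to be compatible with the call sites shown in this diff.
from typing import Any, Optional, Tuple

import torch


def unpack_cross_entropy_result(result: Any) -> Tuple[torch.Tensor, Optional[Any], Optional[torch.Tensor]]:
    # Normalize whatever the loss path returned into a fixed (loss, extra, token_accuracy) triple.
    if isinstance(result, tuple):
        padded = tuple(result) + (None,) * (3 - len(result))
        return padded[0], padded[1], padded[2]
    # A bare tensor means only the loss was computed.
    return result, None, None
```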
@@ -5,10 +5,11 @@ from typing import Union
 
 import torch
 
-from transformers.models.qwen2_vl.modeling_qwen2_vl import Qwen2VLCausalLMOutputWithPast
 from transformers.utils import can_return_tuple
 
 from liger_kernel.transformers.model.loss_utils import LigerForCausalLMLoss
+from liger_kernel.transformers.model.loss_utils import unpack_cross_entropy_result
+from liger_kernel.transformers.model.output_classes import LigerQwen2VLCausalLMOutputWithPast
 
 
 @can_return_tuple
@@ -32,7 +33,7 @@ def lce_forward(
     cache_position: Optional[torch.LongTensor] = None,
     skip_logits: Optional[bool] = None,
     **kwargs,
-) -> Union[Tuple,
+) -> Union[Tuple, LigerQwen2VLCausalLMOutputWithPast]:
     r"""
     labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
         Labels for computing the masked language modeling loss. Indices should either be in `[0, ...,
@@ -109,6 +110,7 @@ def lce_forward(
     shift_labels = kwargs.pop("shift_labels", None)
     loss = None
     logits = None
+    token_accuracy = None
 
     if skip_logits and labels is None and shift_labels is None:
         raise ValueError("skip_logits is True, but labels and shift_labels are None")
@@ -116,8 +118,9 @@ def lce_forward(
     if skip_logits is None:
         skip_logits = self.training and (labels is not None or shift_labels is not None)
 
+    # Compute loss
     if skip_logits:
-
+        result = LigerForCausalLMLoss(
             hidden_states=hidden_states,
             lm_head_weight=self.lm_head.weight,
             labels=labels,
@@ -125,6 +128,7 @@ def lce_forward(
             hidden_size=self.config.hidden_size,
             **kwargs,
         )
+        loss, _, token_accuracy = unpack_cross_entropy_result(result)
     else:
         logits = self.lm_head(hidden_states)
 
@@ -137,11 +141,19 @@ def lce_forward(
             vocab_size=self.config.vocab_size,
         )
 
-
+    if not return_dict:
+        output_tuple = (logits,) + outputs[1:]
+        output = (loss,) + output_tuple if loss is not None else output_tuple
+        output = output + (token_accuracy,) if token_accuracy is not None else output
+        return output
+
+    # Return Qwen2VL output with token accuracy
+    return LigerQwen2VLCausalLMOutputWithPast(
         loss=loss,
         logits=logits,
         past_key_values=outputs.past_key_values,
         hidden_states=outputs.hidden_states,
         attentions=outputs.attentions,
         rope_deltas=outputs.rope_deltas,
+        token_accuracy=token_accuracy,
     )

@@ -4,9 +4,9 @@ from typing import Union
 
 import torch
 
-from transformers.modeling_outputs import CausalLMOutputWithPast
-
 from liger_kernel.transformers.model.loss_utils import LigerForCausalLMLoss
+from liger_kernel.transformers.model.loss_utils import unpack_cross_entropy_result
+from liger_kernel.transformers.model.output_classes import LigerCausalLMOutputWithPast
 
 
 def lce_forward(
@@ -23,8 +23,9 @@ def lce_forward(
     cache_position: Optional[torch.LongTensor] = None,
     logits_to_keep: Union[int, torch.Tensor] = 0,
     skip_logits: Optional[bool] = None,
+    return_dict: Optional[bool] = None,
     **kwargs,
-) ->
+) -> LigerCausalLMOutputWithPast:
     r"""
     labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
         Labels for computing the masked language modeling loss. Indices should either be in `[0, ...,
@@ -60,6 +61,7 @@ def lce_forward(
     output_hidden_states = (
         output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
     )
+    return_dict = return_dict if return_dict is not None else self.config.use_return_dict
 
     # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn)
     outputs = self.model(
@@ -83,6 +85,7 @@ def lce_forward(
     shift_labels = kwargs.pop("shift_labels", None)
     logits = None
     loss = None
+    token_accuracy = None
 
     if skip_logits and labels is None and shift_labels is None:
         raise ValueError("skip_logits is True, but labels and shift_labels are None")
@@ -91,8 +94,9 @@ def lce_forward(
         # By default, if in training mode, don't materialize logits
         skip_logits = self.training and (labels is not None or shift_labels is not None)
 
+    # Compute loss
     if skip_logits:
-
+        result = LigerForCausalLMLoss(
             hidden_states=kept_hidden_states,
             lm_head_weight=self.lm_head.weight,
             labels=labels,
@@ -100,6 +104,7 @@ def lce_forward(
             hidden_size=self.config.hidden_size,
             **kwargs,
         )
+        loss, _, token_accuracy = unpack_cross_entropy_result(result)
 
     else:
         logits = self.lm_head(kept_hidden_states)
@@ -112,10 +117,18 @@ def lce_forward(
             **kwargs,
         )
 
-
+    if not return_dict:
+        output = (logits,) + outputs[1:]
+        output = ((loss,) + output) if loss is not None else output
+        output = output + (token_accuracy,) if token_accuracy is not None else output
+        return output
+
+    # Return custom output class with accuracy field
+    return LigerCausalLMOutputWithPast(
         loss=loss,
         logits=logits,
         past_key_values=outputs.past_key_values,
         hidden_states=outputs.hidden_states,
         attentions=outputs.attentions,
+        token_accuracy=token_accuracy,
     )

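When callers ask for plain tuples (`return_dict=False`), the patched forwards append `token_accuracy` at the end of the tuple rather than exposing a named field. A small self-contained illustration of the resulting layout follows; the values are placeholders and only the assembly logic mirrors the hunks above:

```python
# Placeholder values; only the tuple-assembly logic mirrors the diff.
import torch

loss = torch.tensor(2.31)            # training loss (None during pure inference)
logits = None                        # stays None when logits are not materialized
token_accuracy = torch.tensor(0.47)  # new metric, None when no labels were given
outputs_rest = ("past_key_values", "hidden_states", "attentions")  # stand-in for outputs[1:]

output = (logits,) + outputs_rest
output = ((loss,) + output) if loss is not None else output
output = output + (token_accuracy,) if token_accuracy is not None else output

# The loss stays first and token_accuracy is appended last, so existing callers that
# read output[0] keep working and new callers can read output[-1] for the accuracy.
assert output[0] is loss and output[-1] is token_accuracy
```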
@@ -4,11 +4,12 @@ from typing import Union
 
 import torch
 
-from transformers.modeling_outputs import MoeCausalLMOutputWithPast
 from transformers.modeling_outputs import MoeModelOutputWithPast
 from transformers.models.mixtral.modeling_mixtral import load_balancing_loss_func
 
 from liger_kernel.transformers.model.loss_utils import LigerForCausalLMLoss
+from liger_kernel.transformers.model.loss_utils import unpack_cross_entropy_result
+from liger_kernel.transformers.model.output_classes import LigerMoeCausalLMOutputWithPast
 
 
 def lce_forward(
@@ -26,8 +27,9 @@ def lce_forward(
     cache_position: Optional[torch.LongTensor] = None,
     logits_to_keep: Union[int, torch.Tensor] = 0,
     skip_logits: Optional[bool] = None,
+    return_dict: Optional[bool] = None,
     **kwargs,
-) ->
+) -> LigerMoeCausalLMOutputWithPast:
     r"""
     labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
         Labels for computing the masked language modeling loss. Indices should either be in `[0, ...,
@@ -64,10 +66,10 @@ def lce_forward(
     output_router_logits = (
         output_router_logits if output_router_logits is not None else self.config.output_router_logits
     )
-
     output_hidden_states = (
         output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
     )
+    return_dict = return_dict if return_dict is not None else self.config.use_return_dict
 
     # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn)
     outputs: MoeModelOutputWithPast = self.model(
@@ -92,12 +94,14 @@ def lce_forward(
     shift_labels = kwargs.pop("shift_labels", None)
     logits = None
     loss = None
+    token_accuracy = None
 
     if skip_logits is None:
         skip_logits = self.training and (labels is not None or shift_labels is not None)
 
+    # Compute loss
     if skip_logits:
-
+        result = LigerForCausalLMLoss(
             hidden_states=kept_hidden_states,
             lm_head_weight=self.lm_head.weight,
             labels=labels,
@@ -105,6 +109,7 @@ def lce_forward(
             hidden_size=self.config.hidden_size,
             **kwargs,
         )
+        loss, _, token_accuracy = unpack_cross_entropy_result(result)
     else:  # if in inference model materialize logits
         logits = self.lm_head(kept_hidden_states)
         if labels is not None or shift_labels is not None:
@@ -127,7 +132,15 @@ def lce_forward(
         if labels is not None:
             loss += self.router_aux_loss_coef * aux_loss.to(loss.device)  # make sure to reside in the same device
 
-
+    if not return_dict:
+        output = (logits,) + outputs[1:]
+        output = ((aux_loss,) + output) if aux_loss is not None else output
+        output = ((loss,) + output) if loss is not None else output
+        output = output + (token_accuracy,) if token_accuracy is not None else output
+        return output
+
+    # Return custom output class with accuracy field
+    return LigerMoeCausalLMOutputWithPast(
         loss=loss,
         aux_loss=aux_loss,
         logits=logits,
@@ -135,4 +148,5 @@ def lce_forward(
         hidden_states=outputs.hidden_states,
         attentions=outputs.attentions,
         router_logits=outputs.router_logits,
+        token_accuracy=token_accuracy,
     )

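The MoE variants also have to slot the router `aux_loss` into the tuple path. Note that the Mixtral-style hunk above prepends it right after the loss, while the Qwen3-VL-MoE hunk later in this diff appends it after the model outputs instead. A short sketch of the Mixtral-style ordering with placeholder values:

```python
# Placeholder values; only the ordering mirrors the Mixtral-style hunk above.
import torch

loss = torch.tensor(1.87)
aux_loss = torch.tensor(0.03)        # router load-balancing loss
logits = None
token_accuracy = torch.tensor(0.52)
outputs_rest = ("past_key_values", "hidden_states", "attentions", "router_logits")

output = (logits,) + outputs_rest
output = ((aux_loss,) + output) if aux_loss is not None else output
output = ((loss,) + output) if loss is not None else output
output = output + (token_accuracy,) if token_accuracy is not None else output

print(output[0], output[1], output[-1])  # loss, aux_loss, token_accuracy
```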
@@ -5,13 +5,14 @@ from typing import Union
 
 import torch
 
-from transformers.modeling_outputs import MoeCausalLMOutputWithPast
 from transformers.modeling_outputs import MoeModelOutputWithPast
 
 if TYPE_CHECKING:
     from transformers.models.qwen3_next.modeling_qwen3_next import load_balancing_loss_func
 
 from liger_kernel.transformers.model.loss_utils import LigerForCausalLMLoss
+from liger_kernel.transformers.model.loss_utils import unpack_cross_entropy_result
+from liger_kernel.transformers.model.output_classes import LigerMoeCausalLMOutputWithPast
 
 
 def lce_forward(
@@ -29,8 +30,9 @@ def lce_forward(
     cache_position: Optional[torch.LongTensor] = None,
     logits_to_keep: Union[int, torch.Tensor] = 0,
     skip_logits: Optional[bool] = None,
+    return_dict: Optional[bool] = None,
     **kwargs,
-) ->
+) -> LigerMoeCausalLMOutputWithPast:
     r"""
     labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
         Labels for computing the masked language modeling loss. Indices should either be in `[0, ...,
@@ -66,10 +68,10 @@ def lce_forward(
     output_router_logits = (
         output_router_logits if output_router_logits is not None else self.config.output_router_logits
     )
-
     output_hidden_states = (
         output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
     )
+    return_dict = return_dict if return_dict is not None else self.config.use_return_dict
 
     # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn)
     outputs: MoeModelOutputWithPast = self.model(
@@ -94,12 +96,13 @@ def lce_forward(
     shift_labels = kwargs.pop("shift_labels", None)
     logits = None
     loss = None
+    token_accuracy = None
 
     if skip_logits is None:
         skip_logits = self.training and (labels is not None or shift_labels is not None)
 
     if skip_logits:
-
+        result = LigerForCausalLMLoss(
             hidden_states=kept_hidden_states,
             lm_head_weight=self.lm_head.weight,
             labels=labels,
@@ -107,6 +110,7 @@ def lce_forward(
             hidden_size=self.config.hidden_size,
             **kwargs,
         )
+        loss, _, token_accuracy = unpack_cross_entropy_result(result)
     else:  # if in inference model materialize logits
         logits = self.lm_head(kept_hidden_states)
         if labels is not None or shift_labels is not None:
@@ -123,7 +127,14 @@ def lce_forward(
         if labels is not None:
             loss += self.router_aux_loss_coef * aux_loss.to(loss.device)  # make sure to reside in the same device
 
-
+    if not return_dict:
+        output = (logits,) + outputs[1:]
+        output = ((aux_loss,) + output) if aux_loss is not None else output
+        output = ((loss,) + output) if loss is not None else output
+        output = output + (token_accuracy,) if token_accuracy is not None else output
+        return output
+
+    return LigerMoeCausalLMOutputWithPast(
         loss=loss,
         aux_loss=aux_loss,
         logits=logits,
@@ -131,4 +142,5 @@ def lce_forward(
         hidden_states=outputs.hidden_states,
         attentions=outputs.attentions,
         router_logits=outputs.router_logits,
+        token_accuracy=token_accuracy,
     )

@@ -5,10 +5,11 @@ from typing import Union
 
 import torch
 
-from transformers.models.qwen3_vl.modeling_qwen3_vl import Qwen3VLCausalLMOutputWithPast
 from transformers.utils import can_return_tuple
 
 from liger_kernel.transformers.model.loss_utils import LigerForCausalLMLoss
+from liger_kernel.transformers.model.loss_utils import unpack_cross_entropy_result
+from liger_kernel.transformers.model.output_classes import LigerQwen3VLCausalLMOutputWithPast
 
 
 @can_return_tuple
@@ -33,7 +34,7 @@ def lce_forward(
     second_per_grid_ts: Optional[torch.Tensor] = None,
     skip_logits: Optional[bool] = None,
     **kwargs,
-) -> Union[Tuple,
+) -> Union[Tuple, LigerQwen3VLCausalLMOutputWithPast]:
     """
     labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
         Labels for computing the masked language modeling loss. Indices should either be in `[0, ...,
@@ -107,6 +108,7 @@ def lce_forward(
     shift_labels = kwargs.pop("shift_labels", None)
     loss = None
     logits = None
+    token_accuracy = None
 
     if skip_logits and labels is None and shift_labels is None:
         raise ValueError("skip_logits is True, but labels and shift_labels are None")
@@ -115,7 +117,7 @@ def lce_forward(
         skip_logits = self.training and (labels is not None or shift_labels is not None)
 
     if skip_logits:
-
+        result = LigerForCausalLMLoss(
             hidden_states=hidden_states,
             lm_head_weight=self.lm_head.weight,
             labels=labels,
@@ -123,6 +125,7 @@ def lce_forward(
             hidden_size=self.config.text_config.hidden_size,
             **kwargs,
         )
+        loss, _, token_accuracy = unpack_cross_entropy_result(result)
     else:
         logits = self.lm_head(hidden_states)
 
@@ -132,13 +135,16 @@ def lce_forward(
 
     if not return_dict:
         output = (logits,) + outputs[1:]
-
+        output = (loss,) + output if loss is not None else output
+        output = output + (token_accuracy,) if token_accuracy is not None else output
+        return output
 
-    return
+    return LigerQwen3VLCausalLMOutputWithPast(
         loss=loss,
         logits=logits,
         past_key_values=outputs.past_key_values,
         hidden_states=outputs.hidden_states,
         attentions=outputs.attentions,
         rope_deltas=outputs.rope_deltas,
+        token_accuracy=token_accuracy,
     )

@@ -5,11 +5,12 @@ from typing import Union
 
 import torch
 
-from transformers.models.qwen3_vl_moe.modeling_qwen3_vl_moe import Qwen3VLMoeCausalLMOutputWithPast
 from transformers.models.qwen3_vl_moe.modeling_qwen3_vl_moe import load_balancing_loss_func
 from transformers.utils import can_return_tuple
 
 from liger_kernel.transformers.model.loss_utils import LigerForCausalLMLoss
+from liger_kernel.transformers.model.loss_utils import unpack_cross_entropy_result
+from liger_kernel.transformers.model.output_classes import LigerQwen3VLMoeCausalLMOutputWithPast
 
 
 @can_return_tuple
@@ -34,7 +35,7 @@ def lce_forward(
     second_per_grid_ts: Optional[torch.Tensor] = None,
     skip_logits: Optional[bool] = None,
     **kwargs,
-) -> Union[Tuple,
+) -> Union[Tuple, LigerQwen3VLMoeCausalLMOutputWithPast]:
     """
     Qwen3-VL-MoE forward with fused linear cross entropy support mirroring Qwen3-VL behaviour.
     """
@@ -69,6 +70,7 @@ def lce_forward(
     shift_labels = kwargs.pop("shift_labels", None)
     loss = None
     logits = None
+    token_accuracy = None
 
     if skip_logits and labels is None and shift_labels is None:
         raise ValueError("skip_logits is True, but labels and shift_labels are None")
@@ -77,7 +79,7 @@ def lce_forward(
         skip_logits = self.training and (labels is not None or shift_labels is not None)
 
     if skip_logits:
-
+        result = LigerForCausalLMLoss(
             hidden_states=hidden_states,
             lm_head_weight=self.lm_head.weight,
             labels=labels,
@@ -85,6 +87,7 @@ def lce_forward(
             hidden_size=self.config.text_config.hidden_size,
             **kwargs,
         )
+        loss, _, token_accuracy = unpack_cross_entropy_result(result)
     else:
         logits = self.lm_head(hidden_states)
 
@@ -106,9 +109,12 @@ def lce_forward(
 
     if not return_dict:
         output = (logits,) + outputs[1:]
-
+        output = (loss,) + output if loss is not None else output
+        output = output + (aux_loss,) if aux_loss is not None else output
+        output = output + (token_accuracy,) if token_accuracy is not None else output
+        return output
 
-    return
+    return LigerQwen3VLMoeCausalLMOutputWithPast(
         loss=loss,
         logits=logits,
         past_key_values=outputs.past_key_values,
@@ -116,4 +122,5 @@ def lce_forward(
         attentions=outputs.attentions,
         rope_deltas=outputs.rope_deltas,
         aux_loss=aux_loss,
+        token_accuracy=token_accuracy,
     )

@@ -7,11 +7,12 @@ from typing import Union
 import torch
 
 from torch.distributed.fsdp import FullyShardedDataParallel
-from transformers.modeling_outputs import CausalLMOutputWithPast
 from transformers.utils.deprecation import deprecate_kwarg
 
 from liger_kernel.transformers.fsdp import _FSDPForwardRedirection
 from liger_kernel.transformers.model.loss_utils import LigerForCausalLMLoss
+from liger_kernel.transformers.model.loss_utils import unpack_cross_entropy_result
+from liger_kernel.transformers.model.output_classes import LigerCausalLMOutputWithPast
 from liger_kernel.utils import PEFT_AVAILABLE
 
 if TYPE_CHECKING:
@@ -38,7 +39,7 @@ def lce_forward(
     logits_to_keep: Union[int, torch.Tensor] = 0,
     skip_logits: Optional[bool] = None,
     **kwargs,
-) -> Union[Tuple,
+) -> Union[Tuple, LigerCausalLMOutputWithPast]:
     r"""
     Args:
         labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
@@ -101,6 +102,8 @@ def lce_forward(
     shift_labels = kwargs.pop("shift_labels", None)
     logits = None
     loss = None
+    token_accuracy = None
+
     # if in training mode, don't materialize logits
     if skip_logits and labels is None and shift_labels is None:
         raise ValueError("skip_logits is True, but labels and shift_labels are None")
@@ -109,8 +112,9 @@ def lce_forward(
         # By default, if in training mode, don't materialize logits
         skip_logits = self.training and (labels is not None or shift_labels is not None)
 
+    # Compute loss
     if skip_logits:
-
+        result = lce_maybe_trainable_lm_head(
             self,
             hidden_states=kept_hidden_states,
             hidden_size=self.config.hidden_size,
@@ -118,6 +122,7 @@ def lce_forward(
             shift_labels=shift_labels,
             **kwargs,
         )
+        loss, _, token_accuracy = unpack_cross_entropy_result(result)
 
     else:
         logits = self.lm_head(kept_hidden_states)
@@ -131,15 +136,19 @@ def lce_forward(
         )
 
     if not return_dict:
-
-
+        output_tuple = (logits,) + outputs[1:]
+        output = (loss,) + output_tuple if loss is not None else output_tuple
+        output = output + (token_accuracy,) if token_accuracy is not None else output
+        return output
 
-
+    # Return custom output class with token_accuracy field
+    return LigerCausalLMOutputWithPast(
         loss=loss,
         logits=logits,
         past_key_values=outputs.past_key_values,
         hidden_states=outputs.hidden_states,
         attentions=outputs.attentions,
+        token_accuracy=token_accuracy,
     )

@@ -30,8 +30,6 @@ from liger_kernel.transformers.model.mixtral import lce_forward_deprecated as mi
 from liger_kernel.transformers.model.phi3 import lce_forward as phi3_lce_forward
 from liger_kernel.transformers.model.qwen2 import lce_forward as qwen2_lce_forward
 from liger_kernel.transformers.model.qwen2 import lce_forward_deprecated as qwen2_lce_forward_deprecated
-from liger_kernel.transformers.model.qwen3_vl import lce_forward as qwen3_vl_lce_forward
-from liger_kernel.transformers.model.qwen3_vl_moe import lce_forward as qwen3_vl_moe_lce_forward
 from liger_kernel.transformers.model.smollm3 import lce_forward as smollm3_lce_forward
 from liger_kernel.transformers.qwen2vl_mrope import liger_multimodal_rotary_pos_emb
 from liger_kernel.transformers.rms_norm import LigerRMSNorm
@@ -1679,6 +1677,8 @@ def apply_liger_kernel_to_qwen3_vl(
     from transformers.models.qwen3_vl.modeling_qwen3_vl import Qwen3VLModel
     from transformers.models.qwen3_vl.modeling_qwen3_vl import Qwen3VLTextModel
 
+    from liger_kernel.transformers.model.qwen3_vl import lce_forward as qwen3_vl_lce_forward
+
     if rope:
         modeling_qwen3_vl.apply_rotary_pos_emb = liger_rotary_pos_emb_with_cast
         modeling_qwen3_vl.apply_rotary_pos_emb_vision = liger_rotary_pos_emb_with_cast_and_leading_batch
@@ -1752,6 +1752,8 @@ def apply_liger_kernel_to_qwen3_vl_moe(
     from transformers.models.qwen3_vl_moe.modeling_qwen3_vl_moe import Qwen3VLMoeModel
     from transformers.models.qwen3_vl_moe.modeling_qwen3_vl_moe import Qwen3VLMoeTextModel
 
+    from liger_kernel.transformers.model.qwen3_vl_moe import lce_forward as qwen3_vl_moe_lce_forward
+
     if rope:
         modeling_qwen3_vl_moe.apply_rotary_pos_emb = liger_rotary_pos_emb_with_cast
         modeling_qwen3_vl_moe.apply_rotary_pos_emb_vision = liger_rotary_pos_emb_with_cast_and_leading_batch