liger-kernel-nightly 0.5.10.dev20250605210201__py3-none-any.whl → 0.5.10.dev20250605223455__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- liger_kernel/transformers/model/gemma.py +5 -4
- liger_kernel/transformers/model/gemma2.py +7 -4
- liger_kernel/transformers/model/glm4.py +5 -4
- liger_kernel/transformers/model/llama.py +5 -4
- liger_kernel/transformers/model/mistral.py +5 -4
- liger_kernel/transformers/model/mixtral.py +5 -4
- liger_kernel/transformers/model/mllama.py +5 -4
- liger_kernel/transformers/model/olmo2.py +5 -4
- liger_kernel/transformers/model/phi3.py +5 -4
- liger_kernel/transformers/model/qwen2.py +5 -4
- liger_kernel/transformers/model/qwen2_5_vl.py +4 -3
- liger_kernel/transformers/model/qwen2_vl.py +4 -3
- liger_kernel/transformers/model/qwen3_moe.py +5 -4
- {liger_kernel_nightly-0.5.10.dev20250605210201.dist-info → liger_kernel_nightly-0.5.10.dev20250605223455.dist-info}/METADATA +1 -1
- {liger_kernel_nightly-0.5.10.dev20250605210201.dist-info → liger_kernel_nightly-0.5.10.dev20250605223455.dist-info}/RECORD +19 -19
- {liger_kernel_nightly-0.5.10.dev20250605210201.dist-info → liger_kernel_nightly-0.5.10.dev20250605223455.dist-info}/LICENSE +0 -0
- {liger_kernel_nightly-0.5.10.dev20250605210201.dist-info → liger_kernel_nightly-0.5.10.dev20250605223455.dist-info}/NOTICE +0 -0
- {liger_kernel_nightly-0.5.10.dev20250605210201.dist-info → liger_kernel_nightly-0.5.10.dev20250605223455.dist-info}/WHEEL +0 -0
- {liger_kernel_nightly-0.5.10.dev20250605210201.dist-info → liger_kernel_nightly-0.5.10.dev20250605223455.dist-info}/top_level.txt +0 -0
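Every model-file hunk below applies the same pattern to the Liger `lce_forward` patches: the signature gains `**kwargs`, those kwargs are forwarded to the backbone `self.model(...)` call, and `shift_labels` is popped from `kwargs` before the remaining kwargs are passed on to the fused linear cross-entropy path or to `self.loss_function`. The sketch below is a minimal, self-contained illustration of that flow, not the library's actual code; `TinyBackbone`, `fused_linear_cross_entropy_stub`, and `lce_forward_sketch` are hypothetical stand-ins used only to make the example runnable.

```python
# Minimal sketch of the kwargs-forwarding pattern shown in the hunks below.
# Not the library's actual code: TinyBackbone and fused_linear_cross_entropy_stub
# are hypothetical stand-ins.
import torch


class TinyBackbone(torch.nn.Module):
    def __init__(self, vocab_size: int = 16, hidden_size: int = 8):
        super().__init__()
        self.embed = torch.nn.Embedding(vocab_size, hidden_size)

    def forward(self, input_ids: torch.Tensor, **kwargs) -> torch.Tensor:
        # A real backbone would also consume attention_mask, position_ids, etc.
        return self.embed(input_ids)


def fused_linear_cross_entropy_stub(hidden_states, lm_head_weight, labels=None, shift_labels=None, **kwargs):
    # Stand-in for the fused linear + cross-entropy kernel: project, then CE.
    logits = hidden_states @ lm_head_weight.T
    target = shift_labels if shift_labels is not None else labels
    return torch.nn.functional.cross_entropy(
        logits.view(-1, logits.size(-1)), target.view(-1), ignore_index=-100
    )


def lce_forward_sketch(backbone, lm_head, input_ids, labels=None, skip_logits=None, **kwargs):
    # 1) kwargs now flow through to the backbone call ...
    hidden_states = backbone(input_ids, **kwargs)
    # 2) ... and shift_labels is taken out of kwargs before they are forwarded further.
    shift_labels = kwargs.pop("shift_labels", None)
    if skip_logits or labels is not None or shift_labels is not None:
        # 3) The remaining kwargs also reach the fused loss.
        return fused_linear_cross_entropy_stub(
            hidden_states, lm_head.weight, labels=labels, shift_labels=shift_labels, **kwargs
        )
    return lm_head(hidden_states)


if __name__ == "__main__":
    backbone, lm_head = TinyBackbone(), torch.nn.Linear(8, 16, bias=False)
    ids = torch.randint(0, 16, (2, 5))
    # shift_labels travels inside **kwargs, as the patched forwards now allow.
    loss = lce_forward_sketch(backbone, lm_head, ids, shift_labels=torch.randint(0, 16, (2, 5)))
    print(loss.item())
```

Popping `shift_labels` before forwarding keeps it from reaching calls that do not expect it, while still letting callers pass pre-shifted labels straight through the forward call.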
liger_kernel/transformers/model/gemma.py (+5 -4)

@@ -138,7 +138,7 @@ def lce_forward(
 cache_position: Optional[torch.LongTensor] = None,
 logits_to_keep: Union[int, torch.Tensor] = 0,
 skip_logits: Optional[bool] = None,
-**
+**kwargs,
 ) -> Union[Tuple, CausalLMOutputWithPast]:
 r"""
 Args:

@@ -190,6 +190,7 @@ def lce_forward(
 output_hidden_states=output_hidden_states,
 return_dict=return_dict,
 cache_position=cache_position,
+**kwargs,
 )

 hidden_states = outputs[0]

@@ -197,7 +198,7 @@ def lce_forward(
 slice_indices = slice(-logits_to_keep, None) if isinstance(logits_to_keep, int) else logits_to_keep
 kept_hidden_states = hidden_states[:, slice_indices, :]

-shift_labels =
+shift_labels = kwargs.pop("shift_labels", None)
 logits = None
 loss = None

@@ -215,7 +216,7 @@ def lce_forward(
 labels=labels,
 shift_labels=shift_labels,
 hidden_size=self.config.hidden_size,
-**
+**kwargs,
 )
 else:
 logits = self.lm_head(kept_hidden_states)

@@ -224,7 +225,7 @@ def lce_forward(
 logits=logits,
 labels=labels,
 vocab_size=self.config.vocab_size,
-**
+**kwargs,
 )

 if not return_dict:
liger_kernel/transformers/model/gemma2.py (+7 -4)

@@ -30,6 +30,7 @@ def lce_forward_deprecated(
 output_hidden_states: Optional[bool] = None,
 return_dict: Optional[bool] = None,
 cache_position: Optional[torch.LongTensor] = None,
+**kwargs,
 ) -> Union[Tuple, CausalLMOutputWithPast]:
 r"""
 Args:

@@ -76,6 +77,7 @@ def lce_forward_deprecated(
 output_hidden_states=output_hidden_states,
 return_dict=return_dict,
 cache_position=cache_position,
+**kwargs,
 )

 hidden_states = outputs[0]

@@ -147,7 +149,7 @@ def lce_forward(
 cache_position: Optional[torch.LongTensor] = None,
 logits_to_keep: Union[int, torch.Tensor] = 0,
 skip_logits: Optional[bool] = None,
-**
+**kwargs,
 ) -> Union[Tuple, CausalLMOutputWithPast]:
 r"""
 Args:

@@ -204,6 +206,7 @@ def lce_forward(
 output_hidden_states=output_hidden_states,
 return_dict=return_dict,
 cache_position=cache_position,
+**kwargs,
 )

 hidden_states = outputs[0]

@@ -211,7 +214,7 @@ def lce_forward(
 slice_indices = slice(-logits_to_keep, None) if isinstance(logits_to_keep, int) else logits_to_keep
 kept_hidden_states = hidden_states[:, slice_indices, :]

-shift_labels =
+shift_labels = kwargs.pop("shift_labels", None)
 logits = None
 loss = None

@@ -230,7 +233,7 @@ def lce_forward(
 shift_labels=shift_labels,
 hidden_size=self.config.hidden_size,
 final_logit_softcapping=self.config.final_logit_softcapping,
-**
+**kwargs,
 )

 else:

@@ -242,7 +245,7 @@ def lce_forward(

 loss = None
 if labels is not None:
-loss = self.loss_function(logits, labels, self.vocab_size, **
+loss = self.loss_function(logits, labels, self.vocab_size, **kwargs)

 if not return_dict:
 output = (logits,) + outputs[1:]
liger_kernel/transformers/model/glm4.py (+5 -4)

@@ -27,7 +27,7 @@ def lce_forward(
 cache_position: Optional[torch.LongTensor] = None,
 logits_to_keep: Union[int, torch.Tensor] = 0,
 skip_logits: Optional[bool] = None,
-**
+**kwargs,
 ) -> Union[Tuple, CausalLMOutputWithPast]:
 r"""
 Args:

@@ -80,6 +80,7 @@ def lce_forward(
 output_hidden_states=output_hidden_states,
 return_dict=return_dict,
 cache_position=cache_position,
+**kwargs,
 )

 hidden_states = outputs[0]

@@ -87,7 +88,7 @@ def lce_forward(
 slice_indices = slice(-logits_to_keep, None) if isinstance(logits_to_keep, int) else logits_to_keep
 kept_hidden_states = hidden_states[:, slice_indices, :]

-shift_labels =
+shift_labels = kwargs.pop("shift_labels", None)
 logits = None
 loss = None

@@ -105,7 +106,7 @@ def lce_forward(
 labels=labels,
 shift_labels=shift_labels,
 hidden_size=self.config.hidden_size,
-**
+**kwargs,
 )

 else:

@@ -115,7 +116,7 @@ def lce_forward(
 logits=logits,
 labels=labels,
 vocab_size=self.config.vocab_size,
-**
+**kwargs,
 )

 return CausalLMOutputWithPast(
liger_kernel/transformers/model/llama.py (+5 -4)

@@ -152,7 +152,7 @@ def lce_forward(
 cache_position: Optional[torch.LongTensor] = None,
 logits_to_keep: Union[int, torch.Tensor] = 0,
 skip_logits: Optional[bool] = None,
-**
+**kwargs,
 ) -> Union[Tuple, CausalLMOutputWithPast]:
 r"""
 Args:

@@ -205,6 +205,7 @@ def lce_forward(
 output_hidden_states=output_hidden_states,
 return_dict=return_dict,
 cache_position=cache_position,
+**kwargs,
 )

 hidden_states = outputs[0]

@@ -215,7 +216,7 @@ def lce_forward(
 if self.config.pretraining_tp > 1:
 raise Exception("Liger Kernel does not support pretraining_tp!!")

-shift_labels =
+shift_labels = kwargs.pop("shift_labels", None)
 logits = None
 loss = None
 # if in training mode, don't materialize logits

@@ -233,7 +234,7 @@ def lce_forward(
 hidden_size=self.config.hidden_size,
 labels=labels,
 shift_labels=shift_labels,
-**
+**kwargs,
 )

 else:

@@ -243,7 +244,7 @@ def lce_forward(
 logits=logits,
 labels=labels,
 vocab_size=self.config.vocab_size,
-**
+**kwargs,
 )

 if not return_dict:
liger_kernel/transformers/model/mistral.py (+5 -4)

@@ -28,7 +28,7 @@ def lce_forward(
 cache_position: Optional[torch.LongTensor] = None,
 logits_to_keep: Union[int, torch.Tensor] = 0,
 skip_logits: Optional[bool] = None,
-**
+**kwargs,
 ) -> Union[Tuple, CausalLMOutputWithPast]:
 r"""
 Copy paste Mistral's forward but replace torch cross entropy with liger fused linear cross entropy

@@ -83,6 +83,7 @@ def lce_forward(
 output_hidden_states=output_hidden_states,
 return_dict=return_dict,
 cache_position=cache_position,
+**kwargs,
 )

 hidden_states = outputs[0]

@@ -90,7 +91,7 @@ def lce_forward(
 slice_indices = slice(-logits_to_keep, None) if isinstance(logits_to_keep, int) else logits_to_keep
 kept_hidden_states = hidden_states[:, slice_indices, :]

-shift_labels =
+shift_labels = kwargs.pop("shift_labels", None)
 loss = None
 logits = None

@@ -107,7 +108,7 @@ def lce_forward(
 labels=labels,
 shift_labels=shift_labels,
 hidden_size=self.config.hidden_size,
-**
+**kwargs,
 )

 else:

@@ -119,7 +120,7 @@ def lce_forward(
 logits=logits,
 labels=labels,
 vocab_size=self.config.vocab_size,
-**
+**kwargs,
 )
 if not return_dict:
 output = (logits,) + outputs[1:]
liger_kernel/transformers/model/mixtral.py (+5 -4)

@@ -157,7 +157,7 @@ def lce_forward(
 cache_position: Optional[torch.LongTensor] = None,
 logits_to_keep: Union[int, torch.Tensor] = 0,
 skip_logits: Optional[bool] = None,
-**
+**kwargs,
 ) -> Union[Tuple, MoeCausalLMOutputWithPast]:
 r"""
 Args:

@@ -215,6 +215,7 @@ def lce_forward(
 output_router_logits=output_router_logits,
 return_dict=return_dict,
 cache_position=cache_position,
+**kwargs,
 )

 hidden_states = outputs[0]

@@ -222,7 +223,7 @@ def lce_forward(
 slice_indices = slice(-logits_to_keep, None) if isinstance(logits_to_keep, int) else logits_to_keep
 kept_hidden_states = hidden_states[:, slice_indices, :]

-shift_labels =
+shift_labels = kwargs.pop("shift_labels", None)
 logits = None
 loss = None

@@ -240,7 +241,7 @@ def lce_forward(
 labels=labels,
 shift_labels=shift_labels,
 hidden_size=self.config.hidden_size,
-**
+**kwargs,
 )

 else:

@@ -248,7 +249,7 @@ def lce_forward(

 loss = None
 if labels is not None:
-loss = self.loss_function(logits, labels, self.vocab_size, **
+loss = self.loss_function(logits, labels, self.vocab_size, **kwargs)
 aux_loss = None
 if output_router_logits:
 aux_loss = load_balancing_loss_func(
liger_kernel/transformers/model/mllama.py (+5 -4)

@@ -148,7 +148,7 @@ def lce_forward(
 cache_position: Optional[torch.LongTensor] = None,
 logits_to_keep: Union[int, torch.Tensor] = 0,
 skip_logits: Optional[bool] = None,
-**
+**kwargs,
 ) -> Union[Tuple, CausalLMOutputWithPast]:
 r"""
 Args:

@@ -206,6 +206,7 @@ def lce_forward(
 output_hidden_states=output_hidden_states,
 return_dict=return_dict,
 cache_position=cache_position,
+**kwargs,
 )

 hidden_states = outputs[0]

@@ -213,7 +214,7 @@ def lce_forward(
 slice_indices = slice(-logits_to_keep, None) if isinstance(logits_to_keep, int) else logits_to_keep
 kept_hidden_states = hidden_states[:, slice_indices, :]

-shift_labels =
+shift_labels = kwargs.pop("shift_labels", None)
 logits = None
 loss = None

@@ -231,7 +232,7 @@ def lce_forward(
 labels=labels,
 shift_labels=shift_labels,
 hidden_size=self.config.hidden_size,
-**
+**kwargs,
 )

 else:

@@ -241,7 +242,7 @@ def lce_forward(
 logits=logits,
 labels=labels,
 vocab_size=self.config.vocab_size,
-**
+**kwargs,
 )

 if not return_dict:
liger_kernel/transformers/model/olmo2.py (+5 -4)

@@ -27,7 +27,7 @@ def lce_forward(
 cache_position: Optional[torch.LongTensor] = None,
 logits_to_keep: Union[int, torch.Tensor] = 0,
 skip_logits: Optional[bool] = None,
-**
+**kwargs,
 ) -> Union[Tuple, CausalLMOutputWithPast]:
 r"""
 Args:

@@ -80,6 +80,7 @@ def lce_forward(
 output_hidden_states=output_hidden_states,
 return_dict=return_dict,
 cache_position=cache_position,
+**kwargs,
 )

 hidden_states = outputs[0]

@@ -87,7 +88,7 @@ def lce_forward(
 slice_indices = slice(-logits_to_keep, None) if isinstance(logits_to_keep, int) else logits_to_keep
 kept_hidden_states = hidden_states[:, slice_indices, :]

-shift_labels =
+shift_labels = kwargs.pop("shift_labels", None)
 logits = None
 loss = None

@@ -105,7 +106,7 @@ def lce_forward(
 labels=labels,
 shift_labels=shift_labels,
 hidden_size=self.config.hidden_size,
-**
+**kwargs,
 )

 else:

@@ -115,7 +116,7 @@ def lce_forward(
 logits=logits,
 labels=labels,
 vocab_size=self.config.vocab_size,
-**
+**kwargs,
 )

 return CausalLMOutputWithPast(
liger_kernel/transformers/model/phi3.py (+5 -4)

@@ -137,7 +137,7 @@ def lce_forward(
 cache_position: Optional[torch.LongTensor] = None,
 logits_to_keep: Union[int, torch.Tensor] = 0,
 skip_logits: Optional[bool] = None,
-**
+**kwargs,
 ) -> Union[Tuple, CausalLMOutputWithPast]:
 r"""
 Args:

@@ -203,6 +203,7 @@ def lce_forward(
 output_attentions=output_attentions,
 output_hidden_states=output_hidden_states,
 return_dict=return_dict,
+**kwargs,
 )

 hidden_states = outputs[0]

@@ -210,7 +211,7 @@ def lce_forward(
 slice_indices = slice(-logits_to_keep, None) if isinstance(logits_to_keep, int) else logits_to_keep
 kept_hidden_states = hidden_states[:, slice_indices, :]

-shift_labels =
+shift_labels = kwargs.pop("shift_labels", None)
 logits = None
 loss = None

@@ -228,7 +229,7 @@ def lce_forward(
 labels=labels,
 shift_labels=shift_labels,
 hidden_size=self.config.hidden_size,
-**
+**kwargs,
 )

 else:

@@ -238,7 +239,7 @@ def lce_forward(
 logits=logits,
 labels=labels,
 vocab_size=self.config.vocab_size,
-**
+**kwargs,
 )

 if not return_dict:
liger_kernel/transformers/model/qwen2.py (+5 -4)

@@ -136,7 +136,7 @@ def lce_forward(
 cache_position: Optional[torch.LongTensor] = None,
 logits_to_keep: Union[int, torch.Tensor] = 0,
 skip_logits: Optional[bool] = None,
-**
+**kwargs,
 ) -> Union[Tuple, CausalLMOutputWithPast]:
 r"""
 Args:

@@ -189,6 +189,7 @@ def lce_forward(
 output_hidden_states=output_hidden_states,
 return_dict=return_dict,
 cache_position=cache_position,
+**kwargs,
 )

 hidden_states = outputs[0]

@@ -196,7 +197,7 @@ def lce_forward(
 slice_indices = slice(-logits_to_keep, None) if isinstance(logits_to_keep, int) else logits_to_keep
 kept_hidden_states = hidden_states[:, slice_indices, :]

-shift_labels =
+shift_labels = kwargs.pop("shift_labels", None)
 logits = None
 loss = None

@@ -214,7 +215,7 @@ def lce_forward(
 labels=labels,
 shift_labels=shift_labels,
 hidden_size=self.config.hidden_size,
-**
+**kwargs,
 )

 else:

@@ -224,7 +225,7 @@ def lce_forward(
 logits=logits,
 labels=labels,
 vocab_size=self.config.vocab_size,
-**
+**kwargs,
 )

 return CausalLMOutputWithPast(
liger_kernel/transformers/model/qwen2_5_vl.py (+4 -3)

@@ -31,7 +31,7 @@ def lce_forward(
 cache_position: Optional[torch.LongTensor] = None,
 second_per_grid_ts: Optional[torch.Tensor] = None,
 skip_logits: Optional[bool] = None,
-**
+**kwargs,
 ) -> Union[Tuple, Qwen2_5_VLCausalLMOutputWithPast]:
 r"""
 Copy paste Qwen2_5_VL's forward but replace torch cross entropy with liger fused linear cross entropy

@@ -154,11 +154,12 @@ def lce_forward(
 output_hidden_states=output_hidden_states,
 return_dict=return_dict,
 cache_position=cache_position,
+**kwargs,
 )

 hidden_states = outputs[0]

-shift_labels =
+shift_labels = kwargs.pop("shift_labels", None)
 loss = None
 logits = None

@@ -175,7 +176,7 @@ def lce_forward(
 labels=labels,
 shift_labels=shift_labels,
 hidden_size=self.config.hidden_size,
-**
+**kwargs,
 )
 else:
 logits = self.lm_head(hidden_states)
liger_kernel/transformers/model/qwen2_vl.py (+4 -3)

@@ -32,7 +32,7 @@ def lce_forward(
 rope_deltas: Optional[torch.LongTensor] = None,
 cache_position: Optional[torch.LongTensor] = None,
 skip_logits: Optional[bool] = None,
-**
+**kwargs,
 ) -> Union[Tuple, Qwen2VLCausalLMOutputWithPast]:
 r"""
 Copy paste Qwen2VL's forward but replace torch cross entropy with liger fused linear cross entropy

@@ -158,11 +158,12 @@ def lce_forward(
 output_hidden_states=output_hidden_states,
 return_dict=return_dict,
 cache_position=cache_position,
+**kwargs,
 )

 hidden_states = outputs[0]

-shift_labels =
+shift_labels = kwargs.pop("shift_labels", None)
 loss = None
 logits = None

@@ -179,7 +180,7 @@ def lce_forward(
 labels=labels,
 shift_labels=shift_labels,
 hidden_size=self.config.hidden_size,
-**
+**kwargs,
 )
 else:
 logits = self.lm_head(hidden_states)
liger_kernel/transformers/model/qwen3_moe.py (+5 -4)

@@ -26,7 +26,7 @@ def lce_forward(
 cache_position: Optional[torch.LongTensor] = None,
 logits_to_keep: Union[int, torch.Tensor] = 0,
 skip_logits: Optional[bool] = None,
-**
+**kwargs,
 ) -> MoeCausalLMOutputWithPast:
 r"""
 labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):

@@ -81,6 +81,7 @@ def lce_forward(
 output_hidden_states=output_hidden_states,
 output_router_logits=output_router_logits,
 cache_position=cache_position,
+**kwargs,
 )

 hidden_states = outputs.last_hidden_state

@@ -88,7 +89,7 @@ def lce_forward(
 slice_indices = slice(-logits_to_keep, None) if isinstance(logits_to_keep, int) else logits_to_keep
 kept_hidden_states = hidden_states[:, slice_indices, :]

-shift_labels =
+shift_labels = kwargs.pop("shift_labels", None)
 logits = None
 loss = None

@@ -102,12 +103,12 @@ def lce_forward(
 labels=labels,
 shift_labels=shift_labels,
 hidden_size=self.config.hidden_size,
-**
+**kwargs,
 )
 else: # if in inference model materialize logits
 logits = self.lm_head(kept_hidden_states)
 if labels is not None:
-loss = self.loss_function(logits, labels, self.vocab_size, **
+loss = self.loss_function(logits, labels, self.vocab_size, **kwargs)

 aux_loss = None
 if output_router_logits:
{liger_kernel_nightly-0.5.10.dev20250605210201.dist-info → liger_kernel_nightly-0.5.10.dev20250605223455.dist-info}/RECORD (+19 -19)

@@ -63,31 +63,31 @@ liger_kernel/transformers/trainer_integration.py,sha256=W3ON51O5GkyzNJsItz0y5rKx
 liger_kernel/transformers/tvd.py,sha256=XrRfyJIqN6HFxXk8MYyFVZM1OLz3mtSbRZvWfZ_JerQ,450
 liger_kernel/transformers/experimental/embedding.py,sha256=2P0QYdlFyFrG5OqTzTa1wcRgDSyjBMv5i1a7BrDPDQw,881
 liger_kernel/transformers/model/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
-liger_kernel/transformers/model/gemma.py,sha256=
-liger_kernel/transformers/model/gemma2.py,sha256=
+liger_kernel/transformers/model/gemma.py,sha256=gvP-9zZ1e-DQD06qltWmRhiJClJDtkMQL1UrPMMZZGQ,9730
+liger_kernel/transformers/model/gemma2.py,sha256=ORmzklEAMpk93nToRo4d_ZJbM4ScVE2szczsEL4hw7w,11019
 liger_kernel/transformers/model/gemma3.py,sha256=JI4jj9K660HeRsofB6cpkCHBQ0OsazElArRtKUehUmw,15945
-liger_kernel/transformers/model/glm4.py,sha256=
-liger_kernel/transformers/model/llama.py,sha256=
+liger_kernel/transformers/model/glm4.py,sha256=GlnEhdGJuDIqp2R9qC54biY3HwV1tWmfpJm6ijoAsrM,5257
+liger_kernel/transformers/model/llama.py,sha256=LcIxVfF0PXXWHBVJa6Ody_5fAtIpxQcI4jC_j-o51fU,12503
 liger_kernel/transformers/model/llava.py,sha256=ONdpx96AVbbL8QDQvHSm08jMJPz3tzkbeO92IRbAb1A,19270
 liger_kernel/transformers/model/loss_utils.py,sha256=WWAMdiONPaXpIvxyOim_0igLrYh0yyOok5Q9_L9xvZw,1787
-liger_kernel/transformers/model/mistral.py,sha256=
-liger_kernel/transformers/model/mixtral.py,sha256=
-liger_kernel/transformers/model/mllama.py,sha256=
-liger_kernel/transformers/model/olmo2.py,sha256=
+liger_kernel/transformers/model/mistral.py,sha256=okKkyashfFLfhjIT--f3JY6JHOslOtDI8U1dlpBC2Zs,5565
+liger_kernel/transformers/model/mixtral.py,sha256=VY-y73IyjcCyWyI7ahxXLw0fJrhgjYfr1xwRYtsHX0o,11396
+liger_kernel/transformers/model/mllama.py,sha256=my29NXk-p6ckQaP8qDIN8e318yI_9mQZHt38MV3SqLY,11280
+liger_kernel/transformers/model/olmo2.py,sha256=6L_bo-ZUgO1lYppdJneOtYxNIylQKS6BiGp13g7Uq9E,5259
 liger_kernel/transformers/model/paligemma.py,sha256=xuIx3oOwTgftU3jqLfWOxUxgCLBNJh0yNC21an9qDjo,18773
-liger_kernel/transformers/model/phi3.py,sha256=
-liger_kernel/transformers/model/qwen2.py,sha256=
-liger_kernel/transformers/model/qwen2_5_vl.py,sha256=
-liger_kernel/transformers/model/qwen2_vl.py,sha256=
+liger_kernel/transformers/model/phi3.py,sha256=m-MD_OuTaYMGZhHOvl-RHOVEObrL8tL5cBv3VTNd4F0,10376
+liger_kernel/transformers/model/qwen2.py,sha256=SdN7V-MI3eX9s2DAFRvC1g-G146uG_5n1fnNdY9QwYk,9658
+liger_kernel/transformers/model/qwen2_5_vl.py,sha256=k6jt1bTCJsKsZVGhTxqIbDzmnL8-B3CpWJOjLazswbo,9203
+liger_kernel/transformers/model/qwen2_vl.py,sha256=Cgs7-nPlKFifiDO9gqSI6np4vRUVCKiqoospT_vIi_M,9614
 liger_kernel/transformers/model/qwen3.py,sha256=w2jBHuK9kK9EmOr5dnEIXNQXUgUSV_sJUkXSEwxLPHs,4885
-liger_kernel/transformers/model/qwen3_moe.py,sha256=
+liger_kernel/transformers/model/qwen3_moe.py,sha256=BkpfFH3fOH0yRfA7LF-AoHTLut2GV0Y4MOlkiIYewfU,5511
 liger_kernel/transformers/trainer/__init__.py,sha256=p7yQfklV8-467qSz_ZMimkbDF7HHWHwku25A-GYL0WU,193
 liger_kernel/transformers/trainer/orpo_trainer.py,sha256=tX0h63aOFe3rNqTmk6JpMf75UPo981yzEa6TghnjS0Q,5370
 liger_kernel/triton/__init__.py,sha256=qCiCamzCRv6lpV8IqpAc9YMdNKC7GKurClWceQPnlis,92
 liger_kernel/triton/monkey_patch.py,sha256=Rd0hUHAzDkFfHvnX7-PBaNK5EKnZhtfM_h-fgQH9HPY,1568
-liger_kernel_nightly-0.5.10.
-liger_kernel_nightly-0.5.10.
-liger_kernel_nightly-0.5.10.
-liger_kernel_nightly-0.5.10.
-liger_kernel_nightly-0.5.10.
-liger_kernel_nightly-0.5.10.
+liger_kernel_nightly-0.5.10.dev20250605223455.dist-info/LICENSE,sha256=OhzLDHJ0to4a8sodVLELZiCFylZ1NAAYLs-HrjPy0ag,1312
+liger_kernel_nightly-0.5.10.dev20250605223455.dist-info/METADATA,sha256=jtKbBFfhtiyDQ7ZfpSZ1EwxGFNTYt0ND_4jL8Xr_pmc,24309
+liger_kernel_nightly-0.5.10.dev20250605223455.dist-info/NOTICE,sha256=njwnoPZLh9AN8SJQzxvCGLHi-8X__AvWRze6joNXIY8,2066
+liger_kernel_nightly-0.5.10.dev20250605223455.dist-info/WHEEL,sha256=iAkIy5fosb7FzIOwONchHf19Qu7_1wCWyFNR5gu9nU0,91
+liger_kernel_nightly-0.5.10.dev20250605223455.dist-info/top_level.txt,sha256=2eghu4hA3LnkM7ElW92tQ8zegWKgSbeo-k-aGe1YnvY,13
+liger_kernel_nightly-0.5.10.dev20250605223455.dist-info/RECORD,,
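For readers unfamiliar with the RECORD format above: each entry is `path,sha256=<hash>,<size>`, where the hash is the urlsafe-base64 encoding of the file's SHA-256 digest with trailing `=` padding stripped. The sketch below shows how such a value can be recomputed for an unpacked wheel; the path in the usage comment is just one example taken from this diff.

```python
# Hedged sketch: recompute the "sha256=..." value used in RECORD lines.
# Per the wheel spec, it is the urlsafe base64 of the SHA-256 digest,
# with trailing "=" padding stripped.
import base64
import hashlib
from pathlib import Path


def record_hash(path: str) -> str:
    digest = hashlib.sha256(Path(path).read_bytes()).digest()
    return "sha256=" + base64.urlsafe_b64encode(digest).rstrip(b"=").decode()


# Example, assuming the wheel has been unpacked into the current directory:
# print(record_hash("liger_kernel/transformers/model/gemma.py"))
```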
File without changes
File without changes
File without changes