liger-kernel-nightly 0.5.5.dev20250402185702__py3-none-any.whl → 0.6.4.dev20260112233432__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release: this version of liger-kernel-nightly might be problematic; see the release page for more details.
- liger_kernel/chunked_loss/__init__.py +1 -0
- liger_kernel/chunked_loss/cosine_similarity_loss.py +142 -0
- liger_kernel/chunked_loss/dpo_loss.py +61 -3
- liger_kernel/chunked_loss/functional.py +2 -0
- liger_kernel/chunked_loss/fused_linear_distillation.py +23 -5
- liger_kernel/chunked_loss/fused_linear_ppo.py +36 -0
- liger_kernel/chunked_loss/fused_linear_preference.py +0 -1
- liger_kernel/chunked_loss/grpo_loss.py +76 -5
- liger_kernel/chunked_loss/jsd_loss.py +46 -15
- liger_kernel/ops/__init__.py +141 -0
- liger_kernel/ops/backends/README.md +151 -0
- liger_kernel/ops/backends/__init__.py +13 -0
- liger_kernel/ops/backends/_ascend/__init__.py +5 -0
- liger_kernel/ops/backends/_ascend/ascend-ub-manager-design.md +485 -0
- liger_kernel/ops/backends/_ascend/ops/__init__.py +49 -0
- liger_kernel/ops/backends/_ascend/ops/geglu.py +266 -0
- liger_kernel/ops/backends/_ascend/ops/qwen2vl_mrope.py +285 -0
- liger_kernel/ops/backends/_ascend/ops/rope.py +290 -0
- liger_kernel/ops/backends/_ascend/ops/swiglu.py +142 -0
- liger_kernel/ops/backends/_ascend/ops/tvd.py +221 -0
- liger_kernel/ops/backends/_ascend/ub_manager.py +349 -0
- liger_kernel/ops/backends/registry.py +61 -0
- liger_kernel/ops/cross_entropy.py +134 -65
- liger_kernel/ops/dyt.py +115 -180
- liger_kernel/ops/fused_add_rms_norm.py +416 -0
- liger_kernel/ops/fused_linear_cross_entropy.py +117 -23
- liger_kernel/ops/fused_neighborhood_attention.py +1022 -0
- liger_kernel/ops/geglu.py +6 -4
- liger_kernel/ops/group_norm.py +7 -7
- liger_kernel/ops/grpo_loss.py +312 -0
- liger_kernel/ops/jsd.py +2 -1
- liger_kernel/ops/kl_div.py +9 -5
- liger_kernel/ops/layer_norm.py +146 -78
- liger_kernel/ops/llama4_rope.py +225 -0
- liger_kernel/ops/multi_token_attention.py +207 -0
- liger_kernel/ops/poly_norm.py +390 -0
- liger_kernel/ops/rms_norm.py +398 -99
- liger_kernel/ops/rope.py +1 -1
- liger_kernel/ops/softmax.py +201 -0
- liger_kernel/ops/sparsemax.py +179 -0
- liger_kernel/ops/swiglu.py +1 -1
- liger_kernel/ops/tiled_mlp.py +136 -0
- liger_kernel/ops/utils.py +14 -0
- liger_kernel/transformers/__init__.py +208 -17
- liger_kernel/transformers/auto_model.py +21 -0
- liger_kernel/transformers/cross_entropy.py +9 -4
- liger_kernel/transformers/dyt.py +6 -4
- liger_kernel/transformers/experimental/__init__.py +5 -0
- liger_kernel/transformers/experimental/embedding.py +1 -1
- liger_kernel/transformers/fsdp.py +55 -0
- liger_kernel/transformers/functional.py +122 -20
- liger_kernel/transformers/fused_add_rms_norm.py +39 -0
- liger_kernel/transformers/fused_linear_cross_entropy.py +16 -5
- liger_kernel/transformers/fused_linear_jsd.py +1 -1
- liger_kernel/transformers/fused_neighborhood_attention.py +234 -0
- liger_kernel/transformers/geglu.py +1 -1
- liger_kernel/transformers/group_norm.py +1 -1
- liger_kernel/transformers/grpo_loss.py +153 -0
- liger_kernel/transformers/jsd.py +1 -1
- liger_kernel/transformers/kl_div.py +1 -1
- liger_kernel/transformers/layer_norm.py +1 -1
- liger_kernel/transformers/llama4_rope.py +93 -0
- liger_kernel/transformers/model/exaone4.py +136 -0
- liger_kernel/transformers/model/falcon_h1.py +122 -0
- liger_kernel/transformers/model/gemma.py +57 -27
- liger_kernel/transformers/model/gemma2.py +65 -28
- liger_kernel/transformers/model/gemma3.py +331 -0
- liger_kernel/transformers/model/glm4.py +141 -0
- liger_kernel/transformers/model/glm4v.py +163 -0
- liger_kernel/transformers/model/glm4v_moe.py +172 -0
- liger_kernel/transformers/model/gpt_oss.py +211 -0
- liger_kernel/transformers/model/hunyuan_v1.py +134 -0
- liger_kernel/transformers/model/internvl.py +157 -0
- liger_kernel/transformers/model/llama.py +109 -27
- liger_kernel/transformers/model/llama4.py +121 -0
- liger_kernel/transformers/model/llava.py +111 -136
- liger_kernel/transformers/model/loss_utils.py +50 -12
- liger_kernel/transformers/model/mistral.py +51 -34
- liger_kernel/transformers/model/mixtral.py +50 -29
- liger_kernel/transformers/model/mllama.py +46 -24
- liger_kernel/transformers/model/olmo2.py +47 -22
- liger_kernel/transformers/model/olmo3.py +142 -0
- liger_kernel/transformers/model/output_classes.py +147 -0
- liger_kernel/transformers/model/paligemma.py +50 -14
- liger_kernel/transformers/model/phi3.py +47 -172
- liger_kernel/transformers/model/qwen2.py +55 -23
- liger_kernel/transformers/model/qwen2_5_vl.py +62 -103
- liger_kernel/transformers/model/qwen2_vl.py +59 -108
- liger_kernel/transformers/model/qwen3.py +136 -0
- liger_kernel/transformers/model/qwen3_moe.py +152 -0
- liger_kernel/transformers/model/qwen3_next.py +146 -0
- liger_kernel/transformers/model/qwen3_vl.py +150 -0
- liger_kernel/transformers/model/qwen3_vl_moe.py +126 -0
- liger_kernel/transformers/model/smollm3.py +199 -0
- liger_kernel/transformers/model/smolvlm.py +158 -0
- liger_kernel/transformers/monkey_patch.py +2018 -244
- liger_kernel/transformers/multi_token_attention.py +64 -0
- liger_kernel/transformers/poly_norm.py +42 -0
- liger_kernel/transformers/qwen2vl_mrope.py +1 -1
- liger_kernel/transformers/rms_norm.py +54 -6
- liger_kernel/transformers/rope.py +45 -1
- liger_kernel/transformers/softmax.py +12 -0
- liger_kernel/transformers/sparsemax.py +16 -0
- liger_kernel/transformers/swiglu.py +39 -1
- liger_kernel/transformers/tiled_mlp.py +125 -0
- liger_kernel/transformers/trainer/orpo_trainer.py +1 -53
- liger_kernel/transformers/tvd.py +1 -1
- liger_kernel/utils.py +63 -0
- {liger_kernel_nightly-0.5.5.dev20250402185702.dist-info → liger_kernel_nightly-0.6.4.dev20260112233432.dist-info}/METADATA +73 -39
- liger_kernel_nightly-0.6.4.dev20260112233432.dist-info/RECORD +132 -0
- liger_kernel_nightly-0.5.5.dev20250402185702.dist-info/RECORD +0 -80
- {liger_kernel_nightly-0.5.5.dev20250402185702.dist-info → liger_kernel_nightly-0.6.4.dev20260112233432.dist-info}/LICENSE +0 -0
- {liger_kernel_nightly-0.5.5.dev20250402185702.dist-info → liger_kernel_nightly-0.6.4.dev20260112233432.dist-info}/NOTICE +0 -0
- {liger_kernel_nightly-0.5.5.dev20250402185702.dist-info → liger_kernel_nightly-0.6.4.dev20260112233432.dist-info}/WHEEL +0 -0
- {liger_kernel_nightly-0.5.5.dev20250402185702.dist-info → liger_kernel_nightly-0.6.4.dev20260112233432.dist-info}/top_level.txt +0 -0

liger_kernel/transformers/model/mixtral.py

@@ -7,18 +7,15 @@ import torch

 from torch.nn import CrossEntropyLoss
 from transformers.modeling_outputs import MoeCausalLMOutputWithPast
-from transformers.models.mixtral.modeling_mixtral import _CONFIG_FOR_DOC
-from transformers.models.mixtral.modeling_mixtral import MIXTRAL_INPUTS_DOCSTRING
 from transformers.models.mixtral.modeling_mixtral import load_balancing_loss_func
-from transformers.utils import
-from transformers.utils import replace_return_docstrings
+from transformers.utils.deprecation import deprecate_kwarg

 from liger_kernel.transformers.fused_linear_cross_entropy import LigerFusedLinearCrossEntropyLoss
 from liger_kernel.transformers.model.loss_utils import LigerForCausalLMLoss
+from liger_kernel.transformers.model.loss_utils import unpack_cross_entropy_result
+from liger_kernel.transformers.model.output_classes import LigerMoeCausalLMOutputWithPast


-@add_start_docstrings_to_model_forward(MIXTRAL_INPUTS_DOCSTRING)
-@replace_return_docstrings(output_type=MoeCausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC)
 def lce_forward_deprecated(
     self,
     input_ids: torch.LongTensor = None,
@@ -144,8 +141,7 @@ def lce_forward_deprecated(
     )


-@
-@replace_return_docstrings(output_type=MoeCausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC)
+@deprecate_kwarg("num_logits_to_keep", version="4.50", new_name="logits_to_keep")
 # Ignore copy
 def lce_forward(
     self,
@@ -161,9 +157,10 @@ def lce_forward(
     output_router_logits: Optional[bool] = None,
     return_dict: Optional[bool] = None,
     cache_position: Optional[torch.LongTensor] = None,
-
-
-
+    logits_to_keep: Union[int, torch.Tensor] = 0,
+    skip_logits: Optional[bool] = None,
+    **kwargs,
+) -> Union[Tuple, LigerMoeCausalLMOutputWithPast]:
     r"""
     Args:
         labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
@@ -171,10 +168,12 @@ def lce_forward(
             config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored
             (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`.

-
-
+        logits_to_keep (`int` or `torch.Tensor`, *optional*):
+            If an `int`, compute logits for the last `logits_to_keep` tokens. If `0`, calculate logits for all
             `input_ids` (special case). Only last token logits are needed for generation, and calculating them only for that
             token can save memory, which becomes pretty significant for long sequences or large vocabulary size.
+            If a `torch.Tensor`, must be 1D corresponding to the indices to keep in the sequence length dimension.
+            This is useful when using packed tensor format (single dimension for batch and sequence length).

     Returns:

@@ -218,32 +217,50 @@ def lce_forward(
         output_router_logits=output_router_logits,
         return_dict=return_dict,
         cache_position=cache_position,
+        **kwargs,
     )

     hidden_states = outputs[0]
+    # Only compute necessary logits, and do not upcast them to float if we are not computing the loss
+    slice_indices = slice(-logits_to_keep, None) if isinstance(logits_to_keep, int) else logits_to_keep
+    kept_hidden_states = hidden_states[:, slice_indices, :]

+    shift_labels = kwargs.pop("shift_labels", None)
     logits = None
     loss = None
-
-
-
-
+    token_accuracy = None
+
+    if skip_logits and labels is None and shift_labels is None:
+        raise ValueError("skip_logits is True, but labels and shift_labels are None")
+
+    if skip_logits is None:
+        # By default, if in training mode, don't materialize logits
+        skip_logits = self.training and (labels is not None or shift_labels is not None)
+
+    # Compute loss
+    if skip_logits:
+        result = LigerForCausalLMLoss(
+            hidden_states=kept_hidden_states,
             lm_head_weight=self.lm_head.weight,
             labels=labels,
+            shift_labels=shift_labels,
             hidden_size=self.config.hidden_size,
-            **
+            **kwargs,
         )
+        loss, _, token_accuracy = unpack_cross_entropy_result(result)

-    else:
-        logits = self.lm_head(
-
+    else:
+        logits = self.lm_head(kept_hidden_states)
+
+        loss = None
+        if labels is not None or shift_labels is not None:
             loss = self.loss_function(
                 logits=logits,
                 labels=labels,
-
-
+                shift_labels=shift_labels,
+                vocab_size=self.vocab_size,
+                **kwargs,
             )
-
     aux_loss = None
     if output_router_logits:
         aux_loss = load_balancing_loss_func(
@@ -256,17 +273,21 @@ def lce_forward(
             loss += self.router_aux_loss_coef * aux_loss.to(loss.device)  # make sure to reside in the same device

     if not return_dict:
-
+        output_tuple = (logits,) + outputs[1:]
         if output_router_logits:
-
-
+            output_tuple = (aux_loss,) + output_tuple
+        if token_accuracy is not None:
+            output_tuple = output_tuple + (token_accuracy,)
+        return (loss,) + output_tuple if loss is not None else output_tuple

-
+    # Return custom output class with token_accuracy field
+    return LigerMoeCausalLMOutputWithPast(
         loss=loss,
         aux_loss=aux_loss,
         logits=logits,
         past_key_values=outputs.past_key_values,
         hidden_states=outputs.hidden_states,
         attentions=outputs.attentions,
-        router_logits=outputs.router_logits,
+        router_logits=outputs.router_logits if return_dict else outputs[-1],
+        token_accuracy=token_accuracy,
     )
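
In all of these patched forwards, `skip_logits` decides whether the logits tensor is materialized at all: left unset, it defaults to skipping logits during training whenever `labels` or `shift_labels` are supplied, and it is an error to request skipping when neither is available. Below is a minimal, self-contained sketch of just that decision; the helper name `resolve_skip_logits` and its signature are illustrative, not part of the package.

from typing import Optional


def resolve_skip_logits(skip_logits: Optional[bool], training: bool, has_labels: bool) -> bool:
    # has_labels stands for "labels is not None or shift_labels is not None".
    if skip_logits and not has_labels:
        # Nothing to compute a fused loss against, so materialized logits would be required.
        raise ValueError("skip_logits is True, but labels and shift_labels are None")
    if skip_logits is None:
        # Default: during training with labels available, skip materializing logits
        # and let the fused linear cross-entropy path produce the loss directly.
        skip_logits = training and has_labels
    return bool(skip_logits)


# Training step with labels: logits are skipped by default.
assert resolve_skip_logits(None, training=True, has_labels=True) is True
# Inference without labels: logits are materialized.
assert resolve_skip_logits(None, training=False, has_labels=False) is False

In the diff itself this logic runs inline near the top of each `lce_forward`, immediately before the fused `LigerForCausalLMLoss` call.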

liger_kernel/transformers/model/mllama.py

@@ -8,16 +8,14 @@ import torch
 from torch.nn import CrossEntropyLoss
 from transformers.cache_utils import Cache
 from transformers.modeling_outputs import CausalLMOutputWithPast
-from transformers.
-from transformers.utils import add_start_docstrings_to_model_forward
-from transformers.utils import replace_return_docstrings
+from transformers.utils.deprecation import deprecate_kwarg

 from liger_kernel.transformers.fused_linear_cross_entropy import LigerFusedLinearCrossEntropyLoss
 from liger_kernel.transformers.model.loss_utils import LigerForCausalLMLoss
+from liger_kernel.transformers.model.loss_utils import unpack_cross_entropy_result
+from liger_kernel.transformers.model.output_classes import LigerCausalLMOutputWithPast


-@add_start_docstrings_to_model_forward(MLLAMA_INPUTS_DOCSTRING)
-@replace_return_docstrings(output_type=CausalLMOutputWithPast, config_class="MllamaTextConfig")
 def lce_forward_deprecated(
     self,
     input_ids: torch.LongTensor = None,
@@ -133,8 +131,7 @@ def lce_forward_deprecated(
     )


-@
-@replace_return_docstrings(output_type=CausalLMOutputWithPast, config_class="MllamaTextConfig")
+@deprecate_kwarg("num_logits_to_keep", version="4.50", new_name="logits_to_keep")
 def lce_forward(
     self,
     input_ids: torch.LongTensor = None,
@@ -151,9 +148,10 @@ def lce_forward(
     output_hidden_states: Optional[bool] = None,
     return_dict: Optional[bool] = None,
     cache_position: Optional[torch.LongTensor] = None,
-
-
-
+    logits_to_keep: Union[int, torch.Tensor] = 0,
+    skip_logits: Optional[bool] = None,
+    **kwargs,
+) -> Union[Tuple, LigerCausalLMOutputWithPast]:
     r"""
     Args:
         labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
@@ -161,10 +159,12 @@ def lce_forward(
             config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored
             (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`.

-
-
+        logits_to_keep (`int` or `torch.Tensor`, *optional*):
+            If an `int`, compute logits for the last `logits_to_keep` tokens. If `0`, calculate logits for all
             `input_ids` (special case). Only last token logits are needed for generation, and calculating them only for that
             token can save memory, which becomes pretty significant for long sequences or large vocabulary size.
+            If a `torch.Tensor`, must be 1D corresponding to the indices to keep in the sequence length dimension.
+            This is useful when using packed tensor format (single dimension for batch and sequence length).

     Returns:

@@ -192,7 +192,9 @@ def lce_forward(
         output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
     )
     return_dict = return_dict if return_dict is not None else self.config.use_return_dict
-
+    # Filter out accum_dtype from kwargs for model call as MllamaTextModel doesn't accept it in transformers 4.49.0
+    # but preserve it for loss function calls
+    model_kwargs = {k: v for k, v in kwargs.items() if k != "accum_dtype"}
     # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn)
     outputs = self.model(
         input_ids=input_ids,
@@ -208,40 +210,60 @@ def lce_forward(
         output_hidden_states=output_hidden_states,
         return_dict=return_dict,
         cache_position=cache_position,
+        **model_kwargs,
     )

     hidden_states = outputs[0]
+    # Only compute necessary logits, and do not upcast them to float if we are not computing the loss
+    slice_indices = slice(-logits_to_keep, None) if isinstance(logits_to_keep, int) else logits_to_keep
+    kept_hidden_states = hidden_states[:, slice_indices, :]

+    shift_labels = kwargs.pop("shift_labels", None)
     logits = None
     loss = None
-
-
-
-
+    token_accuracy = None
+
+    if skip_logits and labels is None and shift_labels is None:
+        raise ValueError("skip_logits is True, but labels and shift_labels are None")
+
+    if skip_logits is None:
+        # By default, if in training mode, don't materialize logits
+        skip_logits = self.training and (labels is not None or shift_labels is not None)
+
+    if skip_logits:
+        result = LigerForCausalLMLoss(
+            hidden_states=kept_hidden_states,
             lm_head_weight=self.lm_head.weight,
             labels=labels,
+            shift_labels=shift_labels,
             hidden_size=self.config.hidden_size,
-            **
+            **kwargs,
         )
+        loss, _, token_accuracy = unpack_cross_entropy_result(result)

-    else:
-        logits = self.lm_head(
-        if labels is not None:
+    else:
+        logits = self.lm_head(kept_hidden_states)
+        if labels is not None or shift_labels is not None:
             loss = self.loss_function(
                 logits=logits,
                 labels=labels,
+                shift_labels=shift_labels,
                 vocab_size=self.config.vocab_size,
-                **
+                **kwargs,
             )

     if not return_dict:
         output = (logits,) + outputs[1:]
-
+        output = (loss,) + output if loss is not None else output
+        output = output + (token_accuracy,) if token_accuracy is not None else output
+        return output

-
+    # Return custom output class with token_accuracy field
+    return LigerCausalLMOutputWithPast(
         loss=loss,
         logits=logits,
         past_key_values=outputs.past_key_values,
         hidden_states=outputs.hidden_states,
         attentions=outputs.attentions,
+        token_accuracy=token_accuracy,
     )
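
The new `logits_to_keep` argument feeds the `slice_indices` line shared by these forwards: an `int` keeps only the last N sequence positions (with `0` meaning keep everything), while a 1-D tensor selects explicit positions, which is what packed batches need. A small standalone demonstration of that slicing; the `keep` helper below is for illustration only.

import torch


def keep(hidden_states: torch.Tensor, logits_to_keep):
    # Same expression as in the patched forwards, applied to a (batch, seq_len, hidden) tensor.
    slice_indices = slice(-logits_to_keep, None) if isinstance(logits_to_keep, int) else logits_to_keep
    return hidden_states[:, slice_indices, :]


hidden_states = torch.randn(2, 6, 4)
print(keep(hidden_states, 0).shape)                     # torch.Size([2, 6, 4]): slice(0, None) keeps every position
print(keep(hidden_states, 1).shape)                     # torch.Size([2, 1, 4]): only the final position
print(keep(hidden_states, torch.tensor([0, 5])).shape)  # torch.Size([2, 2, 4]): explicit indices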

liger_kernel/transformers/model/olmo2.py

@@ -5,17 +5,14 @@ from typing import Union

 import torch

-from transformers.
-from transformers.models.olmo2.modeling_olmo2 import _CONFIG_FOR_DOC
-from transformers.models.olmo2.modeling_olmo2 import OLMO2_INPUTS_DOCSTRING
-from transformers.utils import add_start_docstrings_to_model_forward
-from transformers.utils import replace_return_docstrings
+from transformers.utils.deprecation import deprecate_kwarg

 from liger_kernel.transformers.model.loss_utils import LigerForCausalLMLoss
+from liger_kernel.transformers.model.loss_utils import unpack_cross_entropy_result
+from liger_kernel.transformers.model.output_classes import LigerCausalLMOutputWithPast


-@
-@replace_return_docstrings(output_type=CausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC)
+@deprecate_kwarg("num_logits_to_keep", version="4.50", new_name="logits_to_keep")
 def lce_forward(
     self,
     input_ids: torch.LongTensor = None,
@@ -29,9 +26,10 @@ def lce_forward(
     output_hidden_states: Optional[bool] = None,
     return_dict: Optional[bool] = None,
     cache_position: Optional[torch.LongTensor] = None,
-
-
-
+    logits_to_keep: Union[int, torch.Tensor] = 0,
+    skip_logits: Optional[bool] = None,
+    **kwargs,
+) -> Union[Tuple, LigerCausalLMOutputWithPast]:
     r"""
     Args:
         labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
@@ -39,10 +37,12 @@ def lce_forward(
             config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored
             (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`.

-
-
+        logits_to_keep (`int` or `torch.Tensor`, *optional*):
+            If an `int`, compute logits for the last `logits_to_keep` tokens. If `0`, calculate logits for all
             `input_ids` (special case). Only last token logits are needed for generation, and calculating them only for that
             token can save memory, which becomes pretty significant for long sequences or large vocabulary size.
+            If a `torch.Tensor`, must be 1D corresponding to the indices to keep in the sequence length dimension.
+            This is useful when using packed tensor format (single dimension for batch and sequence length).

     Returns:

@@ -81,36 +81,61 @@ def lce_forward(
         output_hidden_states=output_hidden_states,
         return_dict=return_dict,
         cache_position=cache_position,
+        **kwargs,
     )

     hidden_states = outputs[0]
+    # Only compute necessary logits, and do not upcast them to float if we are not computing the loss
+    slice_indices = slice(-logits_to_keep, None) if isinstance(logits_to_keep, int) else logits_to_keep
+    kept_hidden_states = hidden_states[:, slice_indices, :]

+    shift_labels = kwargs.pop("shift_labels", None)
     logits = None
     loss = None
-
-
-
-
+    token_accuracy = None
+
+    if skip_logits and labels is None and shift_labels is None:
+        raise ValueError("skip_logits is True, but labels and shift_labels are None")
+
+    if skip_logits is None:
+        # By default, if in training mode, don't materialize logits
+        skip_logits = self.training and (labels is not None or shift_labels is not None)
+
+    # Compute loss
+    if skip_logits:
+        result = LigerForCausalLMLoss(
+            hidden_states=kept_hidden_states,
             lm_head_weight=self.lm_head.weight,
             labels=labels,
+            shift_labels=shift_labels,
             hidden_size=self.config.hidden_size,
-            **
+            **kwargs,
         )
+        loss, _, token_accuracy = unpack_cross_entropy_result(result)

-    else:
-        logits = self.lm_head(
-        if labels is not None:
+    else:
+        logits = self.lm_head(kept_hidden_states)
+        if labels is not None or shift_labels is not None:
             loss = self.loss_function(
                 logits=logits,
                 labels=labels,
+                shift_labels=shift_labels,
                 vocab_size=self.config.vocab_size,
-                **
+                **kwargs,
             )

-
+    if not return_dict:
+        output = (logits,) + outputs[1:]
+        output = ((loss,) + output) if loss is not None else output
+        output = output + (token_accuracy,) if token_accuracy is not None else output
+        return output
+
+    # Return custom output class with token_accuracy field
+    return LigerCausalLMOutputWithPast(
         loss=loss,
         logits=logits,
         past_key_values=outputs.past_key_values,
         hidden_states=outputs.hidden_states,
         attentions=outputs.attentions,
+        token_accuracy=token_accuracy,
     )
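
When `return_dict=False`, the patched forwards assemble their tuple output in three steps: logits first, the loss prepended only if one was computed, and `token_accuracy` appended only if the fused loss path produced it. A plain-Python sketch of that assembly with placeholder values; the function below is illustrative, not package code.

def assemble_tuple_output(logits, model_outputs_tail, loss=None, token_accuracy=None):
    # Mirrors the tuple-building lines in the non-return_dict branch of the patched forwards.
    output = (logits,) + model_outputs_tail
    output = ((loss,) + output) if loss is not None else output
    output = output + (token_accuracy,) if token_accuracy is not None else output
    return output


# Training-style call: fused loss computed, logits skipped (None), accuracy reported.
print(assemble_tuple_output(None, ("past_kv", "hidden", "attn"), loss=0.42, token_accuracy=0.87))
# -> (0.42, None, 'past_kv', 'hidden', 'attn', 0.87)

# Inference-style call: logits materialized, no loss or accuracy to attach.
print(assemble_tuple_output("logits", ("past_kv",)))
# -> ('logits', 'past_kv')

The Mixtral variant additionally prepends `aux_loss` to the tuple when `output_router_logits` is set.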

liger_kernel/transformers/model/olmo3.py (new file)

@@ -0,0 +1,142 @@
+from typing import List
+from typing import Optional
+from typing import Tuple
+from typing import Union
+
+import torch
+
+from transformers.modeling_outputs import BaseModelOutputWithPast
+from transformers.utils.deprecation import deprecate_kwarg
+
+from liger_kernel.transformers.model.loss_utils import LigerForCausalLMLoss
+from liger_kernel.transformers.model.loss_utils import unpack_cross_entropy_result
+from liger_kernel.transformers.model.output_classes import LigerCausalLMOutputWithPast
+
+
+@deprecate_kwarg("num_logits_to_keep", version="4.50", new_name="logits_to_keep")
+def lce_forward(
+    self,
+    input_ids: torch.LongTensor = None,
+    attention_mask: Optional[torch.Tensor] = None,
+    position_ids: Optional[torch.LongTensor] = None,
+    past_key_values: Optional[List[torch.FloatTensor]] = None,
+    inputs_embeds: Optional[torch.FloatTensor] = None,
+    labels: Optional[torch.LongTensor] = None,
+    use_cache: Optional[bool] = None,
+    output_attentions: Optional[bool] = None,
+    output_hidden_states: Optional[bool] = None,
+    return_dict: Optional[bool] = None,
+    cache_position: Optional[torch.LongTensor] = None,
+    logits_to_keep: Union[int, torch.Tensor] = 0,
+    skip_logits: Optional[bool] = None,
+    **kwargs,
+) -> Union[Tuple, LigerCausalLMOutputWithPast]:
+    r"""
+    Args:
+        labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
+            Labels for computing the masked language modeling loss. Indices should either be in `[0, ...,
+            config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored
+            (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`.
+
+        logits_to_keep (`int` or `torch.Tensor`, *optional*):
+            If an `int`, compute logits for the last `logits_to_keep` tokens. If `0`, calculate logits for all
+            `input_ids` (special case). Only last token logits are needed for generation, and calculating them only for that
+            token can save memory, which becomes pretty significant for long sequences or large vocabulary size.
+            If a `torch.Tensor`, must be 1D corresponding to the indices to keep in the sequence length dimension.
+            This is useful when using packed tensor format (single dimension for batch and sequence length).
+
+    Returns:
+
+    Example:
+
+    ```python
+    >>> from transformers import AutoTokenizer, Olmo3ForCausalLM
+
+    >>> model = Olmo3ForCausalLM.from_pretrained("allenai/Olmo-3-7B-Instruct")
+    >>> tokenizer = AutoTokenizer.from_pretrained("allenai/Olmo-3-7B-Instruct")
+
+    >>> prompt = "Hey, are you conscious? Can you talk to me?"
+    >>> inputs = tokenizer(prompt, return_tensors="pt")
+
+    >>> # Generate
+    >>> generate_ids = model.generate(inputs.input_ids, max_length=30)
+    >>> tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]
+    'Hey, are you conscious? Can you talk to me?\nI’m not sure if you’re conscious of this, but I’m'
+    ```
+    """
+    output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
+    output_hidden_states = (
+        output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
+    )
+    return_dict = return_dict if return_dict is not None else self.config.use_return_dict
+
+    # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn)
+    outputs: BaseModelOutputWithPast = self.model(
+        input_ids=input_ids,
+        attention_mask=attention_mask,
+        position_ids=position_ids,
+        past_key_values=past_key_values,
+        inputs_embeds=inputs_embeds,
+        use_cache=use_cache,
+        output_attentions=output_attentions,
+        output_hidden_states=output_hidden_states,
+        return_dict=return_dict,
+        cache_position=cache_position,
+        **kwargs,
+    )
+
+    hidden_states = outputs[0]
+    # Only compute necessary logits, and do not upcast them to float if we are not computing the loss
+    slice_indices = slice(-logits_to_keep, None) if isinstance(logits_to_keep, int) else logits_to_keep
+    kept_hidden_states = hidden_states[:, slice_indices, :]
+
+    shift_labels = kwargs.pop("shift_labels", None)
+    logits = None
+    loss = None
+    token_accuracy = None
+
+    if skip_logits and labels is None and shift_labels is None:
+        raise ValueError("skip_logits is True, but labels and shift_labels are None")
+
+    if skip_logits is None:
+        # By default, if in training mode, don't materialize logits
+        skip_logits = self.training and (labels is not None or shift_labels is not None)
+
+    # Compute loss
+    if skip_logits:
+        result = LigerForCausalLMLoss(
+            hidden_states=kept_hidden_states,
+            lm_head_weight=self.lm_head.weight,
+            labels=labels,
+            shift_labels=shift_labels,
+            hidden_size=self.config.hidden_size,
+            **kwargs,
+        )
+        loss, _, token_accuracy = unpack_cross_entropy_result(result)
+
+    else:
+        logits = self.lm_head(kept_hidden_states)
+        if labels is not None or shift_labels is not None:
+            loss = self.loss_function(
+                logits=logits,
+                labels=labels,
+                shift_labels=shift_labels,
+                vocab_size=self.config.vocab_size,
+                **kwargs,
+            )
+
+    if not return_dict:
+        output = (logits,) + outputs[1:]
+        output = ((loss,) + output) if loss is not None else output
+        output = output + (token_accuracy,) if token_accuracy is not None else output
+        return output
+
+    # Return custom output class with token_accuracy field
+    return LigerCausalLMOutputWithPast(
+        loss=loss,
+        logits=logits,
+        past_key_values=outputs.past_key_values,
+        hidden_states=outputs.hidden_states,
+        attentions=outputs.attentions,
+        token_accuracy=token_accuracy,
+    )
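
These forwards only rely on the shape of the fused-loss result through `loss, _, token_accuracy = unpack_cross_entropy_result(result)` and then surface `token_accuracy` on the new output classes. The stand-in dataclass below sketches that consumer-facing surface; it is not the real `LigerCausalLMOutputWithPast` from `liger_kernel/transformers/model/output_classes.py`, which also carries the usual `past_key_values`, `hidden_states`, and `attentions` fields.

from dataclasses import dataclass
from typing import Any, Optional


@dataclass
class CausalLMOutputSketch:
    # Illustrative subset of the fields the patched forwards populate.
    loss: Optional[Any] = None
    logits: Optional[Any] = None
    token_accuracy: Optional[Any] = None


# Training-style result with placeholder values: the fused loss path ran,
# logits were not materialized, and token accuracy came back alongside the loss.
out = CausalLMOutputSketch(loss=0.42, logits=None, token_accuracy=0.87)
if out.token_accuracy is not None:
    print(f"loss={out.loss:.2f}, token_accuracy={out.token_accuracy:.2f}")

With `return_dict=False`, the same value instead rides at the end of the returned tuple, as shown in the sketch after the olmo2.py diff.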