liger-kernel-nightly 0.4.0.dev20241107052928__py3-none-any.whl → 0.6.3.dev20251121010306__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of liger-kernel-nightly might be problematic.
- liger_kernel/__init__.py +0 -0
- liger_kernel/chunked_loss/README.md +25 -0
- liger_kernel/chunked_loss/__init__.py +8 -0
- liger_kernel/chunked_loss/cosine_similarity_loss.py +136 -0
- liger_kernel/chunked_loss/cpo_loss.py +157 -0
- liger_kernel/chunked_loss/dpo_loss.py +229 -0
- liger_kernel/chunked_loss/functional.py +17 -0
- liger_kernel/chunked_loss/fused_linear_distillation.py +292 -0
- liger_kernel/chunked_loss/fused_linear_ppo.py +350 -0
- liger_kernel/chunked_loss/fused_linear_preference.py +433 -0
- liger_kernel/chunked_loss/fused_linear_unpaired_preference.py +341 -0
- liger_kernel/chunked_loss/grpo_loss.py +304 -0
- liger_kernel/chunked_loss/jsd_loss.py +200 -0
- liger_kernel/chunked_loss/kto_loss.py +210 -0
- liger_kernel/chunked_loss/orpo_loss.py +144 -0
- liger_kernel/chunked_loss/simpo_loss.py +165 -0
- liger_kernel/env_report.py +21 -4
- liger_kernel/ops/cross_entropy.py +235 -84
- liger_kernel/ops/dyt.py +157 -0
- liger_kernel/ops/experimental/embedding.py +1 -3
- liger_kernel/ops/experimental/mm_int8int2.py +3 -9
- liger_kernel/ops/fused_add_rms_norm.py +412 -0
- liger_kernel/ops/fused_linear_cross_entropy.py +197 -75
- liger_kernel/ops/fused_linear_jsd.py +17 -34
- liger_kernel/ops/fused_neighborhood_attention.py +1022 -0
- liger_kernel/ops/geglu.py +7 -18
- liger_kernel/ops/group_norm.py +305 -0
- liger_kernel/ops/grpo_loss.py +310 -0
- liger_kernel/ops/jsd.py +46 -21
- liger_kernel/ops/kl_div.py +23 -19
- liger_kernel/ops/layer_norm.py +150 -86
- liger_kernel/ops/llama4_rope.py +225 -0
- liger_kernel/ops/multi_token_attention.py +207 -0
- liger_kernel/ops/poly_norm.py +386 -0
- liger_kernel/ops/qwen2vl_mrope.py +222 -0
- liger_kernel/ops/rms_norm.py +314 -84
- liger_kernel/ops/rope.py +32 -34
- liger_kernel/ops/softmax.py +201 -0
- liger_kernel/ops/sparsemax.py +179 -0
- liger_kernel/ops/swiglu.py +5 -9
- liger_kernel/ops/tiled_mlp.py +136 -0
- liger_kernel/ops/tvd.py +207 -0
- liger_kernel/ops/utils.py +8 -4
- liger_kernel/transformers/__init__.py +199 -24
- liger_kernel/transformers/auto_model.py +6 -13
- liger_kernel/transformers/cross_entropy.py +33 -20
- liger_kernel/transformers/dyt.py +22 -0
- liger_kernel/transformers/experimental/__init__.py +5 -0
- liger_kernel/transformers/experimental/embedding.py +1 -3
- liger_kernel/transformers/fsdp.py +55 -0
- liger_kernel/transformers/functional.py +291 -13
- liger_kernel/transformers/fused_add_rms_norm.py +39 -0
- liger_kernel/transformers/fused_linear_cross_entropy.py +43 -14
- liger_kernel/transformers/fused_linear_jsd.py +1 -4
- liger_kernel/transformers/fused_neighborhood_attention.py +234 -0
- liger_kernel/transformers/geglu.py +1 -4
- liger_kernel/transformers/group_norm.py +50 -0
- liger_kernel/transformers/grpo_loss.py +98 -0
- liger_kernel/transformers/jsd.py +2 -7
- liger_kernel/transformers/kl_div.py +1 -3
- liger_kernel/transformers/layer_norm.py +3 -9
- liger_kernel/transformers/llama4_rope.py +93 -0
- liger_kernel/transformers/model/falcon_h1.py +122 -0
- liger_kernel/transformers/model/gemma.py +77 -77
- liger_kernel/transformers/model/gemma2.py +283 -0
- liger_kernel/transformers/model/gemma3.py +331 -0
- liger_kernel/transformers/model/glm4.py +141 -0
- liger_kernel/transformers/model/glm4v.py +163 -0
- liger_kernel/transformers/model/glm4v_moe.py +172 -0
- liger_kernel/transformers/model/internvl.py +157 -0
- liger_kernel/transformers/model/llama.py +128 -79
- liger_kernel/transformers/model/llama4.py +121 -0
- liger_kernel/transformers/model/llava.py +344 -0
- liger_kernel/transformers/model/loss_utils.py +95 -0
- liger_kernel/transformers/model/mistral.py +68 -64
- liger_kernel/transformers/model/mixtral.py +75 -91
- liger_kernel/transformers/model/mllama.py +63 -68
- liger_kernel/transformers/model/olmo2.py +141 -0
- liger_kernel/transformers/model/output_classes.py +147 -0
- liger_kernel/transformers/model/paligemma.py +432 -0
- liger_kernel/transformers/model/phi3.py +59 -213
- liger_kernel/transformers/model/qwen2.py +75 -72
- liger_kernel/transformers/model/qwen2_5_vl.py +163 -0
- liger_kernel/transformers/model/qwen2_vl.py +78 -98
- liger_kernel/transformers/model/qwen3.py +136 -0
- liger_kernel/transformers/model/qwen3_moe.py +152 -0
- liger_kernel/transformers/model/qwen3_next.py +146 -0
- liger_kernel/transformers/model/qwen3_vl.py +150 -0
- liger_kernel/transformers/model/qwen3_vl_moe.py +126 -0
- liger_kernel/transformers/model/smollm3.py +199 -0
- liger_kernel/transformers/model/smolvlm.py +158 -0
- liger_kernel/transformers/monkey_patch.py +2106 -289
- liger_kernel/transformers/multi_token_attention.py +64 -0
- liger_kernel/transformers/poly_norm.py +42 -0
- liger_kernel/transformers/qwen2vl_mrope.py +20 -0
- liger_kernel/transformers/rms_norm.py +57 -6
- liger_kernel/transformers/rope.py +45 -2
- liger_kernel/transformers/softmax.py +12 -0
- liger_kernel/transformers/sparsemax.py +16 -0
- liger_kernel/transformers/swiglu.py +23 -8
- liger_kernel/transformers/tiled_mlp.py +133 -0
- liger_kernel/transformers/trainer/__init__.py +4 -0
- liger_kernel/transformers/trainer/orpo_trainer.py +130 -0
- liger_kernel/transformers/tvd.py +13 -0
- liger_kernel/triton/__init__.py +1 -3
- liger_kernel/triton/monkey_patch.py +1 -3
- liger_kernel/utils.py +71 -0
- {liger_kernel_nightly-0.4.0.dev20241107052928.dist-info → liger_kernel_nightly-0.6.3.dev20251121010306.dist-info}/METADATA +150 -137
- liger_kernel_nightly-0.6.3.dev20251121010306.dist-info/RECORD +116 -0
- {liger_kernel_nightly-0.4.0.dev20241107052928.dist-info → liger_kernel_nightly-0.6.3.dev20251121010306.dist-info}/WHEEL +1 -1
- liger_kernel_nightly-0.4.0.dev20241107052928.dist-info/RECORD +0 -48
- {liger_kernel_nightly-0.4.0.dev20241107052928.dist-info → liger_kernel_nightly-0.6.3.dev20251121010306.dist-info}/LICENSE +0 -0
- {liger_kernel_nightly-0.4.0.dev20241107052928.dist-info → liger_kernel_nightly-0.6.3.dev20251121010306.dist-info}/NOTICE +0 -0
- {liger_kernel_nightly-0.4.0.dev20241107052928.dist-info → liger_kernel_nightly-0.6.3.dev20251121010306.dist-info}/top_level.txt +0 -0
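
Most of the growth above sits in the transformers integration (monkey_patch.py, auto_model.py, the per-model forwards) and the new chunked_loss package. For orientation, a minimal usage sketch follows; the entry-point names (`apply_liger_kernel_to_llama`, `AutoLigerKernelForCausalLM`) are assumed from earlier liger-kernel releases rather than taken from this diff.

```python
# Minimal usage sketch (assumed API names, not verified against this nightly build).
import torch
from transformers import AutoTokenizer

from liger_kernel.transformers import AutoLigerKernelForCausalLM  # assumed export from auto_model.py
from liger_kernel.transformers import apply_liger_kernel_to_llama  # assumed export from monkey_patch.py

# Option 1: monkey-patch a supported model family in place before loading it.
apply_liger_kernel_to_llama()

# Option 2: let the auto wrapper apply the matching patch for a checkpoint.
model = AutoLigerKernelForCausalLM.from_pretrained("meta-llama/Llama-3.2-1B", torch_dtype=torch.bfloat16)
tokenizer = AutoTokenizer.from_pretrained("meta-llama/Llama-3.2-1B")
```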
liger_kernel/transformers/model/phi3.py
@@ -1,146 +1,17 @@
-from typing import List
+from typing import List
+from typing import Optional
+from typing import Tuple
+from typing import Union

 import torch
-from torch.nn import CrossEntropyLoss
-from transformers.modeling_outputs import CausalLMOutputWithPast
-from transformers.models.phi3.modeling_phi3 import (
-    _CONFIG_FOR_DOC,
-    PHI3_INPUTS_DOCSTRING,
-)
-from transformers.utils import (
-    add_start_docstrings_to_model_forward,
-    replace_return_docstrings,
-)

-from
-    LigerFusedLinearCrossEntropyLoss,
-)
+from transformers.modeling_outputs import BaseModelOutputWithPast

+from liger_kernel.transformers.model.loss_utils import LigerForCausalLMLoss
+from liger_kernel.transformers.model.loss_utils import unpack_cross_entropy_result
+from liger_kernel.transformers.model.output_classes import LigerCausalLMOutputWithPast

-@add_start_docstrings_to_model_forward(PHI3_INPUTS_DOCSTRING)
-@replace_return_docstrings(
-    output_type=CausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC
-)
-def lce_forward_deprecated(
-    self,
-    input_ids: torch.LongTensor = None,
-    attention_mask: Optional[torch.Tensor] = None,
-    position_ids: Optional[torch.LongTensor] = None,
-    past_key_values: Optional[List[torch.FloatTensor]] = None,
-    inputs_embeds: Optional[torch.FloatTensor] = None,
-    labels: Optional[torch.LongTensor] = None,
-    use_cache: Optional[bool] = None,
-    output_attentions: Optional[bool] = None,
-    output_hidden_states: Optional[bool] = None,
-    return_dict: Optional[bool] = None,
-    cache_position: Optional[torch.LongTensor] = None,
-) -> Union[Tuple, CausalLMOutputWithPast]:
-    r"""
-    Copy paste phi3 forward from transfomers v4.44.2 but replace torch cross entropy with liger fused linear cross entropy
-
-
-    Args:
-        labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
-            Labels for computing the masked language modeling loss. Indices should either be in `[0, ...,
-            config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored
-            (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`.

-    Returns:
-
-    Example:
-
-    ```python
-    >>> from transformers import AutoTokenizer, Phi3ForCausalLM
-
-    >>> model = Phi3ForCausalLM.from_pretrained("microsoft/phi-3-mini-4k-instruct")
-    >>> tokenizer = AutoTokenizer.from_pretrained("microsoft/phi-3-mini-4k-instruct")
-
-    >>> prompt = "This is an example script ."
-    >>> inputs = tokenizer(prompt, return_tensors="pt")
-
-    >>> # Generate
-    >>> generate_ids = model.generate(inputs.input_ids, max_length=30)
-    >>> tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]
-    'This is an example script .\n Certainly! Below is a sample script that demonstrates a simple task, such as calculating the sum'
-    ```"""
-
-    output_attentions = (
-        output_attentions
-        if output_attentions is not None
-        else self.config.output_attentions
-    )
-    output_hidden_states = (
-        output_hidden_states
-        if output_hidden_states is not None
-        else self.config.output_hidden_states
-    )
-    return_dict = (
-        return_dict if return_dict is not None else self.config.use_return_dict
-    )
-
-    # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn)
-    outputs = self.model(
-        input_ids=input_ids,
-        attention_mask=attention_mask,
-        position_ids=position_ids,
-        past_key_values=past_key_values,
-        inputs_embeds=inputs_embeds,
-        use_cache=use_cache,
-        output_attentions=output_attentions,
-        output_hidden_states=output_hidden_states,
-        return_dict=return_dict,
-    )
-
-    hidden_states = outputs[0]
-
-    loss = None
-    logits = None
-
-    if self.training and labels is not None:
-        shift_hidden_states = hidden_states[..., :-1, :].contiguous()
-        shift_labels = labels[..., 1:].contiguous()
-
-        # flatten tokens
-        shift_hidden_states = shift_hidden_states.view(-1, self.config.hidden_size)
-        shift_labels = shift_labels.view(-1)
-
-        lce = LigerFusedLinearCrossEntropyLoss()
-        loss = lce(self.lm_head.weight, shift_hidden_states, shift_labels)
-    else:
-        logits = self.lm_head(hidden_states)
-
-        loss = None
-        if labels is not None:
-            # Upcast to float if we need to compute the loss to avoid potential precision issues
-            logits = logits.float()
-            # Shift so that tokens < n predict n
-            shift_logits = logits[..., :-1, :].contiguous()
-            shift_labels = labels[..., 1:].contiguous()
-            # Flatten the tokens
-            loss_fct = CrossEntropyLoss()
-            shift_logits = shift_logits.view(-1, self.config.vocab_size)
-            shift_labels = shift_labels.view(-1)
-            # Enable model parallelism
-            shift_labels = shift_labels.to(shift_logits.device)
-            loss = loss_fct(shift_logits, shift_labels)
-
-    if not return_dict:
-        output = (logits,) + outputs[1:]
-        return (loss,) + output if loss is not None else output
-
-    return CausalLMOutputWithPast(
-        loss=loss,
-        logits=logits,
-        past_key_values=outputs.past_key_values,
-        hidden_states=outputs.hidden_states,
-        attentions=outputs.attentions,
-    )
-
-
-@add_start_docstrings_to_model_forward(PHI3_INPUTS_DOCSTRING)
-@replace_return_docstrings(
-    output_type=CausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC
-)
 def lce_forward(
     self,
     input_ids: torch.LongTensor = None,
@@ -154,121 +25,96 @@ def lce_forward(
     output_hidden_states: Optional[bool] = None,
     return_dict: Optional[bool] = None,
     cache_position: Optional[torch.LongTensor] = None,
-
-
-
+    logits_to_keep: Union[int, torch.Tensor] = 0,
+    skip_logits: Optional[bool] = None,
+    **kwargs,
+) -> Union[Tuple, LigerCausalLMOutputWithPast]:
     r"""
-    Args:
-        labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
-            Labels for computing the masked language modeling loss. Indices should either be in `[0, ...,
-            config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored
-            (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`.
-
-        num_logits_to_keep (`int`, *optional*):
-            Calculate logits for the last `num_logits_to_keep` tokens. If `0`, calculate logits for all
-            `input_ids` (special case). Only last token logits are needed for generation, and calculating them only for that
-            token can save memory, which becomes pretty significant for long sequences or large vocabulary size.
-
-    Returns:
-
     Example:

     ```python
     >>> from transformers import AutoTokenizer, Phi3ForCausalLM

-    >>> model = Phi3ForCausalLM.from_pretrained("
-    >>> tokenizer = AutoTokenizer.from_pretrained("
+    >>> model = Phi3ForCausalLM.from_pretrained("meta-phi3/Phi3-2-7b-hf")
+    >>> tokenizer = AutoTokenizer.from_pretrained("meta-phi3/Phi3-2-7b-hf")

-    >>> prompt = "
+    >>> prompt = "Hey, are you conscious? Can you talk to me?"
     >>> inputs = tokenizer(prompt, return_tensors="pt")

     >>> # Generate
     >>> generate_ids = model.generate(inputs.input_ids, max_length=30)
     >>> tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]
-
+    "Hey, are you conscious? Can you talk to me?\nI'm not conscious, but I can talk to you."
     ```"""

-
-
-    logger = logging.get_logger(__name__)
-
-    if (
-        use_cache
-        and self.config.rope_scaling
-        and cache_position is not None
-        and cache_position[0] == self.config.original_max_position_embeddings
-    ):
-        logger.warning(
-            f"If you are not using the generate method, you may encounter nonsensical outputs after the {self.config.original_max_position_embeddings}th token, as the KV cache needs to be recomputed."
-        )
-
-    output_attentions = (
-        output_attentions
-        if output_attentions is not None
-        else self.config.output_attentions
-    )
+    output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
     output_hidden_states = (
-        output_hidden_states
-        if output_hidden_states is not None
-        else self.config.output_hidden_states
-    )
-    return_dict = (
-        return_dict if return_dict is not None else self.config.use_return_dict
+        output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
     )
+    return_dict = return_dict if return_dict is not None else self.config.use_return_dict

-
-    outputs = self.model(
+    outputs: BaseModelOutputWithPast = self.model(
         input_ids=input_ids,
         attention_mask=attention_mask,
         position_ids=position_ids,
         past_key_values=past_key_values,
         inputs_embeds=inputs_embeds,
         use_cache=use_cache,
-
-
-        return_dict=return_dict,
+        cache_position=cache_position,
+        **kwargs,
     )

-    hidden_states = outputs
+    hidden_states = outputs.last_hidden_state
+    # Only compute necessary logits, and do not upcast them to float if we are not computing the loss
+    slice_indices = slice(-logits_to_keep, None) if isinstance(logits_to_keep, int) else logits_to_keep
+    kept_hidden_states = hidden_states[:, slice_indices, :]

+    shift_labels = kwargs.pop("shift_labels", None)
     logits = None
     loss = None
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
+    token_accuracy = None
+
+    if skip_logits and labels is None and shift_labels is None:
+        raise ValueError("skip_logits is True, but labels and shift_labels are None")
+
+    if skip_logits is None:
+        # By default, if in training mode, don't materialize logits
+        skip_logits = self.training and (labels is not None or shift_labels is not None)
+
+    # Compute loss
+    if skip_logits:
+        result = LigerForCausalLMLoss(
+            hidden_states=kept_hidden_states,
+            lm_head_weight=self.lm_head.weight,
+            labels=labels,
+            shift_labels=shift_labels,
+            hidden_size=self.config.hidden_size,
+            **kwargs,
+        )
+        loss, _, token_accuracy = unpack_cross_entropy_result(result)
+    else:
+        logits = self.lm_head(kept_hidden_states)
+        if labels is not None or shift_labels is not None:
             loss = self.loss_function(
                 logits=logits,
                 labels=labels,
+                shift_labels=shift_labels,
                 vocab_size=self.config.vocab_size,
-                **
+                **kwargs,
             )

     if not return_dict:
-
-
+        output_tuple = (logits,) + outputs[1:]
+        output = (loss,) + output_tuple if loss is not None else output_tuple
+        output = output + (token_accuracy,) if token_accuracy is not None else output
+        return output

-
+    # Return custom output class with token_accuracy field
+    return LigerCausalLMOutputWithPast(
         loss=loss,
         logits=logits,
         past_key_values=outputs.past_key_values,
         hidden_states=outputs.hidden_states,
         attentions=outputs.attentions,
+        token_accuracy=token_accuracy,
     )
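
In the new `lce_forward` above, only the hidden states that are actually needed are kept before either the fused Liger loss or `lm_head` runs, and `skip_logits` defaults to `self.training and (labels is not None or shift_labels is not None)`, so logits are only materialized outside training or when no labels are given. The qwen2 patch below follows the same pattern. The toy sketch here (not part of the diff) just demonstrates the `slice(-logits_to_keep, None)` behavior: `0` keeps every position, a positive int keeps the last N positions, and a 1-D index tensor keeps an explicit subset.

```python
# Toy sketch (not from the diff) of the logits_to_keep slicing used in the new lce_forward.
import torch

hidden_states = torch.randn(2, 8, 16)  # (batch, seq_len, hidden_size)

def keep(hidden_states, logits_to_keep):
    # Mirrors the patched code: an int selects the trailing window, a tensor selects explicit indices.
    slice_indices = slice(-logits_to_keep, None) if isinstance(logits_to_keep, int) else logits_to_keep
    return hidden_states[:, slice_indices, :]

print(keep(hidden_states, 0).shape)                        # torch.Size([2, 8, 16]) -- 0 keeps all positions
print(keep(hidden_states, 1).shape)                        # torch.Size([2, 1, 16]) -- last token only
print(keep(hidden_states, torch.tensor([0, 3, 7])).shape)  # torch.Size([2, 3, 16]) -- explicit indices
```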
liger_kernel/transformers/model/qwen2.py
@@ -1,26 +1,20 @@
-from typing import List
+from typing import List
+from typing import Optional
+from typing import Tuple
+from typing import Union

 import torch
+
 from torch.nn import CrossEntropyLoss
 from transformers.modeling_outputs import CausalLMOutputWithPast
-from transformers.
-
-
-
-from transformers.
-
-
-
-
-from liger_kernel.transformers.fused_linear_cross_entropy import (
-    LigerFusedLinearCrossEntropyLoss,
-)
-
-
-@add_start_docstrings_to_model_forward(QWEN2_INPUTS_DOCSTRING)
-@replace_return_docstrings(
-    output_type=CausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC
-)
+from transformers.utils.deprecation import deprecate_kwarg
+
+from liger_kernel.transformers.fused_linear_cross_entropy import LigerFusedLinearCrossEntropyLoss
+from liger_kernel.transformers.model.loss_utils import LigerForCausalLMLoss
+from liger_kernel.transformers.model.loss_utils import unpack_cross_entropy_result
+from liger_kernel.transformers.model.output_classes import LigerCausalLMOutputWithPast
+
+
 def lce_forward_deprecated(
     self,
     input_ids: torch.LongTensor = None,

@@ -34,6 +28,7 @@ def lce_forward_deprecated(
     output_hidden_states: Optional[bool] = None,
     return_dict: Optional[bool] = None,
     cache_position: Optional[torch.LongTensor] = None,
+    skip_logits: Optional[bool] = None,
 ) -> Union[Tuple, CausalLMOutputWithPast]:
     r"""
     Copy paste Qwen2's forward but replace torch cross entropy with liger fused linear cross entropy

@@ -63,19 +58,11 @@ def lce_forward_deprecated(
     >>> tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]
     "Hey, are you conscious? Can you talk to me?\nI'm not conscious, but I can talk to you."
     ```"""
-    output_attentions =
-        output_attentions
-        if output_attentions is not None
-        else self.config.output_attentions
-    )
+    output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
     output_hidden_states = (
-        output_hidden_states
-        if output_hidden_states is not None
-        else self.config.output_hidden_states
-    )
-    return_dict = (
-        return_dict if return_dict is not None else self.config.use_return_dict
+        output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
     )
+    return_dict = return_dict if return_dict is not None else self.config.use_return_dict

     # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn)
     outputs = self.model(

@@ -96,6 +83,13 @@ def lce_forward_deprecated(
     loss = None
     logits = None

+    if skip_logits and labels is None:
+        raise ValueError("skip_logits is True, but labels is None")
+
+    if skip_logits is None:
+        # By default, if in training mode, don't materialize logits
+        skip_logits = self.training and labels is not None
+
     if self.training and (labels is not None):
         shift_hidden_states = hidden_states[..., :-1, :].contiguous()
         shift_labels = labels[..., 1:].contiguous()

@@ -136,10 +130,7 @@ def lce_forward_deprecated(
     )


-@
-@replace_return_docstrings(
-    output_type=CausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC
-)
+@deprecate_kwarg("num_logits_to_keep", version="4.50", new_name="logits_to_keep")
 def lce_forward(
     self,
     input_ids: torch.LongTensor = None,

@@ -153,9 +144,10 @@ def lce_forward(
     output_hidden_states: Optional[bool] = None,
     return_dict: Optional[bool] = None,
     cache_position: Optional[torch.LongTensor] = None,
-
-
-
+    logits_to_keep: Union[int, torch.Tensor] = 0,
+    skip_logits: Optional[bool] = None,
+    **kwargs,
+) -> Union[Tuple, LigerCausalLMOutputWithPast]:
     r"""
     Args:
         labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):

@@ -163,10 +155,12 @@ def lce_forward(
             config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored
             (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`.

-
-
+        logits_to_keep (`int` or `torch.Tensor`, *optional*):
+            If an `int`, compute logits for the last `logits_to_keep` tokens. If `0`, calculate logits for all
             `input_ids` (special case). Only last token logits are needed for generation, and calculating them only for that
             token can save memory, which becomes pretty significant for long sequences or large vocabulary size.
+            If a `torch.Tensor`, must be 1D corresponding to the indices to keep in the sequence length dimension.
+            This is useful when using packed tensor format (single dimension for batch and sequence length).

     Returns:

@@ -187,19 +181,11 @@ def lce_forward(
     "Hey, are you conscious? Can you talk to me?\nI'm not conscious, but I can talk to you."
     ```"""

-    output_attentions =
-        output_attentions
-        if output_attentions is not None
-        else self.config.output_attentions
-    )
+    output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
     output_hidden_states = (
-        output_hidden_states
-        if output_hidden_states is not None
-        else self.config.output_hidden_states
-    )
-    return_dict = (
-        return_dict if return_dict is not None else self.config.use_return_dict
+        output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
     )
+    return_dict = return_dict if return_dict is not None else self.config.use_return_dict

     # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn)
     outputs = self.model(

@@ -213,44 +199,61 @@ def lce_forward(
         output_hidden_states=output_hidden_states,
         return_dict=return_dict,
         cache_position=cache_position,
+        **kwargs,
     )

     hidden_states = outputs[0]
+    # Only compute necessary logits, and do not upcast them to float if we are not computing the loss
+    slice_indices = slice(-logits_to_keep, None) if isinstance(logits_to_keep, int) else logits_to_keep
+    kept_hidden_states = hidden_states[:, slice_indices, :]

+    shift_labels = kwargs.pop("shift_labels", None)
     logits = None
     loss = None
-
-
-
+    token_accuracy = None
+
+    if skip_logits and labels is None and shift_labels is None:
+        raise ValueError("skip_logits is True, but labels and shift_labels are None")
+
+    if skip_logits is None:
+        # By default, if in training mode, don't materialize logits
+        skip_logits = self.training and (labels is not None or shift_labels is not None)
+
+    # Compute loss
+    if skip_logits:
+        result = LigerForCausalLMLoss(
+            hidden_states=kept_hidden_states,
+            lm_head_weight=self.lm_head.weight,
+            labels=labels,
+            shift_labels=shift_labels,
+            hidden_size=self.config.hidden_size,
+            **kwargs,
+        )
+        loss, _, token_accuracy = unpack_cross_entropy_result(result)

-
-
-
-        # flatten tokens
-        shift_hidden_states = shift_hidden_states.view(-1, self.config.hidden_size)
-        shift_labels = shift_labels.view(-1)
-
-        reduction = "sum" if "num_items_in_batch" in loss_kwargs else "mean"
-        lce = LigerFusedLinearCrossEntropyLoss(reduction=reduction)
-
-        loss = lce(self.lm_head.weight, shift_hidden_states, shift_labels)
-        if reduction == "sum":
-            loss /= loss_kwargs["num_items_in_batch"]
-
-    else:  # if in inference mode materialize logits
-        logits = self.lm_head(hidden_states[:, -num_logits_to_keep:, :])
-        if labels is not None:
+    else:
+        logits = self.lm_head(kept_hidden_states)
+        if labels is not None or shift_labels is not None:
             loss = self.loss_function(
                 logits=logits,
                 labels=labels,
+                shift_labels=shift_labels,
                 vocab_size=self.config.vocab_size,
-                **
+                **kwargs,
             )

-
+    if not return_dict:
+        output_tuple = (logits,) + outputs[1:]
+        output = (loss,) + output_tuple if loss is not None else output_tuple
+        output = output + (token_accuracy,) if token_accuracy is not None else output
+        return output
+
+    # Return custom output class with token accuracy field
+    return LigerCausalLMOutputWithPast(
         loss=loss,
         logits=logits,
         past_key_values=outputs.past_key_values,
         hidden_states=outputs.hidden_states,
         attentions=outputs.attentions,
+        token_accuracy=token_accuracy,
     )
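
Both patched forwards share the same tuple fallback when `return_dict=False`: the (possibly `None`) logits take the first slot of the base model's tuple, the loss is prepended when present, and `token_accuracy` is appended last. A dummy-value sketch, not from the diff, just to make the element ordering explicit:

```python
# Dummy-value sketch of the return_dict=False tuple assembly in the patched forwards.
loss = 0.42
logits = None                              # stays None when skip_logits materializes no logits
token_accuracy = 0.87
outputs = ("last_hidden_state", "past_key_values", "attentions")  # stand-in for the base model tuple

output_tuple = (logits,) + outputs[1:]
output = (loss,) + output_tuple if loss is not None else output_tuple
output = output + (token_accuracy,) if token_accuracy is not None else output

print(output)  # (0.42, None, 'past_key_values', 'attentions', 0.87)
```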