liger-kernel-nightly 0.4.0.dev20241107052928__py3-none-any.whl → 0.6.3.dev20251121010306__py3-none-any.whl

This diff shows the changes between two publicly released versions of this package, as they appear in their respective public registries. It is provided for informational purposes only.

Potentially problematic release: this version of liger-kernel-nightly has been flagged as potentially problematic.

Files changed (114)
  1. liger_kernel/__init__.py +0 -0
  2. liger_kernel/chunked_loss/README.md +25 -0
  3. liger_kernel/chunked_loss/__init__.py +8 -0
  4. liger_kernel/chunked_loss/cosine_similarity_loss.py +136 -0
  5. liger_kernel/chunked_loss/cpo_loss.py +157 -0
  6. liger_kernel/chunked_loss/dpo_loss.py +229 -0
  7. liger_kernel/chunked_loss/functional.py +17 -0
  8. liger_kernel/chunked_loss/fused_linear_distillation.py +292 -0
  9. liger_kernel/chunked_loss/fused_linear_ppo.py +350 -0
  10. liger_kernel/chunked_loss/fused_linear_preference.py +433 -0
  11. liger_kernel/chunked_loss/fused_linear_unpaired_preference.py +341 -0
  12. liger_kernel/chunked_loss/grpo_loss.py +304 -0
  13. liger_kernel/chunked_loss/jsd_loss.py +200 -0
  14. liger_kernel/chunked_loss/kto_loss.py +210 -0
  15. liger_kernel/chunked_loss/orpo_loss.py +144 -0
  16. liger_kernel/chunked_loss/simpo_loss.py +165 -0
  17. liger_kernel/env_report.py +21 -4
  18. liger_kernel/ops/cross_entropy.py +235 -84
  19. liger_kernel/ops/dyt.py +157 -0
  20. liger_kernel/ops/experimental/embedding.py +1 -3
  21. liger_kernel/ops/experimental/mm_int8int2.py +3 -9
  22. liger_kernel/ops/fused_add_rms_norm.py +412 -0
  23. liger_kernel/ops/fused_linear_cross_entropy.py +197 -75
  24. liger_kernel/ops/fused_linear_jsd.py +17 -34
  25. liger_kernel/ops/fused_neighborhood_attention.py +1022 -0
  26. liger_kernel/ops/geglu.py +7 -18
  27. liger_kernel/ops/group_norm.py +305 -0
  28. liger_kernel/ops/grpo_loss.py +310 -0
  29. liger_kernel/ops/jsd.py +46 -21
  30. liger_kernel/ops/kl_div.py +23 -19
  31. liger_kernel/ops/layer_norm.py +150 -86
  32. liger_kernel/ops/llama4_rope.py +225 -0
  33. liger_kernel/ops/multi_token_attention.py +207 -0
  34. liger_kernel/ops/poly_norm.py +386 -0
  35. liger_kernel/ops/qwen2vl_mrope.py +222 -0
  36. liger_kernel/ops/rms_norm.py +314 -84
  37. liger_kernel/ops/rope.py +32 -34
  38. liger_kernel/ops/softmax.py +201 -0
  39. liger_kernel/ops/sparsemax.py +179 -0
  40. liger_kernel/ops/swiglu.py +5 -9
  41. liger_kernel/ops/tiled_mlp.py +136 -0
  42. liger_kernel/ops/tvd.py +207 -0
  43. liger_kernel/ops/utils.py +8 -4
  44. liger_kernel/transformers/__init__.py +199 -24
  45. liger_kernel/transformers/auto_model.py +6 -13
  46. liger_kernel/transformers/cross_entropy.py +33 -20
  47. liger_kernel/transformers/dyt.py +22 -0
  48. liger_kernel/transformers/experimental/__init__.py +5 -0
  49. liger_kernel/transformers/experimental/embedding.py +1 -3
  50. liger_kernel/transformers/fsdp.py +55 -0
  51. liger_kernel/transformers/functional.py +291 -13
  52. liger_kernel/transformers/fused_add_rms_norm.py +39 -0
  53. liger_kernel/transformers/fused_linear_cross_entropy.py +43 -14
  54. liger_kernel/transformers/fused_linear_jsd.py +1 -4
  55. liger_kernel/transformers/fused_neighborhood_attention.py +234 -0
  56. liger_kernel/transformers/geglu.py +1 -4
  57. liger_kernel/transformers/group_norm.py +50 -0
  58. liger_kernel/transformers/grpo_loss.py +98 -0
  59. liger_kernel/transformers/jsd.py +2 -7
  60. liger_kernel/transformers/kl_div.py +1 -3
  61. liger_kernel/transformers/layer_norm.py +3 -9
  62. liger_kernel/transformers/llama4_rope.py +93 -0
  63. liger_kernel/transformers/model/falcon_h1.py +122 -0
  64. liger_kernel/transformers/model/gemma.py +77 -77
  65. liger_kernel/transformers/model/gemma2.py +283 -0
  66. liger_kernel/transformers/model/gemma3.py +331 -0
  67. liger_kernel/transformers/model/glm4.py +141 -0
  68. liger_kernel/transformers/model/glm4v.py +163 -0
  69. liger_kernel/transformers/model/glm4v_moe.py +172 -0
  70. liger_kernel/transformers/model/internvl.py +157 -0
  71. liger_kernel/transformers/model/llama.py +128 -79
  72. liger_kernel/transformers/model/llama4.py +121 -0
  73. liger_kernel/transformers/model/llava.py +344 -0
  74. liger_kernel/transformers/model/loss_utils.py +95 -0
  75. liger_kernel/transformers/model/mistral.py +68 -64
  76. liger_kernel/transformers/model/mixtral.py +75 -91
  77. liger_kernel/transformers/model/mllama.py +63 -68
  78. liger_kernel/transformers/model/olmo2.py +141 -0
  79. liger_kernel/transformers/model/output_classes.py +147 -0
  80. liger_kernel/transformers/model/paligemma.py +432 -0
  81. liger_kernel/transformers/model/phi3.py +59 -213
  82. liger_kernel/transformers/model/qwen2.py +75 -72
  83. liger_kernel/transformers/model/qwen2_5_vl.py +163 -0
  84. liger_kernel/transformers/model/qwen2_vl.py +78 -98
  85. liger_kernel/transformers/model/qwen3.py +136 -0
  86. liger_kernel/transformers/model/qwen3_moe.py +152 -0
  87. liger_kernel/transformers/model/qwen3_next.py +146 -0
  88. liger_kernel/transformers/model/qwen3_vl.py +150 -0
  89. liger_kernel/transformers/model/qwen3_vl_moe.py +126 -0
  90. liger_kernel/transformers/model/smollm3.py +199 -0
  91. liger_kernel/transformers/model/smolvlm.py +158 -0
  92. liger_kernel/transformers/monkey_patch.py +2106 -289
  93. liger_kernel/transformers/multi_token_attention.py +64 -0
  94. liger_kernel/transformers/poly_norm.py +42 -0
  95. liger_kernel/transformers/qwen2vl_mrope.py +20 -0
  96. liger_kernel/transformers/rms_norm.py +57 -6
  97. liger_kernel/transformers/rope.py +45 -2
  98. liger_kernel/transformers/softmax.py +12 -0
  99. liger_kernel/transformers/sparsemax.py +16 -0
  100. liger_kernel/transformers/swiglu.py +23 -8
  101. liger_kernel/transformers/tiled_mlp.py +133 -0
  102. liger_kernel/transformers/trainer/__init__.py +4 -0
  103. liger_kernel/transformers/trainer/orpo_trainer.py +130 -0
  104. liger_kernel/transformers/tvd.py +13 -0
  105. liger_kernel/triton/__init__.py +1 -3
  106. liger_kernel/triton/monkey_patch.py +1 -3
  107. liger_kernel/utils.py +71 -0
  108. {liger_kernel_nightly-0.4.0.dev20241107052928.dist-info → liger_kernel_nightly-0.6.3.dev20251121010306.dist-info}/METADATA +150 -137
  109. liger_kernel_nightly-0.6.3.dev20251121010306.dist-info/RECORD +116 -0
  110. {liger_kernel_nightly-0.4.0.dev20241107052928.dist-info → liger_kernel_nightly-0.6.3.dev20251121010306.dist-info}/WHEEL +1 -1
  111. liger_kernel_nightly-0.4.0.dev20241107052928.dist-info/RECORD +0 -48
  112. {liger_kernel_nightly-0.4.0.dev20241107052928.dist-info → liger_kernel_nightly-0.6.3.dev20251121010306.dist-info}/LICENSE +0 -0
  113. {liger_kernel_nightly-0.4.0.dev20241107052928.dist-info → liger_kernel_nightly-0.6.3.dev20251121010306.dist-info}/NOTICE +0 -0
  114. {liger_kernel_nightly-0.4.0.dev20241107052928.dist-info → liger_kernel_nightly-0.6.3.dev20251121010306.dist-info}/top_level.txt +0 -0
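
The excerpted diffs below, for liger_kernel/transformers/model/phi3.py and liger_kernel/transformers/model/qwen2.py, show the pattern repeated across the patched model files: the causal LM loss avoids materializing logits by fusing the lm_head projection with the cross entropy (LigerFusedLinearCrossEntropyLoss in the deprecated paths, LigerForCausalLMLoss in the new ones). A minimal standalone sketch of that call pattern follows; the tensor shapes are made up, a CUDA device is assumed since the kernels are written in Triton, and only the import path and argument order are taken from the diff itself:

```python
# Illustrative sketch (not from the package): the fused linear cross entropy call
# pattern used in the patched forwards. The loss consumes the lm_head weight and the
# flattened hidden states, so the full [tokens, vocab] logits are never materialized.
import torch

from liger_kernel.transformers.fused_linear_cross_entropy import LigerFusedLinearCrossEntropyLoss

hidden_size, vocab_size, num_tokens = 64, 128, 10  # toy sizes, chosen arbitrarily
lm_head_weight = torch.randn(vocab_size, hidden_size, device="cuda", requires_grad=True)
hidden_states = torch.randn(num_tokens, hidden_size, device="cuda", requires_grad=True)
labels = torch.randint(0, vocab_size, (num_tokens,), device="cuda")

lce = LigerFusedLinearCrossEntropyLoss()
loss = lce(lm_head_weight, hidden_states, labels)  # same argument order as in the diff below
loss.backward()
```
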
liger_kernel/transformers/model/phi3.py
@@ -1,146 +1,17 @@
-from typing import List, Optional, Tuple, Union
+from typing import List
+from typing import Optional
+from typing import Tuple
+from typing import Union

 import torch
-from torch.nn import CrossEntropyLoss
-from transformers.modeling_outputs import CausalLMOutputWithPast
-from transformers.models.phi3.modeling_phi3 import (
-    _CONFIG_FOR_DOC,
-    PHI3_INPUTS_DOCSTRING,
-)
-from transformers.utils import (
-    add_start_docstrings_to_model_forward,
-    replace_return_docstrings,
-)

-from liger_kernel.transformers.fused_linear_cross_entropy import (
-    LigerFusedLinearCrossEntropyLoss,
-)
+from transformers.modeling_outputs import BaseModelOutputWithPast

+from liger_kernel.transformers.model.loss_utils import LigerForCausalLMLoss
+from liger_kernel.transformers.model.loss_utils import unpack_cross_entropy_result
+from liger_kernel.transformers.model.output_classes import LigerCausalLMOutputWithPast

-@add_start_docstrings_to_model_forward(PHI3_INPUTS_DOCSTRING)
-@replace_return_docstrings(
-    output_type=CausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC
-)
-def lce_forward_deprecated(
-    self,
-    input_ids: torch.LongTensor = None,
-    attention_mask: Optional[torch.Tensor] = None,
-    position_ids: Optional[torch.LongTensor] = None,
-    past_key_values: Optional[List[torch.FloatTensor]] = None,
-    inputs_embeds: Optional[torch.FloatTensor] = None,
-    labels: Optional[torch.LongTensor] = None,
-    use_cache: Optional[bool] = None,
-    output_attentions: Optional[bool] = None,
-    output_hidden_states: Optional[bool] = None,
-    return_dict: Optional[bool] = None,
-    cache_position: Optional[torch.LongTensor] = None,
-) -> Union[Tuple, CausalLMOutputWithPast]:
-    r"""
-    Copy paste phi3 forward from transfomers v4.44.2 but replace torch cross entropy with liger fused linear cross entropy
-
-
-    Args:
-        labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
-            Labels for computing the masked language modeling loss. Indices should either be in `[0, ...,
-            config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored
-            (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`.

-    Returns:
-
-    Example:
-
-    ```python
-    >>> from transformers import AutoTokenizer, Phi3ForCausalLM
-
-    >>> model = Phi3ForCausalLM.from_pretrained("microsoft/phi-3-mini-4k-instruct")
-    >>> tokenizer = AutoTokenizer.from_pretrained("microsoft/phi-3-mini-4k-instruct")
-
-    >>> prompt = "This is an example script ."
-    >>> inputs = tokenizer(prompt, return_tensors="pt")
-
-    >>> # Generate
-    >>> generate_ids = model.generate(inputs.input_ids, max_length=30)
-    >>> tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]
-    'This is an example script .\n Certainly! Below is a sample script that demonstrates a simple task, such as calculating the sum'
-    ```"""
-
-    output_attentions = (
-        output_attentions
-        if output_attentions is not None
-        else self.config.output_attentions
-    )
-    output_hidden_states = (
-        output_hidden_states
-        if output_hidden_states is not None
-        else self.config.output_hidden_states
-    )
-    return_dict = (
-        return_dict if return_dict is not None else self.config.use_return_dict
-    )
-
-    # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn)
-    outputs = self.model(
-        input_ids=input_ids,
-        attention_mask=attention_mask,
-        position_ids=position_ids,
-        past_key_values=past_key_values,
-        inputs_embeds=inputs_embeds,
-        use_cache=use_cache,
-        output_attentions=output_attentions,
-        output_hidden_states=output_hidden_states,
-        return_dict=return_dict,
-    )
-
-    hidden_states = outputs[0]
-
-    loss = None
-    logits = None
-
-    if self.training and labels is not None:
-        shift_hidden_states = hidden_states[..., :-1, :].contiguous()
-        shift_labels = labels[..., 1:].contiguous()
-
-        # flatten tokens
-        shift_hidden_states = shift_hidden_states.view(-1, self.config.hidden_size)
-        shift_labels = shift_labels.view(-1)
-
-        lce = LigerFusedLinearCrossEntropyLoss()
-        loss = lce(self.lm_head.weight, shift_hidden_states, shift_labels)
-    else:
-        logits = self.lm_head(hidden_states)
-
-        loss = None
-        if labels is not None:
-            # Upcast to float if we need to compute the loss to avoid potential precision issues
-            logits = logits.float()
-            # Shift so that tokens < n predict n
-            shift_logits = logits[..., :-1, :].contiguous()
-            shift_labels = labels[..., 1:].contiguous()
-            # Flatten the tokens
-            loss_fct = CrossEntropyLoss()
-            shift_logits = shift_logits.view(-1, self.config.vocab_size)
-            shift_labels = shift_labels.view(-1)
-            # Enable model parallelism
-            shift_labels = shift_labels.to(shift_logits.device)
-            loss = loss_fct(shift_logits, shift_labels)
-
-    if not return_dict:
-        output = (logits,) + outputs[1:]
-        return (loss,) + output if loss is not None else output
-
-    return CausalLMOutputWithPast(
-        loss=loss,
-        logits=logits,
-        past_key_values=outputs.past_key_values,
-        hidden_states=outputs.hidden_states,
-        attentions=outputs.attentions,
-    )
-
-
-@add_start_docstrings_to_model_forward(PHI3_INPUTS_DOCSTRING)
-@replace_return_docstrings(
-    output_type=CausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC
-)
 def lce_forward(
     self,
     input_ids: torch.LongTensor = None,
@@ -154,121 +25,96 @@ def lce_forward(
     output_hidden_states: Optional[bool] = None,
     return_dict: Optional[bool] = None,
     cache_position: Optional[torch.LongTensor] = None,
-    num_logits_to_keep: int = 0,
-    **loss_kwargs,
-) -> Union[Tuple, CausalLMOutputWithPast]:
+    logits_to_keep: Union[int, torch.Tensor] = 0,
+    skip_logits: Optional[bool] = None,
+    **kwargs,
+) -> Union[Tuple, LigerCausalLMOutputWithPast]:
     r"""
-    Args:
-        labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
-            Labels for computing the masked language modeling loss. Indices should either be in `[0, ...,
-            config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored
-            (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`.
-
-        num_logits_to_keep (`int`, *optional*):
-            Calculate logits for the last `num_logits_to_keep` tokens. If `0`, calculate logits for all
-            `input_ids` (special case). Only last token logits are needed for generation, and calculating them only for that
-            token can save memory, which becomes pretty significant for long sequences or large vocabulary size.
-
-    Returns:
-
     Example:

     ```python
     >>> from transformers import AutoTokenizer, Phi3ForCausalLM

-    >>> model = Phi3ForCausalLM.from_pretrained("microsoft/phi-3-mini-4k-instruct")
-    >>> tokenizer = AutoTokenizer.from_pretrained("microsoft/phi-3-mini-4k-instruct")
+    >>> model = Phi3ForCausalLM.from_pretrained("meta-phi3/Phi3-2-7b-hf")
+    >>> tokenizer = AutoTokenizer.from_pretrained("meta-phi3/Phi3-2-7b-hf")

-    >>> prompt = "This is an example script ."
+    >>> prompt = "Hey, are you conscious? Can you talk to me?"
     >>> inputs = tokenizer(prompt, return_tensors="pt")

     >>> # Generate
     >>> generate_ids = model.generate(inputs.input_ids, max_length=30)
     >>> tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]
-    'This is an example script .\n Certainly! Below is a sample script that demonstrates a simple task, such as calculating the sum'
+    "Hey, are you conscious? Can you talk to me?\nI'm not conscious, but I can talk to you."
     ```"""

-    from transformers.models.phi3.modeling_phi3 import logging
-
-    logger = logging.get_logger(__name__)
-
-    if (
-        use_cache
-        and self.config.rope_scaling
-        and cache_position is not None
-        and cache_position[0] == self.config.original_max_position_embeddings
-    ):
-        logger.warning(
-            f"If you are not using the generate method, you may encounter nonsensical outputs after the {self.config.original_max_position_embeddings}th token, as the KV cache needs to be recomputed."
-        )
-
-    output_attentions = (
-        output_attentions
-        if output_attentions is not None
-        else self.config.output_attentions
-    )
+    output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
     output_hidden_states = (
-        output_hidden_states
-        if output_hidden_states is not None
-        else self.config.output_hidden_states
-    )
-    return_dict = (
-        return_dict if return_dict is not None else self.config.use_return_dict
+        output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
     )
+    return_dict = return_dict if return_dict is not None else self.config.use_return_dict

-    # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn)
-    outputs = self.model(
+    outputs: BaseModelOutputWithPast = self.model(
         input_ids=input_ids,
         attention_mask=attention_mask,
         position_ids=position_ids,
         past_key_values=past_key_values,
         inputs_embeds=inputs_embeds,
         use_cache=use_cache,
-        output_attentions=output_attentions,
-        output_hidden_states=output_hidden_states,
-        return_dict=return_dict,
+        cache_position=cache_position,
+        **kwargs,
     )

-    hidden_states = outputs[0]
+    hidden_states = outputs.last_hidden_state
+    # Only compute necessary logits, and do not upcast them to float if we are not computing the loss
+    slice_indices = slice(-logits_to_keep, None) if isinstance(logits_to_keep, int) else logits_to_keep
+    kept_hidden_states = hidden_states[:, slice_indices, :]

+    shift_labels = kwargs.pop("shift_labels", None)
     logits = None
     loss = None
-    # if in training mode, don't materialize logits
-    if self.training and (labels is not None):
-        # We do the same thing as ForCausalLMLoss but using Liger FLCE
-
-        shift_hidden_states = hidden_states[..., :-1, :].contiguous()
-        shift_labels = labels[..., 1:].contiguous()
-
-        # flatten tokens
-        shift_hidden_states = shift_hidden_states.view(-1, self.config.hidden_size)
-        shift_labels = shift_labels.view(-1)
-
-        reduction = "sum" if "num_items_in_batch" in loss_kwargs else "mean"
-        lce = LigerFusedLinearCrossEntropyLoss(reduction=reduction)
-
-        loss = lce(self.lm_head.weight, shift_hidden_states, shift_labels)
-        if reduction == "sum":
-            loss /= loss_kwargs["num_items_in_batch"]
-
-    else:  # if in inference mode materialize logits
-        logits = self.lm_head(hidden_states[:, -num_logits_to_keep:, :])
-        if labels is not None:
+    token_accuracy = None
+
+    if skip_logits and labels is None and shift_labels is None:
+        raise ValueError("skip_logits is True, but labels and shift_labels are None")
+
+    if skip_logits is None:
+        # By default, if in training mode, don't materialize logits
+        skip_logits = self.training and (labels is not None or shift_labels is not None)
+
+    # Compute loss
+    if skip_logits:
+        result = LigerForCausalLMLoss(
+            hidden_states=kept_hidden_states,
+            lm_head_weight=self.lm_head.weight,
+            labels=labels,
+            shift_labels=shift_labels,
+            hidden_size=self.config.hidden_size,
+            **kwargs,
+        )
+        loss, _, token_accuracy = unpack_cross_entropy_result(result)
+    else:
+        logits = self.lm_head(kept_hidden_states)
+        if labels is not None or shift_labels is not None:
             loss = self.loss_function(
                 logits=logits,
                 labels=labels,
+                shift_labels=shift_labels,
                 vocab_size=self.config.vocab_size,
-                **loss_kwargs,
+                **kwargs,
             )

     if not return_dict:
-        output = (logits,) + outputs[1:]
-        return (loss,) + output if loss is not None else output
+        output_tuple = (logits,) + outputs[1:]
+        output = (loss,) + output_tuple if loss is not None else output_tuple
+        output = output + (token_accuracy,) if token_accuracy is not None else output
+        return output

-    return CausalLMOutputWithPast(
+    # Return custom output class with token_accuracy field
+    return LigerCausalLMOutputWithPast(
         loss=loss,
         logits=logits,
         past_key_values=outputs.past_key_values,
         hidden_states=outputs.hidden_states,
         attentions=outputs.attentions,
+        token_accuracy=token_accuracy,
     )
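
A side note on the logits_to_keep handling introduced above: it may be an int or a 1D tensor of sequence indices. With an int n, slice(-n, None) keeps the last n positions, and logits_to_keep=0 becomes slice(0, None), which keeps everything; a tensor is used directly as an index along the sequence dimension. A small self-contained illustration with dummy tensors (nothing here comes from the package itself):

```python
# Illustration only: how the slice_indices trick from the patched forwards behaves.
import torch

hidden_states = torch.arange(2 * 5 * 3).reshape(2, 5, 3)  # [batch=2, seq=5, hidden=3]

def keep(hidden_states: torch.Tensor, logits_to_keep):
    # Mirrors the line in the diff: slice(-n, None) for an int, the tensor itself otherwise.
    slice_indices = slice(-logits_to_keep, None) if isinstance(logits_to_keep, int) else logits_to_keep
    return hidden_states[:, slice_indices, :]

print(keep(hidden_states, 0).shape)                     # torch.Size([2, 5, 3]) -- slice(0, None) keeps all positions
print(keep(hidden_states, 1).shape)                     # torch.Size([2, 1, 3]) -- only the last position
print(keep(hidden_states, torch.tensor([0, 2])).shape)  # torch.Size([2, 2, 3]) -- explicit sequence indices
```
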
liger_kernel/transformers/model/qwen2.py
@@ -1,26 +1,20 @@
-from typing import List, Optional, Tuple, Union
+from typing import List
+from typing import Optional
+from typing import Tuple
+from typing import Union

 import torch
+
 from torch.nn import CrossEntropyLoss
 from transformers.modeling_outputs import CausalLMOutputWithPast
-from transformers.models.qwen2.modeling_qwen2 import (
-    _CONFIG_FOR_DOC,
-    QWEN2_INPUTS_DOCSTRING,
-)
-from transformers.utils import (
-    add_start_docstrings_to_model_forward,
-    replace_return_docstrings,
-)
-
-from liger_kernel.transformers.fused_linear_cross_entropy import (
-    LigerFusedLinearCrossEntropyLoss,
-)
-
-
-@add_start_docstrings_to_model_forward(QWEN2_INPUTS_DOCSTRING)
-@replace_return_docstrings(
-    output_type=CausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC
-)
+from transformers.utils.deprecation import deprecate_kwarg
+
+from liger_kernel.transformers.fused_linear_cross_entropy import LigerFusedLinearCrossEntropyLoss
+from liger_kernel.transformers.model.loss_utils import LigerForCausalLMLoss
+from liger_kernel.transformers.model.loss_utils import unpack_cross_entropy_result
+from liger_kernel.transformers.model.output_classes import LigerCausalLMOutputWithPast
+
+
 def lce_forward_deprecated(
     self,
     input_ids: torch.LongTensor = None,
@@ -34,6 +28,7 @@ def lce_forward_deprecated(
     output_hidden_states: Optional[bool] = None,
     return_dict: Optional[bool] = None,
     cache_position: Optional[torch.LongTensor] = None,
+    skip_logits: Optional[bool] = None,
 ) -> Union[Tuple, CausalLMOutputWithPast]:
     r"""
     Copy paste Qwen2's forward but replace torch cross entropy with liger fused linear cross entropy
@@ -63,19 +58,11 @@ def lce_forward_deprecated(
     >>> tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]
     "Hey, are you conscious? Can you talk to me?\nI'm not conscious, but I can talk to you."
     ```"""
-    output_attentions = (
-        output_attentions
-        if output_attentions is not None
-        else self.config.output_attentions
-    )
+    output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
     output_hidden_states = (
-        output_hidden_states
-        if output_hidden_states is not None
-        else self.config.output_hidden_states
-    )
-    return_dict = (
-        return_dict if return_dict is not None else self.config.use_return_dict
+        output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
     )
+    return_dict = return_dict if return_dict is not None else self.config.use_return_dict

     # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn)
     outputs = self.model(
@@ -96,6 +83,13 @@ def lce_forward_deprecated(
     loss = None
     logits = None

+    if skip_logits and labels is None:
+        raise ValueError("skip_logits is True, but labels is None")
+
+    if skip_logits is None:
+        # By default, if in training mode, don't materialize logits
+        skip_logits = self.training and labels is not None
+
     if self.training and (labels is not None):
         shift_hidden_states = hidden_states[..., :-1, :].contiguous()
         shift_labels = labels[..., 1:].contiguous()
@@ -136,10 +130,7 @@ def lce_forward_deprecated(
     )


-@add_start_docstrings_to_model_forward(QWEN2_INPUTS_DOCSTRING)
-@replace_return_docstrings(
-    output_type=CausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC
-)
+@deprecate_kwarg("num_logits_to_keep", version="4.50", new_name="logits_to_keep")
 def lce_forward(
     self,
     input_ids: torch.LongTensor = None,
@@ -153,9 +144,10 @@ def lce_forward(
     output_hidden_states: Optional[bool] = None,
     return_dict: Optional[bool] = None,
     cache_position: Optional[torch.LongTensor] = None,
-    num_logits_to_keep: int = 0,
-    **loss_kwargs,
-) -> Union[Tuple, CausalLMOutputWithPast]:
+    logits_to_keep: Union[int, torch.Tensor] = 0,
+    skip_logits: Optional[bool] = None,
+    **kwargs,
+) -> Union[Tuple, LigerCausalLMOutputWithPast]:
     r"""
     Args:
         labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
@@ -163,10 +155,12 @@ def lce_forward(
             config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored
             (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`.

-        num_logits_to_keep (`int`, *optional*):
-            Calculate logits for the last `num_logits_to_keep` tokens. If `0`, calculate logits for all
+        logits_to_keep (`int` or `torch.Tensor`, *optional*):
+            If an `int`, compute logits for the last `logits_to_keep` tokens. If `0`, calculate logits for all
             `input_ids` (special case). Only last token logits are needed for generation, and calculating them only for that
             token can save memory, which becomes pretty significant for long sequences or large vocabulary size.
+            If a `torch.Tensor`, must be 1D corresponding to the indices to keep in the sequence length dimension.
+            This is useful when using packed tensor format (single dimension for batch and sequence length).

     Returns:

@@ -187,19 +181,11 @@ def lce_forward(
     "Hey, are you conscious? Can you talk to me?\nI'm not conscious, but I can talk to you."
     ```"""

-    output_attentions = (
-        output_attentions
-        if output_attentions is not None
-        else self.config.output_attentions
-    )
+    output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
     output_hidden_states = (
-        output_hidden_states
-        if output_hidden_states is not None
-        else self.config.output_hidden_states
-    )
-    return_dict = (
-        return_dict if return_dict is not None else self.config.use_return_dict
+        output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
     )
+    return_dict = return_dict if return_dict is not None else self.config.use_return_dict

     # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn)
     outputs = self.model(
@@ -213,44 +199,61 @@ def lce_forward(
         output_hidden_states=output_hidden_states,
         return_dict=return_dict,
         cache_position=cache_position,
+        **kwargs,
     )

     hidden_states = outputs[0]
+    # Only compute necessary logits, and do not upcast them to float if we are not computing the loss
+    slice_indices = slice(-logits_to_keep, None) if isinstance(logits_to_keep, int) else logits_to_keep
+    kept_hidden_states = hidden_states[:, slice_indices, :]

+    shift_labels = kwargs.pop("shift_labels", None)
     logits = None
     loss = None
-    # if in training mode, don't materialize logits
-    if self.training and (labels is not None):
-        # We do the same thing as ForCausalLMLoss but using Liger FLCE
+    token_accuracy = None
+
+    if skip_logits and labels is None and shift_labels is None:
+        raise ValueError("skip_logits is True, but labels and shift_labels are None")
+
+    if skip_logits is None:
+        # By default, if in training mode, don't materialize logits
+        skip_logits = self.training and (labels is not None or shift_labels is not None)
+
+    # Compute loss
+    if skip_logits:
+        result = LigerForCausalLMLoss(
+            hidden_states=kept_hidden_states,
+            lm_head_weight=self.lm_head.weight,
+            labels=labels,
+            shift_labels=shift_labels,
+            hidden_size=self.config.hidden_size,
+            **kwargs,
+        )
+        loss, _, token_accuracy = unpack_cross_entropy_result(result)

-        shift_hidden_states = hidden_states[..., :-1, :].contiguous()
-        shift_labels = labels[..., 1:].contiguous()
-
-        # flatten tokens
-        shift_hidden_states = shift_hidden_states.view(-1, self.config.hidden_size)
-        shift_labels = shift_labels.view(-1)
-
-        reduction = "sum" if "num_items_in_batch" in loss_kwargs else "mean"
-        lce = LigerFusedLinearCrossEntropyLoss(reduction=reduction)
-
-        loss = lce(self.lm_head.weight, shift_hidden_states, shift_labels)
-        if reduction == "sum":
-            loss /= loss_kwargs["num_items_in_batch"]
-
-    else:  # if in inference mode materialize logits
-        logits = self.lm_head(hidden_states[:, -num_logits_to_keep:, :])
-        if labels is not None:
+    else:
+        logits = self.lm_head(kept_hidden_states)
+        if labels is not None or shift_labels is not None:
             loss = self.loss_function(
                 logits=logits,
                 labels=labels,
+                shift_labels=shift_labels,
                 vocab_size=self.config.vocab_size,
-                **loss_kwargs,
+                **kwargs,
             )

-    return CausalLMOutputWithPast(
+    if not return_dict:
+        output_tuple = (logits,) + outputs[1:]
+        output = (loss,) + output_tuple if loss is not None else output_tuple
+        output = output + (token_accuracy,) if token_accuracy is not None else output
+        return output
+
+    # Return custom output class with token accuracy field
+    return LigerCausalLMOutputWithPast(
         loss=loss,
         logits=logits,
         past_key_values=outputs.past_key_values,
         hidden_states=outputs.hidden_states,
         attentions=outputs.attentions,
+        token_accuracy=token_accuracy,
     )
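
Across both patched forwards, the skip_logits flag resolves the same way: an explicit True with no labels (and no shift_labels) raises a ValueError, an explicit value is otherwise respected, and when it is left as None logits are skipped only in training mode with labels present. A hypothetical helper that mirrors that in-line logic (illustrative only, not an API of the package):

```python
# Sketch of the skip_logits resolution used in the patched lce_forward functions above.
from typing import Optional

def resolve_skip_logits(skip_logits: Optional[bool], training: bool, has_labels: bool) -> bool:
    """Hypothetical helper mirroring the in-line checks from the diff; not part of liger-kernel."""
    if skip_logits and not has_labels:
        raise ValueError("skip_logits is True, but labels and shift_labels are None")
    if skip_logits is None:
        # By default, only skip materializing logits while training with labels present.
        return training and has_labels
    return skip_logits

assert resolve_skip_logits(None, training=True, has_labels=True) is True
assert resolve_skip_logits(None, training=False, has_labels=True) is False
assert resolve_skip_logits(False, training=True, has_labels=True) is False
```
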