liger-kernel-nightly 0.5.6.dev20250403190551__py3-none-any.whl → 0.6.4.dev20251212103629__py3-none-any.whl

This diff compares two publicly released versions of the package as published to a supported registry. It is provided for informational purposes only and reflects the package contents as they appear in the public registry.
Files changed (107)
  1. liger_kernel/chunked_loss/__init__.py +1 -0
  2. liger_kernel/chunked_loss/cosine_similarity_loss.py +136 -0
  3. liger_kernel/chunked_loss/dpo_loss.py +61 -3
  4. liger_kernel/chunked_loss/functional.py +2 -0
  5. liger_kernel/chunked_loss/fused_linear_distillation.py +13 -2
  6. liger_kernel/chunked_loss/fused_linear_ppo.py +35 -0
  7. liger_kernel/chunked_loss/fused_linear_preference.py +0 -1
  8. liger_kernel/chunked_loss/grpo_loss.py +76 -5
  9. liger_kernel/chunked_loss/jsd_loss.py +25 -9
  10. liger_kernel/ops/__init__.py +141 -0
  11. liger_kernel/ops/backends/README.md +151 -0
  12. liger_kernel/ops/backends/__init__.py +13 -0
  13. liger_kernel/ops/backends/_ascend/__init__.py +5 -0
  14. liger_kernel/ops/backends/_ascend/ops/__init__.py +15 -0
  15. liger_kernel/ops/backends/registry.py +61 -0
  16. liger_kernel/ops/cross_entropy.py +124 -64
  17. liger_kernel/ops/dyt.py +115 -180
  18. liger_kernel/ops/fused_add_rms_norm.py +416 -0
  19. liger_kernel/ops/fused_linear_cross_entropy.py +115 -22
  20. liger_kernel/ops/fused_neighborhood_attention.py +1022 -0
  21. liger_kernel/ops/geglu.py +3 -2
  22. liger_kernel/ops/group_norm.py +2 -1
  23. liger_kernel/ops/grpo_loss.py +312 -0
  24. liger_kernel/ops/jsd.py +2 -1
  25. liger_kernel/ops/kl_div.py +13 -6
  26. liger_kernel/ops/layer_norm.py +146 -78
  27. liger_kernel/ops/llama4_rope.py +225 -0
  28. liger_kernel/ops/multi_token_attention.py +207 -0
  29. liger_kernel/ops/poly_norm.py +390 -0
  30. liger_kernel/ops/rms_norm.py +283 -56
  31. liger_kernel/ops/rope.py +1 -1
  32. liger_kernel/ops/softmax.py +201 -0
  33. liger_kernel/ops/sparsemax.py +179 -0
  34. liger_kernel/ops/swiglu.py +1 -1
  35. liger_kernel/ops/tiled_mlp.py +136 -0
  36. liger_kernel/ops/utils.py +2 -0
  37. liger_kernel/transformers/__init__.py +205 -19
  38. liger_kernel/transformers/cross_entropy.py +9 -4
  39. liger_kernel/transformers/dyt.py +6 -4
  40. liger_kernel/transformers/experimental/__init__.py +5 -0
  41. liger_kernel/transformers/experimental/embedding.py +1 -1
  42. liger_kernel/transformers/fsdp.py +55 -0
  43. liger_kernel/transformers/functional.py +122 -20
  44. liger_kernel/transformers/fused_add_rms_norm.py +39 -0
  45. liger_kernel/transformers/fused_linear_cross_entropy.py +16 -5
  46. liger_kernel/transformers/fused_linear_jsd.py +1 -1
  47. liger_kernel/transformers/fused_neighborhood_attention.py +234 -0
  48. liger_kernel/transformers/geglu.py +1 -1
  49. liger_kernel/transformers/group_norm.py +1 -1
  50. liger_kernel/transformers/grpo_loss.py +153 -0
  51. liger_kernel/transformers/jsd.py +1 -1
  52. liger_kernel/transformers/kl_div.py +1 -1
  53. liger_kernel/transformers/layer_norm.py +1 -1
  54. liger_kernel/transformers/llama4_rope.py +93 -0
  55. liger_kernel/transformers/model/falcon_h1.py +122 -0
  56. liger_kernel/transformers/model/gemma.py +50 -25
  57. liger_kernel/transformers/model/gemma2.py +55 -23
  58. liger_kernel/transformers/model/gemma3.py +117 -120
  59. liger_kernel/transformers/model/glm4.py +141 -0
  60. liger_kernel/transformers/model/glm4v.py +163 -0
  61. liger_kernel/transformers/model/glm4v_moe.py +172 -0
  62. liger_kernel/transformers/model/gpt_oss.py +211 -0
  63. liger_kernel/transformers/model/hunyuan_v1.py +134 -0
  64. liger_kernel/transformers/model/internvl.py +157 -0
  65. liger_kernel/transformers/model/llama.py +102 -25
  66. liger_kernel/transformers/model/llama4.py +121 -0
  67. liger_kernel/transformers/model/llava.py +111 -136
  68. liger_kernel/transformers/model/loss_utils.py +50 -12
  69. liger_kernel/transformers/model/mistral.py +36 -23
  70. liger_kernel/transformers/model/mixtral.py +45 -25
  71. liger_kernel/transformers/model/mllama.py +39 -22
  72. liger_kernel/transformers/model/olmo2.py +40 -20
  73. liger_kernel/transformers/model/olmo3.py +142 -0
  74. liger_kernel/transformers/model/output_classes.py +147 -0
  75. liger_kernel/transformers/model/paligemma.py +50 -14
  76. liger_kernel/transformers/model/phi3.py +47 -177
  77. liger_kernel/transformers/model/qwen2.py +48 -21
  78. liger_kernel/transformers/model/qwen2_5_vl.py +62 -103
  79. liger_kernel/transformers/model/qwen2_vl.py +59 -108
  80. liger_kernel/transformers/model/qwen3.py +136 -0
  81. liger_kernel/transformers/model/qwen3_moe.py +152 -0
  82. liger_kernel/transformers/model/qwen3_next.py +146 -0
  83. liger_kernel/transformers/model/qwen3_vl.py +150 -0
  84. liger_kernel/transformers/model/qwen3_vl_moe.py +126 -0
  85. liger_kernel/transformers/model/smollm3.py +199 -0
  86. liger_kernel/transformers/model/smolvlm.py +158 -0
  87. liger_kernel/transformers/monkey_patch.py +1678 -160
  88. liger_kernel/transformers/multi_token_attention.py +64 -0
  89. liger_kernel/transformers/poly_norm.py +42 -0
  90. liger_kernel/transformers/qwen2vl_mrope.py +1 -1
  91. liger_kernel/transformers/rms_norm.py +48 -5
  92. liger_kernel/transformers/rope.py +45 -1
  93. liger_kernel/transformers/softmax.py +12 -0
  94. liger_kernel/transformers/sparsemax.py +16 -0
  95. liger_kernel/transformers/swiglu.py +39 -1
  96. liger_kernel/transformers/tiled_mlp.py +133 -0
  97. liger_kernel/transformers/trainer/orpo_trainer.py +1 -53
  98. liger_kernel/transformers/tvd.py +1 -1
  99. liger_kernel/utils.py +36 -0
  100. {liger_kernel_nightly-0.5.6.dev20250403190551.dist-info → liger_kernel_nightly-0.6.4.dev20251212103629.dist-info}/METADATA +68 -38
  101. liger_kernel_nightly-0.6.4.dev20251212103629.dist-info/RECORD +124 -0
  102. liger_kernel/transformers/gema3_rms.py +0 -8
  103. liger_kernel_nightly-0.5.6.dev20250403190551.dist-info/RECORD +0 -82
  104. {liger_kernel_nightly-0.5.6.dev20250403190551.dist-info → liger_kernel_nightly-0.6.4.dev20251212103629.dist-info}/LICENSE +0 -0
  105. {liger_kernel_nightly-0.5.6.dev20250403190551.dist-info → liger_kernel_nightly-0.6.4.dev20251212103629.dist-info}/NOTICE +0 -0
  106. {liger_kernel_nightly-0.5.6.dev20250403190551.dist-info → liger_kernel_nightly-0.6.4.dev20251212103629.dist-info}/WHEEL +0 -0
  107. {liger_kernel_nightly-0.5.6.dev20250403190551.dist-info → liger_kernel_nightly-0.6.4.dev20251212103629.dist-info}/top_level.txt +0 -0
liger_kernel/transformers/model/paligemma.py
@@ -7,22 +7,19 @@ import torch
 
 from torch.nn import CrossEntropyLoss
 from transformers.cache_utils import Cache
-from transformers.models.paligemma.modeling_paligemma import _CONFIG_FOR_DOC
-from transformers.models.paligemma.modeling_paligemma import PALIGEMMA_INPUTS_DOCSTRING
 from transformers.models.paligemma.modeling_paligemma import PaliGemmaCausalLMOutputWithPast
-from transformers.utils import add_start_docstrings_to_model_forward
 from transformers.utils import is_torchdynamo_compiling
 from transformers.utils import logging
-from transformers.utils import replace_return_docstrings
 from transformers.utils.deprecation import deprecate_kwarg
 
 from liger_kernel.transformers.fused_linear_cross_entropy import LigerFusedLinearCrossEntropyLoss
+from liger_kernel.transformers.model.loss_utils import LigerForCausalLMLoss
+from liger_kernel.transformers.model.loss_utils import unpack_cross_entropy_result
+from liger_kernel.transformers.model.output_classes import LigerPaliGemmaCausalLMOutputWithPast
 
 logger = logging.get_logger(__name__)
 
 
-@add_start_docstrings_to_model_forward(PALIGEMMA_INPUTS_DOCSTRING)
-@replace_return_docstrings(output_type=PaliGemmaCausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC)
 def lce_forward_deprecated(
     self,
     input_ids: torch.LongTensor = None,
@@ -206,8 +203,6 @@ def lce_forward_deprecated(
 
 
 @deprecate_kwarg("num_logits_to_keep", version="4.50", new_name="logits_to_keep")
-@add_start_docstrings_to_model_forward(PALIGEMMA_INPUTS_DOCSTRING)
-@replace_return_docstrings(output_type=PaliGemmaCausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC)
 def lce_forward(
     self,
     input_ids: torch.LongTensor = None,
@@ -224,8 +219,9 @@ def lce_forward(
     output_hidden_states: Optional[bool] = None,
     return_dict: Optional[bool] = None,
     logits_to_keep: Union[int, torch.Tensor] = 0,
+    skip_logits: Optional[bool] = None,
     **lm_kwargs,
-) -> Union[Tuple, PaliGemmaCausalLMOutputWithPast]:
+) -> Union[Tuple, LigerPaliGemmaCausalLMOutputWithPast]:
     r"""
     Args:
         labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
@@ -334,12 +330,20 @@ def lce_forward(
         **lm_kwargs,
     )
 
+    shift_labels = lm_kwargs.pop("shift_labels", None)
     hidden_states = outputs[0]
 
     loss = None
     logits = None
+    token_accuracy = None
 
-    if self.training and (labels is not None):
+    if skip_logits and labels is None:
+        raise ValueError("skip_logits is True, but labels is None")
+
+    if skip_logits is None:
+        skip_logits = self.training and (labels is not None)
+
+    if skip_logits:
         shift_hidden_states = hidden_states[..., :-1, :]
         shift_labels = labels[..., 1:]
 
@@ -359,8 +363,16 @@ def lce_forward(
         shift_hidden_states = shift_hidden_states.view(-1, self.config.text_config.hidden_size)
         shift_labels = shift_labels.view(-1).to(hidden_device)
 
-        lce = LigerFusedLinearCrossEntropyLoss()
-        loss = lce(self.language_model.lm_head.weight, shift_hidden_states, shift_labels)
+        # Use LigerForCausalLMLoss with accuracy support and pass already shifted labels
+        result = LigerForCausalLMLoss(
+            hidden_states=shift_hidden_states,
+            lm_head_weight=self.language_model.lm_head.weight,
+            labels=None,
+            shift_labels=shift_labels,
+            hidden_size=self.config.text_config.hidden_size,
+            **lm_kwargs,
+        )
+        loss, _, token_accuracy = unpack_cross_entropy_result(result)
     else:
         logits = self.language_model.lm_head(hidden_states)
         if labels is not None:
@@ -383,15 +395,39 @@ def lce_forward(
             flat_logits = shift_logits.view(-1, self.config.text_config.vocab_size)
             flat_labels = shift_labels.view(-1).to(shift_logits.device)
             loss = loss_fct(flat_logits, flat_labels)
+        elif shift_labels is not None:
+            # Upcast to float if we need to compute the loss to avoid potential precision issues
+            logits = logits.float()
+            shift_logits = logits[..., :-1, :]
+            if attention_mask is not None:
+                # we use the input attention mask to shift the logits and labels, because it is 2D.
+                # we also crop attn mask in case it is longer, which happens in PrefixTuning with peft
+                shift_attention_mask = attention_mask[:, -shift_logits.shape[1] :].to(logits.device)
+                shift_logits = shift_logits[shift_attention_mask.to(logits.device) != 0].contiguous()
+                shift_labels = shift_labels[shift_attention_mask.to(shift_labels.device) != 0].contiguous()
+            else:
+                shift_logits = shift_logits.contiguous()
+                shift_labels = shift_labels.contiguous()
+            # Flatten the tokens
+            loss_fct = CrossEntropyLoss()
+
+            flat_logits = shift_logits.view(-1, self.config.text_config.vocab_size)
+            flat_labels = shift_labels.view(-1).to(shift_logits.device)
+            loss = loss_fct(flat_logits, flat_labels)
+
     if not return_dict:
         output = (logits,) + outputs[1:]
-        return (loss,) + output if loss is not None else output
+        output = (loss,) + output if loss is not None else output
+        output = output + (token_accuracy,) if token_accuracy is not None else output
+        return output
 
-    return PaliGemmaCausalLMOutputWithPast(
+    # Return PaliGemma output with token_accuracy field
+    return LigerPaliGemmaCausalLMOutputWithPast(
        loss=loss,
        logits=logits,
        past_key_values=outputs.past_key_values,
        hidden_states=outputs.hidden_states,
        attentions=outputs.attentions,
        image_hidden_states=image_features if pixel_values is not None else None,
+        token_accuracy=token_accuracy,
    )
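The key behavioral change in this forward (and in the phi3 and qwen2 forwards below) is the new `skip_logits` switch: the caller can force or forbid logit materialization, and when left unset it defaults to skipping logits only while training with labels present, which is exactly when the fused linear cross entropy path can produce the loss without the full logits tensor. A minimal sketch of that gating, where `decide_skip_logits` is a hypothetical helper written only to restate the branch logic from the diff:

```python
from typing import Optional


def decide_skip_logits(training: bool, has_labels: bool, skip_logits: Optional[bool] = None) -> bool:
    """Hypothetical helper restating the skip_logits gating used by the patched forwards."""
    if skip_logits and not has_labels:
        # Mirrors the ValueError in the patched forward: the fused loss path needs labels.
        raise ValueError("skip_logits is True, but labels is None")
    if skip_logits is None:
        # Default: only skip logit materialization when training with labels present.
        skip_logits = training and has_labels
    return skip_logits


assert decide_skip_logits(training=True, has_labels=True) is True     # training step: logits never materialized
assert decide_skip_logits(training=False, has_labels=False) is False  # generation/inference: logits kept as before
```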
liger_kernel/transformers/model/phi3.py
@@ -5,131 +5,13 @@ from typing import Union
 
 import torch
 
-from torch.nn import CrossEntropyLoss
-from transformers.modeling_outputs import CausalLMOutputWithPast
-from transformers.models.phi3.modeling_phi3 import _CONFIG_FOR_DOC
-from transformers.models.phi3.modeling_phi3 import PHI3_INPUTS_DOCSTRING
-from transformers.utils import add_start_docstrings_to_model_forward
-from transformers.utils import replace_return_docstrings
-from transformers.utils.deprecation import deprecate_kwarg
-
-from liger_kernel.transformers.fused_linear_cross_entropy import LigerFusedLinearCrossEntropyLoss
-from liger_kernel.transformers.model.loss_utils import LigerForCausalLMLoss
-
-
-@add_start_docstrings_to_model_forward(PHI3_INPUTS_DOCSTRING)
-@replace_return_docstrings(output_type=CausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC)
-def lce_forward_deprecated(
-    self,
-    input_ids: torch.LongTensor = None,
-    attention_mask: Optional[torch.Tensor] = None,
-    position_ids: Optional[torch.LongTensor] = None,
-    past_key_values: Optional[List[torch.FloatTensor]] = None,
-    inputs_embeds: Optional[torch.FloatTensor] = None,
-    labels: Optional[torch.LongTensor] = None,
-    use_cache: Optional[bool] = None,
-    output_attentions: Optional[bool] = None,
-    output_hidden_states: Optional[bool] = None,
-    return_dict: Optional[bool] = None,
-    cache_position: Optional[torch.LongTensor] = None,
-) -> Union[Tuple, CausalLMOutputWithPast]:
-    r"""
-    Copy paste phi3 forward from transfomers v4.44.2 but replace torch cross entropy with liger fused linear cross entropy
-
-
-    Args:
-        labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
-            Labels for computing the masked language modeling loss. Indices should either be in `[0, ...,
-            config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored
-            (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`.
-
-    Returns:
-
-    Example:
-
-    ```python
-    >>> from transformers import AutoTokenizer, Phi3ForCausalLM
-
-    >>> model = Phi3ForCausalLM.from_pretrained("microsoft/phi-3-mini-4k-instruct")
-    >>> tokenizer = AutoTokenizer.from_pretrained("microsoft/phi-3-mini-4k-instruct")
-
-    >>> prompt = "This is an example script ."
-    >>> inputs = tokenizer(prompt, return_tensors="pt")
-
-    >>> # Generate
-    >>> generate_ids = model.generate(inputs.input_ids, max_length=30)
-    >>> tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]
-    'This is an example script .\n Certainly! Below is a sample script that demonstrates a simple task, such as calculating the sum'
-    ```"""
-
-    output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
-    output_hidden_states = (
-        output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
-    )
-    return_dict = return_dict if return_dict is not None else self.config.use_return_dict
+from transformers.modeling_outputs import BaseModelOutputWithPast
 
-    # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn)
-    outputs = self.model(
-        input_ids=input_ids,
-        attention_mask=attention_mask,
-        position_ids=position_ids,
-        past_key_values=past_key_values,
-        inputs_embeds=inputs_embeds,
-        use_cache=use_cache,
-        output_attentions=output_attentions,
-        output_hidden_states=output_hidden_states,
-        return_dict=return_dict,
-    )
-
-    hidden_states = outputs[0]
-
-    loss = None
-    logits = None
-
-    if self.training and labels is not None:
-        shift_hidden_states = hidden_states[..., :-1, :].contiguous()
-        shift_labels = labels[..., 1:].contiguous()
-
-        # flatten tokens
-        shift_hidden_states = shift_hidden_states.view(-1, self.config.hidden_size)
-        shift_labels = shift_labels.view(-1)
-
-        lce = LigerFusedLinearCrossEntropyLoss()
-        loss = lce(self.lm_head.weight, shift_hidden_states, shift_labels)
-    else:
-        logits = self.lm_head(hidden_states)
-
-        loss = None
-        if labels is not None:
-            # Upcast to float if we need to compute the loss to avoid potential precision issues
-            logits = logits.float()
-            # Shift so that tokens < n predict n
-            shift_logits = logits[..., :-1, :].contiguous()
-            shift_labels = labels[..., 1:].contiguous()
-            # Flatten the tokens
-            loss_fct = CrossEntropyLoss()
-            shift_logits = shift_logits.view(-1, self.config.vocab_size)
-            shift_labels = shift_labels.view(-1)
-            # Enable model parallelism
-            shift_labels = shift_labels.to(shift_logits.device)
-            loss = loss_fct(shift_logits, shift_labels)
-
-    if not return_dict:
-        output = (logits,) + outputs[1:]
-        return (loss,) + output if loss is not None else output
-
-    return CausalLMOutputWithPast(
-        loss=loss,
-        logits=logits,
-        past_key_values=outputs.past_key_values,
-        hidden_states=outputs.hidden_states,
-        attentions=outputs.attentions,
-    )
+from liger_kernel.transformers.model.loss_utils import LigerForCausalLMLoss
+from liger_kernel.transformers.model.loss_utils import unpack_cross_entropy_result
+from liger_kernel.transformers.model.output_classes import LigerCausalLMOutputWithPast
 
 
-@deprecate_kwarg("num_logits_to_keep", version="4.50", new_name="logits_to_keep")
-@add_start_docstrings_to_model_forward(PHI3_INPUTS_DOCSTRING)
-@replace_return_docstrings(output_type=CausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC)
 def lce_forward(
     self,
     input_ids: torch.LongTensor = None,
@@ -144,107 +26,95 @@ def lce_forward(
     return_dict: Optional[bool] = None,
     cache_position: Optional[torch.LongTensor] = None,
     logits_to_keep: Union[int, torch.Tensor] = 0,
-    **loss_kwargs,
-) -> Union[Tuple, CausalLMOutputWithPast]:
+    skip_logits: Optional[bool] = None,
+    **kwargs,
+) -> Union[Tuple, LigerCausalLMOutputWithPast]:
     r"""
-    Args:
-        labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
-            Labels for computing the masked language modeling loss. Indices should either be in `[0, ...,
-            config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored
-            (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`.
-
-        logits_to_keep (`int` or `torch.Tensor`, *optional*):
-            If an `int`, compute logits for the last `logits_to_keep` tokens. If `0`, calculate logits for all
-            `input_ids` (special case). Only last token logits are needed for generation, and calculating them only for that
-            token can save memory, which becomes pretty significant for long sequences or large vocabulary size.
-            If a `torch.Tensor`, must be 1D corresponding to the indices to keep in the sequence length dimension.
-            This is useful when using packed tensor format (single dimension for batch and sequence length).
-
-    Returns:
-
     Example:
 
     ```python
     >>> from transformers import AutoTokenizer, Phi3ForCausalLM
 
-    >>> model = Phi3ForCausalLM.from_pretrained("microsoft/phi-3-mini-4k-instruct")
-    >>> tokenizer = AutoTokenizer.from_pretrained("microsoft/phi-3-mini-4k-instruct")
+    >>> model = Phi3ForCausalLM.from_pretrained("meta-phi3/Phi3-2-7b-hf")
+    >>> tokenizer = AutoTokenizer.from_pretrained("meta-phi3/Phi3-2-7b-hf")
 
-    >>> prompt = "This is an example script ."
+    >>> prompt = "Hey, are you conscious? Can you talk to me?"
     >>> inputs = tokenizer(prompt, return_tensors="pt")
 
     >>> # Generate
     >>> generate_ids = model.generate(inputs.input_ids, max_length=30)
     >>> tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]
-    'This is an example script .\n Certainly! Below is a sample script that demonstrates a simple task, such as calculating the sum'
+    "Hey, are you conscious? Can you talk to me?\nI'm not conscious, but I can talk to you."
    ```"""
 
-    from transformers.models.phi3.modeling_phi3 import logging
-
-    logger = logging.get_logger(__name__)
-
-    if (
-        use_cache
-        and self.config.rope_scaling
-        and cache_position is not None
-        and cache_position[0] == self.config.original_max_position_embeddings
-    ):
-        logger.warning(
-            f"If you are not using the generate method, you may encounter nonsensical outputs after the {self.config.original_max_position_embeddings}th token, as the KV cache needs to be recomputed."
-        )
-
     output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
     output_hidden_states = (
         output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
     )
     return_dict = return_dict if return_dict is not None else self.config.use_return_dict
 
-    # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn)
-    outputs = self.model(
+    outputs: BaseModelOutputWithPast = self.model(
        input_ids=input_ids,
        attention_mask=attention_mask,
        position_ids=position_ids,
        past_key_values=past_key_values,
        inputs_embeds=inputs_embeds,
        use_cache=use_cache,
-        output_attentions=output_attentions,
-        output_hidden_states=output_hidden_states,
-        return_dict=return_dict,
+        cache_position=cache_position,
+        **kwargs,
    )
 
-    hidden_states = outputs[0]
+    hidden_states = outputs.last_hidden_state
+    # Only compute necessary logits, and do not upcast them to float if we are not computing the loss
+    slice_indices = slice(-logits_to_keep, None) if isinstance(logits_to_keep, int) else logits_to_keep
+    kept_hidden_states = hidden_states[:, slice_indices, :]
 
+    shift_labels = kwargs.pop("shift_labels", None)
    logits = None
    loss = None
-    # if in training mode, don't materialize logits
-    if self.training and (labels is not None):
-        loss = LigerForCausalLMLoss(
-            hidden_states=hidden_states,
+    token_accuracy = None
+
+    if skip_logits and labels is None and shift_labels is None:
+        raise ValueError("skip_logits is True, but labels and shift_labels are None")
+
+    if skip_logits is None:
+        # By default, if in training mode, don't materialize logits
+        skip_logits = self.training and (labels is not None or shift_labels is not None)
+
+    # Compute loss
+    if skip_logits:
+        result = LigerForCausalLMLoss(
+            hidden_states=kept_hidden_states,
            lm_head_weight=self.lm_head.weight,
            labels=labels,
+            shift_labels=shift_labels,
            hidden_size=self.config.hidden_size,
-            **loss_kwargs,
+            **kwargs,
        )
-
-    else:  # if in inference mode materialize logits
-        slice_indices = slice(-logits_to_keep, None) if isinstance(logits_to_keep, int) else logits_to_keep
-        logits = self.lm_head(hidden_states[:, slice_indices, :])
-        if labels is not None:
+        loss, _, token_accuracy = unpack_cross_entropy_result(result)
+    else:
+        logits = self.lm_head(kept_hidden_states)
+        if labels is not None or shift_labels is not None:
            loss = self.loss_function(
                logits=logits,
                labels=labels,
+                shift_labels=shift_labels,
                vocab_size=self.config.vocab_size,
-                **loss_kwargs,
+                **kwargs,
            )
 
    if not return_dict:
-        output = (logits,) + outputs[1:]
-        return (loss,) + output if loss is not None else output
+        output_tuple = (logits,) + outputs[1:]
+        output = (loss,) + output_tuple if loss is not None else output_tuple
+        output = output + (token_accuracy,) if token_accuracy is not None else output
+        return output
 
-    return CausalLMOutputWithPast(
+    # Return custom output class with token_accuracy field
+    return LigerCausalLMOutputWithPast(
        loss=loss,
        logits=logits,
        past_key_values=outputs.past_key_values,
        hidden_states=outputs.hidden_states,
        attentions=outputs.attentions,
+        token_accuracy=token_accuracy,
    )
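The rewritten phi3 forward also adopts `logits_to_keep` slicing of the hidden states before the `lm_head` projection or the fused loss. The trick is plain Python slice semantics: `slice(-0, None)` equals `slice(0, None)`, so the default `logits_to_keep=0` keeps every position, a positive integer keeps only the trailing tokens, and a 1D tensor selects explicit positions. A small self-contained sketch of just that slicing (the tensor shapes here are illustrative):

```python
import torch

# (batch, seq_len, hidden_size); the shapes are made up for illustration
hidden_states = torch.randn(2, 16, 64)

for logits_to_keep in (0, 1, 4):
    # Same expression as in the patched forward: 0 keeps all positions because
    # slice(-0, None) == slice(0, None); a positive int keeps the last N positions.
    slice_indices = slice(-logits_to_keep, None) if isinstance(logits_to_keep, int) else logits_to_keep
    kept_hidden_states = hidden_states[:, slice_indices, :]
    print(logits_to_keep, tuple(kept_hidden_states.shape))  # keeps 16, 1, and 4 positions respectively

# A 1D index tensor is also accepted, e.g. for packed sequences (indices are hypothetical)
keep_idx = torch.tensor([0, 5, 15])
print(tuple(hidden_states[:, keep_idx, :].shape))  # (2, 3, 64)
```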
liger_kernel/transformers/model/qwen2.py
@@ -7,18 +7,14 @@ import torch
 
 from torch.nn import CrossEntropyLoss
 from transformers.modeling_outputs import CausalLMOutputWithPast
-from transformers.models.qwen2.modeling_qwen2 import _CONFIG_FOR_DOC
-from transformers.models.qwen2.modeling_qwen2 import QWEN2_INPUTS_DOCSTRING
-from transformers.utils import add_start_docstrings_to_model_forward
-from transformers.utils import replace_return_docstrings
 from transformers.utils.deprecation import deprecate_kwarg
 
 from liger_kernel.transformers.fused_linear_cross_entropy import LigerFusedLinearCrossEntropyLoss
 from liger_kernel.transformers.model.loss_utils import LigerForCausalLMLoss
+from liger_kernel.transformers.model.loss_utils import unpack_cross_entropy_result
+from liger_kernel.transformers.model.output_classes import LigerCausalLMOutputWithPast
 
 
-@add_start_docstrings_to_model_forward(QWEN2_INPUTS_DOCSTRING)
-@replace_return_docstrings(output_type=CausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC)
 def lce_forward_deprecated(
     self,
     input_ids: torch.LongTensor = None,
@@ -32,6 +28,7 @@ def lce_forward_deprecated(
     output_hidden_states: Optional[bool] = None,
     return_dict: Optional[bool] = None,
     cache_position: Optional[torch.LongTensor] = None,
+    skip_logits: Optional[bool] = None,
 ) -> Union[Tuple, CausalLMOutputWithPast]:
     r"""
     Copy paste Qwen2's forward but replace torch cross entropy with liger fused linear cross entropy
@@ -86,6 +83,13 @@ def lce_forward_deprecated(
     loss = None
     logits = None
 
+    if skip_logits and labels is None:
+        raise ValueError("skip_logits is True, but labels is None")
+
+    if skip_logits is None:
+        # By default, if in training mode, don't materialize logits
+        skip_logits = self.training and labels is not None
+
     if self.training and (labels is not None):
         shift_hidden_states = hidden_states[..., :-1, :].contiguous()
         shift_labels = labels[..., 1:].contiguous()
@@ -127,8 +131,6 @@ def lce_forward_deprecated(
 
 
 @deprecate_kwarg("num_logits_to_keep", version="4.50", new_name="logits_to_keep")
-@add_start_docstrings_to_model_forward(QWEN2_INPUTS_DOCSTRING)
-@replace_return_docstrings(output_type=CausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC)
 def lce_forward(
     self,
     input_ids: torch.LongTensor = None,
@@ -143,8 +145,9 @@ def lce_forward(
     return_dict: Optional[bool] = None,
     cache_position: Optional[torch.LongTensor] = None,
     logits_to_keep: Union[int, torch.Tensor] = 0,
-    **loss_kwargs,
-) -> Union[Tuple, CausalLMOutputWithPast]:
+    skip_logits: Optional[bool] = None,
+    **kwargs,
+) -> Union[Tuple, LigerCausalLMOutputWithPast]:
     r"""
     Args:
         labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
@@ -196,37 +199,61 @@ def lce_forward(
         output_hidden_states=output_hidden_states,
         return_dict=return_dict,
         cache_position=cache_position,
+        **kwargs,
     )
 
     hidden_states = outputs[0]
+    # Only compute necessary logits, and do not upcast them to float if we are not computing the loss
+    slice_indices = slice(-logits_to_keep, None) if isinstance(logits_to_keep, int) else logits_to_keep
+    kept_hidden_states = hidden_states[:, slice_indices, :]
 
+    shift_labels = kwargs.pop("shift_labels", None)
     logits = None
     loss = None
-    # if in training mode, don't materialize logits
-    if self.training and (labels is not None):
-        loss = LigerForCausalLMLoss(
-            hidden_states=hidden_states,
+    token_accuracy = None
+
+    if skip_logits and labels is None and shift_labels is None:
+        raise ValueError("skip_logits is True, but labels and shift_labels are None")
+
+    if skip_logits is None:
+        # By default, if in training mode, don't materialize logits
+        skip_logits = self.training and (labels is not None or shift_labels is not None)
+
+    # Compute loss
+    if skip_logits:
+        result = LigerForCausalLMLoss(
+            hidden_states=kept_hidden_states,
            lm_head_weight=self.lm_head.weight,
            labels=labels,
+            shift_labels=shift_labels,
            hidden_size=self.config.hidden_size,
-            **loss_kwargs,
+            **kwargs,
        )
+        loss, _, token_accuracy = unpack_cross_entropy_result(result)
 
-    else:  # if in inference mode materialize logits
-        slice_indices = slice(-logits_to_keep, None) if isinstance(logits_to_keep, int) else logits_to_keep
-        logits = self.lm_head(hidden_states[:, slice_indices, :])
-        if labels is not None:
+    else:
+        logits = self.lm_head(kept_hidden_states)
+        if labels is not None or shift_labels is not None:
            loss = self.loss_function(
                logits=logits,
                labels=labels,
+                shift_labels=shift_labels,
                vocab_size=self.config.vocab_size,
-                **loss_kwargs,
+                **kwargs,
            )
 
-    return CausalLMOutputWithPast(
+    if not return_dict:
+        output_tuple = (logits,) + outputs[1:]
+        output = (loss,) + output_tuple if loss is not None else output_tuple
+        output = output + (token_accuracy,) if token_accuracy is not None else output
+        return output
+
+    # Return custom output class with token accuracy field
+    return LigerCausalLMOutputWithPast(
        loss=loss,
        logits=logits,
        past_key_values=outputs.past_key_values,
        hidden_states=outputs.hidden_states,
        attentions=outputs.attentions,
+        token_accuracy=token_accuracy,
    )
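End to end, these patched forwards are what the monkey patch installs on the Hugging Face model classes. A hedged usage sketch for the Qwen2 case, assuming the package's usual `apply_liger_kernel_to_qwen2` entry point and an illustrative checkpoint; the exact flags and which output fields are populated depend on the installed liger-kernel and transformers versions:

```python
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

from liger_kernel.transformers import apply_liger_kernel_to_qwen2

# Patch Qwen2ForCausalLM.forward with the lce_forward shown above (call before loading the model).
apply_liger_kernel_to_qwen2()

model = AutoModelForCausalLM.from_pretrained("Qwen/Qwen2-0.5B")  # illustrative checkpoint
tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen2-0.5B")

batch = tokenizer("Liger fuses the lm_head projection into the loss kernel.", return_tensors="pt")
labels = batch["input_ids"].clone()

model.train()
out = model(**batch, labels=labels, skip_logits=True)

print(out.loss)            # loss from the fused linear cross entropy path
print(out.logits)          # None: the full logits tensor was never materialized
print(out.token_accuracy)  # new field; may be None unless accuracy reporting is enabled in the loss kwargs
```

In inference (`model.eval()` with no labels), `skip_logits` defaults to False and the forward materializes logits as before, so generation is unaffected.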