liger-kernel-nightly 0.4.0.dev20241107052928__py3-none-any.whl → 0.6.3.dev20251121010306__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release: this version of liger-kernel-nightly might be problematic.

Files changed (114)
  1. liger_kernel/__init__.py +0 -0
  2. liger_kernel/chunked_loss/README.md +25 -0
  3. liger_kernel/chunked_loss/__init__.py +8 -0
  4. liger_kernel/chunked_loss/cosine_similarity_loss.py +136 -0
  5. liger_kernel/chunked_loss/cpo_loss.py +157 -0
  6. liger_kernel/chunked_loss/dpo_loss.py +229 -0
  7. liger_kernel/chunked_loss/functional.py +17 -0
  8. liger_kernel/chunked_loss/fused_linear_distillation.py +292 -0
  9. liger_kernel/chunked_loss/fused_linear_ppo.py +350 -0
  10. liger_kernel/chunked_loss/fused_linear_preference.py +433 -0
  11. liger_kernel/chunked_loss/fused_linear_unpaired_preference.py +341 -0
  12. liger_kernel/chunked_loss/grpo_loss.py +304 -0
  13. liger_kernel/chunked_loss/jsd_loss.py +200 -0
  14. liger_kernel/chunked_loss/kto_loss.py +210 -0
  15. liger_kernel/chunked_loss/orpo_loss.py +144 -0
  16. liger_kernel/chunked_loss/simpo_loss.py +165 -0
  17. liger_kernel/env_report.py +21 -4
  18. liger_kernel/ops/cross_entropy.py +235 -84
  19. liger_kernel/ops/dyt.py +157 -0
  20. liger_kernel/ops/experimental/embedding.py +1 -3
  21. liger_kernel/ops/experimental/mm_int8int2.py +3 -9
  22. liger_kernel/ops/fused_add_rms_norm.py +412 -0
  23. liger_kernel/ops/fused_linear_cross_entropy.py +197 -75
  24. liger_kernel/ops/fused_linear_jsd.py +17 -34
  25. liger_kernel/ops/fused_neighborhood_attention.py +1022 -0
  26. liger_kernel/ops/geglu.py +7 -18
  27. liger_kernel/ops/group_norm.py +305 -0
  28. liger_kernel/ops/grpo_loss.py +310 -0
  29. liger_kernel/ops/jsd.py +46 -21
  30. liger_kernel/ops/kl_div.py +23 -19
  31. liger_kernel/ops/layer_norm.py +150 -86
  32. liger_kernel/ops/llama4_rope.py +225 -0
  33. liger_kernel/ops/multi_token_attention.py +207 -0
  34. liger_kernel/ops/poly_norm.py +386 -0
  35. liger_kernel/ops/qwen2vl_mrope.py +222 -0
  36. liger_kernel/ops/rms_norm.py +314 -84
  37. liger_kernel/ops/rope.py +32 -34
  38. liger_kernel/ops/softmax.py +201 -0
  39. liger_kernel/ops/sparsemax.py +179 -0
  40. liger_kernel/ops/swiglu.py +5 -9
  41. liger_kernel/ops/tiled_mlp.py +136 -0
  42. liger_kernel/ops/tvd.py +207 -0
  43. liger_kernel/ops/utils.py +8 -4
  44. liger_kernel/transformers/__init__.py +199 -24
  45. liger_kernel/transformers/auto_model.py +6 -13
  46. liger_kernel/transformers/cross_entropy.py +33 -20
  47. liger_kernel/transformers/dyt.py +22 -0
  48. liger_kernel/transformers/experimental/__init__.py +5 -0
  49. liger_kernel/transformers/experimental/embedding.py +1 -3
  50. liger_kernel/transformers/fsdp.py +55 -0
  51. liger_kernel/transformers/functional.py +291 -13
  52. liger_kernel/transformers/fused_add_rms_norm.py +39 -0
  53. liger_kernel/transformers/fused_linear_cross_entropy.py +43 -14
  54. liger_kernel/transformers/fused_linear_jsd.py +1 -4
  55. liger_kernel/transformers/fused_neighborhood_attention.py +234 -0
  56. liger_kernel/transformers/geglu.py +1 -4
  57. liger_kernel/transformers/group_norm.py +50 -0
  58. liger_kernel/transformers/grpo_loss.py +98 -0
  59. liger_kernel/transformers/jsd.py +2 -7
  60. liger_kernel/transformers/kl_div.py +1 -3
  61. liger_kernel/transformers/layer_norm.py +3 -9
  62. liger_kernel/transformers/llama4_rope.py +93 -0
  63. liger_kernel/transformers/model/falcon_h1.py +122 -0
  64. liger_kernel/transformers/model/gemma.py +77 -77
  65. liger_kernel/transformers/model/gemma2.py +283 -0
  66. liger_kernel/transformers/model/gemma3.py +331 -0
  67. liger_kernel/transformers/model/glm4.py +141 -0
  68. liger_kernel/transformers/model/glm4v.py +163 -0
  69. liger_kernel/transformers/model/glm4v_moe.py +172 -0
  70. liger_kernel/transformers/model/internvl.py +157 -0
  71. liger_kernel/transformers/model/llama.py +128 -79
  72. liger_kernel/transformers/model/llama4.py +121 -0
  73. liger_kernel/transformers/model/llava.py +344 -0
  74. liger_kernel/transformers/model/loss_utils.py +95 -0
  75. liger_kernel/transformers/model/mistral.py +68 -64
  76. liger_kernel/transformers/model/mixtral.py +75 -91
  77. liger_kernel/transformers/model/mllama.py +63 -68
  78. liger_kernel/transformers/model/olmo2.py +141 -0
  79. liger_kernel/transformers/model/output_classes.py +147 -0
  80. liger_kernel/transformers/model/paligemma.py +432 -0
  81. liger_kernel/transformers/model/phi3.py +59 -213
  82. liger_kernel/transformers/model/qwen2.py +75 -72
  83. liger_kernel/transformers/model/qwen2_5_vl.py +163 -0
  84. liger_kernel/transformers/model/qwen2_vl.py +78 -98
  85. liger_kernel/transformers/model/qwen3.py +136 -0
  86. liger_kernel/transformers/model/qwen3_moe.py +152 -0
  87. liger_kernel/transformers/model/qwen3_next.py +146 -0
  88. liger_kernel/transformers/model/qwen3_vl.py +150 -0
  89. liger_kernel/transformers/model/qwen3_vl_moe.py +126 -0
  90. liger_kernel/transformers/model/smollm3.py +199 -0
  91. liger_kernel/transformers/model/smolvlm.py +158 -0
  92. liger_kernel/transformers/monkey_patch.py +2106 -289
  93. liger_kernel/transformers/multi_token_attention.py +64 -0
  94. liger_kernel/transformers/poly_norm.py +42 -0
  95. liger_kernel/transformers/qwen2vl_mrope.py +20 -0
  96. liger_kernel/transformers/rms_norm.py +57 -6
  97. liger_kernel/transformers/rope.py +45 -2
  98. liger_kernel/transformers/softmax.py +12 -0
  99. liger_kernel/transformers/sparsemax.py +16 -0
  100. liger_kernel/transformers/swiglu.py +23 -8
  101. liger_kernel/transformers/tiled_mlp.py +133 -0
  102. liger_kernel/transformers/trainer/__init__.py +4 -0
  103. liger_kernel/transformers/trainer/orpo_trainer.py +130 -0
  104. liger_kernel/transformers/tvd.py +13 -0
  105. liger_kernel/triton/__init__.py +1 -3
  106. liger_kernel/triton/monkey_patch.py +1 -3
  107. liger_kernel/utils.py +71 -0
  108. {liger_kernel_nightly-0.4.0.dev20241107052928.dist-info → liger_kernel_nightly-0.6.3.dev20251121010306.dist-info}/METADATA +150 -137
  109. liger_kernel_nightly-0.6.3.dev20251121010306.dist-info/RECORD +116 -0
  110. {liger_kernel_nightly-0.4.0.dev20241107052928.dist-info → liger_kernel_nightly-0.6.3.dev20251121010306.dist-info}/WHEEL +1 -1
  111. liger_kernel_nightly-0.4.0.dev20241107052928.dist-info/RECORD +0 -48
  112. {liger_kernel_nightly-0.4.0.dev20241107052928.dist-info → liger_kernel_nightly-0.6.3.dev20251121010306.dist-info}/LICENSE +0 -0
  113. {liger_kernel_nightly-0.4.0.dev20241107052928.dist-info → liger_kernel_nightly-0.6.3.dev20251121010306.dist-info}/NOTICE +0 -0
  114. {liger_kernel_nightly-0.4.0.dev20241107052928.dist-info → liger_kernel_nightly-0.6.3.dev20251121010306.dist-info}/top_level.txt +0 -0
liger_kernel/transformers/model/mistral.py

@@ -1,27 +1,19 @@
-from typing import List, Optional, Tuple, Union
+from typing import List
+from typing import Optional
+from typing import Tuple
+from typing import Union
 
 import torch
-from torch.nn import CrossEntropyLoss
+
 from transformers.cache_utils import Cache
-from transformers.modeling_outputs import CausalLMOutputWithPast
-from transformers.models.mistral.modeling_mistral import (
-    _CONFIG_FOR_DOC,
-    MISTRAL_INPUTS_DOCSTRING,
-)
-from transformers.utils import (
-    add_start_docstrings_to_model_forward,
-    replace_return_docstrings,
-)
-
-from liger_kernel.transformers.fused_linear_cross_entropy import (
-    LigerFusedLinearCrossEntropyLoss,
-)
-
-
-@add_start_docstrings_to_model_forward(MISTRAL_INPUTS_DOCSTRING)
-@replace_return_docstrings(
-    output_type=CausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC
-)
+from transformers.utils.deprecation import deprecate_kwarg
+
+from liger_kernel.transformers.model.loss_utils import LigerForCausalLMLoss
+from liger_kernel.transformers.model.loss_utils import unpack_cross_entropy_result
+from liger_kernel.transformers.model.output_classes import LigerCausalLMOutputWithPast
+
+
+@deprecate_kwarg("num_logits_to_keep", version="4.50", new_name="logits_to_keep")
 def lce_forward(
     self,
     input_ids: torch.LongTensor = None,
@@ -35,7 +27,10 @@ def lce_forward(
     output_hidden_states: Optional[bool] = None,
     return_dict: Optional[bool] = None,
     cache_position: Optional[torch.LongTensor] = None,
-) -> Union[Tuple, CausalLMOutputWithPast]:
+    logits_to_keep: Union[int, torch.Tensor] = 0,
+    skip_logits: Optional[bool] = None,
+    **kwargs,
+) -> Union[Tuple, LigerCausalLMOutputWithPast]:
     r"""
     Copy paste Mistral's forward but replace torch cross entropy with liger fused linear cross entropy
 
@@ -46,6 +41,12 @@ def lce_forward(
             config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored
             (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`.
 
+        logits_to_keep (`int` or `torch.Tensor`, *optional*):
+            If an `int`, compute logits for the last `logits_to_keep` tokens. If `0`, calculate logits for all
+            `input_ids` (special case). Only last token logits are needed for generation, and calculating them only for that
+            token can save memory, which becomes pretty significant for long sequences or large vocabulary size.
+            If a `torch.Tensor`, must be 1D corresponding to the indices to keep in the sequence length dimension.
+            This is useful when using packed tensor format (single dimension for batch and sequence length).
     Returns:
 
     Example:
@@ -65,19 +66,11 @@ def lce_forward(
     "Hey, are you conscious? Can you talk to me?\nI'm not conscious, but I can talk to you."
     ```"""
 
-    output_attentions = (
-        output_attentions
-        if output_attentions is not None
-        else self.config.output_attentions
-    )
+    output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
     output_hidden_states = (
-        output_hidden_states
-        if output_hidden_states is not None
-        else self.config.output_hidden_states
-    )
-    return_dict = (
-        return_dict if return_dict is not None else self.config.use_return_dict
+        output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
     )
+    return_dict = return_dict if return_dict is not None else self.config.use_return_dict
 
     # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn)
     outputs = self.model(
@@ -91,51 +84,62 @@ def lce_forward(
         output_hidden_states=output_hidden_states,
         return_dict=return_dict,
         cache_position=cache_position,
+        **kwargs,
     )
 
     hidden_states = outputs[0]
+    # Only compute necessary logits, and do not upcast them to float if we are not computing the loss
+    slice_indices = slice(-logits_to_keep, None) if isinstance(logits_to_keep, int) else logits_to_keep
+    kept_hidden_states = hidden_states[:, slice_indices, :]
 
+    shift_labels = kwargs.pop("shift_labels", None)
    loss = None
    logits = None
-
-    if self.training and (labels is not None):
-        shift_hidden_states = hidden_states[..., :-1, :].contiguous()
-        shift_labels = labels[..., 1:].contiguous()
-
-        # flatten tokens
-        shift_hidden_states = shift_hidden_states.view(-1, self.config.hidden_size)
-        shift_labels = shift_labels.view(-1)
-
-        lce = LigerFusedLinearCrossEntropyLoss()
-        loss = lce(self.lm_head.weight, shift_hidden_states, shift_labels)
+    token_accuracy = None
+
+    if skip_logits and labels is None and shift_labels is None:
+        raise ValueError("skip_logits is True, but labels and shift_labels are None")
+
+    if skip_logits is None:
+        skip_logits = self.training and (labels is not None or shift_labels is not None)
+
+    # Compute loss
+    if skip_logits:
+        result = LigerForCausalLMLoss(
+            hidden_states=kept_hidden_states,
+            lm_head_weight=self.lm_head.weight,
+            labels=labels,
+            shift_labels=shift_labels,
+            hidden_size=self.config.hidden_size,
+            **kwargs,
+        )
+        loss, _, token_accuracy = unpack_cross_entropy_result(result)
 
     else:
-        logits = self.lm_head(hidden_states)
-        if labels is not None:
-            # Upcast to float if we need to compute the loss to avoid potential precision issues
-            logits = logits.float()
-            # Shift so that tokens < n predict n
-            shift_logits = logits[..., :-1, :].contiguous()
-            shift_labels = labels[..., 1:].contiguous()
-            # Flatten the tokens
-            shift_logits = shift_logits.view(-1, self.config.vocab_size)
-            shift_labels = shift_labels.view(-1)
-            # Ensure tensors are on the same device
-            shift_labels = shift_labels.to(shift_logits.device)
-            loss_fct = CrossEntropyLoss()
-            loss = loss_fct(shift_logits, shift_labels)
+        logits = self.lm_head(kept_hidden_states)
+
+        loss = None
+        if labels is not None or shift_labels is not None:
+            loss = self.loss_function(
+                logits=logits,
+                labels=labels,
+                shift_labels=shift_labels,
+                vocab_size=self.config.vocab_size,
+                **kwargs,
+            )
 
     if not return_dict:
-        output = (logits,) + outputs[1:]
-        return (loss,) + output if loss is not None else output
+        output_tuple = (logits,) + outputs[1:]
+        output = (loss,) + output_tuple if loss is not None else output_tuple
+        output = output + (token_accuracy,) if token_accuracy is not None else output
+        return output
 
-    return CausalLMOutputWithPast(
+    # Return custom output class with token_accuracy field
+    return LigerCausalLMOutputWithPast(
         loss=loss,
         logits=logits,
         past_key_values=outputs.past_key_values,
         hidden_states=outputs.hidden_states,
         attentions=outputs.attentions,
+        token_accuracy=token_accuracy,
     )
-
-
-# Note: Grad Acc is not fixed in mistral at transformer 4.46.1
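
The Mistral patch above (and the Mixtral and Mllama patches below) funnel the hidden states through the same `logits_to_keep` slice before either the fused loss or `lm_head`. The following standalone sketch is not part of the wheel; the helper name is made up, and it only demonstrates with plain PyTorch how the `slice_indices` expression behaves for the `0`, `int`, and 1D-tensor cases.

import torch

def keep_hidden_states(hidden_states: torch.Tensor, logits_to_keep) -> torch.Tensor:
    # Same expression as in the patched forwards: slice(-0, None) == slice(0, None),
    # so logits_to_keep=0 keeps every position.
    slice_indices = slice(-logits_to_keep, None) if isinstance(logits_to_keep, int) else logits_to_keep
    return hidden_states[:, slice_indices, :]

h = torch.randn(2, 8, 16)  # (batch, seq_len, hidden_size)
assert keep_hidden_states(h, 0).shape == (2, 8, 16)   # 0 -> keep all positions
assert keep_hidden_states(h, 1).shape == (2, 1, 16)   # int -> last N positions only
idx = torch.tensor([0, 3, 7])                         # tensor -> explicit positions (packed format)
assert keep_hidden_states(h, idx).shape == (2, 3, 16)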

liger_kernel/transformers/model/mixtral.py

@@ -1,27 +1,21 @@
-from typing import List, Optional, Tuple, Union
+from typing import List
+from typing import Optional
+from typing import Tuple
+from typing import Union
 
 import torch
+
 from torch.nn import CrossEntropyLoss
 from transformers.modeling_outputs import MoeCausalLMOutputWithPast
-from transformers.models.mixtral.modeling_mixtral import (
-    _CONFIG_FOR_DOC,
-    MIXTRAL_INPUTS_DOCSTRING,
-    load_balancing_loss_func,
-)
-from transformers.utils import (
-    add_start_docstrings_to_model_forward,
-    replace_return_docstrings,
-)
-
-from liger_kernel.transformers.fused_linear_cross_entropy import (
-    LigerFusedLinearCrossEntropyLoss,
-)
-
-
-@add_start_docstrings_to_model_forward(MIXTRAL_INPUTS_DOCSTRING)
-@replace_return_docstrings(
-    output_type=MoeCausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC
-)
+from transformers.models.mixtral.modeling_mixtral import load_balancing_loss_func
+from transformers.utils.deprecation import deprecate_kwarg
+
+from liger_kernel.transformers.fused_linear_cross_entropy import LigerFusedLinearCrossEntropyLoss
+from liger_kernel.transformers.model.loss_utils import LigerForCausalLMLoss
+from liger_kernel.transformers.model.loss_utils import unpack_cross_entropy_result
+from liger_kernel.transformers.model.output_classes import LigerMoeCausalLMOutputWithPast
+
+
 def lce_forward_deprecated(
     self,
     input_ids: torch.LongTensor = None,
@@ -38,7 +32,7 @@ def lce_forward_deprecated(
     cache_position: Optional[torch.LongTensor] = None,
 ) -> Union[Tuple, MoeCausalLMOutputWithPast]:
     r"""
-    Copy paste Mixtral's forward from transfomers v4.44.2 but replace torch cross entropy with liger fused linear cross entropy
+    Copy paste Mixtral's forward from transformers v4.44.2 but replace torch cross entropy with liger fused linear cross entropy
 
 
     Args:
@@ -66,25 +60,15 @@ def lce_forward_deprecated(
     "Hey, are you conscious? Can you talk to me?\nI'm not conscious, but I can talk to you."
     ```"""
 
-    output_attentions = (
-        output_attentions
-        if output_attentions is not None
-        else self.config.output_attentions
-    )
+    output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
     output_router_logits = (
-        output_router_logits
-        if output_router_logits is not None
-        else self.config.output_router_logits
+        output_router_logits if output_router_logits is not None else self.config.output_router_logits
     )
 
     output_hidden_states = (
-        output_hidden_states
-        if output_hidden_states is not None
-        else self.config.output_hidden_states
-    )
-    return_dict = (
-        return_dict if return_dict is not None else self.config.use_return_dict
+        output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
     )
+    return_dict = return_dict if return_dict is not None else self.config.use_return_dict
 
     # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn)
     outputs = self.model(
@@ -138,9 +122,7 @@ def lce_forward_deprecated(
             attention_mask,
         )
         if labels is not None:
-            loss += self.router_aux_loss_coef * aux_loss.to(
-                loss.device
-            )  # make sure to reside in the same device
+            loss += self.router_aux_loss_coef * aux_loss.to(loss.device)  # make sure to reside in the same device
 
     if not return_dict:
         output = (logits,) + outputs[1:]
@@ -159,10 +141,7 @@ def lce_forward_deprecated(
     )
 
 
-@add_start_docstrings_to_model_forward(MIXTRAL_INPUTS_DOCSTRING)
-@replace_return_docstrings(
-    output_type=MoeCausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC
-)
+@deprecate_kwarg("num_logits_to_keep", version="4.50", new_name="logits_to_keep")
 # Ignore copy
 def lce_forward(
     self,
@@ -178,9 +157,10 @@ def lce_forward(
     output_router_logits: Optional[bool] = None,
     return_dict: Optional[bool] = None,
     cache_position: Optional[torch.LongTensor] = None,
-    num_logits_to_keep: int = 0,
-    **loss_kwargs,
-) -> Union[Tuple, MoeCausalLMOutputWithPast]:
+    logits_to_keep: Union[int, torch.Tensor] = 0,
+    skip_logits: Optional[bool] = None,
+    **kwargs,
+) -> Union[Tuple, LigerMoeCausalLMOutputWithPast]:
     r"""
     Args:
         labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
@@ -188,10 +168,12 @@ def lce_forward(
             config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored
             (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`.
 
-        num_logits_to_keep (`int`, *optional*):
-            Calculate logits for the last `num_logits_to_keep` tokens. If `0`, calculate logits for all
+        logits_to_keep (`int` or `torch.Tensor`, *optional*):
+            If an `int`, compute logits for the last `logits_to_keep` tokens. If `0`, calculate logits for all
             `input_ids` (special case). Only last token logits are needed for generation, and calculating them only for that
             token can save memory, which becomes pretty significant for long sequences or large vocabulary size.
+            If a `torch.Tensor`, must be 1D corresponding to the indices to keep in the sequence length dimension.
+            This is useful when using packed tensor format (single dimension for batch and sequence length).
 
     Returns:
 
@@ -212,25 +194,15 @@ def lce_forward(
     "Hey, are you conscious? Can you talk to me?\nI'm not conscious, but I can talk to you."
     ```"""
 
-    output_attentions = (
-        output_attentions
-        if output_attentions is not None
-        else self.config.output_attentions
-    )
+    output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
     output_router_logits = (
-        output_router_logits
-        if output_router_logits is not None
-        else self.config.output_router_logits
+        output_router_logits if output_router_logits is not None else self.config.output_router_logits
     )
 
     output_hidden_states = (
-        output_hidden_states
-        if output_hidden_states is not None
-        else self.config.output_hidden_states
-    )
-    return_dict = (
-        return_dict if return_dict is not None else self.config.use_return_dict
+        output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
     )
+    return_dict = return_dict if return_dict is not None else self.config.use_return_dict
 
     # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn)
     outputs = self.model(
@@ -245,40 +217,50 @@ def lce_forward(
         output_router_logits=output_router_logits,
         return_dict=return_dict,
         cache_position=cache_position,
+        **kwargs,
     )
 
     hidden_states = outputs[0]
+    # Only compute necessary logits, and do not upcast them to float if we are not computing the loss
+    slice_indices = slice(-logits_to_keep, None) if isinstance(logits_to_keep, int) else logits_to_keep
+    kept_hidden_states = hidden_states[:, slice_indices, :]
 
+    shift_labels = kwargs.pop("shift_labels", None)
     logits = None
     loss = None
-    # if in training mode, don't materialize logits
-    if self.training and (labels is not None):
-        # We do the same thing as ForCausalLMLoss but using Liger FLCE
-
-        shift_hidden_states = hidden_states[..., :-1, :].contiguous()
-        shift_labels = labels[..., 1:].contiguous()
-
-        # flatten tokens
-        shift_hidden_states = shift_hidden_states.view(-1, self.config.hidden_size)
-        shift_labels = shift_labels.view(-1)
-
-        reduction = "sum" if "num_items_in_batch" in loss_kwargs else "mean"
-        lce = LigerFusedLinearCrossEntropyLoss(reduction=reduction)
+    token_accuracy = None
+
+    if skip_logits and labels is None and shift_labels is None:
+        raise ValueError("skip_logits is True, but labels and shift_labels are None")
+
+    if skip_logits is None:
+        # By default, if in training mode, don't materialize logits
+        skip_logits = self.training and (labels is not None or shift_labels is not None)
+
+    # Compute loss
+    if skip_logits:
+        result = LigerForCausalLMLoss(
+            hidden_states=kept_hidden_states,
+            lm_head_weight=self.lm_head.weight,
+            labels=labels,
+            shift_labels=shift_labels,
+            hidden_size=self.config.hidden_size,
+            **kwargs,
+        )
+        loss, _, token_accuracy = unpack_cross_entropy_result(result)
 
-        loss = lce(self.lm_head.weight, shift_hidden_states, shift_labels)
-        if reduction == "sum":
-            loss /= loss_kwargs["num_items_in_batch"]
+    else:
+        logits = self.lm_head(kept_hidden_states)
 
-    else:  # if in inference mode materialize logits
-        logits = self.lm_head(hidden_states[:, -num_logits_to_keep:, :])
-        if labels is not None:
+        loss = None
+        if labels is not None or shift_labels is not None:
            loss = self.loss_function(
                logits=logits,
                labels=labels,
-                vocab_size=self.config.vocab_size,
-                **loss_kwargs,
+                shift_labels=shift_labels,
+                vocab_size=self.vocab_size,
+                **kwargs,
            )
-
     aux_loss = None
     if output_router_logits:
         aux_loss = load_balancing_loss_func(
@@ -288,22 +270,24 @@ def lce_forward(
             attention_mask,
         )
         if labels is not None:
-            loss += self.router_aux_loss_coef * aux_loss.to(
-                loss.device
-            )  # make sure to reside in the same device
+            loss += self.router_aux_loss_coef * aux_loss.to(loss.device)  # make sure to reside in the same device
 
     if not return_dict:
-        output = (logits,) + outputs[1:]
+        output_tuple = (logits,) + outputs[1:]
         if output_router_logits:
-            output = (aux_loss,) + output
-        return (loss,) + output if loss is not None else output
+            output_tuple = (aux_loss,) + output_tuple
+        if token_accuracy is not None:
+            output_tuple = output_tuple + (token_accuracy,)
+        return (loss,) + output_tuple if loss is not None else output_tuple
 
-    return MoeCausalLMOutputWithPast(
+    # Return custom output class with token_accuracy field
+    return LigerMoeCausalLMOutputWithPast(
         loss=loss,
         aux_loss=aux_loss,
         logits=logits,
         past_key_values=outputs.past_key_values,
         hidden_states=outputs.hidden_states,
         attentions=outputs.attentions,
-        router_logits=outputs.router_logits,
+        router_logits=outputs.router_logits if return_dict else outputs[-1],
+        token_accuracy=token_accuracy,
     )
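
Both new forwards gate the fused Liger loss on a `skip_logits` flag that defaults from the training mode and the presence of `labels`/`shift_labels`. A minimal sketch of that decision logic, pulled out here purely for illustration (the helper name is made up, not package API):

from typing import Optional

import torch

def resolve_skip_logits(
    skip_logits: Optional[bool],
    training: bool,
    labels: Optional[torch.Tensor],
    shift_labels: Optional[torch.Tensor],
) -> bool:
    # Mirrors the guard and default used by the patched lce_forward implementations above.
    if skip_logits and labels is None and shift_labels is None:
        raise ValueError("skip_logits is True, but labels and shift_labels are None")
    if skip_logits is None:
        # By default, only take the fused path (no materialized logits) during training
        # when some form of labels is available.
        skip_logits = training and (labels is not None or shift_labels is not None)
    return skip_logits

labels = torch.tensor([[1, 2, 3]])
assert resolve_skip_logits(None, training=True, labels=labels, shift_labels=None) is True
assert resolve_skip_logits(None, training=False, labels=labels, shift_labels=None) is False
assert resolve_skip_logits(True, training=True, labels=labels, shift_labels=None) is True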

liger_kernel/transformers/model/mllama.py

@@ -1,24 +1,21 @@
-from typing import List, Optional, Tuple, Union
+from typing import List
+from typing import Optional
+from typing import Tuple
+from typing import Union
 
 import torch
+
 from torch.nn import CrossEntropyLoss
 from transformers.cache_utils import Cache
 from transformers.modeling_outputs import CausalLMOutputWithPast
-from transformers.models.mllama.modeling_mllama import MLLAMA_INPUTS_DOCSTRING
-from transformers.utils import (
-    add_start_docstrings_to_model_forward,
-    replace_return_docstrings,
-)
+from transformers.utils.deprecation import deprecate_kwarg
 
-from liger_kernel.transformers.fused_linear_cross_entropy import (
-    LigerFusedLinearCrossEntropyLoss,
-)
+from liger_kernel.transformers.fused_linear_cross_entropy import LigerFusedLinearCrossEntropyLoss
+from liger_kernel.transformers.model.loss_utils import LigerForCausalLMLoss
+from liger_kernel.transformers.model.loss_utils import unpack_cross_entropy_result
+from liger_kernel.transformers.model.output_classes import LigerCausalLMOutputWithPast
 
 
-@add_start_docstrings_to_model_forward(MLLAMA_INPUTS_DOCSTRING)
-@replace_return_docstrings(
-    output_type=CausalLMOutputWithPast, config_class="MllamaTextConfig"
-)
 def lce_forward_deprecated(
     self,
     input_ids: torch.LongTensor = None,
@@ -66,19 +63,11 @@ def lce_forward_deprecated(
     I love the idea of snowflakes gently falling, each one
     ```
     """
-    output_attentions = (
-        output_attentions
-        if output_attentions is not None
-        else self.config.output_attentions
-    )
+    output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
     output_hidden_states = (
-        output_hidden_states
-        if output_hidden_states is not None
-        else self.config.output_hidden_states
-    )
-    return_dict = (
-        return_dict if return_dict is not None else self.config.use_return_dict
+        output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
     )
+    return_dict = return_dict if return_dict is not None else self.config.use_return_dict
 
     # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn)
     outputs = self.model(
@@ -142,10 +131,7 @@ def lce_forward_deprecated(
     )
 
 
-@add_start_docstrings_to_model_forward(MLLAMA_INPUTS_DOCSTRING)
-@replace_return_docstrings(
-    output_type=CausalLMOutputWithPast, config_class="MllamaTextConfig"
-)
+@deprecate_kwarg("num_logits_to_keep", version="4.50", new_name="logits_to_keep")
 def lce_forward(
     self,
     input_ids: torch.LongTensor = None,
@@ -162,9 +148,10 @@ def lce_forward(
     output_hidden_states: Optional[bool] = None,
     return_dict: Optional[bool] = None,
     cache_position: Optional[torch.LongTensor] = None,
-    num_logits_to_keep: int = 0,
-    **loss_kwargs,
-) -> Union[Tuple, CausalLMOutputWithPast]:
+    logits_to_keep: Union[int, torch.Tensor] = 0,
+    skip_logits: Optional[bool] = None,
+    **kwargs,
+) -> Union[Tuple, LigerCausalLMOutputWithPast]:
     r"""
     Args:
         labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
@@ -172,10 +159,12 @@ def lce_forward(
            config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored
            (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`.
 
-        num_logits_to_keep (`int`, *optional*):
-            Calculate logits for the last `num_logits_to_keep` tokens. If `0`, calculate logits for all
+        logits_to_keep (`int` or `torch.Tensor`, *optional*):
+            If an `int`, compute logits for the last `logits_to_keep` tokens. If `0`, calculate logits for all
            `input_ids` (special case). Only last token logits are needed for generation, and calculating them only for that
            token can save memory, which becomes pretty significant for long sequences or large vocabulary size.
+            If a `torch.Tensor`, must be 1D corresponding to the indices to keep in the sequence length dimension.
+            This is useful when using packed tensor format (single dimension for batch and sequence length).
 
     Returns:
 
@@ -198,20 +187,14 @@ def lce_forward(
     I love the idea of snowflakes gently falling, each one
     ```
     """
-    output_attentions = (
-        output_attentions
-        if output_attentions is not None
-        else self.config.output_attentions
-    )
+    output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
     output_hidden_states = (
-        output_hidden_states
-        if output_hidden_states is not None
-        else self.config.output_hidden_states
-    )
-    return_dict = (
-        return_dict if return_dict is not None else self.config.use_return_dict
+        output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
     )
-
+    return_dict = return_dict if return_dict is not None else self.config.use_return_dict
+    # Filter out accum_dtype from kwargs for model call as MllamaTextModel doesn't accept it in transformers 4.49.0
+    # but preserve it for loss function calls
+    model_kwargs = {k: v for k, v in kwargs.items() if k != "accum_dtype"}
     # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn)
     outputs = self.model(
         input_ids=input_ids,
@@ -227,48 +210,60 @@ def lce_forward(
         output_hidden_states=output_hidden_states,
         return_dict=return_dict,
         cache_position=cache_position,
+        **model_kwargs,
     )
 
     hidden_states = outputs[0]
+    # Only compute necessary logits, and do not upcast them to float if we are not computing the loss
+    slice_indices = slice(-logits_to_keep, None) if isinstance(logits_to_keep, int) else logits_to_keep
+    kept_hidden_states = hidden_states[:, slice_indices, :]
 
+    shift_labels = kwargs.pop("shift_labels", None)
     logits = None
     loss = None
-    # if in training mode, don't materialize logits
-    if self.training and (labels is not None):
-        # We do the same thing as ForCausalLMLoss but using Liger FLCE
-
-        shift_hidden_states = hidden_states[..., :-1, :].contiguous()
-        shift_labels = labels[..., 1:].contiguous()
-
-        # flatten tokens
-        shift_hidden_states = shift_hidden_states.view(-1, self.config.hidden_size)
-        shift_labels = shift_labels.view(-1)
-
-        reduction = "sum" if "num_items_in_batch" in loss_kwargs else "mean"
-        lce = LigerFusedLinearCrossEntropyLoss(reduction=reduction)
+    token_accuracy = None
+
+    if skip_logits and labels is None and shift_labels is None:
+        raise ValueError("skip_logits is True, but labels and shift_labels are None")
+
+    if skip_logits is None:
+        # By default, if in training mode, don't materialize logits
+        skip_logits = self.training and (labels is not None or shift_labels is not None)
+
+    if skip_logits:
+        result = LigerForCausalLMLoss(
+            hidden_states=kept_hidden_states,
+            lm_head_weight=self.lm_head.weight,
+            labels=labels,
+            shift_labels=shift_labels,
+            hidden_size=self.config.hidden_size,
+            **kwargs,
+        )
+        loss, _, token_accuracy = unpack_cross_entropy_result(result)
 
-        loss = lce(self.lm_head.weight, shift_hidden_states, shift_labels)
-        if reduction == "sum":
-            loss /= loss_kwargs["num_items_in_batch"]
-
-    else:  # if in inference mode materialize logits
-        logits = self.lm_head(hidden_states[:, -num_logits_to_keep:, :])
-        if labels is not None:
+    else:
+        logits = self.lm_head(kept_hidden_states)
+        if labels is not None or shift_labels is not None:
             loss = self.loss_function(
                 logits=logits,
                 labels=labels,
+                shift_labels=shift_labels,
                 vocab_size=self.config.vocab_size,
-                **loss_kwargs,
+                **kwargs,
             )
 
     if not return_dict:
         output = (logits,) + outputs[1:]
-        return (loss,) + output if loss is not None else output
+        output = (loss,) + output if loss is not None else output
+        output = output + (token_accuracy,) if token_accuracy is not None else output
+        return output
 
-    return CausalLMOutputWithPast(
+    # Return custom output class with token_accuracy field
+    return LigerCausalLMOutputWithPast(
         loss=loss,
         logits=logits,
         past_key_values=outputs.past_key_values,
         hidden_states=outputs.hidden_states,
         attentions=outputs.attentions,
+        token_accuracy=token_accuracy,
     )
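
The Mllama patch additionally strips `accum_dtype` from the kwargs forwarded to the text model (which does not accept it in transformers 4.49.0) while keeping it available for the loss call. A tiny standalone sketch of that filtering; the values used here are illustrative only, just the key names come from the diff above.

import torch

# Hypothetical incoming kwargs; only the key names are taken from the patched forward.
kwargs = {"accum_dtype": torch.float32, "num_items_in_batch": 42}

# Drop accum_dtype before calling self.model(...), keep it in kwargs for the loss call.
model_kwargs = {k: v for k, v in kwargs.items() if k != "accum_dtype"}

assert "accum_dtype" not in model_kwargs
assert kwargs["accum_dtype"] is torch.float32
assert model_kwargs == {"num_items_in_batch": 42}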