liger-kernel-nightly 0.6.2.dev20251011154427__py3-none-any.whl → 0.6.4.dev20260107111351__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.


Files changed (97)
  1. liger_kernel/chunked_loss/cosine_similarity_loss.py +20 -5
  2. liger_kernel/chunked_loss/fused_linear_distillation.py +23 -5
  3. liger_kernel/chunked_loss/fused_linear_ppo.py +21 -5
  4. liger_kernel/chunked_loss/grpo_loss.py +8 -5
  5. liger_kernel/chunked_loss/jsd_loss.py +39 -11
  6. liger_kernel/ops/__init__.py +141 -0
  7. liger_kernel/ops/backends/README.md +151 -0
  8. liger_kernel/ops/backends/__init__.py +13 -0
  9. liger_kernel/ops/backends/_ascend/__init__.py +5 -0
  10. liger_kernel/ops/backends/_ascend/ascend-ub-manager-design.md +485 -0
  11. liger_kernel/ops/backends/_ascend/ops/__init__.py +43 -0
  12. liger_kernel/ops/backends/_ascend/ops/geglu.py +244 -0
  13. liger_kernel/ops/backends/_ascend/ops/qwen2vl_mrope.py +285 -0
  14. liger_kernel/ops/backends/_ascend/ops/rope.py +290 -0
  15. liger_kernel/ops/backends/_ascend/ops/swiglu.py +142 -0
  16. liger_kernel/ops/backends/_ascend/ub_manager.py +349 -0
  17. liger_kernel/ops/backends/registry.py +61 -0
  18. liger_kernel/ops/cross_entropy.py +75 -12
  19. liger_kernel/ops/dyt.py +5 -2
  20. liger_kernel/ops/fused_add_rms_norm.py +5 -1
  21. liger_kernel/ops/fused_linear_cross_entropy.py +45 -14
  22. liger_kernel/ops/geglu.py +5 -3
  23. liger_kernel/ops/group_norm.py +2 -1
  24. liger_kernel/ops/grpo_loss.py +3 -1
  25. liger_kernel/ops/layer_norm.py +86 -66
  26. liger_kernel/ops/poly_norm.py +390 -0
  27. liger_kernel/ops/rms_norm.py +131 -49
  28. liger_kernel/ops/tiled_mlp.py +136 -0
  29. liger_kernel/ops/utils.py +14 -0
  30. liger_kernel/transformers/__init__.py +30 -0
  31. liger_kernel/transformers/auto_model.py +21 -0
  32. liger_kernel/transformers/cross_entropy.py +9 -4
  33. liger_kernel/transformers/dyt.py +1 -1
  34. liger_kernel/transformers/experimental/embedding.py +1 -1
  35. liger_kernel/transformers/functional.py +48 -25
  36. liger_kernel/transformers/fused_add_rms_norm.py +1 -1
  37. liger_kernel/transformers/fused_linear_cross_entropy.py +9 -4
  38. liger_kernel/transformers/fused_linear_jsd.py +1 -1
  39. liger_kernel/transformers/fused_neighborhood_attention.py +1 -1
  40. liger_kernel/transformers/geglu.py +1 -1
  41. liger_kernel/transformers/group_norm.py +1 -1
  42. liger_kernel/transformers/grpo_loss.py +57 -2
  43. liger_kernel/transformers/jsd.py +1 -1
  44. liger_kernel/transformers/kl_div.py +1 -1
  45. liger_kernel/transformers/layer_norm.py +1 -1
  46. liger_kernel/transformers/llama4_rope.py +1 -1
  47. liger_kernel/transformers/model/falcon_h1.py +19 -5
  48. liger_kernel/transformers/model/gemma.py +17 -6
  49. liger_kernel/transformers/model/gemma2.py +14 -5
  50. liger_kernel/transformers/model/gemma3.py +26 -12
  51. liger_kernel/transformers/model/glm4.py +16 -4
  52. liger_kernel/transformers/model/glm4v.py +16 -4
  53. liger_kernel/transformers/model/glm4v_moe.py +23 -4
  54. liger_kernel/transformers/model/gpt_oss.py +211 -0
  55. liger_kernel/transformers/model/hunyuan_v1.py +134 -0
  56. liger_kernel/transformers/model/internvl.py +12 -5
  57. liger_kernel/transformers/model/llama.py +14 -5
  58. liger_kernel/transformers/model/llama4.py +16 -4
  59. liger_kernel/transformers/model/llava.py +12 -4
  60. liger_kernel/transformers/model/loss_utils.py +31 -3
  61. liger_kernel/transformers/model/mistral.py +15 -6
  62. liger_kernel/transformers/model/mixtral.py +16 -7
  63. liger_kernel/transformers/model/mllama.py +12 -4
  64. liger_kernel/transformers/model/olmo2.py +16 -4
  65. liger_kernel/transformers/model/olmo3.py +142 -0
  66. liger_kernel/transformers/model/output_classes.py +147 -0
  67. liger_kernel/transformers/model/paligemma.py +23 -5
  68. liger_kernel/transformers/model/phi3.py +14 -7
  69. liger_kernel/transformers/model/qwen2.py +16 -3
  70. liger_kernel/transformers/model/qwen2_5_vl.py +14 -6
  71. liger_kernel/transformers/model/qwen2_vl.py +16 -4
  72. liger_kernel/transformers/model/qwen3.py +20 -5
  73. liger_kernel/transformers/model/qwen3_moe.py +19 -5
  74. liger_kernel/transformers/model/qwen3_next.py +146 -0
  75. liger_kernel/transformers/model/qwen3_vl.py +150 -0
  76. liger_kernel/transformers/model/qwen3_vl_moe.py +126 -0
  77. liger_kernel/transformers/model/smollm3.py +15 -6
  78. liger_kernel/transformers/model/smolvlm.py +158 -0
  79. liger_kernel/transformers/monkey_patch.py +702 -48
  80. liger_kernel/transformers/multi_token_attention.py +1 -1
  81. liger_kernel/transformers/poly_norm.py +42 -0
  82. liger_kernel/transformers/qwen2vl_mrope.py +1 -1
  83. liger_kernel/transformers/rms_norm.py +15 -3
  84. liger_kernel/transformers/rope.py +45 -1
  85. liger_kernel/transformers/softmax.py +1 -1
  86. liger_kernel/transformers/sparsemax.py +1 -1
  87. liger_kernel/transformers/swiglu.py +18 -1
  88. liger_kernel/transformers/tiled_mlp.py +133 -0
  89. liger_kernel/transformers/tvd.py +1 -1
  90. liger_kernel/utils.py +52 -0
  91. {liger_kernel_nightly-0.6.2.dev20251011154427.dist-info → liger_kernel_nightly-0.6.4.dev20260107111351.dist-info}/METADATA +12 -3
  92. liger_kernel_nightly-0.6.4.dev20260107111351.dist-info/RECORD +130 -0
  93. liger_kernel_nightly-0.6.2.dev20251011154427.dist-info/RECORD +0 -107
  94. {liger_kernel_nightly-0.6.2.dev20251011154427.dist-info → liger_kernel_nightly-0.6.4.dev20260107111351.dist-info}/LICENSE +0 -0
  95. {liger_kernel_nightly-0.6.2.dev20251011154427.dist-info → liger_kernel_nightly-0.6.4.dev20260107111351.dist-info}/NOTICE +0 -0
  96. {liger_kernel_nightly-0.6.2.dev20251011154427.dist-info → liger_kernel_nightly-0.6.4.dev20260107111351.dist-info}/WHEEL +0 -0
  97. {liger_kernel_nightly-0.6.2.dev20251011154427.dist-info → liger_kernel_nightly-0.6.4.dev20260107111351.dist-info}/top_level.txt +0 -0
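
Most of the model-level changes in this release follow one pattern: the fused linear cross entropy path can now report token-level accuracy, so each patched `lce_forward` unpacks a structured result via the new `unpack_cross_entropy_result` helper and returns a Liger-specific output class carrying a `token_accuracy` field (see the hunks below). A minimal sketch of that consumption pattern; the wrapper function name and the tensor-shape comments are illustrative, not part of the package:

```python
from typing import Optional

import torch

from liger_kernel.transformers.model.loss_utils import LigerForCausalLMLoss
from liger_kernel.transformers.model.loss_utils import unpack_cross_entropy_result


def compute_patched_loss(
    kept_hidden_states: torch.Tensor,  # (batch, kept_seq_len, hidden_size)
    lm_head_weight: torch.Tensor,      # (vocab_size, hidden_size)
    labels: torch.Tensor,
    hidden_size: int,
    shift_labels: Optional[torch.Tensor] = None,
):
    # Same call shape as the patched lce_forward functions in the diff below.
    # The kernel may hand back a bare tensor, a tuple, or a CrossEntropyOutput;
    # unpack_cross_entropy_result normalizes all of them to (loss, z_loss, token_accuracy).
    result = LigerForCausalLMLoss(
        hidden_states=kept_hidden_states,
        lm_head_weight=lm_head_weight,
        labels=labels,
        shift_labels=shift_labels,
        hidden_size=hidden_size,
        return_token_accuracy=True,  # new opt-in flag; defaults to False
    )
    return unpack_cross_entropy_result(result)
```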
@@ -0,0 +1,211 @@
+ from typing import List
+ from typing import Optional
+ from typing import Union
+
+ import torch
+
+ from transformers.modeling_outputs import MoeModelOutputWithPast
+ from transformers.models.mixtral.modeling_mixtral import load_balancing_loss_func
+
+ from liger_kernel.transformers.model.loss_utils import LigerForCausalLMLoss
+ from liger_kernel.transformers.model.loss_utils import unpack_cross_entropy_result
+ from liger_kernel.transformers.model.output_classes import LigerMoeCausalLMOutputWithPast
+
+
+ def lce_forward(
+     self,
+     input_ids: Optional[torch.LongTensor] = None,
+     attention_mask: Optional[torch.Tensor] = None,
+     position_ids: Optional[torch.LongTensor] = None,
+     past_key_values: Optional[List[torch.FloatTensor]] = None,
+     inputs_embeds: Optional[torch.FloatTensor] = None,
+     labels: Optional[torch.LongTensor] = None,
+     use_cache: Optional[bool] = None,
+     output_attentions: Optional[bool] = None,
+     output_hidden_states: Optional[bool] = None,
+     output_router_logits: Optional[bool] = None,
+     cache_position: Optional[torch.LongTensor] = None,
+     logits_to_keep: Union[int, torch.Tensor] = 0,
+     skip_logits: Optional[bool] = None,
+     **kwargs,
+ ) -> LigerMoeCausalLMOutputWithPast:
+     r"""
+     Forward pass for causal language modeling with Mixture of Experts (MoE) architecture using Liger Kernel optimizations.
+
+     Args:
+         input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
+             Indices of input sequence tokens in the vocabulary. Indices can be obtained using tokenizers.
+         attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*):
+             Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`:
+             - 1 for tokens that are **not masked**,
+             - 0 for tokens that are **masked**.
+         position_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
+             Indices of positions of each input sequence tokens in the position embeddings.
+         past_key_values (`List[torch.FloatTensor]` or `Cache`, *optional*):
+             Pre-computed hidden-states (key and values in the self-attention blocks) that can be used to speed up
+             sequential decoding. See `past_key_values` input for more details.
+         inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*):
+             Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation.
+             This is useful if you want more control over how to convert `input_ids` indices into associated vectors
+             than the model's internal embedding lookup matrix.
+         labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
+             Labels for computing the masked language modeling loss. Indices should either be in `[0, ...,
+             config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored
+             (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`.
+         use_cache (`bool`, *optional*):
+             If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding
+             (see `past_key_values`).
+         output_attentions (`bool`, *optional*):
+             Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned
+             tensors for more detail.
+         output_hidden_states (`bool`, *optional*):
+             Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for
+             more detail.
+         output_router_logits (`bool`, *optional*):
+             Whether or not to return the router logits of all MoE layers. See `router_logits` under returned tensors
+             for more detail.
+         cache_position (`torch.LongTensor` of shape `(sequence_length)`, *optional*):
+             Indices depicting the position of the input sequence tokens in the sequence.
+         logits_to_keep (`int` or `torch.Tensor`, *optional*, defaults to 0):
+             If an `int`, compute logits for the last `logits_to_keep` tokens. If `0`, calculate logits for all
+             `input_ids` (special case). Only last token logits are needed for generation, and calculating them only for that
+             token can save memory, which becomes pretty significant for long sequences or large vocabulary size.
+             If a `torch.Tensor`, must be 1D corresponding to the indices to keep in the sequence length dimension.
+             This is useful when using packed tensor format (single dimension for batch and sequence length).
+         skip_logits (`bool`, *optional*):
+             Whether to skip logit computation and directly compute loss. If `None`, defaults to `True` during training
+             when labels are provided (to save memory), and `False` during inference.
+
+     Returns:
+         `LigerMoeCausalLMOutputWithPast`: An output object containing:
+         - loss (`torch.FloatTensor` of shape `(1,)`, *optional*, returned when `labels` is provided):
+             Language modeling loss (for next-token prediction), including the auxiliary load balancing loss.
+         - aux_loss (`torch.FloatTensor`, *optional*, returned when `labels` is provided):
+             Auxiliary load balancing loss for the sparse MoE modules.
+         - logits (`torch.FloatTensor` of shape `(batch_size, sequence_length, config.vocab_size)`, *optional*):
+             Prediction scores of the language modeling head (scores for each vocabulary token before SoftMax).
+             Note: logits are `None` during training when `skip_logits=True` to save memory.
+         - past_key_values (`Cache`, *optional*, returned when `use_cache=True` is passed):
+             Cached key and value projection states for faster sequential decoding.
+         - hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True`):
+             Tuple of `torch.FloatTensor` (one for the output of the embeddings + one for each layer) of shape
+             `(batch_size, sequence_length, hidden_size)`. Hidden-states of the model at the output of each layer.
+         - attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True`):
+             Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length,
+             sequence_length)`. Attentions weights after the attention softmax.
+         - router_logits (`tuple(torch.FloatTensor)`, *optional*, returned when `output_router_logits=True`):
+             Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, sequence_length, num_experts)`.
+             Router logits of the MoE layers, useful to compute the auxiliary loss and z_loss.
+         - token_accuracy (`torch.FloatTensor`, *optional*, returned when `labels` is provided):
+             Token-level prediction accuracy.
+
+     Example:
+
+     ```python
+     >>> from transformers import AutoTokenizer, GptOssForCausalLM
+     >>> from liger_kernel.transformers import apply_liger_kernel_to_gpt_oss
+
+     >>> # Apply Liger Kernel patches for optimized performance
+     >>> apply_liger_kernel_to_gpt_oss()
+
+     >>> model = GptOssForCausalLM.from_pretrained("openai/gpt-oss-20b")
+     >>> tokenizer = AutoTokenizer.from_pretrained("openai/gpt-oss-20b")
+
+     >>> prompt = "Hey, are you conscious? Can you talk to me?"
+     >>> inputs = tokenizer(prompt, return_tensors="pt")
+
+     >>> # Inference: Forward pass returns logits
+     >>> outputs = model(**inputs)
+     >>> outputs.logits.shape
+     torch.Size([1, 12, 201088])
+
+     >>> # Get next token prediction
+     >>> next_token_logits = outputs.logits[:, -1, :]
+     >>> predicted_token_id = next_token_logits.argmax(dim=-1)
+
+     >>> # Training: Forward pass with labels returns loss
+     >>> labels = inputs.input_ids.clone()
+     >>> outputs = model(**inputs, labels=labels)
+     >>> outputs.loss
+     tensor(2.6454)
+     ```"""
+
+     output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
+     output_router_logits = (
+         output_router_logits if output_router_logits is not None else self.config.output_router_logits
+     )
+
+     output_hidden_states = (
+         output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
+     )
+
+     # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn)
+     outputs: MoeModelOutputWithPast = self.model(
+         input_ids=input_ids,
+         attention_mask=attention_mask,
+         position_ids=position_ids,
+         past_key_values=past_key_values,
+         inputs_embeds=inputs_embeds,
+         use_cache=use_cache,
+         output_attentions=output_attentions,
+         output_hidden_states=output_hidden_states,
+         output_router_logits=output_router_logits,
+         cache_position=cache_position,
+         **kwargs,
+     )
+
+     hidden_states = outputs.last_hidden_state
+     # Only compute necessary logits, and do not upcast them to float if we are not computing the loss
+     slice_indices = slice(-logits_to_keep, None) if isinstance(logits_to_keep, int) else logits_to_keep
+     kept_hidden_states = hidden_states[:, slice_indices, :]
+
+     shift_labels = kwargs.pop("shift_labels", None)
+     logits = None
+     loss = None
+     token_accuracy = None
+
+     if skip_logits is None:
+         skip_logits = self.training and (labels is not None or shift_labels is not None)
+
+     if skip_logits:
+         result = LigerForCausalLMLoss(
+             hidden_states=kept_hidden_states,
+             lm_head_weight=self.lm_head.weight,
+             labels=labels,
+             shift_labels=shift_labels,
+             hidden_size=self.config.hidden_size,
+             **kwargs,
+         )
+         loss, _, token_accuracy = unpack_cross_entropy_result(result)
+     else:  # if in inference model materialize logits
+         logits = self.lm_head(kept_hidden_states)
+         if labels is not None or shift_labels is not None:
+             loss = self.loss_function(
+                 logits=logits,
+                 labels=labels,
+                 shift_labels=shift_labels,
+                 vocab_size=self.vocab_size,
+                 **kwargs,
+             )
+
+     aux_loss = None
+     if output_router_logits:
+         aux_loss = load_balancing_loss_func(
+             outputs.router_logits,
+             self.num_experts,
+             self.num_experts_per_tok,
+             attention_mask,
+         )
+         if labels is not None:
+             loss += self.router_aux_loss_coef * aux_loss.to(loss.device)  # make sure to reside in the same device
+
+     return LigerMoeCausalLMOutputWithPast(
+         loss=loss,
+         aux_loss=aux_loss,
+         logits=logits,
+         past_key_values=outputs.past_key_values,
+         hidden_states=outputs.hidden_states,
+         attentions=outputs.attentions,
+         router_logits=outputs.router_logits,
+         token_accuracy=token_accuracy,
+     )
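
The `skip_logits` default documented above means the same patched forward behaves differently in training and inference. A hedged illustration of that contract, reusing the model and tokenizer names from the docstring example; the assertions restate the documented behavior rather than a verified run:

```python
import torch
from transformers import AutoTokenizer, GptOssForCausalLM

from liger_kernel.transformers import apply_liger_kernel_to_gpt_oss

apply_liger_kernel_to_gpt_oss()  # install the patched lce_forward shown above

model = GptOssForCausalLM.from_pretrained("openai/gpt-oss-20b")
tokenizer = AutoTokenizer.from_pretrained("openai/gpt-oss-20b")
inputs = tokenizer("Hey, are you conscious?", return_tensors="pt")

# Training with labels: skip_logits defaults to True, so the full logits tensor
# is never materialized and only the fused loss (plus optional aux_loss) comes back.
model.train()
out = model(**inputs, labels=inputs.input_ids.clone())
assert out.logits is None and out.loss is not None

# Inference: logits are materialized as usual.
model.eval()
with torch.no_grad():
    out = model(**inputs)
assert out.logits is not None
```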
@@ -0,0 +1,134 @@
+ from typing import List
+ from typing import Optional
+ from typing import Union
+
+ import torch
+
+ from liger_kernel.transformers.model.loss_utils import LigerForCausalLMLoss
+ from liger_kernel.transformers.model.loss_utils import unpack_cross_entropy_result
+ from liger_kernel.transformers.model.output_classes import LigerCausalLMOutputWithPast
+
+
+ def lce_forward(
+     self,
+     input_ids: Optional[torch.LongTensor] = None,
+     attention_mask: Optional[torch.Tensor] = None,
+     position_ids: Optional[torch.LongTensor] = None,
+     past_key_values: Optional[List[torch.FloatTensor]] = None,
+     inputs_embeds: Optional[torch.FloatTensor] = None,
+     labels: Optional[torch.LongTensor] = None,
+     use_cache: Optional[bool] = None,
+     output_attentions: Optional[bool] = None,
+     output_hidden_states: Optional[bool] = None,
+     cache_position: Optional[torch.LongTensor] = None,
+     logits_to_keep: Union[int, torch.Tensor] = 0,
+     skip_logits: Optional[bool] = None,
+     return_dict: Optional[bool] = None,
+     **kwargs,
+ ) -> LigerCausalLMOutputWithPast:
+     r"""
+     labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
+         Labels for computing the masked language modeling loss. Indices should either be in `[0, ...,
+         config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored
+         (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`.
+
+     logits_to_keep (`int` or `torch.Tensor`, *optional*):
+         If an `int`, compute logits for the last `logits_to_keep` tokens. If `0`, calculate logits for all
+         `input_ids` (special case). Only last token logits are needed for generation, and calculating them only for that
+         token can save memory, which becomes pretty significant for long sequences or large vocabulary size.
+         If a `torch.Tensor`, must be 1D corresponding to the indices to keep in the sequence length dimension.
+         This is useful when using packed tensor format (single dimension for batch and sequence length).
+
+     Returns:
+
+     Example:
+
+     ```python
+     >>> from transformers import AutoTokenizer, HunYuanDenseV1ForCausalLM
+
+     >>> model = HunYuanDenseV1ForCausalLM.from_pretrained("meta-hunyuan_v1_dense/HunYuanDenseV1-2-7b-hf")
+     >>> tokenizer = AutoTokenizer.from_pretrained("meta-hunyuan_v1_dense/HunYuanDenseV1-2-7b-hf")
+
+     >>> prompt = "Hey, are you conscious? Can you talk to me?"
+     >>> inputs = tokenizer(prompt, return_tensors="pt")
+
+     >>> # Generate
+     >>> generate_ids = model.generate(inputs.input_ids, max_length=30)
+     >>> tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]
+     "Hey, are you conscious? Can you talk to me?\nI'm not conscious, but I can talk to you."
+     ```"""
+     output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
+     output_hidden_states = (
+         output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
+     )
+     return_dict = return_dict if return_dict is not None else self.config.use_return_dict
+
+     # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn)
+     outputs = self.model(
+         input_ids=input_ids,
+         attention_mask=attention_mask,
+         position_ids=position_ids,
+         past_key_values=past_key_values,
+         inputs_embeds=inputs_embeds,
+         use_cache=use_cache,
+         output_attentions=output_attentions,
+         output_hidden_states=output_hidden_states,
+         cache_position=cache_position,
+         **kwargs,
+     )
+
+     hidden_states = outputs[0]
+     # Only compute necessary logits, and do not upcast them to float if we are not computing the loss
+     slice_indices = slice(-logits_to_keep, None) if isinstance(logits_to_keep, int) else logits_to_keep
+     kept_hidden_states = hidden_states[:, slice_indices, :]
+
+     shift_labels = kwargs.pop("shift_labels", None)
+     logits = None
+     loss = None
+     token_accuracy = None
+
+     if skip_logits and labels is None and shift_labels is None:
+         raise ValueError("skip_logits is True, but labels and shift_labels are None")
+
+     if skip_logits is None:
+         # By default, if in training mode, don't materialize logits
+         skip_logits = self.training and (labels is not None or shift_labels is not None)
+
+     # Compute loss
+     if skip_logits:
+         result = LigerForCausalLMLoss(
+             hidden_states=kept_hidden_states,
+             lm_head_weight=self.lm_head.weight,
+             labels=labels,
+             shift_labels=shift_labels,
+             hidden_size=self.config.hidden_size,
+             **kwargs,
+         )
+         loss, _, token_accuracy = unpack_cross_entropy_result(result)
+
+     else:
+         logits = self.lm_head(kept_hidden_states)
+         if labels is not None or shift_labels is not None:
+             loss = self.loss_function(
+                 logits=logits,
+                 labels=labels,
+                 shift_labels=shift_labels,
+                 vocab_size=self.config.vocab_size,
+                 **kwargs,
+             )
+
+     if not return_dict:
+         output = (logits,) + outputs[1:]
+         output = ((loss,) + output) if loss is not None else output
+         output = output + (token_accuracy,) if token_accuracy is not None else output
+         return output
+
+     # Return custom output class with accuracy field
+     return LigerCausalLMOutputWithPast(
+         loss=loss,
+         logits=logits,
+         past_key_values=outputs.past_key_values,
+         hidden_states=outputs.hidden_states,
+         attentions=outputs.attentions,
+         token_accuracy=token_accuracy,
+     )
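
This forward also validates `skip_logits` explicitly: asking to skip logits with neither `labels` nor `shift_labels` raises immediately, and an unset `skip_logits` defaults to skipping only during training when targets are present. A stand-alone restatement of that guard (the helper function here is illustrative, not part of the package):

```python
from typing import Optional

import torch


def resolve_skip_logits(
    skip_logits: Optional[bool],
    labels: Optional[torch.Tensor],
    shift_labels: Optional[torch.Tensor],
    training: bool,
) -> bool:
    # Mirrors the guard in the patched forward above.
    if skip_logits and labels is None and shift_labels is None:
        raise ValueError("skip_logits is True, but labels and shift_labels are None")
    if skip_logits is None:
        skip_logits = training and (labels is not None or shift_labels is not None)
    return bool(skip_logits)
```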
@@ -5,10 +5,11 @@ from typing import Union
 
  import torch
 
- from transformers.models.internvl.modeling_internvl import InternVLCausalLMOutputWithPast
  from transformers.utils import can_return_tuple
 
  from liger_kernel.transformers.model.loss_utils import LigerForCausalLMLoss
+ from liger_kernel.transformers.model.loss_utils import unpack_cross_entropy_result
+ from liger_kernel.transformers.model.output_classes import LigerInternVLCausalLMOutputWithPast
 
 
  # Copied from https://github.com/huggingface/transformers/blob/d888bd435d0c0eaabaabad5b33d52af518c7187c/src/transformers/models/internvl/modeling_internvl.py#L862
@@ -33,7 +34,7 @@ def lce_forward(
      image_sizes: Optional[torch.Tensor] = None,
      skip_logits: Optional[bool] = None,  # Added argument for liger-kernel
      **lm_kwargs,  # renamed from kwargs
- ) -> Union[Tuple, InternVLCausalLMOutputWithPast]:
+ ) -> Union[Tuple, LigerInternVLCausalLMOutputWithPast]:
      r"""
      Example:
 
@@ -111,6 +112,7 @@ def lce_forward(
      shift_labels = lm_kwargs.pop("shift_labels", None)
      logits = None
      loss = None
+     token_accuracy = None
 
      if skip_logits and labels is None and shift_labels is None:
          raise ValueError("skip_logits is True, but labels and shift_labels are None")
@@ -120,7 +122,7 @@ def lce_forward(
          skip_logits = self.training and (labels is not None or shift_labels is not None)
 
      if skip_logits:
-         loss = LigerForCausalLMLoss(
+         result = LigerForCausalLMLoss(
              hidden_states=kept_hidden_states,
              lm_head_weight=self.lm_head.weight,
              labels=labels,
@@ -128,6 +130,7 @@ def lce_forward(
              hidden_size=self.config.text_config.hidden_size,
              **lm_kwargs,
          )
+         loss, _, token_accuracy = unpack_cross_entropy_result(result)
 
      else:
          logits = self.lm_head(kept_hidden_states)
@@ -138,13 +141,17 @@
 
      if not return_dict:
          output = (logits,) + outputs[1:]
-         return (loss,) + output if loss is not None else output
+         output = (loss,) + output if loss is not None else output
+         output = output + (token_accuracy,) if token_accuracy is not None else output
+         return output
 
-     return InternVLCausalLMOutputWithPast(
+     # Return custom output class with token_accuracy field
+     return LigerInternVLCausalLMOutputWithPast(
          loss=loss,
          logits=logits,
          past_key_values=outputs.past_key_values,
          hidden_states=outputs.hidden_states,
          attentions=outputs.attentions,
          image_hidden_states=outputs.image_hidden_states,
+         token_accuracy=token_accuracy,
      )
@@ -15,6 +15,8 @@ from transformers.utils.deprecation import deprecate_kwarg
  from liger_kernel.transformers.fsdp import _FSDPForwardRedirection
  from liger_kernel.transformers.fused_linear_cross_entropy import LigerFusedLinearCrossEntropyLoss
  from liger_kernel.transformers.model.loss_utils import LigerForCausalLMLoss
+ from liger_kernel.transformers.model.loss_utils import unpack_cross_entropy_result
+ from liger_kernel.transformers.model.output_classes import LigerCausalLMOutputWithPast
  from liger_kernel.utils import PEFT_AVAILABLE
 
  if TYPE_CHECKING:
@@ -162,7 +164,7 @@
      logits_to_keep: Union[int, torch.Tensor] = 0,
      skip_logits: Optional[bool] = None,
      **kwargs,
- ) -> Union[Tuple, CausalLMOutputWithPast]:
+ ) -> Union[Tuple, LigerCausalLMOutputWithPast]:
      r"""
      Args:
          labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
@@ -228,6 +230,8 @@
      shift_labels = kwargs.pop("shift_labels", None)
      logits = None
      loss = None
+     token_accuracy = None
+
      # if in training mode, don't materialize logits
      if skip_logits and labels is None and shift_labels is None:
          raise ValueError("skip_logits is True, but labels and shift_labels are None")
@@ -236,8 +240,9 @@
          # By default, if in training mode, don't materialize logits
          skip_logits = self.training and (labels is not None or shift_labels is not None)
 
+     # Compute loss
      if skip_logits:
-         loss = lce_maybe_trainable_lm_head(
+         result = lce_maybe_trainable_lm_head(
              self,
              hidden_states=kept_hidden_states,
              hidden_size=self.config.hidden_size,
@@ -245,7 +250,7 @@
              shift_labels=shift_labels,
              **kwargs,
          )
-
+         loss, _, token_accuracy = unpack_cross_entropy_result(result)
      else:
          logits = self.lm_head(kept_hidden_states)
          if labels is not None or shift_labels is not None:
@@ -259,14 +264,18 @@
 
      if not return_dict:
          output = (logits,) + outputs[1:]
-         return (loss,) + output if loss is not None else output
+         output = ((loss,) + output) if loss is not None else output
+         output = output + (token_accuracy,) if token_accuracy is not None else output
+         return output
 
-     return CausalLMOutputWithPast(
+     # Return custom output class with token_accuracy field
+     return LigerCausalLMOutputWithPast(
          loss=loss,
          logits=logits,
          past_key_values=outputs.past_key_values,
          hidden_states=outputs.hidden_states,
          attentions=outputs.attentions,
+         token_accuracy=token_accuracy,
      )
 
 
@@ -6,9 +6,10 @@
  import torch
 
  from transformers.cache_utils import Cache
- from transformers.modeling_outputs import CausalLMOutputWithPast
 
  from liger_kernel.transformers.model.loss_utils import LigerForCausalLMLoss
+ from liger_kernel.transformers.model.loss_utils import unpack_cross_entropy_result
+ from liger_kernel.transformers.model.output_classes import LigerCausalLMOutputWithPast
 
 
  def lce_forward(
@@ -26,7 +27,7 @@ def lce_forward(
      cache_position: Optional[torch.LongTensor] = None,
      logits_to_keep: Union[int, torch.Tensor] = 0,
      **kwargs,
- ) -> Union[Tuple, CausalLMOutputWithPast]:
+ ) -> Union[Tuple, LigerCausalLMOutputWithPast]:
      r"""
      labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
          Labels for computing the masked language modeling loss. Indices should either be in `[0, ...,
@@ -78,9 +79,11 @@
      shift_labels = kwargs.pop("shift_labels", None)
      logits = None
      loss = None
+     token_accuracy = None
 
+     # Compute loss
      if self.training and (labels is not None or shift_labels is not None):
-         loss = LigerForCausalLMLoss(
+         result = LigerForCausalLMLoss(
              hidden_states=kept_hidden_states,
              lm_head_weight=self.lm_head.weight,
              labels=labels,
@@ -88,6 +91,7 @@
              hidden_size=self.config.hidden_size,
              **kwargs,
          )
+         loss, _, token_accuracy = unpack_cross_entropy_result(result)
 
      else:  # if in inference mode materialize logits
          logits = self.lm_head(kept_hidden_states)
@@ -100,10 +104,18 @@
              **kwargs,
          )
 
-     return CausalLMOutputWithPast(
+     if not return_dict:
+         output = (logits,) + outputs[1:]
+         output = ((loss,) + output) if loss is not None else output
+         output = output + (token_accuracy,) if token_accuracy is not None else output
+         return output
+
+     # Return custom output class with token_accuracy field
+     return LigerCausalLMOutputWithPast(
          loss=loss,
          logits=logits,
          past_key_values=outputs.past_key_values,
          hidden_states=outputs.hidden_states,
          attentions=outputs.attentions,
+         token_accuracy=token_accuracy,
      )
@@ -11,6 +11,8 @@ from transformers.utils import is_torchdynamo_compiling
 
  from liger_kernel.transformers.fused_linear_cross_entropy import LigerFusedLinearCrossEntropyLoss
  from liger_kernel.transformers.model.loss_utils import LigerForCausalLMLoss
+ from liger_kernel.transformers.model.loss_utils import unpack_cross_entropy_result
+ from liger_kernel.transformers.model.output_classes import LigerLlavaCausalLMOutputWithPast
 
 
  def lce_forward_deprecated(
@@ -215,7 +217,7 @@ def lce_forward(
      image_sizes: torch.Tensor = None,
      skip_logits: Optional[bool] = None,
      **lm_kwargs,
- ) -> Union[Tuple, LlavaCausalLMOutputWithPast]:
+ ) -> Union[Tuple, LigerLlavaCausalLMOutputWithPast]:
      r"""
      Args:
          labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
@@ -293,6 +295,7 @@
      shift_labels = lm_kwargs.pop("shift_labels", None)
      logits = None
      loss = None
+     token_accuracy = None
 
      if skip_logits and labels is None and shift_labels is None:
          raise ValueError("skip_logits is True, but labels and shift_labels are None")
@@ -302,7 +305,7 @@
          skip_logits = self.training and (labels is not None or shift_labels is not None)
 
      if skip_logits:
-         loss = LigerForCausalLMLoss(
+         result = LigerForCausalLMLoss(
              hidden_states=kept_hidden_states,
              lm_head_weight=self.lm_head.weight,
              labels=labels,
@@ -310,6 +313,7 @@
              hidden_size=self.config.text_config.hidden_size,
              **lm_kwargs,
          )
+         loss, _, token_accuracy = unpack_cross_entropy_result(result)
 
      else:
          logits = self.lm_head(kept_hidden_states)
@@ -324,13 +328,17 @@
 
      if not return_dict:
          output = (logits,) + outputs[1:]
-         return (loss,) + output if loss is not None else output
+         output = (loss,) + output if loss is not None else output
+         output = output + (token_accuracy,) if token_accuracy is not None else output
+         return output
 
-     return LlavaCausalLMOutputWithPast(
+     # Return custom output class with token_accuracy field
+     return LigerLlavaCausalLMOutputWithPast(
          loss=loss,
          logits=logits,
          past_key_values=outputs.past_key_values,
          hidden_states=outputs.hidden_states,
          attentions=outputs.attentions,
          image_hidden_states=outputs.image_hidden_states,
+         token_accuracy=token_accuracy,
      )
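
For callers that pass `return_dict=False`, the patched forwards rebuild the legacy tuple with the same three lines in every model file above. A small restatement of that construction, so the expected ordering is explicit (the wrapper function name is illustrative):

```python
from typing import Any, Optional, Tuple

import torch


def build_legacy_tuple(
    logits: Optional[torch.Tensor],
    base_outputs: Tuple[Any, ...],
    loss: Optional[torch.Tensor],
    token_accuracy: Optional[torch.Tensor],
) -> Tuple[Any, ...]:
    # Same construction as the return_dict=False branch above: optional loss first,
    # then logits and the remaining base model outputs, with token_accuracy appended
    # at the end only when it was actually computed.
    output = (logits,) + base_outputs[1:]
    output = ((loss,) + output) if loss is not None else output
    output = output + (token_accuracy,) if token_accuracy is not None else output
    return output
```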
@@ -1,10 +1,28 @@
  from typing import Optional
+ from typing import Tuple
 
  import torch
  import torch.nn as nn
 
  import liger_kernel.transformers.functional as F
 
+ from liger_kernel.transformers.functional import CrossEntropyOutput
+
+
+ def unpack_cross_entropy_result(
+     result,
+ ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[torch.Tensor]]:
+     if isinstance(result, CrossEntropyOutput):
+         return result.loss, result.z_loss, result.token_accuracy
+
+     if isinstance(result, tuple):
+         loss = result[0]
+         z_loss = result[1] if len(result) > 1 else None
+         token_accuracy = result[2] if len(result) > 2 else None
+         return loss, z_loss, token_accuracy
+
+     return result, None, None
+
 
 
  def fixed_fused_linear_cross_entropy(
@@ -14,10 +32,11 @@ def fixed_fused_linear_cross_entropy(
      ignore_index: int = -100,
      final_logit_softcapping: Optional[float] = None,
      accum_dtype: Optional[torch.dtype] = None,
+     return_token_accuracy: bool = False,
      **kwargs,
  ):
      reduction = "sum" if num_items_in_batch is not None else "mean"
-     loss = F.liger_fused_linear_cross_entropy(
+     result = F.liger_fused_linear_cross_entropy(
          hidden_states,
          lm_head_weight,
          target,
@@ -25,11 +44,18 @@
          ignore_index=ignore_index,
          softcap=final_logit_softcapping,
          accum_dtype=accum_dtype,
+         return_token_accuracy=return_token_accuracy,
          **kwargs,
      )
+
+     loss, _, token_accuracy = unpack_cross_entropy_result(result)
+
      if reduction == "sum":
          loss = loss / num_items_in_batch
 
+     if return_token_accuracy:
+         return CrossEntropyOutput(loss=loss, token_accuracy=token_accuracy)
+
      return loss
 
 
@@ -42,6 +68,7 @@ def LigerForCausalLMLoss(
      ignore_index: int = -100,
      shift_labels: Optional[torch.Tensor] = None,
      final_logit_softcapping: Optional[float] = None,
+     return_token_accuracy: bool = False,
      **kwargs,
  ):
      # Skip upcast since intermediate values for the loss are all fp32 in kernel
@@ -55,13 +82,14 @@
      shift_labels = shift_labels.view(-1)
      # Enable model parallelism
      shift_labels = shift_labels.to(hidden_states.device)
-     loss = fixed_fused_linear_cross_entropy(
+     result = fixed_fused_linear_cross_entropy(
          hidden_states,
          lm_head_weight,
          shift_labels,
          num_items_in_batch,
          ignore_index,
          final_logit_softcapping,
+         return_token_accuracy=return_token_accuracy,
          **kwargs,
      )
-     return loss
+     return result
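
Since `unpack_cross_entropy_result` has to cope with the kernel returning a bare tensor, a tuple, or the new `CrossEntropyOutput`, a short usage sketch of the three cases (the behavior follows directly from the implementation above):

```python
import torch

from liger_kernel.transformers.functional import CrossEntropyOutput
from liger_kernel.transformers.model.loss_utils import unpack_cross_entropy_result

# Bare tensor: only the loss is known.
loss, z_loss, acc = unpack_cross_entropy_result(torch.tensor(2.5))
assert z_loss is None and acc is None

# Tuple: missing trailing entries come back as None.
loss, z_loss, acc = unpack_cross_entropy_result((torch.tensor(2.5), torch.tensor(0.1)))
assert z_loss is not None and acc is None

# Structured output: fields are passed through unchanged.
out = CrossEntropyOutput(loss=torch.tensor(2.5), token_accuracy=torch.tensor(0.42))
loss, z_loss, acc = unpack_cross_entropy_result(out)
assert acc is not None
```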