liger-kernel-nightly 0.5.5.dev20250402185702__py3-none-any.whl → 0.6.4.dev20260112233432__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Note: this release of liger-kernel-nightly has been flagged as potentially problematic.

Files changed (115)
  1. liger_kernel/chunked_loss/__init__.py +1 -0
  2. liger_kernel/chunked_loss/cosine_similarity_loss.py +142 -0
  3. liger_kernel/chunked_loss/dpo_loss.py +61 -3
  4. liger_kernel/chunked_loss/functional.py +2 -0
  5. liger_kernel/chunked_loss/fused_linear_distillation.py +23 -5
  6. liger_kernel/chunked_loss/fused_linear_ppo.py +36 -0
  7. liger_kernel/chunked_loss/fused_linear_preference.py +0 -1
  8. liger_kernel/chunked_loss/grpo_loss.py +76 -5
  9. liger_kernel/chunked_loss/jsd_loss.py +46 -15
  10. liger_kernel/ops/__init__.py +141 -0
  11. liger_kernel/ops/backends/README.md +151 -0
  12. liger_kernel/ops/backends/__init__.py +13 -0
  13. liger_kernel/ops/backends/_ascend/__init__.py +5 -0
  14. liger_kernel/ops/backends/_ascend/ascend-ub-manager-design.md +485 -0
  15. liger_kernel/ops/backends/_ascend/ops/__init__.py +49 -0
  16. liger_kernel/ops/backends/_ascend/ops/geglu.py +266 -0
  17. liger_kernel/ops/backends/_ascend/ops/qwen2vl_mrope.py +285 -0
  18. liger_kernel/ops/backends/_ascend/ops/rope.py +290 -0
  19. liger_kernel/ops/backends/_ascend/ops/swiglu.py +142 -0
  20. liger_kernel/ops/backends/_ascend/ops/tvd.py +221 -0
  21. liger_kernel/ops/backends/_ascend/ub_manager.py +349 -0
  22. liger_kernel/ops/backends/registry.py +61 -0
  23. liger_kernel/ops/cross_entropy.py +134 -65
  24. liger_kernel/ops/dyt.py +115 -180
  25. liger_kernel/ops/fused_add_rms_norm.py +416 -0
  26. liger_kernel/ops/fused_linear_cross_entropy.py +117 -23
  27. liger_kernel/ops/fused_neighborhood_attention.py +1022 -0
  28. liger_kernel/ops/geglu.py +6 -4
  29. liger_kernel/ops/group_norm.py +7 -7
  30. liger_kernel/ops/grpo_loss.py +312 -0
  31. liger_kernel/ops/jsd.py +2 -1
  32. liger_kernel/ops/kl_div.py +9 -5
  33. liger_kernel/ops/layer_norm.py +146 -78
  34. liger_kernel/ops/llama4_rope.py +225 -0
  35. liger_kernel/ops/multi_token_attention.py +207 -0
  36. liger_kernel/ops/poly_norm.py +390 -0
  37. liger_kernel/ops/rms_norm.py +398 -99
  38. liger_kernel/ops/rope.py +1 -1
  39. liger_kernel/ops/softmax.py +201 -0
  40. liger_kernel/ops/sparsemax.py +179 -0
  41. liger_kernel/ops/swiglu.py +1 -1
  42. liger_kernel/ops/tiled_mlp.py +136 -0
  43. liger_kernel/ops/utils.py +14 -0
  44. liger_kernel/transformers/__init__.py +208 -17
  45. liger_kernel/transformers/auto_model.py +21 -0
  46. liger_kernel/transformers/cross_entropy.py +9 -4
  47. liger_kernel/transformers/dyt.py +6 -4
  48. liger_kernel/transformers/experimental/__init__.py +5 -0
  49. liger_kernel/transformers/experimental/embedding.py +1 -1
  50. liger_kernel/transformers/fsdp.py +55 -0
  51. liger_kernel/transformers/functional.py +122 -20
  52. liger_kernel/transformers/fused_add_rms_norm.py +39 -0
  53. liger_kernel/transformers/fused_linear_cross_entropy.py +16 -5
  54. liger_kernel/transformers/fused_linear_jsd.py +1 -1
  55. liger_kernel/transformers/fused_neighborhood_attention.py +234 -0
  56. liger_kernel/transformers/geglu.py +1 -1
  57. liger_kernel/transformers/group_norm.py +1 -1
  58. liger_kernel/transformers/grpo_loss.py +153 -0
  59. liger_kernel/transformers/jsd.py +1 -1
  60. liger_kernel/transformers/kl_div.py +1 -1
  61. liger_kernel/transformers/layer_norm.py +1 -1
  62. liger_kernel/transformers/llama4_rope.py +93 -0
  63. liger_kernel/transformers/model/exaone4.py +136 -0
  64. liger_kernel/transformers/model/falcon_h1.py +122 -0
  65. liger_kernel/transformers/model/gemma.py +57 -27
  66. liger_kernel/transformers/model/gemma2.py +65 -28
  67. liger_kernel/transformers/model/gemma3.py +331 -0
  68. liger_kernel/transformers/model/glm4.py +141 -0
  69. liger_kernel/transformers/model/glm4v.py +163 -0
  70. liger_kernel/transformers/model/glm4v_moe.py +172 -0
  71. liger_kernel/transformers/model/gpt_oss.py +211 -0
  72. liger_kernel/transformers/model/hunyuan_v1.py +134 -0
  73. liger_kernel/transformers/model/internvl.py +157 -0
  74. liger_kernel/transformers/model/llama.py +109 -27
  75. liger_kernel/transformers/model/llama4.py +121 -0
  76. liger_kernel/transformers/model/llava.py +111 -136
  77. liger_kernel/transformers/model/loss_utils.py +50 -12
  78. liger_kernel/transformers/model/mistral.py +51 -34
  79. liger_kernel/transformers/model/mixtral.py +50 -29
  80. liger_kernel/transformers/model/mllama.py +46 -24
  81. liger_kernel/transformers/model/olmo2.py +47 -22
  82. liger_kernel/transformers/model/olmo3.py +142 -0
  83. liger_kernel/transformers/model/output_classes.py +147 -0
  84. liger_kernel/transformers/model/paligemma.py +50 -14
  85. liger_kernel/transformers/model/phi3.py +47 -172
  86. liger_kernel/transformers/model/qwen2.py +55 -23
  87. liger_kernel/transformers/model/qwen2_5_vl.py +62 -103
  88. liger_kernel/transformers/model/qwen2_vl.py +59 -108
  89. liger_kernel/transformers/model/qwen3.py +136 -0
  90. liger_kernel/transformers/model/qwen3_moe.py +152 -0
  91. liger_kernel/transformers/model/qwen3_next.py +146 -0
  92. liger_kernel/transformers/model/qwen3_vl.py +150 -0
  93. liger_kernel/transformers/model/qwen3_vl_moe.py +126 -0
  94. liger_kernel/transformers/model/smollm3.py +199 -0
  95. liger_kernel/transformers/model/smolvlm.py +158 -0
  96. liger_kernel/transformers/monkey_patch.py +2018 -244
  97. liger_kernel/transformers/multi_token_attention.py +64 -0
  98. liger_kernel/transformers/poly_norm.py +42 -0
  99. liger_kernel/transformers/qwen2vl_mrope.py +1 -1
  100. liger_kernel/transformers/rms_norm.py +54 -6
  101. liger_kernel/transformers/rope.py +45 -1
  102. liger_kernel/transformers/softmax.py +12 -0
  103. liger_kernel/transformers/sparsemax.py +16 -0
  104. liger_kernel/transformers/swiglu.py +39 -1
  105. liger_kernel/transformers/tiled_mlp.py +125 -0
  106. liger_kernel/transformers/trainer/orpo_trainer.py +1 -53
  107. liger_kernel/transformers/tvd.py +1 -1
  108. liger_kernel/utils.py +63 -0
  109. {liger_kernel_nightly-0.5.5.dev20250402185702.dist-info → liger_kernel_nightly-0.6.4.dev20260112233432.dist-info}/METADATA +73 -39
  110. liger_kernel_nightly-0.6.4.dev20260112233432.dist-info/RECORD +132 -0
  111. liger_kernel_nightly-0.5.5.dev20250402185702.dist-info/RECORD +0 -80
  112. {liger_kernel_nightly-0.5.5.dev20250402185702.dist-info → liger_kernel_nightly-0.6.4.dev20260112233432.dist-info}/LICENSE +0 -0
  113. {liger_kernel_nightly-0.5.5.dev20250402185702.dist-info → liger_kernel_nightly-0.6.4.dev20260112233432.dist-info}/NOTICE +0 -0
  114. {liger_kernel_nightly-0.5.5.dev20250402185702.dist-info → liger_kernel_nightly-0.6.4.dev20260112233432.dist-info}/WHEEL +0 -0
  115. {liger_kernel_nightly-0.5.5.dev20250402185702.dist-info → liger_kernel_nightly-0.6.4.dev20260112233432.dist-info}/top_level.txt +0 -0
liger_kernel/transformers/model/exaone4.py
@@ -0,0 +1,136 @@
+ from typing import List
+ from typing import Optional
+ from typing import Union
+
+ import torch
+
+ from liger_kernel.transformers.model.loss_utils import LigerForCausalLMLoss
+ from liger_kernel.transformers.model.loss_utils import unpack_cross_entropy_result
+ from liger_kernel.transformers.model.output_classes import LigerCausalLMOutputWithPast
+
+
+ def lce_forward(
+     self,
+     input_ids: Optional[torch.LongTensor] = None,
+     attention_mask: Optional[torch.Tensor] = None,
+     position_ids: Optional[torch.LongTensor] = None,
+     past_key_values: Optional[List[torch.FloatTensor]] = None,
+     inputs_embeds: Optional[torch.FloatTensor] = None,
+     labels: Optional[torch.LongTensor] = None,
+     use_cache: Optional[bool] = None,
+     output_attentions: Optional[bool] = None,
+     output_hidden_states: Optional[bool] = None,
+     cache_position: Optional[torch.LongTensor] = None,
+     logits_to_keep: Union[int, torch.Tensor] = 0,
+     skip_logits: Optional[bool] = None,
+     return_dict: Optional[bool] = None,
+     **kwargs,
+ ) -> LigerCausalLMOutputWithPast:
+     r"""
+     labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
+         Labels for computing the masked language modeling loss. Indices should either be in `[0, ...,
+         config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored
+         (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`.
+
+     logits_to_keep (`int` or `torch.Tensor`, *optional*):
+         If an `int`, compute logits for the last `logits_to_keep` tokens. If `0`, calculate logits for all
+         `input_ids` (special case). Only last token logits are needed for generation, and calculating them only for that
+         token can save memory, which becomes pretty significant for long sequences or large vocabulary size.
+         If a `torch.Tensor`, must be 1D corresponding to the indices to keep in the sequence length dimension.
+         This is useful when using packed tensor format (single dimension for batch and sequence length).
+
+     Returns:
+
+     Example:
+
+     ```python
+     >>> from transformers import AutoTokenizer, Exaone4ForCausalLM
+
+     >>> model = Exaone4ForCausalLM.from_pretrained("LGAI-EXAONE/EXAONE-4.0-1.2B")
+     >>> tokenizer = AutoTokenizer.from_pretrained("LGAI-EXAONE/EXAONE-4.0-1.2B")
+
+     >>> prompt = "Hey, are you conscious? Can you talk to me?"
+     >>> inputs = tokenizer(prompt, return_tensors="pt")
+
+     >>> # Generate
+     >>> generate_ids = model.generate(inputs.input_ids, max_length=30)
+     >>> tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]
+     "Hey, are you conscious? Can you talk to me?\nI'm not conscious, but I can talk to you."
+     ```"""
+     output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
+     output_hidden_states = (
+         output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
+     )
+     return_dict = return_dict if return_dict is not None else self.config.use_return_dict
+
+     # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn)
+     outputs = self.model(
+         input_ids=input_ids,
+         attention_mask=attention_mask,
+         position_ids=position_ids,
+         past_key_values=past_key_values,
+         inputs_embeds=inputs_embeds,
+         use_cache=use_cache,
+         output_attentions=output_attentions,
+         output_hidden_states=output_hidden_states,
+         cache_position=cache_position,
+         **kwargs,
+     )
+
+     hidden_states = outputs[0]
+     # Only compute necessary logits, and do not upcast them to float if we are not computing the loss
+     slice_indices = slice(-logits_to_keep, None) if isinstance(logits_to_keep, int) else logits_to_keep
+     kept_hidden_states = hidden_states[:, slice_indices, :]
+
+     shift_labels = kwargs.pop("shift_labels", None)
+     # Remove output-control parameters that shouldn't be passed to loss functions
+     kwargs.pop("return_dict", None)
+     logits = None
+     loss = None
+     token_accuracy = None
+
+     if skip_logits and labels is None and shift_labels is None:
+         raise ValueError("skip_logits is True, but labels and shift_labels are None")
+
+     if skip_logits is None:
+         # By default, if in training mode, don't materialize logits
+         skip_logits = self.training and (labels is not None or shift_labels is not None)
+
+     # Compute loss
+     if skip_logits:
+         result = LigerForCausalLMLoss(
+             hidden_states=kept_hidden_states,
+             lm_head_weight=self.lm_head.weight,
+             labels=labels,
+             shift_labels=shift_labels,
+             hidden_size=self.config.hidden_size,
+             **kwargs,
+         )
+         loss, _, token_accuracy = unpack_cross_entropy_result(result)
+
+     else:
+         logits = self.lm_head(kept_hidden_states)
+         if labels is not None or shift_labels is not None:
+             loss = self.loss_function(
+                 logits=logits,
+                 labels=labels,
+                 shift_labels=shift_labels,
+                 vocab_size=self.config.vocab_size,
+                 **kwargs,
+             )
+
+     if not return_dict:
+         output = (logits,) + outputs[1:]
+         output = ((loss,) + output) if loss is not None else output
+         output = output + (token_accuracy,) if token_accuracy is not None else output
+         return output
+
+     # Return custom output class with accuracy field
+     return LigerCausalLMOutputWithPast(
+         loss=loss,
+         logits=logits,
+         past_key_values=outputs.past_key_values,
+         hidden_states=outputs.hidden_states,
+         attentions=outputs.attentions,
+         token_accuracy=token_accuracy,
+     )
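
The `skip_logits` handling introduced in this new `lce_forward` is the pattern repeated across the other patched model forwards in this release: reject an explicit `skip_logits=True` when no labels are available, default it to "training and labels present", and only then choose between the fused Liger loss and materialized logits. Below is a minimal, self-contained sketch of just that decision logic; the `training`, `labels`, and `shift_labels` parameters are plain stand-ins for the module state, not part of the package API.

```python
from typing import Optional

import torch


def resolve_skip_logits(
    skip_logits: Optional[bool],
    training: bool,
    labels: Optional[torch.Tensor],
    shift_labels: Optional[torch.Tensor],
) -> bool:
    """Mirror of the defaulting/validation logic used by the patched forwards."""
    if skip_logits and labels is None and shift_labels is None:
        raise ValueError("skip_logits is True, but labels and shift_labels are None")
    if skip_logits is None:
        # By default, only skip materializing logits when training with labels.
        skip_logits = training and (labels is not None or shift_labels is not None)
    return bool(skip_logits)


# Training with labels -> fused loss path, no logits materialized.
print(resolve_skip_logits(None, True, torch.tensor([1, 2]), None))   # True
# Inference without labels -> logits are materialized as before.
print(resolve_skip_logits(None, False, None, None))                  # False
```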
liger_kernel/transformers/model/falcon_h1.py
@@ -0,0 +1,122 @@
+ from typing import TYPE_CHECKING
+ from typing import Optional
+ from typing import Union
+
+ import torch
+
+ if TYPE_CHECKING:
+     from transformers.models.falcon_h1.modeling_falcon_h1 import FalconHybridMambaAttentionDynamicCache
+
+ from liger_kernel.transformers.model.loss_utils import LigerForCausalLMLoss
+ from liger_kernel.transformers.model.loss_utils import unpack_cross_entropy_result
+ from liger_kernel.transformers.model.output_classes import LigerCausalLMOutputWithPast
+
+
+ def lce_forward(
+     self,
+     input_ids: torch.LongTensor = None,
+     attention_mask: Optional[torch.Tensor] = None,
+     position_ids: Optional[torch.LongTensor] = None,
+     past_key_values: Optional["FalconHybridMambaAttentionDynamicCache"] = None,
+     inputs_embeds: Optional[torch.FloatTensor] = None,
+     labels: Optional[torch.LongTensor] = None,
+     use_cache: Optional[bool] = None,
+     output_attentions: Optional[bool] = None,
+     output_hidden_states: Optional[bool] = None,
+     cache_position: Optional[torch.LongTensor] = None,
+     logits_to_keep: Union[int, torch.Tensor] = 0,
+     skip_logits: Optional[bool] = None,
+     return_dict: Optional[bool] = None,
+     **kwargs,
+ ) -> Union[tuple, LigerCausalLMOutputWithPast]:
+     r"""
+     labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
+         Labels for computing the masked language modeling loss. Indices should either be in `[0, ...,
+         config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored
+         (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`.
+
+     Example:
+
+     ```python
+     >>> from transformers import AutoTokenizer, FalconH1ForCausalLM
+
+     >>> model = FalconH1ForCausalLM.from_pretrained("...")
+     >>> tokenizer = AutoTokenizer.from_pretrained("...")
+
+     >>> prompt = "Hey, are you conscious? Can you talk to me?"
+     >>> inputs = tokenizer(prompt, return_tensors="pt")
+
+     >>> # Generate
+     >>> generate_ids = model.generate(inputs.input_ids, max_length=30)
+     >>> tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]
+     "Hey, are you conscious? Can you talk to me?\nI'm not conscious, but I can talk to you."
+     ```"""
+     output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
+     output_hidden_states = (
+         output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
+     )
+     return_dict = return_dict if return_dict is not None else self.config.use_return_dict
+
+     # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn)
+     outputs = self.model(
+         input_ids=input_ids,
+         attention_mask=attention_mask,
+         position_ids=position_ids,
+         past_key_values=past_key_values,
+         inputs_embeds=inputs_embeds,
+         use_cache=use_cache,
+         output_attentions=output_attentions,
+         output_hidden_states=output_hidden_states,
+         cache_position=cache_position,
+         **kwargs,
+     )
+
+     hidden_states = outputs[0]
+     # Only compute necessary logits, and do not upcast them to float if we are not computing the loss
+     slice_indices = slice(-logits_to_keep, None) if isinstance(logits_to_keep, int) else logits_to_keep
+     kept_hidden_states = hidden_states[:, slice_indices, :]
+
+     shift_labels = kwargs.pop("shift_labels", None)
+     logits = None
+     loss = None
+     token_accuracy = None
+
+     # if in training mode, don't materialize logits
+     if skip_logits and labels is None:
+         raise ValueError("skip_logits is True, but labels and shift_labels are None")
+
+     if skip_logits is None:
+         # By default, if in training mode, don't materialize logits
+         skip_logits = self.training and labels is not None
+
+     # Compute loss
+     if skip_logits:
+         result = LigerForCausalLMLoss(
+             hidden_states=kept_hidden_states,
+             lm_head_weight=self.lm_head.weight,
+             labels=labels,
+             shift_labels=shift_labels,
+             hidden_size=self.config.hidden_size,
+             **kwargs,
+         )
+         loss, _, token_accuracy = unpack_cross_entropy_result(result)
+     else:
+         logits = self.lm_head(kept_hidden_states)
+         if labels is not None or shift_labels is not None:
+             loss = self.loss_function(logits=logits, labels=labels, vocab_size=self.config.vocab_size, **kwargs)
+
+     if not return_dict:
+         output = (logits,) + outputs[1:]
+         output = ((loss,) + output) if loss is not None else output
+         output = output + (token_accuracy,) if token_accuracy is not None else output
+         return output
+
+     # Return custom output class with token_accuracy field
+     return LigerCausalLMOutputWithPast(
+         loss=loss,
+         logits=logits,
+         past_key_values=outputs.past_key_values,
+         hidden_states=outputs.hidden_states,
+         attentions=outputs.attentions,
+         token_accuracy=token_accuracy,
+     )
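
Both of these new forwards shrink `hidden_states` before the LM head (or the fused loss) with the same `slice_indices` expression. The sketch below, using made-up tensor shapes, shows how the two accepted forms of `logits_to_keep` behave; note that the `int` value 0 keeps every position because `slice(-0, None)` is the same as `slice(0, None)`.

```python
import torch

hidden_states = torch.randn(2, 8, 16)  # (batch, seq_len, hidden); shapes are illustrative only

# int form: keep the last N sequence positions (0 keeps everything).
for logits_to_keep in (0, 2):
    slice_indices = slice(-logits_to_keep, None) if isinstance(logits_to_keep, int) else logits_to_keep
    print(hidden_states[:, slice_indices, :].shape)  # torch.Size([2, 8, 16]) then torch.Size([2, 2, 16])

# tensor form: keep explicit sequence positions, useful for packed batches.
keep = torch.tensor([0, 3, 7])
print(hidden_states[:, keep, :].shape)  # torch.Size([2, 3, 16])
```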
liger_kernel/transformers/model/gemma.py
@@ -8,17 +8,14 @@ import torch
  from torch.nn import CrossEntropyLoss
  from transformers.cache_utils import Cache
  from transformers.modeling_outputs import CausalLMOutputWithPast
- from transformers.models.gemma.modeling_gemma import _CONFIG_FOR_DOC
- from transformers.models.gemma.modeling_gemma import GEMMA_INPUTS_DOCSTRING
- from transformers.utils import add_start_docstrings_to_model_forward
- from transformers.utils import replace_return_docstrings
+ from transformers.utils.deprecation import deprecate_kwarg

  from liger_kernel.transformers.fused_linear_cross_entropy import LigerFusedLinearCrossEntropyLoss
  from liger_kernel.transformers.model.loss_utils import LigerForCausalLMLoss
+ from liger_kernel.transformers.model.loss_utils import unpack_cross_entropy_result
+ from liger_kernel.transformers.model.output_classes import LigerCausalLMOutputWithPast


- @add_start_docstrings_to_model_forward(GEMMA_INPUTS_DOCSTRING)
- @replace_return_docstrings(output_type=CausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC)
  def lce_forward_deprecated(
      self,
      input_ids: torch.LongTensor = None,
@@ -32,6 +29,7 @@ def lce_forward_deprecated(
      output_hidden_states: Optional[bool] = None,
      return_dict: Optional[bool] = None,
      cache_position: Optional[torch.LongTensor] = None,
+     skip_logits: Optional[bool] = None,
  ) -> Union[Tuple, CausalLMOutputWithPast]:
      r"""

@@ -86,7 +84,14 @@ def lce_forward_deprecated(
      loss = None
      logits = None

-     if self.training and (labels is not None):
+     if skip_logits and labels is None:
+         raise ValueError("skip_logits is True, but labels is None")
+
+     if skip_logits is None:
+         # By default, if in training mode, don't materialize logits
+         skip_logits = self.training and labels is not None
+
+     if skip_logits:
          shift_hidden_states = hidden_states[..., :-1, :].contiguous()
          shift_labels = labels[..., 1:].contiguous()

@@ -127,8 +132,7 @@
      )


- @add_start_docstrings_to_model_forward(GEMMA_INPUTS_DOCSTRING)
- @replace_return_docstrings(output_type=CausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC)
+ @deprecate_kwarg("num_logits_to_keep", version="4.50", new_name="logits_to_keep")
  def lce_forward(
      self,
      input_ids: torch.LongTensor = None,
@@ -142,9 +146,10 @@ def lce_forward(
      output_hidden_states: Optional[bool] = None,
      return_dict: Optional[bool] = None,
      cache_position: Optional[torch.LongTensor] = None,
-     num_logits_to_keep: int = 0,
-     **loss_kwargs,
- ) -> Union[Tuple, CausalLMOutputWithPast]:
+     logits_to_keep: Union[int, torch.Tensor] = 0,
+     skip_logits: Optional[bool] = None,
+     **kwargs,
+ ) -> Union[Tuple, LigerCausalLMOutputWithPast]:
      r"""
      Args:
          labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
@@ -152,10 +157,12 @@ def lce_forward(
              config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored
              (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`.

-         num_logits_to_keep (`int`, *optional*):
-             Calculate logits for the last `num_logits_to_keep` tokens. If `0`, calculate logits for all
+         logits_to_keep (`int` or `torch.Tensor`, *optional*):
+             If an `int`, compute logits for the last `logits_to_keep` tokens. If `0`, calculate logits for all
              `input_ids` (special case). Only last token logits are needed for generation, and calculating them only for that
              token can save memory, which becomes pretty significant for long sequences or large vocabulary size.
+             If a `torch.Tensor`, must be 1D corresponding to the indices to keep in the sequence length dimension.
+             This is useful when using packed tensor format (single dimension for batch and sequence length).

      Returns:

@@ -193,39 +200,62 @@ def lce_forward(
          output_hidden_states=output_hidden_states,
          return_dict=return_dict,
          cache_position=cache_position,
+         **kwargs,
      )

      hidden_states = outputs[0]
+     # Only compute necessary logits, and do not upcast them to float if we are not computing the loss
+     slice_indices = slice(-logits_to_keep, None) if isinstance(logits_to_keep, int) else logits_to_keep
+     kept_hidden_states = hidden_states[:, slice_indices, :]

+     shift_labels = kwargs.pop("shift_labels", None)
      logits = None
      loss = None
-     # if in training mode, don't materialize logits
-     if self.training and (labels is not None):
-         loss = LigerForCausalLMLoss(
-             hidden_states=hidden_states,
+     token_accuracy = None
+
+     if skip_logits and labels is None and shift_labels is None:
+         raise ValueError("skip_logits is True, but labels and shift_labels are None")
+
+     if skip_logits is None:
+         # By default, if in training mode, don't materialize logits
+         skip_logits = self.training and (labels is not None or shift_labels is not None)
+
+     # Compute loss
+     if skip_logits:
+         result = LigerForCausalLMLoss(
+             hidden_states=kept_hidden_states,
              lm_head_weight=self.lm_head.weight,
              labels=labels,
+             shift_labels=shift_labels,
              hidden_size=self.config.hidden_size,
-             **loss_kwargs,
+             **kwargs,
          )
-     else: # if in inference mode materialize logits
-         logits = self.lm_head(hidden_states[:, -num_logits_to_keep:, :])
-         if labels is not None:
+         loss, _, token_accuracy = unpack_cross_entropy_result(result)
+     else:
+         logits = self.lm_head(kept_hidden_states)
+         if labels is not None or shift_labels is not None:
              loss = self.loss_function(
                  logits=logits,
                  labels=labels,
+                 shift_labels=shift_labels,
                  vocab_size=self.config.vocab_size,
-                 **loss_kwargs,
+                 **kwargs,
              )

      if not return_dict:
-         output = (logits,) + outputs[1:]
-         return (loss,) + output if loss is not None else output
-
-     return CausalLMOutputWithPast(
+         output_tuple = (logits,) + outputs[1:]
+         if loss is not None:
+             output_tuple = (loss,) + output_tuple
+         if token_accuracy is not None:
+             output_tuple = output_tuple + (token_accuracy,)
+         return output_tuple
+
+     # Return custom output class with token_accuracy field
+     return LigerCausalLMOutputWithPast(
          loss=loss,
          logits=logits,
          past_key_values=outputs.past_key_values,
          hidden_states=outputs.hidden_states,
          attentions=outputs.attentions,
+         token_accuracy=token_accuracy,
      )
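
A practical consequence of the gemma.py changes above is the layout of the legacy tuple output when `return_dict=False`: `loss` is still prepended when present, and the new `token_accuracy` value, when present, is appended after the usual fields. The toy sketch below (made-up values, no model involved) reproduces that assembly so downstream code knows what to unpack.

```python
import torch

# Stand-ins for what the patched forward would have computed; values are illustrative only.
loss = torch.tensor(2.31)
logits = None                         # None on the fused-loss path, since logits were never materialized
extra_outputs = ("past_key_values",)  # placeholder for outputs[1:]
token_accuracy = torch.tensor(0.54)

# Same assembly as the non-return_dict branch in the diff:
output = (logits,) + extra_outputs
output = ((loss,) + output) if loss is not None else output
output = output + (token_accuracy,) if token_accuracy is not None else output
print(output)  # (tensor(2.3100), None, 'past_key_values', tensor(0.5400))
```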
liger_kernel/transformers/model/gemma2.py
@@ -7,15 +7,14 @@ from typing import Union
  import torch

  from torch.nn import CrossEntropyLoss
- from transformers.cache_utils import HybridCache
+ from transformers.cache_utils import Cache
  from transformers.modeling_outputs import CausalLMOutputWithPast
- from transformers.models.gemma2.modeling_gemma2 import _CONFIG_FOR_DOC
- from transformers.models.gemma2.modeling_gemma2 import GEMMA2_INPUTS_DOCSTRING
- from transformers.utils import add_start_docstrings_to_model_forward
- from transformers.utils import replace_return_docstrings
+ from transformers.utils.deprecation import deprecate_kwarg

  from liger_kernel.transformers.fused_linear_cross_entropy import LigerFusedLinearCrossEntropyLoss
  from liger_kernel.transformers.model.loss_utils import LigerForCausalLMLoss
+ from liger_kernel.transformers.model.loss_utils import unpack_cross_entropy_result
+ from liger_kernel.transformers.model.output_classes import LigerCausalLMOutputWithPast

  logger = logging.getLogger(__name__)

@@ -25,7 +24,7 @@ def lce_forward_deprecated(
      input_ids: torch.LongTensor = None,
      attention_mask: Optional[torch.Tensor] = None,
      position_ids: Optional[torch.LongTensor] = None,
-     past_key_values: Optional[HybridCache] = None,
+     past_key_values: Optional[Cache] = None,
      inputs_embeds: Optional[torch.FloatTensor] = None,
      labels: Optional[torch.LongTensor] = None,
      use_cache: Optional[bool] = None,
@@ -33,6 +32,8 @@ def lce_forward_deprecated(
      output_hidden_states: Optional[bool] = None,
      return_dict: Optional[bool] = None,
      cache_position: Optional[torch.LongTensor] = None,
+     skip_logits: Optional[bool] = None,
+     **kwargs,
  ) -> Union[Tuple, CausalLMOutputWithPast]:
      r"""
      Args:
@@ -79,6 +80,7 @@ def lce_forward_deprecated(
          output_hidden_states=output_hidden_states,
          return_dict=return_dict,
          cache_position=cache_position,
+         **kwargs,
      )

      hidden_states = outputs[0]
@@ -86,7 +88,14 @@ def lce_forward_deprecated(
      loss = None
      logits = None

-     if self.training and (labels is not None):
+     if skip_logits and labels is None:
+         raise ValueError("skip_logits is True, but labels is None")
+
+     if skip_logits is None:
+         # By default, if in training mode, don't materialize logits
+         skip_logits = self.training and labels is not None
+
+     if skip_logits:
          shift_hidden_states = hidden_states[..., :-1, :].contiguous()
          shift_labels = labels[..., 1:].contiguous()

@@ -134,14 +143,13 @@ def lce_forward_deprecated(
      )


- @add_start_docstrings_to_model_forward(GEMMA2_INPUTS_DOCSTRING)
- @replace_return_docstrings(output_type=CausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC)
+ @deprecate_kwarg("num_logits_to_keep", version="4.50", new_name="logits_to_keep")
  def lce_forward(
      self,
      input_ids: torch.LongTensor = None,
      attention_mask: Optional[torch.Tensor] = None,
      position_ids: Optional[torch.LongTensor] = None,
-     past_key_values: Optional[HybridCache] = None,
+     past_key_values: Optional[Cache] = None,
      inputs_embeds: Optional[torch.FloatTensor] = None,
      labels: Optional[torch.LongTensor] = None,
      use_cache: Optional[bool] = None,
@@ -149,9 +157,10 @@ def lce_forward(
      output_hidden_states: Optional[bool] = None,
      return_dict: Optional[bool] = None,
      cache_position: Optional[torch.LongTensor] = None,
-     num_logits_to_keep: int = 0,
-     **loss_kwargs,
- ) -> Union[Tuple, CausalLMOutputWithPast]:
+     logits_to_keep: Union[int, torch.Tensor] = 0,
+     skip_logits: Optional[bool] = None,
+     **kwargs,
+ ) -> Union[Tuple, LigerCausalLMOutputWithPast]:
      r"""
      Args:
          labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
@@ -159,10 +168,12 @@ def lce_forward(
              config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored
              (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`.

-         num_logits_to_keep (`int`, *optional*):
-             Calculate logits for the last `num_logits_to_keep` tokens. If `0`, calculate logits for all
+         logits_to_keep (`int` or `torch.Tensor`, *optional*):
+             If an `int`, compute logits for the last `logits_to_keep` tokens. If `0`, calculate logits for all
              `input_ids` (special case). Only last token logits are needed for generation, and calculating them only for that
              token can save memory, which becomes pretty significant for long sequences or large vocabulary size.
+             If a `torch.Tensor`, must be 1D corresponding to the indices to keep in the sequence length dimension.
+             This is useful when using packed tensor format (single dimension for batch and sequence length).

      Returns:

@@ -205,42 +216,68 @@ def lce_forward(
          output_hidden_states=output_hidden_states,
          return_dict=return_dict,
          cache_position=cache_position,
+         **kwargs,
      )

      hidden_states = outputs[0]
+     # Only compute necessary logits, and do not upcast them to float if we are not computing the loss
+     slice_indices = slice(-logits_to_keep, None) if isinstance(logits_to_keep, int) else logits_to_keep
+     kept_hidden_states = hidden_states[:, slice_indices, :]

+     shift_labels = kwargs.pop("shift_labels", None)
      logits = None
      loss = None
-     # if in training mode, don't materialize logits
-     if self.training and (labels is not None):
-         loss = LigerForCausalLMLoss(
-             hidden_states=hidden_states,
+     token_accuracy = None
+
+     if skip_logits and labels is None and shift_labels is None:
+         raise ValueError("skip_logits is True, but labels and shift_labels are None")
+
+     if skip_logits is None:
+         # By default, if in training mode, don't materialize logits
+         skip_logits = self.training and (labels is not None or shift_labels is not None)
+
+     # Compute loss
+     if skip_logits:
+         result = LigerForCausalLMLoss(
+             hidden_states=kept_hidden_states,
              lm_head_weight=self.lm_head.weight,
              labels=labels,
+             shift_labels=shift_labels,
              hidden_size=self.config.hidden_size,
-             softcap=self.config.final_logit_softcapping,
-             **loss_kwargs,
+             final_logit_softcapping=self.config.final_logit_softcapping,
+             **kwargs,
          )
+         loss, _, token_accuracy = unpack_cross_entropy_result(result)

-     else: # if in inference mode materialize logits
-         logits = self.lm_head(hidden_states[:, -num_logits_to_keep:, :])
+     else:
+         logits = self.lm_head(kept_hidden_states)
          if self.config.final_logit_softcapping is not None:
              logits = logits / self.config.final_logit_softcapping
              logits = torch.tanh(logits)
              logits = logits * self.config.final_logit_softcapping

          loss = None
-         if labels is not None:
-             loss = self.loss_function(logits, labels, self.vocab_size, **loss_kwargs)
+         if labels is not None or shift_labels is not None:
+             loss = self.loss_function(
+                 logits=logits,
+                 labels=labels,
+                 shift_labels=shift_labels,
+                 vocab_size=self.vocab_size,
+                 **kwargs,
+             )

      if not return_dict:
-         output = (logits,) + outputs[1:]
-         return (loss,) + output if loss is not None else output
+         output_tuple = (logits,) + outputs[1:]
+         output_tuple = (loss,) + output_tuple if loss is not None else output_tuple
+         output_tuple = output_tuple + (token_accuracy,) if token_accuracy is not None else output_tuple
+         return output_tuple

-     return CausalLMOutputWithPast(
+     # Return custom output class with token_accuracy field
+     return LigerCausalLMOutputWithPast(
          loss=loss,
          logits=logits,
          past_key_values=outputs.past_key_values,
          hidden_states=outputs.hidden_states,
          attentions=outputs.attentions,
+         token_accuracy=token_accuracy,
      )
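
In the gemma2.py hunk above, the materialized-logits branch keeps Gemma-2's final-logit softcapping, while the fused path now forwards the cap to `LigerForCausalLMLoss` as `final_logit_softcapping` instead of the old `softcap` keyword. For reference, the softcapping applied to materialized logits is just a scaled `tanh`; the small sketch below uses an arbitrary cap value rather than anything taken from a real Gemma-2 config.

```python
import torch

final_logit_softcapping = 30.0        # arbitrary cap for illustration
logits = torch.randn(2, 4, 8) * 100   # made-up "raw" logits with large magnitudes

# Same transform as the materialized-logits branch in the diff:
capped = logits / final_logit_softcapping
capped = torch.tanh(capped)
capped = capped * final_logit_softcapping

# Capped magnitudes are bounded by the cap, raw logits are not.
print(logits.abs().max(), capped.abs().max())
```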