liger-kernel-nightly 0.5.6.dev20250403001329__py3-none-any.whl → 0.5.6.dev20250403190551__py3-none-any.whl

This diff shows the changes between two publicly released versions of the package, as published to their public registry. It is provided for informational purposes only.
liger_kernel/transformers/model/gemma.py

@@ -12,6 +12,7 @@ from transformers.models.gemma.modeling_gemma import _CONFIG_FOR_DOC
  from transformers.models.gemma.modeling_gemma import GEMMA_INPUTS_DOCSTRING
  from transformers.utils import add_start_docstrings_to_model_forward
  from transformers.utils import replace_return_docstrings
+ from transformers.utils.deprecation import deprecate_kwarg

  from liger_kernel.transformers.fused_linear_cross_entropy import LigerFusedLinearCrossEntropyLoss
  from liger_kernel.transformers.model.loss_utils import LigerForCausalLMLoss
@@ -127,6 +128,7 @@ def lce_forward_deprecated(
  )


+ @deprecate_kwarg("num_logits_to_keep", version="4.50", new_name="logits_to_keep")
  @add_start_docstrings_to_model_forward(GEMMA_INPUTS_DOCSTRING)
  @replace_return_docstrings(output_type=CausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC)
  def lce_forward(
@@ -142,7 +144,7 @@ def lce_forward(
  output_hidden_states: Optional[bool] = None,
  return_dict: Optional[bool] = None,
  cache_position: Optional[torch.LongTensor] = None,
- num_logits_to_keep: int = 0,
+ logits_to_keep: Union[int, torch.Tensor] = 0,
  **loss_kwargs,
  ) -> Union[Tuple, CausalLMOutputWithPast]:
  r"""
@@ -152,10 +154,12 @@ def lce_forward(
  config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored
  (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`.

- num_logits_to_keep (`int`, *optional*):
- Calculate logits for the last `num_logits_to_keep` tokens. If `0`, calculate logits for all
+ logits_to_keep (`int` or `torch.Tensor`, *optional*):
+ If an `int`, compute logits for the last `logits_to_keep` tokens. If `0`, calculate logits for all
  `input_ids` (special case). Only last token logits are needed for generation, and calculating them only for that
  token can save memory, which becomes pretty significant for long sequences or large vocabulary size.
+ If a `torch.Tensor`, must be 1D corresponding to the indices to keep in the sequence length dimension.
+ This is useful when using packed tensor format (single dimension for batch and sequence length).

  Returns:

@@ -209,7 +213,8 @@ def lce_forward(
  **loss_kwargs,
  )
  else: # if in inference mode materialize logits
- logits = self.lm_head(hidden_states[:, -num_logits_to_keep:, :])
+ slice_indices = slice(-logits_to_keep, None) if isinstance(logits_to_keep, int) else logits_to_keep
+ logits = self.lm_head(hidden_states[:, slice_indices, :])
  if labels is not None:
  loss = self.loss_function(
  logits=logits,

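The change above is the core of this release: `logits_to_keep` may now be an `int` (keep the last N sequence positions) or a 1D index tensor (keep exactly those positions). A minimal sketch of how that slicing behaves, using made-up shapes (illustrative only, not code from the package):

```python
import torch

# Toy hidden states: (batch=2, seq_len=10, hidden=64).
hidden_states = torch.randn(2, 10, 64)

def select_positions(hidden_states, logits_to_keep):
    # int: keep the last `logits_to_keep` positions; 0 keeps everything because
    # slice(-0, None) == slice(0, None). 1D tensor: keep exactly those indices.
    slice_indices = slice(-logits_to_keep, None) if isinstance(logits_to_keep, int) else logits_to_keep
    return hidden_states[:, slice_indices, :]

print(select_positions(hidden_states, 0).shape)                     # torch.Size([2, 10, 64])
print(select_positions(hidden_states, 1).shape)                     # torch.Size([2, 1, 64])
print(select_positions(hidden_states, torch.tensor([3, 9])).shape)  # torch.Size([2, 2, 64])
```
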
liger_kernel/transformers/model/gemma2.py

@@ -13,6 +13,7 @@ from transformers.models.gemma2.modeling_gemma2 import _CONFIG_FOR_DOC
  from transformers.models.gemma2.modeling_gemma2 import GEMMA2_INPUTS_DOCSTRING
  from transformers.utils import add_start_docstrings_to_model_forward
  from transformers.utils import replace_return_docstrings
+ from transformers.utils.deprecation import deprecate_kwarg

  from liger_kernel.transformers.fused_linear_cross_entropy import LigerFusedLinearCrossEntropyLoss
  from liger_kernel.transformers.model.loss_utils import LigerForCausalLMLoss
@@ -134,6 +135,7 @@ def lce_forward_deprecated(
  )


+ @deprecate_kwarg("num_logits_to_keep", version="4.50", new_name="logits_to_keep")
  @add_start_docstrings_to_model_forward(GEMMA2_INPUTS_DOCSTRING)
  @replace_return_docstrings(output_type=CausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC)
  def lce_forward(
@@ -149,7 +151,7 @@ def lce_forward(
  output_hidden_states: Optional[bool] = None,
  return_dict: Optional[bool] = None,
  cache_position: Optional[torch.LongTensor] = None,
- num_logits_to_keep: int = 0,
+ logits_to_keep: Union[int, torch.Tensor] = 0,
  **loss_kwargs,
  ) -> Union[Tuple, CausalLMOutputWithPast]:
  r"""
@@ -159,10 +161,12 @@ def lce_forward(
  config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored
  (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`.

- num_logits_to_keep (`int`, *optional*):
- Calculate logits for the last `num_logits_to_keep` tokens. If `0`, calculate logits for all
+ logits_to_keep (`int` or `torch.Tensor`, *optional*):
+ If an `int`, compute logits for the last `logits_to_keep` tokens. If `0`, calculate logits for all
  `input_ids` (special case). Only last token logits are needed for generation, and calculating them only for that
  token can save memory, which becomes pretty significant for long sequences or large vocabulary size.
+ If a `torch.Tensor`, must be 1D corresponding to the indices to keep in the sequence length dimension.
+ This is useful when using packed tensor format (single dimension for batch and sequence length).

  Returns:

@@ -223,7 +227,8 @@ def lce_forward(
  )

  else: # if in inference mode materialize logits
- logits = self.lm_head(hidden_states[:, -num_logits_to_keep:, :])
+ slice_indices = slice(-logits_to_keep, None) if isinstance(logits_to_keep, int) else logits_to_keep
+ logits = self.lm_head(hidden_states[:, slice_indices, :])
  if self.config.final_logit_softcapping is not None:
  logits = logits / self.config.final_logit_softcapping
  logits = torch.tanh(logits)

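Gemma2 additionally soft-caps the materialized logits. The hunk above shows only the divide and tanh steps; multiplying back by the cap is the usual final step of this transform and falls just outside the shown context, so the sketch below is an assumption about the full transform rather than a quote of it:

```python
import torch

def soft_cap(logits: torch.Tensor, cap: float) -> torch.Tensor:
    # Squash logits smoothly into (-cap, cap) while staying near-linear around zero.
    return cap * torch.tanh(logits / cap)

x = torch.tensor([-100.0, -5.0, 0.0, 5.0, 100.0])
print(soft_cap(x, cap=30.0))  # all values bounded by +/-30
```
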
liger_kernel/transformers/model/llama.py

@@ -13,6 +13,7 @@ from transformers.models.llama.modeling_llama import _CONFIG_FOR_DOC
  from transformers.models.llama.modeling_llama import LLAMA_INPUTS_DOCSTRING
  from transformers.utils import add_start_docstrings_to_model_forward
  from transformers.utils import replace_return_docstrings
+ from transformers.utils.deprecation import deprecate_kwarg

  from liger_kernel.transformers.fused_linear_cross_entropy import LigerFusedLinearCrossEntropyLoss
  from liger_kernel.transformers.model.loss_utils import LigerForCausalLMLoss
@@ -135,6 +136,7 @@ def lce_forward_deprecated(
  )


+ @deprecate_kwarg("num_logits_to_keep", version="4.50", new_name="logits_to_keep")
  @add_start_docstrings_to_model_forward(LLAMA_INPUTS_DOCSTRING)
  @replace_return_docstrings(output_type=CausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC)
  def lce_forward(
@@ -150,7 +152,7 @@ def lce_forward(
  output_hidden_states: Optional[bool] = None,
  return_dict: Optional[bool] = None,
  cache_position: Optional[torch.LongTensor] = None,
- num_logits_to_keep: int = 0,
+ logits_to_keep: Union[int, torch.Tensor] = 0,
  **loss_kwargs,
  ) -> Union[Tuple, CausalLMOutputWithPast]:
  r"""
@@ -160,10 +162,12 @@ def lce_forward(
  config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored
  (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`.

- num_logits_to_keep (`int`, *optional*):
- Calculate logits for the last `num_logits_to_keep` tokens. If `0`, calculate logits for all
+ logits_to_keep (`int` or `torch.Tensor`, *optional*):
+ If an `int`, compute logits for the last `logits_to_keep` tokens. If `0`, calculate logits for all
  `input_ids` (special case). Only last token logits are needed for generation, and calculating them only for that
  token can save memory, which becomes pretty significant for long sequences or large vocabulary size.
+ If a `torch.Tensor`, must be 1D corresponding to the indices to keep in the sequence length dimension.
+ This is useful when using packed tensor format (single dimension for batch and sequence length).

  Returns:

@@ -222,7 +226,8 @@ def lce_forward(
  )

  else: # if in inference mode materialize logits
- logits = self.lm_head(hidden_states[:, -num_logits_to_keep:, :])
+ slice_indices = slice(-logits_to_keep, None) if isinstance(logits_to_keep, int) else logits_to_keep
+ logits = self.lm_head(hidden_states[:, slice_indices, :])
  if labels is not None:
  loss = self.loss_function(
  logits=logits,

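The docstring repeated across these files notes that the tensor form of `logits_to_keep` is meant for packed inputs, where batch and sequence share a single dimension. A hypothetical example of that usage (the lengths and shapes below are made up for illustration):

```python
import torch

# Three prompts of lengths 4, 3 and 5 packed into a single row (batch dim of 1).
seq_lens = torch.tensor([4, 3, 5])
last_positions = torch.cumsum(seq_lens, dim=0) - 1       # tensor([ 3,  6, 11])

hidden_states = torch.randn(1, int(seq_lens.sum()), 64)  # (1, 12, hidden)

# Passing `logits_to_keep=last_positions` keeps only each prompt's final position,
# so the lm_head is applied to 3 positions instead of all 12.
kept = hidden_states[:, last_positions, :]
print(kept.shape)  # torch.Size([1, 3, 64])
```
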
liger_kernel/transformers/model/mistral.py

@@ -5,17 +5,18 @@ from typing import Union

  import torch

- from torch.nn import CrossEntropyLoss
  from transformers.cache_utils import Cache
  from transformers.modeling_outputs import CausalLMOutputWithPast
  from transformers.models.mistral.modeling_mistral import _CONFIG_FOR_DOC
  from transformers.models.mistral.modeling_mistral import MISTRAL_INPUTS_DOCSTRING
  from transformers.utils import add_start_docstrings_to_model_forward
  from transformers.utils import replace_return_docstrings
+ from transformers.utils.deprecation import deprecate_kwarg

  from liger_kernel.transformers.model.loss_utils import LigerForCausalLMLoss


+ @deprecate_kwarg("num_logits_to_keep", version="4.50", new_name="logits_to_keep")
  @add_start_docstrings_to_model_forward(MISTRAL_INPUTS_DOCSTRING)
  @replace_return_docstrings(output_type=CausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC)
  def lce_forward(
@@ -31,6 +32,7 @@ def lce_forward(
  output_hidden_states: Optional[bool] = None,
  return_dict: Optional[bool] = None,
  cache_position: Optional[torch.LongTensor] = None,
+ logits_to_keep: Union[int, torch.Tensor] = 0,
  **loss_kwargs,
  ) -> Union[Tuple, CausalLMOutputWithPast]:
  r"""
@@ -43,6 +45,12 @@ def lce_forward(
  config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored
  (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`.

+ logits_to_keep (`int` or `torch.Tensor`, *optional*):
+ If an `int`, compute logits for the last `logits_to_keep` tokens. If `0`, calculate logits for all
+ `input_ids` (special case). Only last token logits are needed for generation, and calculating them only for that
+ token can save memory, which becomes pretty significant for long sequences or large vocabulary size.
+ If a `torch.Tensor`, must be 1D corresponding to the indices to keep in the sequence length dimension.
+ This is useful when using packed tensor format (single dimension for batch and sequence length).
  Returns:

  Example:
@@ -97,21 +105,17 @@ def lce_forward(
  )

  else:
- logits = self.lm_head(hidden_states)
- if labels is not None:
- # Upcast to float if we need to compute the loss to avoid potential precision issues
- logits = logits.float()
- # Shift so that tokens < n predict n
- shift_logits = logits[..., :-1, :].contiguous()
- shift_labels = labels[..., 1:].contiguous()
- # Flatten the tokens
- shift_logits = shift_logits.view(-1, self.config.vocab_size)
- shift_labels = shift_labels.view(-1)
- # Ensure tensors are on the same device
- shift_labels = shift_labels.to(shift_logits.device)
- loss_fct = CrossEntropyLoss()
- loss = loss_fct(shift_logits, shift_labels)
+ slice_indices = slice(-logits_to_keep, None) if isinstance(logits_to_keep, int) else logits_to_keep
+ logits = self.lm_head(hidden_states[:, slice_indices, :])

+ loss = None
+ if labels is not None:
+ loss = self.loss_function(
+ logits=logits,
+ labels=labels,
+ vocab_size=self.config.vocab_size,
+ **loss_kwargs,
+ )
  if not return_dict:
  output = (logits,) + outputs[1:]
  return (loss,) + output if loss is not None else output

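For Mistral the patch also drops the hand-rolled shift-and-flatten cross entropy in favour of `self.loss_function`, matching the other models. As a reminder of what the removed branch computed, here is a compact sketch of the same shifted causal-LM loss (illustrative only, not the transformers implementation):

```python
import torch
import torch.nn.functional as F

def shifted_causal_lm_loss(logits: torch.Tensor, labels: torch.Tensor, vocab_size: int) -> torch.Tensor:
    logits = logits.float()                          # upcast to avoid precision issues
    shift_logits = logits[..., :-1, :].contiguous()  # tokens < n predict token n
    shift_labels = labels[..., 1:].contiguous().to(shift_logits.device)
    return F.cross_entropy(
        shift_logits.view(-1, vocab_size),
        shift_labels.view(-1),
        ignore_index=-100,                           # masked labels contribute nothing
    )
```
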
liger_kernel/transformers/model/mixtral.py

@@ -12,6 +12,7 @@ from transformers.models.mixtral.modeling_mixtral import MIXTRAL_INPUTS_DOCSTRIN
  from transformers.models.mixtral.modeling_mixtral import load_balancing_loss_func
  from transformers.utils import add_start_docstrings_to_model_forward
  from transformers.utils import replace_return_docstrings
+ from transformers.utils.deprecation import deprecate_kwarg

  from liger_kernel.transformers.fused_linear_cross_entropy import LigerFusedLinearCrossEntropyLoss
  from liger_kernel.transformers.model.loss_utils import LigerForCausalLMLoss
@@ -144,6 +145,7 @@ def lce_forward_deprecated(
  )


+ @deprecate_kwarg("num_logits_to_keep", version="4.50", new_name="logits_to_keep")
  @add_start_docstrings_to_model_forward(MIXTRAL_INPUTS_DOCSTRING)
  @replace_return_docstrings(output_type=MoeCausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC)
  # Ignore copy
@@ -161,7 +163,7 @@ def lce_forward(
  output_router_logits: Optional[bool] = None,
  return_dict: Optional[bool] = None,
  cache_position: Optional[torch.LongTensor] = None,
- num_logits_to_keep: int = 0,
+ logits_to_keep: Union[int, torch.Tensor] = 0,
  **loss_kwargs,
  ) -> Union[Tuple, MoeCausalLMOutputWithPast]:
  r"""
@@ -171,10 +173,12 @@ def lce_forward(
  config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored
  (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`.

- num_logits_to_keep (`int`, *optional*):
- Calculate logits for the last `num_logits_to_keep` tokens. If `0`, calculate logits for all
+ logits_to_keep (`int` or `torch.Tensor`, *optional*):
+ If an `int`, compute logits for the last `logits_to_keep` tokens. If `0`, calculate logits for all
  `input_ids` (special case). Only last token logits are needed for generation, and calculating them only for that
  token can save memory, which becomes pretty significant for long sequences or large vocabulary size.
+ If a `torch.Tensor`, must be 1D corresponding to the indices to keep in the sequence length dimension.
+ This is useful when using packed tensor format (single dimension for batch and sequence length).

  Returns:

@@ -235,15 +239,12 @@ def lce_forward(
  )

  else: # if in inference mode materialize logits
- logits = self.lm_head(hidden_states[:, -num_logits_to_keep:, :])
- if labels is not None:
- loss = self.loss_function(
- logits=logits,
- labels=labels,
- vocab_size=self.config.vocab_size,
- **loss_kwargs,
- )
+ slice_indices = slice(-logits_to_keep, None) if isinstance(logits_to_keep, int) else logits_to_keep
+ logits = self.lm_head(hidden_states[:, slice_indices, :])

+ loss = None
+ if labels is not None:
+ loss = self.loss_function(logits, labels, self.vocab_size, **loss_kwargs)
  aux_loss = None
  if output_router_logits:
  aux_loss = load_balancing_loss_func(

liger_kernel/transformers/model/mllama.py

@@ -11,6 +11,7 @@ from transformers.modeling_outputs import CausalLMOutputWithPast
  from transformers.models.mllama.modeling_mllama import MLLAMA_INPUTS_DOCSTRING
  from transformers.utils import add_start_docstrings_to_model_forward
  from transformers.utils import replace_return_docstrings
+ from transformers.utils.deprecation import deprecate_kwarg

  from liger_kernel.transformers.fused_linear_cross_entropy import LigerFusedLinearCrossEntropyLoss
  from liger_kernel.transformers.model.loss_utils import LigerForCausalLMLoss
@@ -133,6 +134,7 @@ def lce_forward_deprecated(
  )


+ @deprecate_kwarg("num_logits_to_keep", version="4.50", new_name="logits_to_keep")
  @add_start_docstrings_to_model_forward(MLLAMA_INPUTS_DOCSTRING)
  @replace_return_docstrings(output_type=CausalLMOutputWithPast, config_class="MllamaTextConfig")
  def lce_forward(
@@ -151,7 +153,7 @@ def lce_forward(
  output_hidden_states: Optional[bool] = None,
  return_dict: Optional[bool] = None,
  cache_position: Optional[torch.LongTensor] = None,
- num_logits_to_keep: int = 0,
+ logits_to_keep: Union[int, torch.Tensor] = 0,
  **loss_kwargs,
  ) -> Union[Tuple, CausalLMOutputWithPast]:
  r"""
@@ -161,10 +163,12 @@ def lce_forward(
  config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored
  (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`.

- num_logits_to_keep (`int`, *optional*):
- Calculate logits for the last `num_logits_to_keep` tokens. If `0`, calculate logits for all
+ logits_to_keep (`int` or `torch.Tensor`, *optional*):
+ If an `int`, compute logits for the last `logits_to_keep` tokens. If `0`, calculate logits for all
  `input_ids` (special case). Only last token logits are needed for generation, and calculating them only for that
  token can save memory, which becomes pretty significant for long sequences or large vocabulary size.
+ If a `torch.Tensor`, must be 1D corresponding to the indices to keep in the sequence length dimension.
+ This is useful when using packed tensor format (single dimension for batch and sequence length).

  Returns:

@@ -225,7 +229,8 @@ def lce_forward(
  )

  else: # if in inference mode materialize logits
- logits = self.lm_head(hidden_states[:, -num_logits_to_keep:, :])
+ slice_indices = slice(-logits_to_keep, None) if isinstance(logits_to_keep, int) else logits_to_keep
+ logits = self.lm_head(hidden_states[:, slice_indices, :])
  if labels is not None:
  loss = self.loss_function(
  logits=logits,

liger_kernel/transformers/model/olmo2.py

@@ -10,10 +10,12 @@ from transformers.models.olmo2.modeling_olmo2 import _CONFIG_FOR_DOC
  from transformers.models.olmo2.modeling_olmo2 import OLMO2_INPUTS_DOCSTRING
  from transformers.utils import add_start_docstrings_to_model_forward
  from transformers.utils import replace_return_docstrings
+ from transformers.utils.deprecation import deprecate_kwarg

  from liger_kernel.transformers.model.loss_utils import LigerForCausalLMLoss


+ @deprecate_kwarg("num_logits_to_keep", version="4.50", new_name="logits_to_keep")
  @add_start_docstrings_to_model_forward(OLMO2_INPUTS_DOCSTRING)
  @replace_return_docstrings(output_type=CausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC)
  def lce_forward(
@@ -29,7 +31,7 @@ def lce_forward(
  output_hidden_states: Optional[bool] = None,
  return_dict: Optional[bool] = None,
  cache_position: Optional[torch.LongTensor] = None,
- num_logits_to_keep: int = 0,
+ logits_to_keep: Union[int, torch.Tensor] = 0,
  **loss_kwargs,
  ) -> Union[Tuple, CausalLMOutputWithPast]:
  r"""
@@ -39,10 +41,12 @@ def lce_forward(
  config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored
  (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`.

- num_logits_to_keep (`int`, *optional*):
- Calculate logits for the last `num_logits_to_keep` tokens. If `0`, calculate logits for all
+ logits_to_keep (`int` or `torch.Tensor`, *optional*):
+ If an `int`, compute logits for the last `logits_to_keep` tokens. If `0`, calculate logits for all
  `input_ids` (special case). Only last token logits are needed for generation, and calculating them only for that
  token can save memory, which becomes pretty significant for long sequences or large vocabulary size.
+ If a `torch.Tensor`, must be 1D corresponding to the indices to keep in the sequence length dimension.
+ This is useful when using packed tensor format (single dimension for batch and sequence length).

  Returns:

@@ -98,7 +102,8 @@ def lce_forward(
  )

  else: # if in inference mode materialize logits
- logits = self.lm_head(hidden_states[:, -num_logits_to_keep:, :])
+ slice_indices = slice(-logits_to_keep, None) if isinstance(logits_to_keep, int) else logits_to_keep
+ logits = self.lm_head(hidden_states[:, slice_indices, :])
  if labels is not None:
  loss = self.loss_function(
  logits=logits,

liger_kernel/transformers/model/phi3.py

@@ -11,6 +11,7 @@ from transformers.models.phi3.modeling_phi3 import _CONFIG_FOR_DOC
  from transformers.models.phi3.modeling_phi3 import PHI3_INPUTS_DOCSTRING
  from transformers.utils import add_start_docstrings_to_model_forward
  from transformers.utils import replace_return_docstrings
+ from transformers.utils.deprecation import deprecate_kwarg

  from liger_kernel.transformers.fused_linear_cross_entropy import LigerFusedLinearCrossEntropyLoss
  from liger_kernel.transformers.model.loss_utils import LigerForCausalLMLoss
@@ -126,6 +127,7 @@ def lce_forward_deprecated(
  )


+ @deprecate_kwarg("num_logits_to_keep", version="4.50", new_name="logits_to_keep")
  @add_start_docstrings_to_model_forward(PHI3_INPUTS_DOCSTRING)
  @replace_return_docstrings(output_type=CausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC)
  def lce_forward(
@@ -141,7 +143,7 @@ def lce_forward(
  output_hidden_states: Optional[bool] = None,
  return_dict: Optional[bool] = None,
  cache_position: Optional[torch.LongTensor] = None,
- num_logits_to_keep: int = 0,
+ logits_to_keep: Union[int, torch.Tensor] = 0,
  **loss_kwargs,
  ) -> Union[Tuple, CausalLMOutputWithPast]:
  r"""
@@ -151,10 +153,12 @@ def lce_forward(
  config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored
  (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`.

- num_logits_to_keep (`int`, *optional*):
- Calculate logits for the last `num_logits_to_keep` tokens. If `0`, calculate logits for all
+ logits_to_keep (`int` or `torch.Tensor`, *optional*):
+ If an `int`, compute logits for the last `logits_to_keep` tokens. If `0`, calculate logits for all
  `input_ids` (special case). Only last token logits are needed for generation, and calculating them only for that
  token can save memory, which becomes pretty significant for long sequences or large vocabulary size.
+ If a `torch.Tensor`, must be 1D corresponding to the indices to keep in the sequence length dimension.
+ This is useful when using packed tensor format (single dimension for batch and sequence length).

  Returns:

@@ -223,7 +227,8 @@ def lce_forward(
  )

  else: # if in inference mode materialize logits
- logits = self.lm_head(hidden_states[:, -num_logits_to_keep:, :])
+ slice_indices = slice(-logits_to_keep, None) if isinstance(logits_to_keep, int) else logits_to_keep
+ logits = self.lm_head(hidden_states[:, slice_indices, :])
  if labels is not None:
  loss = self.loss_function(
  logits=logits,

liger_kernel/transformers/model/qwen2.py

@@ -11,6 +11,7 @@ from transformers.models.qwen2.modeling_qwen2 import _CONFIG_FOR_DOC
  from transformers.models.qwen2.modeling_qwen2 import QWEN2_INPUTS_DOCSTRING
  from transformers.utils import add_start_docstrings_to_model_forward
  from transformers.utils import replace_return_docstrings
+ from transformers.utils.deprecation import deprecate_kwarg

  from liger_kernel.transformers.fused_linear_cross_entropy import LigerFusedLinearCrossEntropyLoss
  from liger_kernel.transformers.model.loss_utils import LigerForCausalLMLoss
@@ -125,6 +126,7 @@ def lce_forward_deprecated(
  )


+ @deprecate_kwarg("num_logits_to_keep", version="4.50", new_name="logits_to_keep")
  @add_start_docstrings_to_model_forward(QWEN2_INPUTS_DOCSTRING)
  @replace_return_docstrings(output_type=CausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC)
  def lce_forward(
@@ -140,7 +142,7 @@ def lce_forward(
  output_hidden_states: Optional[bool] = None,
  return_dict: Optional[bool] = None,
  cache_position: Optional[torch.LongTensor] = None,
- num_logits_to_keep: int = 0,
+ logits_to_keep: Union[int, torch.Tensor] = 0,
  **loss_kwargs,
  ) -> Union[Tuple, CausalLMOutputWithPast]:
  r"""
@@ -150,10 +152,12 @@ def lce_forward(
  config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored
  (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`.

- num_logits_to_keep (`int`, *optional*):
- Calculate logits for the last `num_logits_to_keep` tokens. If `0`, calculate logits for all
+ logits_to_keep (`int` or `torch.Tensor`, *optional*):
+ If an `int`, compute logits for the last `logits_to_keep` tokens. If `0`, calculate logits for all
  `input_ids` (special case). Only last token logits are needed for generation, and calculating them only for that
  token can save memory, which becomes pretty significant for long sequences or large vocabulary size.
+ If a `torch.Tensor`, must be 1D corresponding to the indices to keep in the sequence length dimension.
+ This is useful when using packed tensor format (single dimension for batch and sequence length).

  Returns:

@@ -209,7 +213,8 @@ def lce_forward(
  )

  else: # if in inference mode materialize logits
- logits = self.lm_head(hidden_states[:, -num_logits_to_keep:, :])
+ slice_indices = slice(-logits_to_keep, None) if isinstance(logits_to_keep, int) else logits_to_keep
+ logits = self.lm_head(hidden_states[:, slice_indices, :])
  if labels is not None:
  loss = self.loss_function(
  logits=logits,

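Every patched `lce_forward` in this diff carries the same `@deprecate_kwarg("num_logits_to_keep", version="4.50", new_name="logits_to_keep")` decorator from `transformers.utils.deprecation`, so existing callers that still pass the old keyword keep working for a transition period. A minimal sketch of the kind of remapping such a decorator performs (an assumption about the general mechanism, not the transformers implementation):

```python
import functools
import warnings

def deprecate_kwarg_sketch(old_name: str, version: str, new_name: str):
    """Forward `old_name=...` to `new_name=...` with a deprecation warning (illustrative)."""
    def decorator(func):
        @functools.wraps(func)
        def wrapper(*args, **kwargs):
            if old_name in kwargs:
                warnings.warn(
                    f"`{old_name}` is deprecated and will be removed in version {version}; "
                    f"use `{new_name}` instead.",
                    FutureWarning,
                )
                kwargs.setdefault(new_name, kwargs.pop(old_name))
            return func(*args, **kwargs)
        return wrapper
    return decorator

@deprecate_kwarg_sketch("num_logits_to_keep", version="4.50", new_name="logits_to_keep")
def forward_stub(*, logits_to_keep=0):
    return logits_to_keep

print(forward_stub(num_logits_to_keep=5))  # warns once, returns 5
```
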
liger_kernel_nightly-0.5.6.dev20250403190551.dist-info/METADATA

@@ -1,6 +1,6 @@
  Metadata-Version: 2.1
  Name: liger_kernel_nightly
- Version: 0.5.6.dev20250403001329
+ Version: 0.5.6.dev20250403190551
  Summary: Efficient Triton kernels for LLM Training
  License: BSD 2-CLAUSE LICENSE
  Copyright 2024 LinkedIn Corporation

liger_kernel_nightly-0.5.6.dev20250403190551.dist-info/RECORD

@@ -55,28 +55,28 @@ liger_kernel/transformers/trainer_integration.py,sha256=W3ON51O5GkyzNJsItz0y5rKx
  liger_kernel/transformers/tvd.py,sha256=XrRfyJIqN6HFxXk8MYyFVZM1OLz3mtSbRZvWfZ_JerQ,450
  liger_kernel/transformers/experimental/embedding.py,sha256=2P0QYdlFyFrG5OqTzTa1wcRgDSyjBMv5i1a7BrDPDQw,881
  liger_kernel/transformers/model/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
- liger_kernel/transformers/model/gemma.py,sha256=7cBTljzh-8_ACBhYl6NUfj5_ux92YRlmnAU5gfDAQAI,9312
- liger_kernel/transformers/model/gemma2.py,sha256=X0FOIhvFlTrmWI7Ws06wUkutgHW3lWtLOnnHp1NgZ3A,10403
+ liger_kernel/transformers/model/gemma.py,sha256=-JoHKWjtYPpxHQa6QbCwnzX_cctRZG2ZTsaUv-dmOt4,9816
+ liger_kernel/transformers/model/gemma2.py,sha256=tLl1v-O8K0NZ7BQcSf1dE3450-xV72RAk4E5oTPcu_s,10907
  liger_kernel/transformers/model/gemma3.py,sha256=PjAfFtupT9EW0sb57Hx8UJXcnvq9HFgNndeAE4EqyPw,16086
- liger_kernel/transformers/model/llama.py,sha256=d9rBaK8e8RSMCFHdgom9ZHuXOlnh6U_o-GkAFGRNGOY,9989
+ liger_kernel/transformers/model/llama.py,sha256=UVXQLRW7rCU5vPab54dLNS3ER37eM446peHX00Yz6eA,10493
  liger_kernel/transformers/model/llava.py,sha256=b0pEagjUbu2-eS9xegjyfl1DwIXLwZcNpff55ibaMbA,17601
  liger_kernel/transformers/model/loss_utils.py,sha256=Z-fUrf-cUDUjUIH7Tl9OL2hT8nmtx7ES3kg8syuWKy4,1476
- liger_kernel/transformers/model/mistral.py,sha256=o7tyl1sPWPfZwwrBLRlryHlSI8I55viuJoMI5Bh5Nww,5014
- liger_kernel/transformers/model/mixtral.py,sha256=T0ITv2-PkR8VErVOVUizoS4EzjmARyR7GFh0tXDB_i4,11089
- liger_kernel/transformers/model/mllama.py,sha256=RCKtwnGOMFYIbtt1zUQ15Cyv4eNpHkTWcgkmG2EEs2I,10804
- liger_kernel/transformers/model/olmo2.py,sha256=5M8kczp4D-jvbjcV7cKATIJGF34xd-Rs-PPdKZWSIlY,4685
+ liger_kernel/transformers/model/mistral.py,sha256=RacuKcckuDK6oSraCGD0R0bm-fE0K3q-lkYaAC56C2E,5481
+ liger_kernel/transformers/model/mixtral.py,sha256=gLcqGabdv1XnuciS9b-TpkTDnGL8K32Hoq9j2vZMBRY,11502
+ liger_kernel/transformers/model/mllama.py,sha256=75mxtmMsNd_q8KlKeawj2uMP6v2KjDuUi4nsUKM5jqA,11308
+ liger_kernel/transformers/model/olmo2.py,sha256=rSzSALikEGkk0w3PLNQPrqg-ioN8TpWCXkAlg3LtCdI,5189
  liger_kernel/transformers/model/paligemma.py,sha256=GNReT6tVZt3ON6aaa9ovg8mnu1hYocSx9OhgC7b-_28,19191
- liger_kernel/transformers/model/phi3.py,sha256=NmU2DuU1Huwha6K7YSsJCnvQfUovTTGlsfBZhbx0UoI,9951
- liger_kernel/transformers/model/qwen2.py,sha256=t7NotBHoebsPqNSxwaf9DXTg8jxgB5BdunSGqYOE0hQ,9240
+ liger_kernel/transformers/model/phi3.py,sha256=ebITCrmwmb4z66CbSrZl1kD6BsP52IcSAR8uwUTp9nc,10455
+ liger_kernel/transformers/model/qwen2.py,sha256=QaoTDrJv2wIuAM8QMoeWVvgNl0N5gHzIrew9QGG7kXc,9744
  liger_kernel/transformers/model/qwen2_5_vl.py,sha256=70BnHZjx6eQWTwi3zc5SMwxTeOOA4Tbdkfy6IYRcTaM,9289
  liger_kernel/transformers/model/qwen2_vl.py,sha256=zo4O9fShNHYqSLrzLGqQYWSMtJI6UHaSY7zvMCYWyD8,9685
  liger_kernel/transformers/trainer/__init__.py,sha256=p7yQfklV8-467qSz_ZMimkbDF7HHWHwku25A-GYL0WU,193
  liger_kernel/transformers/trainer/orpo_trainer.py,sha256=pdekW7l6Qg_aqa5SYKYlSWUF8m3lkOFvFLcIMEHrz9s,8338
  liger_kernel/triton/__init__.py,sha256=qCiCamzCRv6lpV8IqpAc9YMdNKC7GKurClWceQPnlis,92
  liger_kernel/triton/monkey_patch.py,sha256=Rd0hUHAzDkFfHvnX7-PBaNK5EKnZhtfM_h-fgQH9HPY,1568
- liger_kernel_nightly-0.5.6.dev20250403001329.dist-info/LICENSE,sha256=OhzLDHJ0to4a8sodVLELZiCFylZ1NAAYLs-HrjPy0ag,1312
- liger_kernel_nightly-0.5.6.dev20250403001329.dist-info/METADATA,sha256=XTYWQ-SEGTr7X52X8TIIceAeNhIYMf3lzRf7LXP1vHM,23297
- liger_kernel_nightly-0.5.6.dev20250403001329.dist-info/NOTICE,sha256=njwnoPZLh9AN8SJQzxvCGLHi-8X__AvWRze6joNXIY8,2066
- liger_kernel_nightly-0.5.6.dev20250403001329.dist-info/WHEEL,sha256=iAkIy5fosb7FzIOwONchHf19Qu7_1wCWyFNR5gu9nU0,91
- liger_kernel_nightly-0.5.6.dev20250403001329.dist-info/top_level.txt,sha256=2eghu4hA3LnkM7ElW92tQ8zegWKgSbeo-k-aGe1YnvY,13
- liger_kernel_nightly-0.5.6.dev20250403001329.dist-info/RECORD,,
+ liger_kernel_nightly-0.5.6.dev20250403190551.dist-info/LICENSE,sha256=OhzLDHJ0to4a8sodVLELZiCFylZ1NAAYLs-HrjPy0ag,1312
+ liger_kernel_nightly-0.5.6.dev20250403190551.dist-info/METADATA,sha256=-mDZqil7nMfEwTxTcAefjnMBRz-OZQ3iJamSYCnKRps,23297
+ liger_kernel_nightly-0.5.6.dev20250403190551.dist-info/NOTICE,sha256=njwnoPZLh9AN8SJQzxvCGLHi-8X__AvWRze6joNXIY8,2066
+ liger_kernel_nightly-0.5.6.dev20250403190551.dist-info/WHEEL,sha256=iAkIy5fosb7FzIOwONchHf19Qu7_1wCWyFNR5gu9nU0,91
+ liger_kernel_nightly-0.5.6.dev20250403190551.dist-info/top_level.txt,sha256=2eghu4hA3LnkM7ElW92tQ8zegWKgSbeo-k-aGe1YnvY,13
+ liger_kernel_nightly-0.5.6.dev20250403190551.dist-info/RECORD,,
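Each RECORD entry pairs a path with `sha256=<urlsafe-base64 digest, no padding>` and a byte size, which is why the rebuilt modules above show up as changed hashes and slightly larger files. A sketch of how one could check a single entry against an unpacked wheel (the local directory below is a placeholder):

```python
import base64
import hashlib
from pathlib import Path

def record_hash(path: Path) -> str:
    # RECORD stores the urlsafe-base64 SHA-256 digest with '=' padding stripped.
    digest = hashlib.sha256(path.read_bytes()).digest()
    return "sha256=" + base64.urlsafe_b64encode(digest).rstrip(b"=").decode("ascii")

entry = "liger_kernel/transformers/model/llama.py,sha256=UVXQLRW7rCU5vPab54dLNS3ER37eM446peHX00Yz6eA,10493"
rel_path, expected_hash, expected_size = entry.split(",")
path = Path("unpacked_wheel") / rel_path  # placeholder extraction directory
print(record_hash(path) == expected_hash, path.stat().st_size == int(expected_size))
```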