liger-kernel 0.5.8__py3-none-any.whl → 0.5.9__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -68,6 +68,7 @@ class LigerFusedLinearDPOFunction(LigerFusedLinearPreferenceBase):
  compute_nll_loss=False,
  compiled=True,
  use_ref_model=True,
+ average_log_prob=False,
  chunk_size=1,
  ):
  """
@@ -85,6 +86,7 @@ class LigerFusedLinearDPOFunction(LigerFusedLinearPreferenceBase):
  compute_nll_loss (bool): Whether to compute the NLL loss
  compiled (bool): Whether to use torch compile
  use_ref_model (bool): Whether to use a reference model
+ average_log_prob (bool): Whether to average the log probability per non-masked token
  chunk_size (int): Size of chunks for processing.
  Returns:
  torch.Tensor: Computed loss
@@ -104,13 +106,14 @@ class LigerFusedLinearDPOFunction(LigerFusedLinearPreferenceBase):
  ref_input=ref_input,
  ref_weight=ref_weight,
  ref_bias=ref_bias,
+ average_log_prob=average_log_prob,
  chunk_size=chunk_size,
  )

  @staticmethod
  def backward(ctx, *grad_output):
  grads = LigerFusedLinearPreferenceBase.backward(ctx, grad_output)[:4]
- return *grads, None, None, None, None, None, None, None, None, None
+ return *grads, None, None, None, None, None, None, None, None, None, None


  class LigerFusedLinearDPOLoss(torch.nn.Module):
@@ -125,6 +128,7 @@ class LigerFusedLinearDPOLoss(torch.nn.Module):
  compute_nll_loss: bool = False,
  compiled: bool = True,
  use_ref_model: bool = True,
+ average_log_prob: bool = True,
  chunk_size: int = 1,
  ):
  """
@@ -134,6 +138,7 @@ class LigerFusedLinearDPOLoss(torch.nn.Module):
  compute_nll_loss (bool): Whether to compute the NLL loss.
  compiled (bool): Whether to use the torch compiled kernel.
  use_ref_model (bool): Whether to use a reference model for the DPO loss.
+ average_log_prob (bool): Whether to average the log probability per non-masked token.
  chunk_size (int): Size of chunks for processing.
  """
  super().__init__()
@@ -142,6 +147,7 @@ class LigerFusedLinearDPOLoss(torch.nn.Module):
  self.compute_nll_loss = compute_nll_loss
  self.compiled = compiled
  self.use_ref_model = use_ref_model
+ self.average_log_prob = average_log_prob
  self.chunk_size = chunk_size

  def forward(
@@ -167,5 +173,6 @@ class LigerFusedLinearDPOLoss(torch.nn.Module):
  self.compute_nll_loss,
  self.compiled,
  self.use_ref_model,
+ self.average_log_prob,
  self.chunk_size,
  )
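The new `average_log_prob` flag threads through the whole chunked DPO stack: the autograd `Function` defaults it to `False`, the `LigerFusedLinearDPOLoss` module defaults it to `True`, and the extra `None` returned from `backward` accounts for the added forward argument. A minimal sketch of opting in from the module API, assuming the usual `liger_kernel.chunked_loss` export; other constructor arguments keep their existing defaults:

```python
# Hedged sketch, not a full training loop: construct the chunked DPO loss
# with the new flag. Per the docstring, average_log_prob controls whether
# per-sequence log-probabilities are averaged over non-masked tokens
# (True) or, presumably, summed (False).
from liger_kernel.chunked_loss import LigerFusedLinearDPOLoss

dpo_loss = LigerFusedLinearDPOLoss(
    use_ref_model=True,
    average_log_prob=False,  # new in 0.5.9; module default is True
    chunk_size=1,
)
```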
@@ -351,7 +351,10 @@ def cross_entropy_backward(_input, grad_output):
  # If cross entropy is the last layer, grad_output is 1.0. Skip the mul to save time
  if torch.equal(grad_output, torch.tensor(1.0, device=grad_output.device)):
  pass
-
+ # If reduction is 'none'
+ elif grad_output.ndim > 0:
+ _input = _input * grad_output.unsqueeze(dim=1)
+ # If reduction is ['mean', 'sum'], grad_output is just a scalar
  # We use a Triton kernel instead of a PyTorch operation because modifying inputs in-place
  # for gradient storage and backward multiple times causes anomalies with PyTorch but not with Triton.
  else:
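The new `elif` branch covers `reduction='none'`: the upstream gradient is then a vector with one entry per token, so each stored per-token gradient row has to be scaled by its own `grad_output` entry rather than by a single scalar. A small self-contained illustration of that broadcast, with hypothetical shapes:

```python
import torch

# Illustration only (not library code): why the per-row scaling is needed
# when reduction == "none". grad_input stands in for the stored
# d(loss_i)/d(logits_i) rows; grad_output carries one upstream gradient
# per token.
BT, V = 4, 10                      # hypothetical (batch*seq_len, vocab) shape
grad_input = torch.randn(BT, V)    # per-token cross-entropy gradient rows
grad_output = torch.randn(BT)      # per-token upstream gradient (reduction="none")

scaled = grad_input * grad_output.unsqueeze(dim=1)  # broadcast over the vocab dim
assert scaled.shape == (BT, V)
```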
@@ -143,9 +143,10 @@ def fused_linear_cross_entropy_forward(
  alpha=1.0,
  )

- if reduction == "none":
- loss = loss_1d
- z_loss = z_loss_1d if return_z_loss else None
+ # Need extra calculations for backward if reduction=='none'. Not supporting reduction='none' now.
+ # if reduction == "none":
+ # loss = loss_1d
+ # z_loss = z_loss_1d if return_z_loss else None

  else:
  loss = torch.sum(loss_1d)
@@ -26,6 +26,7 @@ if TYPE_CHECKING:
  from liger_kernel.transformers.monkey_patch import apply_liger_kernel_to_gemma2 # noqa: F401
  from liger_kernel.transformers.monkey_patch import apply_liger_kernel_to_gemma3 # noqa: F401
  from liger_kernel.transformers.monkey_patch import apply_liger_kernel_to_gemma3_text # noqa: F401
+ from liger_kernel.transformers.monkey_patch import apply_liger_kernel_to_glm4 # noqa: F401
  from liger_kernel.transformers.monkey_patch import apply_liger_kernel_to_granite # noqa: F401
  from liger_kernel.transformers.monkey_patch import apply_liger_kernel_to_llama # noqa: F401
  from liger_kernel.transformers.monkey_patch import apply_liger_kernel_to_llava # noqa: F401
@@ -38,6 +39,7 @@ if TYPE_CHECKING:
  from liger_kernel.transformers.monkey_patch import apply_liger_kernel_to_qwen2 # noqa: F401
  from liger_kernel.transformers.monkey_patch import apply_liger_kernel_to_qwen2_5_vl # noqa: F401
  from liger_kernel.transformers.monkey_patch import apply_liger_kernel_to_qwen2_vl # noqa: F401
+ from liger_kernel.transformers.monkey_patch import apply_liger_kernel_to_qwen3 # noqa: F401


  # Check if 'transformers' is installed
@@ -79,6 +81,7 @@ def __getattr__(name: str):
  "apply_liger_kernel_to_gemma2",
  "apply_liger_kernel_to_gemma3",
  "apply_liger_kernel_to_gemma3_text",
+ "apply_liger_kernel_to_glm4",
  "apply_liger_kernel_to_granite",
  "apply_liger_kernel_to_llama",
  "apply_liger_kernel_to_llava",
@@ -91,6+94,7 @@ def __getattr__(name: str):
  "apply_liger_kernel_to_qwen2",
  "apply_liger_kernel_to_qwen2_5_vl",
  "apply_liger_kernel_to_qwen2_vl",
+ "apply_liger_kernel_to_qwen3",
  }

  if name in monkey_patch_symbols:
@@ -129,6 +133,7 @@ if _TRANSFORMERS_AVAILABLE:
  "apply_liger_kernel_to_gemma2",
  "apply_liger_kernel_to_gemma3",
  "apply_liger_kernel_to_gemma3_text",
+ "apply_liger_kernel_to_glm4",
  "apply_liger_kernel_to_granite",
  "apply_liger_kernel_to_llama",
  "apply_liger_kernel_to_llava",
@@ -141,5 +146,6 @@ if _TRANSFORMERS_AVAILABLE:
  "apply_liger_kernel_to_qwen2",
  "apply_liger_kernel_to_qwen2_5_vl",
  "apply_liger_kernel_to_qwen2_vl",
+ "apply_liger_kernel_to_qwen3",
  ]
  )
@@ -23,8 +23,7 @@ class LigerFusedLinearCrossEntropyLoss(torch.nn.Module):
  assert reduction in {
  "mean",
  "sum",
- "none",
- }, f"reduction must be one of 'mean', 'sum', or 'none'. Got: {reduction}"
+ }, f"reduction must be 'mean' or 'sum'. Got: {reduction}"
  assert softcap is None or softcap > 0, f"softcap must greater than 0.0 or None. Got: {softcap}"
  self.ce_weight = ce_weight
  self.ignore_index = ignore_index
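Together with the commented-out `reduction == "none"` branch in `fused_linear_cross_entropy_forward` above, the module-level assert now rejects `'none'` outright. A hedged sketch of the user-visible effect, assuming the import path from this wheel's layout:

```python
# Sketch of the behavioral change in 0.5.9: only "mean" and "sum" pass the assert.
from liger_kernel.transformers.fused_linear_cross_entropy import (
    LigerFusedLinearCrossEntropyLoss,
)

loss_fn = LigerFusedLinearCrossEntropyLoss(reduction="mean")  # still accepted
try:
    LigerFusedLinearCrossEntropyLoss(reduction="none")        # now rejected
except AssertionError as err:
    print(err)  # reduction must be 'mean' or 'sum'. Got: none
```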
@@ -200,21 +200,25 @@ def lce_forward(
  )

  hidden_states = outputs[0]
+ # Only compute necessary logits, and do not upcast them to float if we are not computing the loss
+ slice_indices = slice(-logits_to_keep, None) if isinstance(logits_to_keep, int) else logits_to_keep
+ kept_hidden_states = hidden_states[:, slice_indices, :]

+ shift_labels = loss_kwargs.pop("shift_labels", None)
  logits = None
  loss = None
  # if in training mode, don't materialize logits
- if self.training and (labels is not None):
+ if self.training and (labels is not None or shift_labels is not None):
  loss = LigerForCausalLMLoss(
- hidden_states=hidden_states,
+ hidden_states=kept_hidden_states,
  lm_head_weight=self.lm_head.weight,
  labels=labels,
+ shift_labels=shift_labels,
  hidden_size=self.config.hidden_size,
  **loss_kwargs,
  )
  else: # if in inference mode materialize logits
- slice_indices = slice(-logits_to_keep, None) if isinstance(logits_to_keep, int) else logits_to_keep
- logits = self.lm_head(hidden_states[:, slice_indices, :])
+ logits = self.lm_head(kept_hidden_states)
  if labels is not None:
  loss = self.loss_function(
  logits=logits,
@@ -212,23 +212,27 @@ def lce_forward(
  )

  hidden_states = outputs[0]
+ # Only compute necessary logits, and do not upcast them to float if we are not computing the loss
+ slice_indices = slice(-logits_to_keep, None) if isinstance(logits_to_keep, int) else logits_to_keep
+ kept_hidden_states = hidden_states[:, slice_indices, :]

+ shift_labels = loss_kwargs.pop("shift_labels", None)
  logits = None
  loss = None
  # if in training mode, don't materialize logits
- if self.training and (labels is not None):
+ if self.training and (labels is not None or shift_labels is not None):
  loss = LigerForCausalLMLoss(
- hidden_states=hidden_states,
+ hidden_states=kept_hidden_states,
  lm_head_weight=self.lm_head.weight,
  labels=labels,
+ shift_labels=shift_labels,
  hidden_size=self.config.hidden_size,
  final_logit_softcapping=self.config.final_logit_softcapping,
  **loss_kwargs,
  )

  else: # if in inference mode materialize logits
- slice_indices = slice(-logits_to_keep, None) if isinstance(logits_to_keep, int) else logits_to_keep
- logits = self.lm_head(hidden_states[:, slice_indices, :])
+ logits = self.lm_head(kept_hidden_states)
  if self.config.final_logit_softcapping is not None:
  logits = logits / self.config.final_logit_softcapping
  logits = torch.tanh(logits)
@@ -104,13 +104,15 @@ def causal_forward(
  # Only compute necessary logits, and do not upcast them to float if we are not computing the loss
  slice_indices = slice(-logits_to_keep, None) if isinstance(logits_to_keep, int) else logits_to_keep
  kept_hidden_states = hidden_states[:, slice_indices, :]
+ shift_labels = loss_kwargs.pop("shift_labels", None)
  loss = None
  logits = None
- if self.training and (labels is not None):
+ if self.training and (labels is not None or shift_labels is not None):
  loss = LigerForCausalLMLoss(
  hidden_states=kept_hidden_states,
  lm_head_weight=self.lm_head.weight,
  labels=labels,
+ shift_labels=shift_labels,
  hidden_size=self.config.hidden_size,
  final_logit_softcapping=self.config.final_logit_softcapping,
  **loss_kwargs,
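All of the patched `lce_forward`/`causal_forward` implementations in this release follow the same pattern: slice the hidden states once up front, pop `shift_labels` out of the loss kwargs, and take the fused no-logits path whenever either `labels` or `shift_labels` is present. The sketch below shows, under the assumption that `shift_labels` are labels already shifted left by one position and padded with `-100`, what such a caller-side tensor might look like; the tensor names are illustrative, not library API:

```python
import torch

# Hedged illustration of pre-shifted labels a trainer might pass instead of
# `labels` (e.g. when packing sequences), so the loss does not re-shift them.
input_ids = torch.tensor([[1, 2, 3, 4]])
labels = input_ids.clone()

shift_labels = torch.full_like(labels, -100)  # -100 is ignored by the loss
shift_labels[:, :-1] = labels[:, 1:]          # token t becomes the target at position t-1

# With the patched forward, a training-mode call can supply shift_labels and
# still avoid materializing logits, e.g. (assumed usage):
# outputs = model(input_ids=input_ids, shift_labels=shift_labels)
```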
@@ -0,0 +1,125 @@
+ from typing import List
+ from typing import Optional
+ from typing import Tuple
+ from typing import Union
+
+ import torch
+
+ from transformers.modeling_outputs import CausalLMOutputWithPast
+ from transformers.models.glm4.modeling_glm4 import _CONFIG_FOR_DOC
+ from transformers.models.glm4.modeling_glm4 import GLM4_INPUTS_DOCSTRING
+ from transformers.utils import add_start_docstrings_to_model_forward
+ from transformers.utils import replace_return_docstrings
+ from transformers.utils.deprecation import deprecate_kwarg
+
+ from liger_kernel.transformers.model.loss_utils import LigerForCausalLMLoss
+
+
+ @deprecate_kwarg("num_logits_to_keep", version="4.50", new_name="logits_to_keep")
+ @add_start_docstrings_to_model_forward(GLM4_INPUTS_DOCSTRING)
+ @replace_return_docstrings(output_type=CausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC)
+ def lce_forward(
+ self,
+ input_ids: torch.LongTensor = None,
+ attention_mask: Optional[torch.Tensor] = None,
+ position_ids: Optional[torch.LongTensor] = None,
+ past_key_values: Optional[List[torch.FloatTensor]] = None,
+ inputs_embeds: Optional[torch.FloatTensor] = None,
+ labels: Optional[torch.LongTensor] = None,
+ use_cache: Optional[bool] = None,
+ output_attentions: Optional[bool] = None,
+ output_hidden_states: Optional[bool] = None,
+ return_dict: Optional[bool] = None,
+ cache_position: Optional[torch.LongTensor] = None,
+ logits_to_keep: Union[int, torch.Tensor] = 0,
+ **loss_kwargs,
+ ) -> Union[Tuple, CausalLMOutputWithPast]:
+ r"""
+ Args:
+ labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
+ Labels for computing the masked language modeling loss. Indices should either be in `[0, ...,
+ config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored
+ (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`.
+
+ logits_to_keep (`int` or `torch.Tensor`, *optional*):
+ If an `int`, compute logits for the last `logits_to_keep` tokens. If `0`, calculate logits for all
+ `input_ids` (special case). Only last token logits are needed for generation, and calculating them only for that
+ token can save memory, which becomes pretty significant for long sequences or large vocabulary size.
+ If a `torch.Tensor`, must be 1D corresponding to the indices to keep in the sequence length dimension.
+ This is useful when using packed tensor format (single dimension for batch and sequence length).
+
+ Returns:
+
+ Example:
+
+ ```python
+ >>> from transformers import AutoTokenizer, Glm4ForCausalLM
+
+ >>> model = Glm4ForCausalLM.from_pretrained("THUDM/GLM-4-9B-0414")
+ >>> tokenizer = AutoTokenizer.from_pretrained("THUDM/GLM-4-9B-0414")
+
+ >>> prompt = "Hey, are you conscious? Can you talk to me?"
+ >>> inputs = tokenizer(prompt, return_tensors="pt")
+
+ >>> # Generate
+ >>> generate_ids = model.generate(inputs.input_ids, max_length=30)
+ >>> tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]
+ 'Hey, are you conscious? Can you talk to me?\nI’m not sure if you’re conscious of this, but I’m'
+ ```
+ """
+ output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
+ output_hidden_states = (
+ output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
+ )
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
+
+ # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn)
+ outputs = self.model(
+ input_ids=input_ids,
+ attention_mask=attention_mask,
+ position_ids=position_ids,
+ past_key_values=past_key_values,
+ inputs_embeds=inputs_embeds,
+ use_cache=use_cache,
+ output_attentions=output_attentions,
+ output_hidden_states=output_hidden_states,
+ return_dict=return_dict,
+ cache_position=cache_position,
+ )
+
+ hidden_states = outputs[0]
+ # Only compute necessary logits, and do not upcast them to float if we are not computing the loss
+ slice_indices = slice(-logits_to_keep, None) if isinstance(logits_to_keep, int) else logits_to_keep
+ kept_hidden_states = hidden_states[:, slice_indices, :]
+
+ shift_labels = loss_kwargs.pop("shift_labels", None)
+ logits = None
+ loss = None
+ # if in training mode, don't materialize logits
+ if self.training and (labels is not None or shift_labels is not None):
+ loss = LigerForCausalLMLoss(
+ hidden_states=kept_hidden_states,
+ lm_head_weight=self.lm_head.weight,
+ labels=labels,
+ shift_labels=shift_labels,
+ hidden_size=self.config.hidden_size,
+ **loss_kwargs,
+ )
+
+ else: # if in inference mode materialize logits
+ logits = self.lm_head(kept_hidden_states)
+ if labels is not None:
+ loss = self.loss_function(
+ logits=logits,
+ labels=labels,
+ vocab_size=self.config.vocab_size,
+ **loss_kwargs,
+ )
+
+ return CausalLMOutputWithPast(
+ loss=loss,
+ logits=logits,
+ past_key_values=outputs.past_key_values,
+ hidden_states=outputs.hidden_states,
+ attentions=outputs.attentions,
+ )
@@ -209,25 +209,29 @@ def lce_forward(
  )

  hidden_states = outputs[0]
+ # Only compute necessary logits, and do not upcast them to float if we are not computing the loss
+ slice_indices = slice(-logits_to_keep, None) if isinstance(logits_to_keep, int) else logits_to_keep
+ kept_hidden_states = hidden_states[:, slice_indices, :]

  if self.config.pretraining_tp > 1:
  raise Exception("Liger Kernel does not support pretraining_tp!!")

+ shift_labels = loss_kwargs.pop("shift_labels", None)
  logits = None
  loss = None
  # if in training mode, don't materialize logits
- if self.training and (labels is not None):
+ if self.training and (labels is not None or shift_labels is not None):
  loss = LigerForCausalLMLoss(
- hidden_states=hidden_states,
+ hidden_states=kept_hidden_states,
  lm_head_weight=self.lm_head.weight,
  labels=labels,
+ shift_labels=shift_labels,
  hidden_size=self.config.hidden_size,
  **loss_kwargs,
  )

  else: # if in inference mode materialize logits
- slice_indices = slice(-logits_to_keep, None) if isinstance(logits_to_keep, int) else logits_to_keep
- logits = self.lm_head(hidden_states[:, slice_indices, :])
+ logits = self.lm_head(kept_hidden_states)
  if labels is not None:
  loss = self.loss_function(
  logits=logits,
@@ -91,22 +91,26 @@ def lce_forward(
  )

  hidden_states = outputs[0]
+ # Only compute necessary logits, and do not upcast them to float if we are not computing the loss
+ slice_indices = slice(-logits_to_keep, None) if isinstance(logits_to_keep, int) else logits_to_keep
+ kept_hidden_states = hidden_states[:, slice_indices, :]

+ shift_labels = loss_kwargs.pop("shift_labels", None)
  loss = None
  logits = None

- if self.training and (labels is not None):
+ if self.training and (labels is not None or shift_labels is not None):
  loss = LigerForCausalLMLoss(
- hidden_states=hidden_states,
+ hidden_states=kept_hidden_states,
  lm_head_weight=self.lm_head.weight,
  labels=labels,
+ shift_labels=shift_labels,
  hidden_size=self.config.hidden_size,
  **loss_kwargs,
  )

  else:
- slice_indices = slice(-logits_to_keep, None) if isinstance(logits_to_keep, int) else logits_to_keep
- logits = self.lm_head(hidden_states[:, slice_indices, :])
+ logits = self.lm_head(kept_hidden_states)

  loss = None
  if labels is not None:
@@ -225,22 +225,26 @@ def lce_forward(
  )

  hidden_states = outputs[0]
+ # Only compute necessary logits, and do not upcast them to float if we are not computing the loss
+ slice_indices = slice(-logits_to_keep, None) if isinstance(logits_to_keep, int) else logits_to_keep
+ kept_hidden_states = hidden_states[:, slice_indices, :]

+ shift_labels = loss_kwargs.pop("shift_labels", None)
  logits = None
  loss = None
  # if in training mode, don't materialize logits
- if self.training and (labels is not None):
+ if self.training and (labels is not None or shift_labels is not None):
  loss = LigerForCausalLMLoss(
- hidden_states=hidden_states,
+ hidden_states=kept_hidden_states,
  lm_head_weight=self.lm_head.weight,
  labels=labels,
+ shift_labels=shift_labels,
  hidden_size=self.config.hidden_size,
  **loss_kwargs,
  )

  else: # if in inference mode materialize logits
- slice_indices = slice(-logits_to_keep, None) if isinstance(logits_to_keep, int) else logits_to_keep
- logits = self.lm_head(hidden_states[:, slice_indices, :])
+ logits = self.lm_head(kept_hidden_states)

  loss = None
  if labels is not None:
@@ -215,22 +215,26 @@ def lce_forward(
  )

  hidden_states = outputs[0]
+ # Only compute necessary logits, and do not upcast them to float if we are not computing the loss
+ slice_indices = slice(-logits_to_keep, None) if isinstance(logits_to_keep, int) else logits_to_keep
+ kept_hidden_states = hidden_states[:, slice_indices, :]

+ shift_labels = loss_kwargs.pop("shift_labels", None)
  logits = None
  loss = None
  # if in training mode, don't materialize logits
- if self.training and (labels is not None):
+ if self.training and (labels is not None or shift_labels is not None):
  loss = LigerForCausalLMLoss(
- hidden_states=hidden_states,
+ hidden_states=kept_hidden_states,
  lm_head_weight=self.lm_head.weight,
  labels=labels,
+ shift_labels=shift_labels,
  hidden_size=self.config.hidden_size,
  **loss_kwargs,
  )

  else: # if in inference mode materialize logits
- slice_indices = slice(-logits_to_keep, None) if isinstance(logits_to_keep, int) else logits_to_keep
- logits = self.lm_head(hidden_states[:, slice_indices, :])
+ logits = self.lm_head(kept_hidden_states)
  if labels is not None:
  loss = self.loss_function(
  logits=logits,
@@ -88,22 +88,26 @@ def lce_forward(
  )

  hidden_states = outputs[0]
+ # Only compute necessary logits, and do not upcast them to float if we are not computing the loss
+ slice_indices = slice(-logits_to_keep, None) if isinstance(logits_to_keep, int) else logits_to_keep
+ kept_hidden_states = hidden_states[:, slice_indices, :]

+ shift_labels = loss_kwargs.pop("shift_labels", None)
  logits = None
  loss = None
  # if in training mode, don't materialize logits
- if self.training and (labels is not None):
+ if self.training and (labels is not None or shift_labels is not None):
  loss = LigerForCausalLMLoss(
- hidden_states=hidden_states,
+ hidden_states=kept_hidden_states,
  lm_head_weight=self.lm_head.weight,
  labels=labels,
+ shift_labels=shift_labels,
  hidden_size=self.config.hidden_size,
  **loss_kwargs,
  )

  else: # if in inference mode materialize logits
- slice_indices = slice(-logits_to_keep, None) if isinstance(logits_to_keep, int) else logits_to_keep
- logits = self.lm_head(hidden_states[:, slice_indices, :])
+ logits = self.lm_head(kept_hidden_states)
  if labels is not None:
  loss = self.loss_function(
  logits=logits,
@@ -213,22 +213,26 @@ def lce_forward(
  )

  hidden_states = outputs[0]
+ # Only compute necessary logits, and do not upcast them to float if we are not computing the loss
+ slice_indices = slice(-logits_to_keep, None) if isinstance(logits_to_keep, int) else logits_to_keep
+ kept_hidden_states = hidden_states[:, slice_indices, :]

+ shift_labels = loss_kwargs.pop("shift_labels", None)
  logits = None
  loss = None
  # if in training mode, don't materialize logits
- if self.training and (labels is not None):
+ if self.training and (labels is not None or shift_labels is not None):
  loss = LigerForCausalLMLoss(
- hidden_states=hidden_states,
+ hidden_states=kept_hidden_states,
  lm_head_weight=self.lm_head.weight,
  labels=labels,
+ shift_labels=shift_labels,
  hidden_size=self.config.hidden_size,
  **loss_kwargs,
  )

  else: # if in inference mode materialize logits
- slice_indices = slice(-logits_to_keep, None) if isinstance(logits_to_keep, int) else logits_to_keep
- logits = self.lm_head(hidden_states[:, slice_indices, :])
+ logits = self.lm_head(kept_hidden_states)
  if labels is not None:
  loss = self.loss_function(
  logits=logits,
@@ -199,22 +199,26 @@ def lce_forward(
  )

  hidden_states = outputs[0]
+ # Only compute necessary logits, and do not upcast them to float if we are not computing the loss
+ slice_indices = slice(-logits_to_keep, None) if isinstance(logits_to_keep, int) else logits_to_keep
+ kept_hidden_states = hidden_states[:, slice_indices, :]

+ shift_labels = loss_kwargs.pop("shift_labels", None)
  logits = None
  loss = None
  # if in training mode, don't materialize logits
- if self.training and (labels is not None):
+ if self.training and (labels is not None or shift_labels is not None):
  loss = LigerForCausalLMLoss(
- hidden_states=hidden_states,
+ hidden_states=kept_hidden_states,
  lm_head_weight=self.lm_head.weight,
  labels=labels,
+ shift_labels=shift_labels,
  hidden_size=self.config.hidden_size,
  **loss_kwargs,
  )

  else: # if in inference mode materialize logits
- slice_indices = slice(-logits_to_keep, None) if isinstance(logits_to_keep, int) else logits_to_keep
- logits = self.lm_head(hidden_states[:, slice_indices, :])
+ logits = self.lm_head(kept_hidden_states)
  if labels is not None:
  loss = self.loss_function(
  logits=logits,
@@ -163,14 +163,16 @@ def lce_forward(

  hidden_states = outputs[0]

+ shift_labels = loss_kwargs.pop("shift_labels", None)
  loss = None
  logits = None

- if self.training and (labels is not None):
+ if self.training and (labels is not None or shift_labels is not None):
  loss = LigerForCausalLMLoss(
  hidden_states=hidden_states,
  lm_head_weight=self.lm_head.weight,
  labels=labels,
+ shift_labels=shift_labels,
  hidden_size=self.config.hidden_size,
  **loss_kwargs,
  )
@@ -167,14 +167,16 @@ def lce_forward(

  hidden_states = outputs[0]

+ shift_labels = loss_kwargs.pop("shift_labels", None)
  loss = None
  logits = None

- if self.training and (labels is not None):
+ if self.training and (labels is not None or shift_labels is not None):
  loss = LigerForCausalLMLoss(
  hidden_states=hidden_states,
  lm_head_weight=self.lm_head.weight,
  labels=labels,
+ shift_labels=shift_labels,
  hidden_size=self.config.hidden_size,
  **loss_kwargs,
  )
@@ -0,0 +1,118 @@
+ from typing import List
+ from typing import Optional
+ from typing import Union
+
+ import torch
+
+ from transformers.modeling_outputs import CausalLMOutputWithPast
+ from transformers.models.qwen3.modeling_qwen3 import _CONFIG_FOR_DOC
+ from transformers.models.qwen3.modeling_qwen3 import QWEN3_INPUTS_DOCSTRING
+ from transformers.utils import add_start_docstrings_to_model_forward
+ from transformers.utils import replace_return_docstrings
+
+ from liger_kernel.transformers.model.loss_utils import LigerForCausalLMLoss
+
+
+ @add_start_docstrings_to_model_forward(QWEN3_INPUTS_DOCSTRING)
+ @replace_return_docstrings(output_type=CausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC)
+ def lce_forward(
+ self,
+ input_ids: Optional[torch.LongTensor] = None,
+ attention_mask: Optional[torch.Tensor] = None,
+ position_ids: Optional[torch.LongTensor] = None,
+ past_key_values: Optional[List[torch.FloatTensor]] = None,
+ inputs_embeds: Optional[torch.FloatTensor] = None,
+ labels: Optional[torch.LongTensor] = None,
+ use_cache: Optional[bool] = None,
+ output_attentions: Optional[bool] = None,
+ output_hidden_states: Optional[bool] = None,
+ cache_position: Optional[torch.LongTensor] = None,
+ logits_to_keep: Union[int, torch.Tensor] = 0,
+ **kwargs,
+ ) -> CausalLMOutputWithPast:
+ r"""
+ labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
+ Labels for computing the masked language modeling loss. Indices should either be in `[0, ...,
+ config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored
+ (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`.
+
+ logits_to_keep (`int` or `torch.Tensor`, *optional*):
+ If an `int`, compute logits for the last `logits_to_keep` tokens. If `0`, calculate logits for all
+ `input_ids` (special case). Only last token logits are needed for generation, and calculating them only for that
+ token can save memory, which becomes pretty significant for long sequences or large vocabulary size.
+ If a `torch.Tensor`, must be 1D corresponding to the indices to keep in the sequence length dimension.
+ This is useful when using packed tensor format (single dimension for batch and sequence length).
+
+ Returns:
+
+ Example:
+
+ ```python
+ >>> from transformers import AutoTokenizer, Qwen3ForCausalLM
+
+ >>> model = Qwen3ForCausalLM.from_pretrained("Qwen/Qwen3-8B")
+ >>> tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen3-8B")
+
+ >>> prompt = "Hey, are you conscious? Can you talk to me?"
+ >>> inputs = tokenizer(prompt, return_tensors="pt")
+
+ >>> # Generate
+ >>> generate_ids = model.generate(inputs.input_ids, max_length=30)
+ >>> tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]
+ "Hey, are you conscious? Can you talk to me?\nI'm not conscious, but I can talk to you."
+ ```"""
+ output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
+ output_hidden_states = (
+ output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
+ )
+
+ # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn)
+ outputs = self.model(
+ input_ids=input_ids,
+ attention_mask=attention_mask,
+ position_ids=position_ids,
+ past_key_values=past_key_values,
+ inputs_embeds=inputs_embeds,
+ use_cache=use_cache,
+ output_attentions=output_attentions,
+ output_hidden_states=output_hidden_states,
+ cache_position=cache_position,
+ **kwargs,
+ )
+
+ hidden_states = outputs[0]
+ # Only compute necessary logits, and do not upcast them to float if we are not computing the loss
+ slice_indices = slice(-logits_to_keep, None) if isinstance(logits_to_keep, int) else logits_to_keep
+ kept_hidden_states = hidden_states[:, slice_indices, :]
+
+ shift_labels = kwargs.pop("shift_labels", None)
+ logits = None
+ loss = None
+ # if in training mode, don't materialize logits
+ if self.training and (labels is not None or shift_labels is not None):
+ loss = LigerForCausalLMLoss(
+ hidden_states=kept_hidden_states,
+ lm_head_weight=self.lm_head.weight,
+ labels=labels,
+ shift_labels=shift_labels,
+ hidden_size=self.config.hidden_size,
+ **kwargs,
+ )
+
+ else: # if in inference mode materialize logits
+ logits = self.lm_head(kept_hidden_states)
+ if labels is not None:
+ loss = self.loss_function(
+ logits=logits,
+ labels=labels,
+ vocab_size=self.config.vocab_size,
+ **kwargs,
+ )
+
+ return CausalLMOutputWithPast(
+ loss=loss,
+ logits=logits,
+ past_key_values=outputs.past_key_values,
+ hidden_states=outputs.hidden_states,
+ attentions=outputs.attentions,
+ )
@@ -1048,6 +1048,60 @@ def apply_liger_kernel_to_qwen2(
  print("Applied Liger kernels to Qwen2")


+ def apply_liger_kernel_to_qwen3(
+ rope: bool = True,
+ cross_entropy: bool = False,
+ fused_linear_cross_entropy: bool = True,
+ rms_norm: bool = True,
+ swiglu: bool = True,
+ model: PreTrainedModel = None,
+ ) -> None:
+ """
+ Apply Liger kernels to replace original implementation in HuggingFace Qwen3 models.
+ """
+ assert not (cross_entropy and fused_linear_cross_entropy), (
+ "cross_entropy and fused_linear_cross_entropy cannot both be True."
+ )
+
+ from transformers.models.qwen3 import modeling_qwen3
+ from transformers.models.qwen3.modeling_qwen3 import Qwen3Model
+
+ from liger_kernel.transformers.model.qwen3 import lce_forward as qwen3_lce_forward
+
+ if rope:
+ modeling_qwen3.apply_rotary_pos_emb = liger_rotary_pos_emb
+
+ if rms_norm:
+ modeling_qwen3.Qwen3RMSNorm = LigerRMSNorm
+
+ if cross_entropy:
+ from transformers.loss.loss_utils import nn
+
+ nn.functional.cross_entropy = liger_cross_entropy
+
+ if fused_linear_cross_entropy:
+ modeling_qwen3.Qwen3ForCausalLM.forward = qwen3_lce_forward
+
+ if swiglu:
+ modeling_qwen3.Qwen3MLP = LigerSwiGLUMLP
+
+ if model is not None:
+ # The model instance already exists, so we need to additionally patch the
+ # instance variables that reference already-instantiated modules
+
+ # get the base model from the model instance
+ base_model: Qwen3Model = getattr(model, model.base_model_prefix, model)
+
+ if rms_norm:
+ _patch_rms_norm_module(base_model.norm)
+ for decoder_layer in base_model.layers:
+ if swiglu:
+ _patch_swiglu_module(decoder_layer.mlp, LigerSwiGLUMLP)
+ if rms_norm:
+ _patch_rms_norm_module(decoder_layer.input_layernorm)
+ _patch_rms_norm_module(decoder_layer.post_attention_layernorm)
+
+
  def apply_liger_kernel_to_qwen2_vl(
  rope: bool = True,
  cross_entropy: bool = False,
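A hedged sketch of using the new Qwen3 entry point; the function name and its re-export from `liger_kernel.transformers` come from this diff, while the checkpoint id is the one used in the docstring example:

```python
# Sketch: patch the transformers Qwen3 classes before the model is built,
# so the loaded instance already uses the Liger kernels.
from transformers import AutoModelForCausalLM
from liger_kernel.transformers import apply_liger_kernel_to_qwen3

apply_liger_kernel_to_qwen3()  # RoPE, RMSNorm, SwiGLU, fused linear CE by default
model = AutoModelForCausalLM.from_pretrained("Qwen/Qwen3-8B")

# Alternatively, an already-instantiated model can be patched in place:
# apply_liger_kernel_to_qwen3(model=model)
```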
@@ -1319,12 +1373,78 @@ def apply_liger_kernel_to_olmo2(
  _patch_rms_norm_module(decoder_layer.post_feedforward_layernorm, in_place=False)


+ def apply_liger_kernel_to_glm4(
+ rope: bool = False,
+ cross_entropy: bool = False,
+ fused_linear_cross_entropy: bool = True,
+ rms_norm: bool = True,
+ swiglu: bool = True,
+ model: PreTrainedModel = None,
+ ) -> None:
+ """
+ Apply Liger kernels to replace original implementation in HuggingFace GLM-4 models.
+
+ Args:
+ rope (bool): Whether to apply Liger's rotary position embedding. Default is False.
+ cross_entropy (bool): Whether to apply Liger's cross entropy loss. Default is False.
+ fused_linear_cross_entropy (bool):
+ Whether to apply Liger's fused linear cross entropy loss. Default is True.
+ `cross_entropy` and `fused_linear_cross_entropy` cannot both be True.
+ If `fused_linear_cross_entropy` is True, the logits will not be materialized but more memory efficient.
+ rms_norm (bool): Whether to apply Liger's RMSNorm. Default is True.
+ swiglu (bool): Whether to apply Liger's SwiGLU Glm4MLP. Default is True.
+ model (PreTrainedModel): The model instance to apply Liger kernels to, if the model has already been
+ loaded. Default is None.
+ """
+ assert not (cross_entropy and fused_linear_cross_entropy), (
+ "cross_entropy and fused_linear_cross_entropy cannot both be True."
+ )
+
+ from transformers.models.glm4 import modeling_glm4
+ from transformers.models.glm4.modeling_glm4 import Glm4Model
+
+ from liger_kernel.transformers.model.glm4 import lce_forward as glm4_lce_forward
+
+ if rope:
+ raise NotImplementedError("liger_rotary_pos_emb is not available for Glm4 models.")
+ if rms_norm:
+ modeling_glm4.Glm4RMSNorm = partial(LigerRMSNorm, in_place=False)
+ if swiglu:
+ modeling_glm4.Glm4MLP = LigerPhi3SwiGLUMLP
+ if cross_entropy:
+ from transformers.loss.loss_utils import nn
+
+ nn.functional.cross_entropy = liger_cross_entropy
+ if fused_linear_cross_entropy:
+ modeling_glm4.Glm4ForCausalLM.forward = glm4_lce_forward
+
+ if model is not None:
+ # The model instance already exists, so we need to additionally patch the
+ # instance variables that reference already-instantiated modules
+
+ # get the base model from the model instance
+ base_model: Glm4Model = getattr(model, model.base_model_prefix, model)
+
+ if rms_norm:
+ _patch_rms_norm_module(base_model.norm, in_place=False)
+
+ for decoder_layer in base_model.layers:
+ if swiglu:
+ _patch_swiglu_module(decoder_layer.mlp, LigerPhi3SwiGLUMLP)
+ if rms_norm:
+ _patch_rms_norm_module(decoder_layer.input_layernorm, in_place=False)
+ _patch_rms_norm_module(decoder_layer.post_attention_layernorm, in_place=False)
+ _patch_rms_norm_module(decoder_layer.post_self_attn_layernorm, in_place=False)
+ _patch_rms_norm_module(decoder_layer.post_mlp_layernorm, in_place=False)
+
+
  # Model type corresponds to the keys defined in transformers/models/auto/modeling_auto.py
  MODEL_TYPE_TO_APPLY_LIGER_FN = {
  "gemma": apply_liger_kernel_to_gemma,
  "gemma2": apply_liger_kernel_to_gemma2,
  "gemma3_text": apply_liger_kernel_to_gemma3_text,
  "gemma3": apply_liger_kernel_to_gemma3,
+ "glm4": apply_liger_kernel_to_glm4,
  "llama": apply_liger_kernel_to_llama,
  "llava": apply_liger_kernel_to_llava,
  "granite": apply_liger_kernel_to_granite,
@@ -1334,6 +1454,7 @@ MODEL_TYPE_TO_APPLY_LIGER_FN = {
  "mixtral": apply_liger_kernel_to_mixtral,
  "olmo2": apply_liger_kernel_to_olmo2,
  "qwen2": apply_liger_kernel_to_qwen2,
+ "qwen3": apply_liger_kernel_to_qwen3,
  "qwen2_vl": apply_liger_kernel_to_qwen2_vl,
  "qwen2_5_vl": apply_liger_kernel_to_qwen2_5_vl,
  "phi3": apply_liger_kernel_to_phi3,
@@ -1,6 +1,6 @@
  Metadata-Version: 2.4
  Name: liger_kernel
- Version: 0.5.8
+ Version: 0.5.9
  Summary: Efficient Triton kernels for LLM Training
  License: BSD 2-CLAUSE LICENSE
  Copyright 2024 LinkedIn Corporation
@@ -320,9 +320,11 @@ loss.backward()
  | Qwen2, Qwen2.5, & QwQ | `liger_kernel.transformers.apply_liger_kernel_to_qwen2` | RoPE, RMSNorm, SwiGLU, CrossEntropyLoss, FusedLinearCrossEntropy |
  | Qwen2-VL, & QVQ | `liger_kernel.transformers.apply_liger_kernel_to_qwen2_vl` | RMSNorm, LayerNorm, SwiGLU, CrossEntropyLoss, FusedLinearCrossEntropy |
  | Qwen2.5-VL | `liger_kernel.transformers.apply_liger_kernel_to_qwen2_5_vl` | RMSNorm, SwiGLU, CrossEntropyLoss, FusedLinearCrossEntropy |
+ | Qwen3 | `liger_kernel.transformers.apply_liger_kernel_to_qwen3` | RoPE, RMSNorm, SwiGLU, CrossEntropyLoss, FusedLinearCrossEntropy |
  | Phi3 & Phi3.5 | `liger_kernel.transformers.apply_liger_kernel_to_phi3` | RoPE, RMSNorm, SwiGLU, CrossEntropyLoss, FusedLinearCrossEntropy |
  | Granite 3.0 & 3.1 | `liger_kernel.transformers.apply_liger_kernel_to_granite` | RoPE, RMSNorm, SwiGLU, CrossEntropyLoss |
  | OLMo2 | `liger_kernel.transformers.apply_liger_kernel_to_olmo2` | RoPE, RMSNorm, SwiGLU, CrossEntropyLoss, FusedLinearCrossEntropy |
+ | GLM-4 | `liger_kernel.transformers.apply_liger_kernel_to_glm4` | RoPE, RMSNorm, SwiGLU, CrossEntropyLoss, FusedLinearCrossEntropy |


  ## Low-level APIs
@@ -4,7 +4,7 @@ liger_kernel/utils.py,sha256=178Hn8uD-VauDT6FjqMyXLbKLod8ObIpaTtapHwfEK0,1861
  liger_kernel/chunked_loss/README.md,sha256=0FmkFC3hKBqyoDT5uTlIYmrvRkF-EOCR1y-EBU1LpWU,2248
  liger_kernel/chunked_loss/__init__.py,sha256=ATu-xX5Fc49Cr6yBOGBRNTo593ZrU5ZCsIuvoIbJWw4,603
  liger_kernel/chunked_loss/cpo_loss.py,sha256=Gzz1eU4kgcbdubFVRy55e8A1Cr-r45UgNicXwZIjmBU,5454
- liger_kernel/chunked_loss/dpo_loss.py,sha256=xZwGqS04si9zXyob95SAdalC-hajZg8fWINqiqffN8k,5855
+ liger_kernel/chunked_loss/dpo_loss.py,sha256=Xypt4FoTSmAnJE4SWtsCv4aNHK4ToR1LonUQtCTEuHQ,6258
  liger_kernel/chunked_loss/functional.py,sha256=9G3nKm-Bi7uoZRFkL8wwGMl6juDl4bSzDvTa5GHZPzg,955
  liger_kernel/chunked_loss/fused_linear_distillation.py,sha256=ooR-qnZCyWJN935oHCSWLaKKKyaYERyhNczRGi1VOiw,11935
  liger_kernel/chunked_loss/fused_linear_ppo.py,sha256=AA19cpv6D8mo5RbSK5GRCcZoOSnpxV_Z1eJlAsC5eic,13434
@@ -16,9 +16,9 @@ liger_kernel/chunked_loss/kto_loss.py,sha256=llVCe6DkcpCo57seGWoMikaQVFApx764jsm
  liger_kernel/chunked_loss/orpo_loss.py,sha256=nu9UYG16dcMw93lvHi4_hYs3Q0FK1KnlmMRj7OpYU8s,4872
  liger_kernel/chunked_loss/simpo_loss.py,sha256=fy2w8KbhMrBv7b1jdIeH3bBFxY52bPQPZb3KwBvmurM,5385
  liger_kernel/ops/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
- liger_kernel/ops/cross_entropy.py,sha256=T5oSsqOS1y-Iea5o9v_BSU-_mIEXqWAT1oX_m59NcA4,18941
+ liger_kernel/ops/cross_entropy.py,sha256=e8THGnhOcy_0SbOLABx67HEM7-B8a8pG7nDKbCRpQKM,19123
  liger_kernel/ops/dyt.py,sha256=YD1-buHz9VmIX838VKzLc-lm5CeUQ4LAskGDWBUMQHA,6187
- liger_kernel/ops/fused_linear_cross_entropy.py,sha256=1Y3Uk_TCSjqKgoG2eot1ptnWXJXXQESqGvOmqAW1gsM,10912
+ liger_kernel/ops/fused_linear_cross_entropy.py,sha256=5fbGhN85n3zf0uIdJ7PYHWIRzTf0VTFiS0ARtOmqIP0,11020
  liger_kernel/ops/fused_linear_jsd.py,sha256=CSoprxb-YcJy-YUKiTcYkxN8sb9h2kdk_iHuncvSV5c,9683
  liger_kernel/ops/geglu.py,sha256=axGvCIvlBzuluoAIrWTsp2iZM4BFKNInkPov8YVvH9E,4126
  liger_kernel/ops/group_norm.py,sha256=qD4D4lSjSgVtO52EBNLC2iTseALRgPgqXE50U2woggk,10837
@@ -33,12 +33,12 @@ liger_kernel/ops/tvd.py,sha256=FHJtLQI95ijqgg9UtaHpMAjSCiPxB6CduPwPMcGxelc,6405
  liger_kernel/ops/utils.py,sha256=uoFKQqo-34N2TWQNvXMFywqGiOMMXNEVBxVojzlUAa0,3836
  liger_kernel/ops/experimental/embedding.py,sha256=tolj3tItkzpSb30zWqDN2_yX4ectflaQ8HMyKyFIQc8,4172
  liger_kernel/ops/experimental/mm_int8int2.py,sha256=TrS9lpwekrik_w5qE7AhMJD1bcq-OidjtbsW80oZ6IM,13314
- liger_kernel/transformers/__init__.py,sha256=SH30Pt2ZqyQY-mmWQldg_r-5koowuymTIoU4F4e1KHk,6419
+ liger_kernel/transformers/__init__.py,sha256=x_3CYHJt-xj4va3N32kfwf000F-DNBtj-YE6OylDAW8,6774
  liger_kernel/transformers/auto_model.py,sha256=0qCTRZt280Bj_LcFdzo9hlaR-BWNazawXOGgoCZjgEg,1545
  liger_kernel/transformers/cross_entropy.py,sha256=z3KTWQnFxr_IZaVjtYt0ZNEWQdDdYThN35xWkHlDGH0,1683
  liger_kernel/transformers/dyt.py,sha256=QMqqc14pkE0WhpRZvapfnNAun-6C0C_tHExL2ZJuCUA,648
  liger_kernel/transformers/functional.py,sha256=4h9Pdx_iINBqfv2Zod_c27qOpYXDDwbdVgatQ9_XBmI,5089
- liger_kernel/transformers/fused_linear_cross_entropy.py,sha256=09Rt7FZzLH42VOcIbQ4dlQd0o3Rlb4vk6fqiOQ7WTD8,1778
+ liger_kernel/transformers/fused_linear_cross_entropy.py,sha256=O8Sg5BT81nTaY9fSGoOY9dOD9ekibwwiuXhdUHaxntQ,1742
  liger_kernel/transformers/fused_linear_jsd.py,sha256=bZ4otCvWBuOnA5XdQL-FzZVItJlDt-ht9e_pG7PG93E,3999
  liger_kernel/transformers/geglu.py,sha256=mrgqzIUVd6lN7fkDKLkw5YaESDxDtFgbot430WwPVOQ,1107
  liger_kernel/transformers/gema3_rms.py,sha256=LTmZOXe6WEnv6ZroW-kU1TE2B36-z5v8OLmKr3XEVFo,353
@@ -46,7 +46,7 @@ liger_kernel/transformers/group_norm.py,sha256=6qMAWOprr4SzP0YhNVNGQIBpM5aUHplUD
  liger_kernel/transformers/jsd.py,sha256=DGqRnxIZxsvxo0_tbbxX3b-sDbDjC_yKufyRIHCcScY,2979
  liger_kernel/transformers/kl_div.py,sha256=WLffFbh1EExD2Eb1F7lN11fo9JJC-0751WJjZAF1Fj8,409
  liger_kernel/transformers/layer_norm.py,sha256=c9pk3PEasOKYR0rhe5e5nNrnYKVCEW4VC8S6LpCq9EQ,906
- liger_kernel/transformers/monkey_patch.py,sha256=QpfNU7MmVDGlBWIZ2RLTSyh0vuZ-si7H37SL-qOliUs,64393
+ liger_kernel/transformers/monkey_patch.py,sha256=8Q84xxWA7ltgqgGRBxKxPPNeG7k5HYQfgaw1-HFnKGM,69287
  liger_kernel/transformers/qwen2vl_mrope.py,sha256=5EwSqrMdsL9MYspeBMXBsNJKvH0MOmRrtJXAJlnnlOI,1047
  liger_kernel/transformers/rms_norm.py,sha256=GqCEJuGt0YdqqlMcToE0Wp4A8YFquDa4UUSyH2uFW2A,1191
  liger_kernel/transformers/rope.py,sha256=ZTrTORSAyfcFIKjk6XEeYmk4ROH7xXED9L4g2NFntlE,999
@@ -55,28 +55,30 @@ liger_kernel/transformers/trainer_integration.py,sha256=W3ON51O5GkyzNJsItz0y5rKx
  liger_kernel/transformers/tvd.py,sha256=XrRfyJIqN6HFxXk8MYyFVZM1OLz3mtSbRZvWfZ_JerQ,450
  liger_kernel/transformers/experimental/embedding.py,sha256=2P0QYdlFyFrG5OqTzTa1wcRgDSyjBMv5i1a7BrDPDQw,881
  liger_kernel/transformers/model/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
- liger_kernel/transformers/model/gemma.py,sha256=-JoHKWjtYPpxHQa6QbCwnzX_cctRZG2ZTsaUv-dmOt4,9816
- liger_kernel/transformers/model/gemma2.py,sha256=n4MZupFGDMvtnvkvkNhRrxXS3ZF341BVfyLjrOXp10g,10923
- liger_kernel/transformers/model/gemma3.py,sha256=ge3JYchiKvX1G1Zp00jX2zmQK2K7ymJoZAxbb2ggslw,16102
- liger_kernel/transformers/model/llama.py,sha256=UVXQLRW7rCU5vPab54dLNS3ER37eM446peHX00Yz6eA,10493
+ liger_kernel/transformers/model/gemma.py,sha256=nMUY2Iw7j6a-fOUqYBlfzIPznpKPKVa2DMBIZqCVfuI,10087
+ liger_kernel/transformers/model/gemma2.py,sha256=eulrUbh1DEMpMR6Lupx69kL-FeuRDP19mVoW1gc7keY,11194
+ liger_kernel/transformers/model/gemma3.py,sha256=wGSNqaLRRgIGQ_r9esyhDezm2SkAGZflopoWoWR-nYY,16226
+ liger_kernel/transformers/model/glm4.py,sha256=rtyMTtzgh_ncZ7DsfNxRJoUUm7xlDMKGzNqlxXjdAJk,5452
+ liger_kernel/transformers/model/llama.py,sha256=F8cvDAlf4NeKESdGEFXs8m3ue2F8i0h3aV2LricMqoM,10764
  liger_kernel/transformers/model/llava.py,sha256=b0pEagjUbu2-eS9xegjyfl1DwIXLwZcNpff55ibaMbA,17601
  liger_kernel/transformers/model/loss_utils.py,sha256=WWAMdiONPaXpIvxyOim_0igLrYh0yyOok5Q9_L9xvZw,1787
- liger_kernel/transformers/model/mistral.py,sha256=RacuKcckuDK6oSraCGD0R0bm-fE0K3q-lkYaAC56C2E,5481
- liger_kernel/transformers/model/mixtral.py,sha256=gLcqGabdv1XnuciS9b-TpkTDnGL8K32Hoq9j2vZMBRY,11502
- liger_kernel/transformers/model/mllama.py,sha256=75mxtmMsNd_q8KlKeawj2uMP6v2KjDuUi4nsUKM5jqA,11308
- liger_kernel/transformers/model/olmo2.py,sha256=rSzSALikEGkk0w3PLNQPrqg-ioN8TpWCXkAlg3LtCdI,5189
+ liger_kernel/transformers/model/mistral.py,sha256=1AcwJT9WOIpHkpu4Njs35ZryiGyW8ygERYmGqLz2Z4o,5752
+ liger_kernel/transformers/model/mixtral.py,sha256=URMzPLU1akf1H4hHXalCyfbVGUldRx8_jqdrZfM7Y-w,11773
+ liger_kernel/transformers/model/mllama.py,sha256=v_ayi6m4sC6AVKTrrLHF4W5HVaL86AYQNBqdWuTTOTw,11579
+ liger_kernel/transformers/model/olmo2.py,sha256=Kb6sGPsQS970GsYmWoT0DC2DFiXQ9Yjyxr8FRnT_8tQ,5460
  liger_kernel/transformers/model/paligemma.py,sha256=GNReT6tVZt3ON6aaa9ovg8mnu1hYocSx9OhgC7b-_28,19191
- liger_kernel/transformers/model/phi3.py,sha256=ebITCrmwmb4z66CbSrZl1kD6BsP52IcSAR8uwUTp9nc,10455
- liger_kernel/transformers/model/qwen2.py,sha256=QaoTDrJv2wIuAM8QMoeWVvgNl0N5gHzIrew9QGG7kXc,9744
- liger_kernel/transformers/model/qwen2_5_vl.py,sha256=70BnHZjx6eQWTwi3zc5SMwxTeOOA4Tbdkfy6IYRcTaM,9289
- liger_kernel/transformers/model/qwen2_vl.py,sha256=zo4O9fShNHYqSLrzLGqQYWSMtJI6UHaSY7zvMCYWyD8,9685
+ liger_kernel/transformers/model/phi3.py,sha256=TSeHK8H0mnS2esJaZI3lxmo5X3-Uwtd_TsrgvJRkm3s,10726
+ liger_kernel/transformers/model/qwen2.py,sha256=bEusb6vrVbagtSUHyntpi9j0x79IrZ1NP8iA5GR5Ryw,10015
+ liger_kernel/transformers/model/qwen2_5_vl.py,sha256=oACIsTpg9_GdoSvekCyXLhJkuCpQEiFOTzKj7cjgi2E,9413
+ liger_kernel/transformers/model/qwen2_vl.py,sha256=F6DeQ65wPtcpeQJZ9a3SJZKkQ-e24SRLdYUgC-_jT-k,9809
+ liger_kernel/transformers/model/qwen3.py,sha256=JdIeh0fvDLdGs8nk4_eHrovHCNa09VG15D4aa0X0mwI,5084
  liger_kernel/transformers/trainer/__init__.py,sha256=p7yQfklV8-467qSz_ZMimkbDF7HHWHwku25A-GYL0WU,193
  liger_kernel/transformers/trainer/orpo_trainer.py,sha256=pdekW7l6Qg_aqa5SYKYlSWUF8m3lkOFvFLcIMEHrz9s,8338
  liger_kernel/triton/__init__.py,sha256=qCiCamzCRv6lpV8IqpAc9YMdNKC7GKurClWceQPnlis,92
  liger_kernel/triton/monkey_patch.py,sha256=Rd0hUHAzDkFfHvnX7-PBaNK5EKnZhtfM_h-fgQH9HPY,1568
- liger_kernel-0.5.8.dist-info/licenses/LICENSE,sha256=OhzLDHJ0to4a8sodVLELZiCFylZ1NAAYLs-HrjPy0ag,1312
- liger_kernel-0.5.8.dist-info/licenses/NOTICE,sha256=njwnoPZLh9AN8SJQzxvCGLHi-8X__AvWRze6joNXIY8,2066
- liger_kernel-0.5.8.dist-info/METADATA,sha256=FAr_rRImlE1GETlKdEpEmRKA2Y9UzWbLKDmLWidJqeg,23340
- liger_kernel-0.5.8.dist-info/WHEEL,sha256=CmyFI0kx5cdEMTLiONQRbGQwjIoR1aIYB7eCAQ4KPJ0,91
- liger_kernel-0.5.8.dist-info/top_level.txt,sha256=2eghu4hA3LnkM7ElW92tQ8zegWKgSbeo-k-aGe1YnvY,13
- liger_kernel-0.5.8.dist-info/RECORD,,
+ liger_kernel-0.5.9.dist-info/licenses/LICENSE,sha256=OhzLDHJ0to4a8sodVLELZiCFylZ1NAAYLs-HrjPy0ag,1312
+ liger_kernel-0.5.9.dist-info/licenses/NOTICE,sha256=njwnoPZLh9AN8SJQzxvCGLHi-8X__AvWRze6joNXIY8,2066
+ liger_kernel-0.5.9.dist-info/METADATA,sha256=Wq3nqeBFdqmOj8uiy7S4ZEL4xA88DVb0ad2b9KDn-qI,23627
+ liger_kernel-0.5.9.dist-info/WHEEL,sha256=0CuiUZ_p9E4cD6NyLD6UG80LBXYyiSYZOKDm5lp32xk,91
+ liger_kernel-0.5.9.dist-info/top_level.txt,sha256=2eghu4hA3LnkM7ElW92tQ8zegWKgSbeo-k-aGe1YnvY,13
+ liger_kernel-0.5.9.dist-info/RECORD,,
@@ -1,5 +1,5 @@
  Wheel-Version: 1.0
- Generator: setuptools (78.1.0)
+ Generator: setuptools (80.3.1)
  Root-Is-Purelib: true
  Tag: py3-none-any