liger-kernel-nightly 0.6.2.dev20251016055812__py3-none-any.whl → 0.6.2.dev20251019095057__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
--- liger_kernel/transformers/__init__.py
+++ liger_kernel/transformers/__init__.py
@@ -55,6 +55,7 @@ if TYPE_CHECKING:
      from liger_kernel.transformers.monkey_patch import apply_liger_kernel_to_qwen2_vl  # noqa: F401
      from liger_kernel.transformers.monkey_patch import apply_liger_kernel_to_qwen3  # noqa: F401
      from liger_kernel.transformers.monkey_patch import apply_liger_kernel_to_qwen3_moe  # noqa: F401
+     from liger_kernel.transformers.monkey_patch import apply_liger_kernel_to_qwen3_next  # noqa: F401
      from liger_kernel.transformers.monkey_patch import apply_liger_kernel_to_smollm3  # noqa: F401
@@ -117,6 +118,7 @@ def __getattr__(name: str):
          "apply_liger_kernel_to_qwen2_vl",
          "apply_liger_kernel_to_qwen3",
          "apply_liger_kernel_to_qwen3_moe",
+         "apply_liger_kernel_to_qwen3_next",
          "apply_liger_kernel_to_smollm3",
      }
@@ -185,6 +187,7 @@ if _TRANSFORMERS_AVAILABLE:
              "apply_liger_kernel_to_qwen2_vl",
              "apply_liger_kernel_to_qwen3",
              "apply_liger_kernel_to_qwen3_moe",
+             "apply_liger_kernel_to_qwen3_next",
              "apply_liger_kernel_to_smollm3",
          ]
      )
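
Taken together with the `monkey_patch.py` and `MODEL_TYPE_TO_APPLY_LIGER_FN` hunks further down, these exports give a few ways to enable the new kernels for Qwen3-Next. A minimal usage sketch, not part of the diff: the checkpoint name is borrowed from the docstring in the new model file below, and the `AutoLigerKernelForCausalLM` dispatch is an assumption about how the registry is normally consumed.

```python
from transformers import AutoModelForCausalLM

from liger_kernel.transformers import AutoLigerKernelForCausalLM, apply_liger_kernel_to_qwen3_next

# 1) Patch before loading: class-level replacements apply to modules built afterwards.
apply_liger_kernel_to_qwen3_next(rms_norm=True, swiglu=True, fused_linear_cross_entropy=True)
model = AutoModelForCausalLM.from_pretrained("Qwen/Qwen3-Next-80B-A3B-Instruct")

# 2) Patch an already-loaded model: the `model=` branch in the monkey_patch.py hunk below
#    swaps the instantiated RMSNorm, MLP, shared-expert, and per-expert modules in place.
model = AutoModelForCausalLM.from_pretrained("Qwen/Qwen3-Next-80B-A3B-Instruct")
apply_liger_kernel_to_qwen3_next(rms_norm=True, swiglu=True, model=model)

# 3) Assumed behavior: AutoLigerKernelForCausalLM resolves config.model_type ("qwen3_next")
#    through MODEL_TYPE_TO_APPLY_LIGER_FN and applies the same patch at load time.
model = AutoLigerKernelForCausalLM.from_pretrained("Qwen/Qwen3-Next-80B-A3B-Instruct")
```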
--- /dev/null
+++ liger_kernel/transformers/model/qwen3_next.py
@@ -0,0 +1,134 @@
+ from typing import TYPE_CHECKING
+ from typing import List
+ from typing import Optional
+ from typing import Union
+
+ import torch
+
+ from transformers.modeling_outputs import MoeCausalLMOutputWithPast
+ from transformers.modeling_outputs import MoeModelOutputWithPast
+
+ if TYPE_CHECKING:
+     from transformers.models.qwen3_next.modeling_qwen3_next import load_balancing_loss_func
+
+ from liger_kernel.transformers.model.loss_utils import LigerForCausalLMLoss
+
+
+ def lce_forward(
+     self,
+     input_ids: Optional[torch.LongTensor] = None,
+     attention_mask: Optional[torch.Tensor] = None,
+     position_ids: Optional[torch.LongTensor] = None,
+     past_key_values: Optional[List[torch.FloatTensor]] = None,
+     inputs_embeds: Optional[torch.FloatTensor] = None,
+     labels: Optional[torch.LongTensor] = None,
+     use_cache: Optional[bool] = None,
+     output_attentions: Optional[bool] = None,
+     output_hidden_states: Optional[bool] = None,
+     output_router_logits: Optional[bool] = None,
+     cache_position: Optional[torch.LongTensor] = None,
+     logits_to_keep: Union[int, torch.Tensor] = 0,
+     skip_logits: Optional[bool] = None,
+     **kwargs,
+ ) -> MoeCausalLMOutputWithPast:
+     r"""
+     labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
+         Labels for computing the masked language modeling loss. Indices should either be in `[0, ...,
+         config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored
+         (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`.
+
+     logits_to_keep (`int` or `torch.Tensor`, *optional*):
+         If an `int`, compute logits for the last `logits_to_keep` tokens. If `0`, calculate logits for all
+         `input_ids` (special case). Only last token logits are needed for generation, and calculating them only for that
+         token can save memory, which becomes pretty significant for long sequences or large vocabulary size.
+         If a `torch.Tensor`, must be 1D corresponding to the indices to keep in the sequence length dimension.
+         This is useful when using packed tensor format (single dimension for batch and sequence length).
+
+     Returns:
+
+     Example:
+
+     ```python
+     >>> from transformers import AutoModelForCausalLM, AutoTokenizer
+
+     >>> model = AutoModelForCausalLM.from_pretrained("Qwen/Qwen3-Next-80B-A3B-Instruct")
+     >>> tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen3-Next-80B-A3B-Instruct")
+
+     >>> prompt = "Give me a short introduction to large language model."
+     >>> inputs = tokenizer(prompt, return_tensors="pt")
+
+     >>> # Generate
+     >>> generate_ids = model.generate(inputs.input_ids, max_length=30)
+     >>> tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]
+     "Hey, are you conscious? Can you talk to me?\nI'm not conscious, but I can talk to you."
+     ```"""
+     output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
+     output_router_logits = (
+         output_router_logits if output_router_logits is not None else self.config.output_router_logits
+     )
+
+     output_hidden_states = (
+         output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
+     )
+
+     # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn)
+     outputs: MoeModelOutputWithPast = self.model(
+         input_ids=input_ids,
+         attention_mask=attention_mask,
+         position_ids=position_ids,
+         past_key_values=past_key_values,
+         inputs_embeds=inputs_embeds,
+         use_cache=use_cache,
+         output_attentions=output_attentions,
+         output_hidden_states=output_hidden_states,
+         output_router_logits=output_router_logits,
+         cache_position=cache_position,
+         **kwargs,
+     )
+
+     hidden_states = outputs.last_hidden_state
+     # Only compute necessary logits, and do not upcast them to float if we are not computing the loss
+     slice_indices = slice(-logits_to_keep, None) if isinstance(logits_to_keep, int) else logits_to_keep
+     kept_hidden_states = hidden_states[:, slice_indices, :]
+
+     shift_labels = kwargs.pop("shift_labels", None)
+     logits = None
+     loss = None
+
+     if skip_logits is None:
+         skip_logits = self.training and (labels is not None or shift_labels is not None)
+
+     if skip_logits:
+         loss = LigerForCausalLMLoss(
+             hidden_states=kept_hidden_states,
+             lm_head_weight=self.lm_head.weight,
+             labels=labels,
+             shift_labels=shift_labels,
+             hidden_size=self.config.hidden_size,
+             **kwargs,
+         )
+     else:  # if in inference mode, materialize logits
+         logits = self.lm_head(kept_hidden_states)
+         if labels is not None or shift_labels is not None:
+             loss = self.loss_function(logits, labels, self.vocab_size, **kwargs)
+
+     aux_loss = None
+     if output_router_logits:
+         aux_loss = load_balancing_loss_func(
+             outputs.router_logits,
+             self.num_experts,
+             self.num_experts_per_tok,
+             attention_mask,
+         )
+         if labels is not None:
+             loss += self.router_aux_loss_coef * aux_loss.to(loss.device)  # make sure to reside in the same device
+
+     return MoeCausalLMOutputWithPast(
+         loss=loss,
+         aux_loss=aux_loss,
+         logits=logits,
+         past_key_values=outputs.past_key_values,
+         hidden_states=outputs.hidden_states,
+         attentions=outputs.attentions,
+         router_logits=outputs.router_logits,
+     )
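
The `skip_logits` branch above is the core of the change: during training with labels, `LigerForCausalLMLoss` consumes the kept hidden states and `lm_head.weight` directly, so the full-vocabulary logits tensor is never materialized. A rough sketch of the two paths, assuming `model` is a `Qwen3NextForCausalLM` instance that has already been patched as shown elsewhere in this diff:

```python
import torch

# Training path: with labels and model.train(), skip_logits defaults to True, so
# out.logits is None and out.loss comes from the fused linear cross entropy.
model.train()
input_ids = torch.randint(0, model.config.vocab_size, (1, 16))
out = model(input_ids=input_ids, labels=input_ids.clone())
assert out.logits is None
out.loss.backward()

# Inference path: without labels, logits are materialized by lm_head as usual.
model.eval()
with torch.no_grad():
    out = model(input_ids=input_ids)
print(out.logits.shape)  # (1, 16, vocab_size) since logits_to_keep defaults to 0
```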
--- liger_kernel/transformers/monkey_patch.py
+++ liger_kernel/transformers/monkey_patch.py
@@ -2180,6 +2180,97 @@ def apply_liger_kernel_to_falcon_h1(
              _patch_rms_norm_module(decoder_layer.pre_ff_layernorm)


+ def apply_liger_kernel_to_qwen3_next(
+     rope: bool = False,
+     cross_entropy: bool = False,
+     fused_linear_cross_entropy: bool = True,
+     rms_norm: bool = True,
+     swiglu: bool = True,
+     model: PreTrainedModel = None,
+ ) -> None:
+     """
+     Apply Liger kernels to replace the original implementation in HuggingFace Qwen3Next models.
+
+     Args:
+         rope (bool): Whether to apply Liger's rotary position embedding. Default is False.
+         cross_entropy (bool): Whether to apply Liger's cross entropy loss. Default is False.
+         fused_linear_cross_entropy (bool):
+             Whether to apply Liger's fused linear cross entropy loss. Default is True.
+             `cross_entropy` and `fused_linear_cross_entropy` cannot both be True.
+             If `fused_linear_cross_entropy` is True, the logits will not be materialized, which is more memory efficient.
+         rms_norm (bool): Whether to apply Liger's RMSNorm. Default is True.
+         swiglu (bool): Whether to apply Liger's SwiGLUMLP. Default is True.
+         model (PreTrainedModel): The model instance to apply Liger kernels to, if the model has already been
+             loaded. Default is None.
+     """
+     assert not (cross_entropy and fused_linear_cross_entropy), (
+         "cross_entropy and fused_linear_cross_entropy cannot both be True."
+     )
+
+     from transformers.models.qwen3_next import modeling_qwen3_next
+     from transformers.models.qwen3_next.modeling_qwen3_next import Qwen3NextForCausalLM
+     from transformers.models.qwen3_next.modeling_qwen3_next import Qwen3NextMLP
+     from transformers.models.qwen3_next.modeling_qwen3_next import Qwen3NextModel
+     from transformers.models.qwen3_next.modeling_qwen3_next import Qwen3NextSparseMoeBlock
+
+     from liger_kernel.transformers.model.qwen3_next import lce_forward as qwen3_next_lce_forward
+     from liger_kernel.transformers.rms_norm import LigerRMSNormForQwen3Next
+     from liger_kernel.transformers.swiglu import LigerQwen3MoeSwiGLUMLP
+
+     if rope:
+         # It might encounter a NaN issue
+         # modeling_qwen3_next.apply_rotary_pos_emb = liger_rotary_pos_emb
+         raise NotImplementedError("liger_rotary_pos_emb is not available for Qwen3Next models.")
+     if rms_norm:
+         modeling_qwen3_next.Qwen3NextRMSNorm = LigerRMSNormForQwen3Next
+     if cross_entropy:
+         from transformers.loss.loss_utils import nn
+
+         nn.functional.cross_entropy = liger_cross_entropy
+     if fused_linear_cross_entropy:
+         if model is not None:
+             if isinstance(model, Qwen3NextForCausalLM):
+                 model.forward = MethodType(qwen3_next_lce_forward, model)
+             else:
+                 raise TypeError(
+                     f"fused_linear_cross_entropy is only applicable on Qwen3NextForCausalLM. Got: {type(model)}"
+                 )
+         else:
+             modeling_qwen3_next.Qwen3NextForCausalLM.forward = qwen3_next_lce_forward
+     if swiglu:
+         # Qwen3MoeMLP and Qwen3NextMLP are identical, hence we reuse LigerQwen3MoeSwiGLUMLP
+         modeling_qwen3_next.Qwen3NextMLP = LigerQwen3MoeSwiGLUMLP
+
+     if model is not None:
+         # The model instance already exists, so we need to additionally patch the
+         # instance variables that reference already-instantiated modules
+         if isinstance(model, (Qwen3NextForCausalLM, Qwen3NextModel)):
+             base_model: Qwen3NextForCausalLM = getattr(model, model.base_model_prefix, model)
+         else:
+             raise TypeError(
+                 f"Unsupported qwen3_next model type. `model` must be `Qwen3NextForCausalLM` or `Qwen3NextModel`. Got: {type(model)}"
+             )
+
+         if rms_norm:
+             _patch_rms_norm_module(base_model.norm)
+
+         for decoder_layer in base_model.layers:
+             if rms_norm:
+                 _patch_rms_norm_module(decoder_layer.input_layernorm)
+                 _patch_rms_norm_module(decoder_layer.post_attention_layernorm)
+
+             # Qwen3MoeMLP and Qwen3NextMLP are identical, hence we reuse LigerQwen3MoeSwiGLUMLP
+             if swiglu:
+                 if isinstance(decoder_layer.mlp, Qwen3NextMLP):
+                     _patch_swiglu_module(decoder_layer.mlp, LigerQwen3MoeSwiGLUMLP)
+                 if isinstance(decoder_layer.mlp, Qwen3NextSparseMoeBlock):
+                     _patch_swiglu_module(decoder_layer.mlp.shared_expert, LigerQwen3MoeSwiGLUMLP)
+                     experts = getattr(decoder_layer.mlp, "experts", None)
+                     if experts is not None:
+                         for expert in experts:
+                             _patch_swiglu_module(expert, LigerQwen3MoeSwiGLUMLP)
+
+
  # Model type corresponds to the keys defined in transformers/models/auto/modeling_auto.py
  MODEL_TYPE_TO_APPLY_LIGER_FN = {
      "gemma": apply_liger_kernel_to_gemma,
@@ -2207,6 +2298,7 @@ MODEL_TYPE_TO_APPLY_LIGER_FN = {
      "qwen2_vl_text": apply_liger_kernel_to_qwen2_vl,
      "qwen2_5_vl": apply_liger_kernel_to_qwen2_5_vl,
      "qwen2_5_vl_text": apply_liger_kernel_to_qwen2_5_vl,
+     "qwen3_next": apply_liger_kernel_to_qwen3_next,
      "smollm3": apply_liger_kernel_to_smollm3,
      "phi3": apply_liger_kernel_to_phi3,
      "paligemma": apply_liger_kernel_to_paligemma,
--- liger_kernel/transformers/rms_norm.py
+++ liger_kernel/transformers/rms_norm.py
@@ -77,3 +77,10 @@ class LigerRMSNormForGlm4(LigerRMSNorm):
          self, hidden_size, eps=1e-6, offset=0.0, casting_mode="llama", init_fn="ones", in_place=False, row_mode=None
      ):
          super().__init__(hidden_size, eps, offset, casting_mode, init_fn, in_place, row_mode)
+
+
+ class LigerRMSNormForQwen3Next(LigerRMSNorm):
+     def __init__(
+         self, hidden_size, eps=1e-6, offset=1.0, casting_mode="gemma", init_fn="zeros", in_place=False, row_mode=None
+     ):
+         super().__init__(hidden_size, eps, offset, casting_mode, init_fn, in_place, row_mode)
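
The defaults chosen for `LigerRMSNormForQwen3Next` (`offset=1.0`, `casting_mode="gemma"`, `init_fn="zeros"`) correspond, in LigerRMSNorm terms, to a zero-initialized weight applied as `(1 + weight)` after normalizing in float32, matching the Gemma-style norm variant. A plain-PyTorch sketch of what those parameters imply (illustrative only, not the Triton kernel; the class name below is made up):

```python
import torch
import torch.nn as nn


class Qwen3NextRMSNormReference(nn.Module):
    """Eager reference for offset=1.0, init_fn="zeros", casting_mode="gemma"."""

    def __init__(self, hidden_size: int, eps: float = 1e-6):
        super().__init__()
        self.weight = nn.Parameter(torch.zeros(hidden_size))  # init_fn="zeros"
        self.eps = eps

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        input_dtype = x.dtype
        x = x.float()  # casting_mode="gemma": normalize and scale in float32
        x = x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + self.eps)
        return (x * (1.0 + self.weight.float())).to(input_dtype)  # offset=1.0
```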
--- liger_kernel_nightly-0.6.2.dev20251016055812.dist-info/METADATA
+++ liger_kernel_nightly-0.6.2.dev20251019095057.dist-info/METADATA
@@ -1,6 +1,6 @@
  Metadata-Version: 2.1
  Name: liger_kernel_nightly
- Version: 0.6.2.dev20251016055812
+ Version: 0.6.2.dev20251019095057
  Summary: Efficient Triton kernels for LLM Training
  License: BSD 2-CLAUSE LICENSE
  Copyright 2024 LinkedIn Corporation
--- liger_kernel_nightly-0.6.2.dev20251016055812.dist-info/RECORD
+++ liger_kernel_nightly-0.6.2.dev20251019095057.dist-info/RECORD
@@ -42,7 +42,7 @@ liger_kernel/ops/tvd.py,sha256=FHJtLQI95ijqgg9UtaHpMAjSCiPxB6CduPwPMcGxelc,6405
  liger_kernel/ops/utils.py,sha256=uoFKQqo-34N2TWQNvXMFywqGiOMMXNEVBxVojzlUAa0,3836
  liger_kernel/ops/experimental/embedding.py,sha256=tolj3tItkzpSb30zWqDN2_yX4ectflaQ8HMyKyFIQc8,4172
  liger_kernel/ops/experimental/mm_int8int2.py,sha256=TrS9lpwekrik_w5qE7AhMJD1bcq-OidjtbsW80oZ6IM,13314
- liger_kernel/transformers/__init__.py,sha256=d0H4knUp93iR3OPR3lpYriZYCvC-w_i2cDTYtcgfhzo,9107
+ liger_kernel/transformers/__init__.py,sha256=JovUTGIMKlQGiuoHIICmJqwBWUc9lkdZFNHBToR8bpY,9301
  liger_kernel/transformers/auto_model.py,sha256=0qCTRZt280Bj_LcFdzo9hlaR-BWNazawXOGgoCZjgEg,1545
  liger_kernel/transformers/cross_entropy.py,sha256=z3KTWQnFxr_IZaVjtYt0ZNEWQdDdYThN35xWkHlDGH0,1683
  liger_kernel/transformers/dyt.py,sha256=i-4GPaMrl-jab9TVI5qN0-H9qycn_mCbV82ozU4nbmU,723
@@ -59,11 +59,11 @@ liger_kernel/transformers/jsd.py,sha256=DGqRnxIZxsvxo0_tbbxX3b-sDbDjC_yKufyRIHCc
  liger_kernel/transformers/kl_div.py,sha256=WLffFbh1EExD2Eb1F7lN11fo9JJC-0751WJjZAF1Fj8,409
  liger_kernel/transformers/layer_norm.py,sha256=c9pk3PEasOKYR0rhe5e5nNrnYKVCEW4VC8S6LpCq9EQ,906
  liger_kernel/transformers/llama4_rope.py,sha256=kS6PSHEwf3dS7hD7C7p8S0geugx2EMCiP0h0F7LsUoY,3639
- liger_kernel/transformers/monkey_patch.py,sha256=TUmx8aY0lonyThcATirRBdSs7uItVvnBggohjBItBuQ,106060
+ liger_kernel/transformers/monkey_patch.py,sha256=bD9m04L2EYPzzkA0yEqpw7uR3ktbtwG5nSE-JaT54xc,110694
  liger_kernel/transformers/multi_token_attention.py,sha256=K3NIY9_5TPgZ4_Rahn0xnkMXxD_fmlJHK4CWGYvGQp0,1752
  liger_kernel/transformers/poly_norm.py,sha256=g5tC75i3qy1_N26ZUP-jfpct7ivQAEdJfIfx8IXzeyE,1377
  liger_kernel/transformers/qwen2vl_mrope.py,sha256=5EwSqrMdsL9MYspeBMXBsNJKvH0MOmRrtJXAJlnnlOI,1047
- liger_kernel/transformers/rms_norm.py,sha256=vkekcvTeWY8vL4H6hg3t0XeY0Ew_3OFMPHuzqlxPPVw,2719
+ liger_kernel/transformers/rms_norm.py,sha256=HwddVqrqS58jE-M2_4NkFGARtCDBhGnkKyjBN9b3FYI,3004
  liger_kernel/transformers/rope.py,sha256=ZTrTORSAyfcFIKjk6XEeYmk4ROH7xXED9L4g2NFntlE,999
  liger_kernel/transformers/softmax.py,sha256=yadlAgE4V2JByMwrDDa2s5SUBp8Jgd57xwnVvAWoBaI,264
  liger_kernel/transformers/sparsemax.py,sha256=0lQA0UEOs4mu8CMruZ3VLhImxQVXJWhPsAKUsYA7vj8,403
@@ -96,14 +96,15 @@ liger_kernel/transformers/model/qwen2_5_vl.py,sha256=Ea3zvL1FJfjlaerpeXCq-1zmorr
  liger_kernel/transformers/model/qwen2_vl.py,sha256=ZeasFPGs-bxm2Y_E15mo0YNx5wwtKYDV-bjVKjkLPBk,6018
  liger_kernel/transformers/model/qwen3.py,sha256=Q2aOg5erPrgVgRcqJm8sefLSDtvU1AD5B7aJnP7mRMM,4956
  liger_kernel/transformers/model/qwen3_moe.py,sha256=1CwTMCNFDYsjGoa_aHFBagtC5HuJTV-s0__5UvcjD3A,5686
+ liger_kernel/transformers/model/qwen3_next.py,sha256=7To7azriAogxeE7oEvByKztH9154dnDiDVNHHm7PZK4,5632
  liger_kernel/transformers/model/smollm3.py,sha256=0KWVkDtXbjsBKhJnaquV6vUUYyLtfmNwYH0sxJt-qTk,7667
  liger_kernel/transformers/trainer/__init__.py,sha256=p7yQfklV8-467qSz_ZMimkbDF7HHWHwku25A-GYL0WU,193
  liger_kernel/transformers/trainer/orpo_trainer.py,sha256=tX0h63aOFe3rNqTmk6JpMf75UPo981yzEa6TghnjS0Q,5370
  liger_kernel/triton/__init__.py,sha256=qCiCamzCRv6lpV8IqpAc9YMdNKC7GKurClWceQPnlis,92
  liger_kernel/triton/monkey_patch.py,sha256=Rd0hUHAzDkFfHvnX7-PBaNK5EKnZhtfM_h-fgQH9HPY,1568
- liger_kernel_nightly-0.6.2.dev20251016055812.dist-info/LICENSE,sha256=OhzLDHJ0to4a8sodVLELZiCFylZ1NAAYLs-HrjPy0ag,1312
- liger_kernel_nightly-0.6.2.dev20251016055812.dist-info/METADATA,sha256=0T7yuosaQopminlzrQ4Z2ZyY7Lm_Dst67jQScbOIlHU,24777
- liger_kernel_nightly-0.6.2.dev20251016055812.dist-info/NOTICE,sha256=njwnoPZLh9AN8SJQzxvCGLHi-8X__AvWRze6joNXIY8,2066
- liger_kernel_nightly-0.6.2.dev20251016055812.dist-info/WHEEL,sha256=iAkIy5fosb7FzIOwONchHf19Qu7_1wCWyFNR5gu9nU0,91
- liger_kernel_nightly-0.6.2.dev20251016055812.dist-info/top_level.txt,sha256=2eghu4hA3LnkM7ElW92tQ8zegWKgSbeo-k-aGe1YnvY,13
- liger_kernel_nightly-0.6.2.dev20251016055812.dist-info/RECORD,,
+ liger_kernel_nightly-0.6.2.dev20251019095057.dist-info/LICENSE,sha256=OhzLDHJ0to4a8sodVLELZiCFylZ1NAAYLs-HrjPy0ag,1312
+ liger_kernel_nightly-0.6.2.dev20251019095057.dist-info/METADATA,sha256=cadXEZlvX80i-iJmLtyatXvjID2o2t2EWt0vlDWV9ls,24777
+ liger_kernel_nightly-0.6.2.dev20251019095057.dist-info/NOTICE,sha256=njwnoPZLh9AN8SJQzxvCGLHi-8X__AvWRze6joNXIY8,2066
+ liger_kernel_nightly-0.6.2.dev20251019095057.dist-info/WHEEL,sha256=iAkIy5fosb7FzIOwONchHf19Qu7_1wCWyFNR5gu9nU0,91
+ liger_kernel_nightly-0.6.2.dev20251019095057.dist-info/top_level.txt,sha256=2eghu4hA3LnkM7ElW92tQ8zegWKgSbeo-k-aGe1YnvY,13
+ liger_kernel_nightly-0.6.2.dev20251019095057.dist-info/RECORD,,