liger-kernel-nightly 0.6.4.dev20251202094519__py3-none-any.whl → 0.6.4.dev20251206103502__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- liger_kernel/ops/rms_norm.py +1 -1
- liger_kernel/transformers/__init__.py +3 -0
- liger_kernel/transformers/model/gpt_oss.py +211 -0
- liger_kernel/transformers/monkey_patch.py +75 -0
- {liger_kernel_nightly-0.6.4.dev20251202094519.dist-info → liger_kernel_nightly-0.6.4.dev20251206103502.dist-info}/METADATA +2 -1
- {liger_kernel_nightly-0.6.4.dev20251202094519.dist-info → liger_kernel_nightly-0.6.4.dev20251206103502.dist-info}/RECORD +10 -9
- {liger_kernel_nightly-0.6.4.dev20251202094519.dist-info → liger_kernel_nightly-0.6.4.dev20251206103502.dist-info}/LICENSE +0 -0
- {liger_kernel_nightly-0.6.4.dev20251202094519.dist-info → liger_kernel_nightly-0.6.4.dev20251206103502.dist-info}/NOTICE +0 -0
- {liger_kernel_nightly-0.6.4.dev20251202094519.dist-info → liger_kernel_nightly-0.6.4.dev20251206103502.dist-info}/WHEEL +0 -0
- {liger_kernel_nightly-0.6.4.dev20251202094519.dist-info → liger_kernel_nightly-0.6.4.dev20251206103502.dist-info}/top_level.txt +0 -0
liger_kernel/ops/rms_norm.py CHANGED

@@ -351,7 +351,7 @@ def _block_rms_norm_backward_kernel(

         # calculate the gradient of W
         if casting_mode == _CASTING_MODE_LLAMA:
-
+            # TODO(tcc): use tl.sum(..., dtype=tl.float32) once we upgrade to triton>=3.3.0
            dW_row += tl.sum((dY_row * (X_row * rstd_row[:, None]).to(X_dtype)).to(tl.float32), 0)
         else:
             # here X_row is already in fp32 (see previous if block)
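The TODO in this hunk concerns where the reduction accumulates: today the product is explicitly cast to `tl.float32` before `tl.sum`, and a newer Triton would allow requesting a float32 accumulator via `tl.sum(..., dtype=tl.float32)` instead. A minimal PyTorch sketch (not the Triton kernel itself; shapes and values are illustrative) of why accumulating the weight gradient in float32 matters for low-precision inputs:

```python
import torch

# Stand-ins for one block of dY rows and normalized X rows in bfloat16.
dY_row = torch.randn(64, 4096, dtype=torch.bfloat16)
X_hat_row = torch.randn(64, 4096, dtype=torch.bfloat16)

# Reducing in bfloat16 may accumulate rounding error across the reduction axis.
dW_low_precision = (dY_row * X_hat_row).sum(dim=0)

# Casting to float32 before the reduction (what the kernel does today) keeps the
# accumulation in full precision; triton>=3.3.0 would let tl.sum take dtype=tl.float32.
dW_fp32 = (dY_row * X_hat_row).to(torch.float32).sum(dim=0)

print((dW_low_precision.float() - dW_fp32).abs().max())
```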
liger_kernel/transformers/__init__.py CHANGED

@@ -41,6 +41,7 @@ if TYPE_CHECKING:
     from liger_kernel.transformers.monkey_patch import apply_liger_kernel_to_glm4  # noqa: F401
     from liger_kernel.transformers.monkey_patch import apply_liger_kernel_to_glm4v  # noqa: F401
     from liger_kernel.transformers.monkey_patch import apply_liger_kernel_to_glm4v_moe  # noqa: F401
+    from liger_kernel.transformers.monkey_patch import apply_liger_kernel_to_gpt_oss  # noqa: F401
     from liger_kernel.transformers.monkey_patch import apply_liger_kernel_to_granite  # noqa: F401
     from liger_kernel.transformers.monkey_patch import apply_liger_kernel_to_hunyuan_v1_dense  # noqa: F401
     from liger_kernel.transformers.monkey_patch import apply_liger_kernel_to_hunyuan_v1_moe  # noqa: F401

@@ -110,6 +111,7 @@ def __getattr__(name: str):
         "apply_liger_kernel_to_glm4",
         "apply_liger_kernel_to_glm4v",
         "apply_liger_kernel_to_glm4v_moe",
+        "apply_liger_kernel_to_gpt_oss",
         "apply_liger_kernel_to_granite",
         "apply_liger_kernel_to_internvl",
         "apply_liger_kernel_to_llama",

@@ -187,6 +189,7 @@ if _TRANSFORMERS_AVAILABLE:
         "apply_liger_kernel_to_glm4",
         "apply_liger_kernel_to_glm4v",
         "apply_liger_kernel_to_glm4v_moe",
+        "apply_liger_kernel_to_gpt_oss",
         "apply_liger_kernel_to_granite",
         "apply_liger_kernel_to_internvl",
         "apply_liger_kernel_to_llama",
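These three additions register the new symbol with the package root (the `TYPE_CHECKING` imports, the lazy `__getattr__` name list, and the exported names under `_TRANSFORMERS_AVAILABLE`), so the patch function can be imported directly from `liger_kernel.transformers`. A minimal sketch of the intended call pattern, mirroring the docstring example in the new module below:

```python
from liger_kernel.transformers import apply_liger_kernel_to_gpt_oss

# Patch HF's GPT-OSS classes (RoPE, RMSNorm, fused linear cross entropy) before
# the model is instantiated so newly created modules pick up the Liger kernels.
apply_liger_kernel_to_gpt_oss()
```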
liger_kernel/transformers/model/gpt_oss.py ADDED

@@ -0,0 +1,211 @@
+from typing import List
+from typing import Optional
+from typing import Union
+
+import torch
+
+from transformers.modeling_outputs import MoeModelOutputWithPast
+from transformers.models.mixtral.modeling_mixtral import load_balancing_loss_func
+
+from liger_kernel.transformers.model.loss_utils import LigerForCausalLMLoss
+from liger_kernel.transformers.model.loss_utils import unpack_cross_entropy_result
+from liger_kernel.transformers.model.output_classes import LigerMoeCausalLMOutputWithPast
+
+
+def lce_forward(
+    self,
+    input_ids: Optional[torch.LongTensor] = None,
+    attention_mask: Optional[torch.Tensor] = None,
+    position_ids: Optional[torch.LongTensor] = None,
+    past_key_values: Optional[List[torch.FloatTensor]] = None,
+    inputs_embeds: Optional[torch.FloatTensor] = None,
+    labels: Optional[torch.LongTensor] = None,
+    use_cache: Optional[bool] = None,
+    output_attentions: Optional[bool] = None,
+    output_hidden_states: Optional[bool] = None,
+    output_router_logits: Optional[bool] = None,
+    cache_position: Optional[torch.LongTensor] = None,
+    logits_to_keep: Union[int, torch.Tensor] = 0,
+    skip_logits: Optional[bool] = None,
+    **kwargs,
+) -> LigerMoeCausalLMOutputWithPast:
+    r"""
+    Forward pass for causal language modeling with Mixture of Experts (MoE) architecture using Liger Kernel optimizations.
+
+    Args:
+        input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
+            Indices of input sequence tokens in the vocabulary. Indices can be obtained using tokenizers.
+        attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*):
+            Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`:
+            - 1 for tokens that are **not masked**,
+            - 0 for tokens that are **masked**.
+        position_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
+            Indices of positions of each input sequence tokens in the position embeddings.
+        past_key_values (`List[torch.FloatTensor]` or `Cache`, *optional*):
+            Pre-computed hidden-states (key and values in the self-attention blocks) that can be used to speed up
+            sequential decoding. See `past_key_values` input for more details.
+        inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*):
+            Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation.
+            This is useful if you want more control over how to convert `input_ids` indices into associated vectors
+            than the model's internal embedding lookup matrix.
+        labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
+            Labels for computing the masked language modeling loss. Indices should either be in `[0, ...,
+            config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored
+            (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`.
+        use_cache (`bool`, *optional*):
+            If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding
+            (see `past_key_values`).
+        output_attentions (`bool`, *optional*):
+            Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned
+            tensors for more detail.
+        output_hidden_states (`bool`, *optional*):
+            Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for
+            more detail.
+        output_router_logits (`bool`, *optional*):
+            Whether or not to return the router logits of all MoE layers. See `router_logits` under returned tensors
+            for more detail.
+        cache_position (`torch.LongTensor` of shape `(sequence_length)`, *optional*):
+            Indices depicting the position of the input sequence tokens in the sequence.
+        logits_to_keep (`int` or `torch.Tensor`, *optional*, defaults to 0):
+            If an `int`, compute logits for the last `logits_to_keep` tokens. If `0`, calculate logits for all
+            `input_ids` (special case). Only last token logits are needed for generation, and calculating them only for that
+            token can save memory, which becomes pretty significant for long sequences or large vocabulary size.
+            If a `torch.Tensor`, must be 1D corresponding to the indices to keep in the sequence length dimension.
+            This is useful when using packed tensor format (single dimension for batch and sequence length).
+        skip_logits (`bool`, *optional*):
+            Whether to skip logit computation and directly compute loss. If `None`, defaults to `True` during training
+            when labels are provided (to save memory), and `False` during inference.
+
+    Returns:
+        `LigerMoeCausalLMOutputWithPast`: An output object containing:
+        - loss (`torch.FloatTensor` of shape `(1,)`, *optional*, returned when `labels` is provided):
+            Language modeling loss (for next-token prediction), including the auxiliary load balancing loss.
+        - aux_loss (`torch.FloatTensor`, *optional*, returned when `labels` is provided):
+            Auxiliary load balancing loss for the sparse MoE modules.
+        - logits (`torch.FloatTensor` of shape `(batch_size, sequence_length, config.vocab_size)`, *optional*):
+            Prediction scores of the language modeling head (scores for each vocabulary token before SoftMax).
+            Note: logits are `None` during training when `skip_logits=True` to save memory.
+        - past_key_values (`Cache`, *optional*, returned when `use_cache=True` is passed):
+            Cached key and value projection states for faster sequential decoding.
+        - hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True`):
+            Tuple of `torch.FloatTensor` (one for the output of the embeddings + one for each layer) of shape
+            `(batch_size, sequence_length, hidden_size)`. Hidden-states of the model at the output of each layer.
+        - attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True`):
+            Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length,
+            sequence_length)`. Attentions weights after the attention softmax.
+        - router_logits (`tuple(torch.FloatTensor)`, *optional*, returned when `output_router_logits=True`):
+            Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, sequence_length, num_experts)`.
+            Router logits of the MoE layers, useful to compute the auxiliary loss and z_loss.
+        - token_accuracy (`torch.FloatTensor`, *optional*, returned when `labels` is provided):
+            Token-level prediction accuracy.
+
+    Example:
+
+    ```python
+    >>> from transformers import AutoTokenizer, GptOssForCausalLM
+    >>> from liger_kernel.transformers import apply_liger_kernel_to_gpt_oss
+
+    >>> # Apply Liger Kernel patches for optimized performance
+    >>> apply_liger_kernel_to_gpt_oss()
+
+    >>> model = GptOssForCausalLM.from_pretrained("openai/gpt-oss-20b")
+    >>> tokenizer = AutoTokenizer.from_pretrained("openai/gpt-oss-20b")
+
+    >>> prompt = "Hey, are you conscious? Can you talk to me?"
+    >>> inputs = tokenizer(prompt, return_tensors="pt")
+
+    >>> # Inference: Forward pass returns logits
+    >>> outputs = model(**inputs)
+    >>> outputs.logits.shape
+    torch.Size([1, 12, 201088])
+
+    >>> # Get next token prediction
+    >>> next_token_logits = outputs.logits[:, -1, :]
+    >>> predicted_token_id = next_token_logits.argmax(dim=-1)
+
+    >>> # Training: Forward pass with labels returns loss
+    >>> labels = inputs.input_ids.clone()
+    >>> outputs = model(**inputs, labels=labels)
+    >>> outputs.loss
+    tensor(2.6454)
+    ```"""
+
+    output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
+    output_router_logits = (
+        output_router_logits if output_router_logits is not None else self.config.output_router_logits
+    )
+
+    output_hidden_states = (
+        output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
+    )
+
+    # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn)
+    outputs: MoeModelOutputWithPast = self.model(
+        input_ids=input_ids,
+        attention_mask=attention_mask,
+        position_ids=position_ids,
+        past_key_values=past_key_values,
+        inputs_embeds=inputs_embeds,
+        use_cache=use_cache,
+        output_attentions=output_attentions,
+        output_hidden_states=output_hidden_states,
+        output_router_logits=output_router_logits,
+        cache_position=cache_position,
+        **kwargs,
+    )
+
+    hidden_states = outputs.last_hidden_state
+    # Only compute necessary logits, and do not upcast them to float if we are not computing the loss
+    slice_indices = slice(-logits_to_keep, None) if isinstance(logits_to_keep, int) else logits_to_keep
+    kept_hidden_states = hidden_states[:, slice_indices, :]
+
+    shift_labels = kwargs.pop("shift_labels", None)
+    logits = None
+    loss = None
+    token_accuracy = None
+
+    if skip_logits is None:
+        skip_logits = self.training and (labels is not None or shift_labels is not None)
+
+    if skip_logits:
+        result = LigerForCausalLMLoss(
+            hidden_states=kept_hidden_states,
+            lm_head_weight=self.lm_head.weight,
+            labels=labels,
+            shift_labels=shift_labels,
+            hidden_size=self.config.hidden_size,
+            **kwargs,
+        )
+        loss, _, token_accuracy = unpack_cross_entropy_result(result)
+    else:  # if in inference model materialize logits
+        logits = self.lm_head(kept_hidden_states)
+        if labels is not None or shift_labels is not None:
+            loss = self.loss_function(
+                logits=logits,
+                labels=labels,
+                shift_labels=shift_labels,
+                vocab_size=self.vocab_size,
+                **kwargs,
+            )
+
+    aux_loss = None
+    if output_router_logits:
+        aux_loss = load_balancing_loss_func(
+            outputs.router_logits,
+            self.num_experts,
+            self.num_experts_per_tok,
+            attention_mask,
+        )
+        if labels is not None:
+            loss += self.router_aux_loss_coef * aux_loss.to(loss.device)  # make sure to reside in the same device
+
+    return LigerMoeCausalLMOutputWithPast(
+        loss=loss,
+        aux_loss=aux_loss,
+        logits=logits,
+        past_key_values=outputs.past_key_values,
+        hidden_states=outputs.hidden_states,
+        attentions=outputs.attentions,
+        router_logits=outputs.router_logits,
+        token_accuracy=token_accuracy,
+    )
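The branch on `skip_logits` above is the memory-saving core of this forward: during training with labels the fused linear cross entropy path computes the loss without materializing logits, while inference goes through `lm_head` as usual. A caller-side sketch of that behavior, reusing the `openai/gpt-oss-20b` checkpoint from the docstring example; the tiny prompt and the train/eval toggling are illustrative only:

```python
import torch
from transformers import AutoTokenizer, GptOssForCausalLM

from liger_kernel.transformers import apply_liger_kernel_to_gpt_oss

apply_liger_kernel_to_gpt_oss()  # installs the lce_forward defined above

tokenizer = AutoTokenizer.from_pretrained("openai/gpt-oss-20b")
model = GptOssForCausalLM.from_pretrained("openai/gpt-oss-20b")

inputs = tokenizer("Hey, are you conscious?", return_tensors="pt")
labels = inputs.input_ids.clone()

# Training with labels: skip_logits defaults to True, so the loss comes from
# LigerForCausalLMLoss and outputs.logits is None to save memory.
model.train()
outputs = model(**inputs, labels=labels)
assert outputs.loss is not None and outputs.logits is None

# Inference: logits are materialized through lm_head as usual.
model.eval()
with torch.no_grad():
    outputs = model(**inputs)
assert outputs.logits is not None
```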
liger_kernel/transformers/monkey_patch.py CHANGED

@@ -20,6 +20,7 @@ from liger_kernel.transformers.model.gemma import lce_forward as gemma_lce_forward
 from liger_kernel.transformers.model.gemma import lce_forward_deprecated as gemma_lce_forward_deprecated
 from liger_kernel.transformers.model.gemma2 import lce_forward as gemma2_lce_forward
 from liger_kernel.transformers.model.gemma2 import lce_forward_deprecated as gemma2_lce_forward_deprected
+from liger_kernel.transformers.model.gpt_oss import lce_forward as gpt_oss_lce_forward
 from liger_kernel.transformers.model.llama import lce_forward as llama_lce_forward
 from liger_kernel.transformers.model.llama import lce_forward_deprecated as llama_lce_forward_deprecated
 from liger_kernel.transformers.model.llava import lce_forward as llava_lce_forward

@@ -1459,6 +1460,79 @@ def apply_liger_kernel_to_qwen3_moe(
                 _patch_rms_norm_module(decoder_layer.post_attention_layernorm)


+def apply_liger_kernel_to_gpt_oss(
+    rope: bool = True,
+    cross_entropy: bool = False,
+    fused_linear_cross_entropy: bool = True,
+    rms_norm: bool = True,
+    swiglu: bool = False,  # Set to False by default since GPT-OSS has custom expert implementation
+    model: PreTrainedModel = None,
+) -> None:
+    """
+    Apply Liger kernels to replace original implementation in HuggingFace GPT-OSS models.
+    NOTE: GPT-OSS is supported in transformers >= 4.55.0
+    NOTE: SwiGLU patching is disabled by default for GPT-OSS as it uses a custom expert
+    implementation with clamping and MXFP4 quantization.
+
+    Args:
+        rope (bool): Whether to apply Liger's rotary position embedding. Default is True.
+        cross_entropy (bool): Whether to apply Liger's cross entropy loss. Default is False.
+        fused_linear_cross_entropy (bool):
+            Whether to apply Liger's fused linear cross entropy loss. Default is True.
+            `cross_entropy` and `fused_linear_cross_entropy` cannot both be True.
+            If `fused_linear_cross_entropy` is True, the logits will not be materialized but more memory efficient.
+        rms_norm (bool): Whether to apply Liger's RMSNorm. Default is True.
+        swiglu (bool): Whether to apply Liger's SwiGLU MLP. Default is False.
+            Note: GPT-OSS uses a custom expert implementation, so SwiGLU patching is disabled by default.
+        model (PreTrainedModel): The model instance to apply Liger kernels to, if the model has already been
+            loaded. Default is None.
+    """
+    if version.parse(transformers.__version__) < version.parse("4.55.0"):
+        logger.warning("GPT-OSS support requires transformers >= 4.55.0")
+        return
+
+    assert not (cross_entropy and fused_linear_cross_entropy), (
+        "cross_entropy and fused_linear_cross_entropy cannot both be True."
+    )
+
+    from transformers.models.gpt_oss import modeling_gpt_oss
+    from transformers.models.gpt_oss.modeling_gpt_oss import GptOssModel
+
+    if rope:
+        modeling_gpt_oss.apply_rotary_pos_emb = liger_rotary_pos_emb
+
+    if rms_norm:
+        modeling_gpt_oss.GptOssRMSNorm = LigerRMSNorm
+
+    if cross_entropy:
+        from transformers.loss.loss_utils import nn
+
+        nn.functional.cross_entropy = liger_cross_entropy
+
+    if fused_linear_cross_entropy:
+        if model is not None:
+            model.forward = MethodType(gpt_oss_lce_forward, model)
+        else:
+            modeling_gpt_oss.GptOssForCausalLM.forward = gpt_oss_lce_forward
+
+    # Note: SwiGLU patching is not implemented for GPT-OSS due to custom expert implementation
+    # with clamping (swiglu_limit=7.0) and MXFP4 quantization
+
+    if model is not None:
+        # The model instance already exists, so we need to additionally patch the
+        # instance variables that reference already-instantiated modules
+
+        # get the base model from the model instance
+        base_model: GptOssModel = getattr(model, model.base_model_prefix, model)
+
+        if rms_norm:
+            _patch_rms_norm_module(base_model.norm)
+        for decoder_layer in base_model.layers:
+            if rms_norm:
+                _patch_rms_norm_module(decoder_layer.input_layernorm)
+                _patch_rms_norm_module(decoder_layer.post_attention_layernorm)
+
+
 def apply_liger_kernel_to_qwen2_vl(
     rope: bool = True,
     cross_entropy: bool = False,
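Besides patching the module classes up front, the `model is not None` branch above supports a model that has already been loaded: `forward` is rebound with `MethodType` and the already-instantiated RMSNorm modules are patched in place. A short sketch of that path (checkpoint name taken from the docstring example):

```python
from transformers import GptOssForCausalLM

from liger_kernel.transformers import apply_liger_kernel_to_gpt_oss

# Load first, patch afterwards: the instance branch walks base_model.layers and
# patches each input_layernorm / post_attention_layernorm plus the final norm.
model = GptOssForCausalLM.from_pretrained("openai/gpt-oss-20b")
apply_liger_kernel_to_gpt_oss(model=model)
```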
@@ -2752,6 +2826,7 @@ MODEL_TYPE_TO_APPLY_LIGER_FN = {
     "glm4": apply_liger_kernel_to_glm4,
     "glm4v": apply_liger_kernel_to_glm4v,
     "glm4v_moe": apply_liger_kernel_to_glm4v_moe,
+    "gpt_oss": apply_liger_kernel_to_gpt_oss,
     "internvl": apply_liger_kernel_to_internvl,
     "llama": apply_liger_kernel_to_llama,
     "llama4_text": apply_liger_kernel_to_llama4,
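With the `"gpt_oss"` entry registered, the mapping can be used to pick the right patch function from a checkpoint's `model_type`. A hedged sketch of that dispatch, importing the module-level mapping directly (the library's own auto-model helper performs a lookup along these lines):

```python
from transformers import AutoConfig

from liger_kernel.transformers.monkey_patch import MODEL_TYPE_TO_APPLY_LIGER_FN

config = AutoConfig.from_pretrained("openai/gpt-oss-20b")
apply_fn = MODEL_TYPE_TO_APPLY_LIGER_FN.get(config.model_type)  # "gpt_oss"
if apply_fn is not None:
    apply_fn()  # resolves to apply_liger_kernel_to_gpt_oss for GPT-OSS checkpoints
```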
{liger_kernel_nightly-0.6.4.dev20251202094519.dist-info → liger_kernel_nightly-0.6.4.dev20251206103502.dist-info}/METADATA CHANGED

@@ -1,6 +1,6 @@
 Metadata-Version: 2.1
 Name: liger_kernel_nightly
-Version: 0.6.4.dev20251202094519
+Version: 0.6.4.dev20251206103502
 Summary: Efficient Triton kernels for LLM Training
 License: BSD 2-CLAUSE LICENSE
 Copyright 2024 LinkedIn Corporation

@@ -312,6 +312,7 @@ loss.backward()
 | OLMo2 | `liger_kernel.transformers.apply_liger_kernel_to_olmo2` | RoPE, RMSNorm, SwiGLU, CrossEntropyLoss, FusedLinearCrossEntropy |
 | Olmo3 | `liger_kernel.transformers.apply_liger_kernel_to_olmo3` | RoPE, RMSNorm, SwiGLU, CrossEntropyLoss, FusedLinearCrossEntropy |
 | GLM-4 | `liger_kernel.transformers.apply_liger_kernel_to_glm4` | RoPE, RMSNorm, SwiGLU, CrossEntropyLoss, FusedLinearCrossEntropy |
+| GPT-OSS | `liger_kernel.transformers.apply_liger_kernel_to_gpt_oss` | RoPE, RMSNorm, CrossEntropyLoss, FusedLinearCrossEntropy |
 | InternVL3 | `liger_kernel.transformers.apply_liger_kernel_to_internvl` | RoPE, RMSNorm, SwiGLU, CrossEntropyLoss, FusedLinearCrossEntropy |
 | HunyuanV1 | `liger_kernel.transformers.apply_liger_kernel_to_hunyuan_v1_dense` | RoPE, RMSNorm, SwiGLU, CrossEntropyLoss, FusedLinearCrossEntropy |
 | HunyuanV1 MoE | `liger_kernel.transformers.apply_liger_kernel_to_hunyuan_v1_moe` | RoPE, RMSNorm, SwiGLU, CrossEntropyLoss, FusedLinearCrossEntropy |
{liger_kernel_nightly-0.6.4.dev20251202094519.dist-info → liger_kernel_nightly-0.6.4.dev20251206103502.dist-info}/RECORD CHANGED

@@ -33,7 +33,7 @@ liger_kernel/ops/llama4_rope.py,sha256=-aqdZzllklTN8b9--e-TsWY_ntGCN8-tyseT4x0bd
 liger_kernel/ops/multi_token_attention.py,sha256=Oz_RXDp-OSS_R_HuGmaETHdAJ7Toda_70OfE7TXMUlY,7645
 liger_kernel/ops/poly_norm.py,sha256=5IdJEZnbbhblkL_X8UhSD4A2CooQbOAZJw8nAekWNs4,11372
 liger_kernel/ops/qwen2vl_mrope.py,sha256=3GExhYpLgB4VUtyZyjRk8XjEur3W4EWF6HQ67ML5vBU,8481
-liger_kernel/ops/rms_norm.py,sha256=
+liger_kernel/ops/rms_norm.py,sha256=owWgM1jE5aP4clshCNWiulnemHPzR72D9QN2kc3eoe0,19220
 liger_kernel/ops/rope.py,sha256=v-7JHRrv-5ImoROkpKfl30WwWI4qTa2tAl7zQeB4ml4,8956
 liger_kernel/ops/softmax.py,sha256=tgORx6MK1IDDtZKqGarj0IPIVjqAIEUXXYPiinhRdtI,5864
 liger_kernel/ops/sparsemax.py,sha256=AeWe1xgkHJFEKWTj2vu_0hj7LztGvjqXAps-QTpCY0U,5087

@@ -43,7 +43,7 @@ liger_kernel/ops/tvd.py,sha256=FHJtLQI95ijqgg9UtaHpMAjSCiPxB6CduPwPMcGxelc,6405
 liger_kernel/ops/utils.py,sha256=kYp84AOA7D9PYrvBUSrNsfQIt8elr_uA9OxCkbfiUFA,3980
 liger_kernel/ops/experimental/embedding.py,sha256=tolj3tItkzpSb30zWqDN2_yX4ectflaQ8HMyKyFIQc8,4172
 liger_kernel/ops/experimental/mm_int8int2.py,sha256=TrS9lpwekrik_w5qE7AhMJD1bcq-OidjtbsW80oZ6IM,13314
-liger_kernel/transformers/__init__.py,sha256=
+liger_kernel/transformers/__init__.py,sha256=4sqcDbOZ_JtS9Ag-7oyuhq5jN298GyzjJFu9J-DyyZQ,10872
 liger_kernel/transformers/auto_model.py,sha256=0qCTRZt280Bj_LcFdzo9hlaR-BWNazawXOGgoCZjgEg,1545
 liger_kernel/transformers/cross_entropy.py,sha256=DMtHkKrVJDSsels7KgGQJqrXkEAd6Zopcdr-5oRmQgE,2010
 liger_kernel/transformers/dyt.py,sha256=i-4GPaMrl-jab9TVI5qN0-H9qycn_mCbV82ozU4nbmU,723

@@ -60,7 +60,7 @@ liger_kernel/transformers/jsd.py,sha256=DGqRnxIZxsvxo0_tbbxX3b-sDbDjC_yKufyRIHCc
 liger_kernel/transformers/kl_div.py,sha256=WLffFbh1EExD2Eb1F7lN11fo9JJC-0751WJjZAF1Fj8,409
 liger_kernel/transformers/layer_norm.py,sha256=c9pk3PEasOKYR0rhe5e5nNrnYKVCEW4VC8S6LpCq9EQ,906
 liger_kernel/transformers/llama4_rope.py,sha256=kS6PSHEwf3dS7hD7C7p8S0geugx2EMCiP0h0F7LsUoY,3639
-liger_kernel/transformers/monkey_patch.py,sha256=
+liger_kernel/transformers/monkey_patch.py,sha256=0ER5BjQXcIKwgL2e7ji3_DIm1DaJzevzo53aXSe2YJU,135862
 liger_kernel/transformers/multi_token_attention.py,sha256=K3NIY9_5TPgZ4_Rahn0xnkMXxD_fmlJHK4CWGYvGQp0,1752
 liger_kernel/transformers/poly_norm.py,sha256=g5tC75i3qy1_N26ZUP-jfpct7ivQAEdJfIfx8IXzeyE,1377
 liger_kernel/transformers/qwen2vl_mrope.py,sha256=5EwSqrMdsL9MYspeBMXBsNJKvH0MOmRrtJXAJlnnlOI,1047

@@ -82,6 +82,7 @@ liger_kernel/transformers/model/gemma3.py,sha256=ZUrFCc-pfF8jYHV0HsptBr98hx6p2q9
 liger_kernel/transformers/model/glm4.py,sha256=bSp22iPIjsli4-c_usUOsyh1Bs2gIK8X6ynS0azseUs,5900
 liger_kernel/transformers/model/glm4v.py,sha256=dd-BQpccDCp1SbIxcJ5rG8xcwYQK3KOv1Tgm9TGnZc4,6594
 liger_kernel/transformers/model/glm4v_moe.py,sha256=zKhMdOOrRhlrvCSFaeVYfddL1ubpY8edEO91TN81n98,7135
+liger_kernel/transformers/model/gpt_oss.py,sha256=8jEAQQNEXgVA-yuvEjKkBQvCvZy0E9ns-O9BPlajXXU,11197
 liger_kernel/transformers/model/hunyuan_v1.py,sha256=MJvP9xkUFePIV0HLETJM4YPbVCEPkAE1ZI5Jxyiebh0,5731
 liger_kernel/transformers/model/internvl.py,sha256=OOutracs9qrPHSU7FVYar08yinvGrHQVPvo39JEws6w,6473
 liger_kernel/transformers/model/llama.py,sha256=kqZeONzwTBzudoChlKMzq1w23BtYGbxWZC1l1V__JTw,13410

@@ -110,9 +111,9 @@ liger_kernel/transformers/trainer/__init__.py,sha256=p7yQfklV8-467qSz_ZMimkbDF7H
 liger_kernel/transformers/trainer/orpo_trainer.py,sha256=tX0h63aOFe3rNqTmk6JpMf75UPo981yzEa6TghnjS0Q,5370
 liger_kernel/triton/__init__.py,sha256=qCiCamzCRv6lpV8IqpAc9YMdNKC7GKurClWceQPnlis,92
 liger_kernel/triton/monkey_patch.py,sha256=Rd0hUHAzDkFfHvnX7-PBaNK5EKnZhtfM_h-fgQH9HPY,1568
-liger_kernel_nightly-0.6.4.
-liger_kernel_nightly-0.6.4.
-liger_kernel_nightly-0.6.4.
-liger_kernel_nightly-0.6.4.
-liger_kernel_nightly-0.6.4.
-liger_kernel_nightly-0.6.4.
+liger_kernel_nightly-0.6.4.dev20251206103502.dist-info/LICENSE,sha256=OhzLDHJ0to4a8sodVLELZiCFylZ1NAAYLs-HrjPy0ag,1312
+liger_kernel_nightly-0.6.4.dev20251206103502.dist-info/METADATA,sha256=9fKlKAtH1HWCDW5wysDkZA-z5cqMjTz3YDhacJ_dhn8,25375
+liger_kernel_nightly-0.6.4.dev20251206103502.dist-info/NOTICE,sha256=njwnoPZLh9AN8SJQzxvCGLHi-8X__AvWRze6joNXIY8,2066
+liger_kernel_nightly-0.6.4.dev20251206103502.dist-info/WHEEL,sha256=iAkIy5fosb7FzIOwONchHf19Qu7_1wCWyFNR5gu9nU0,91
+liger_kernel_nightly-0.6.4.dev20251206103502.dist-info/top_level.txt,sha256=2eghu4hA3LnkM7ElW92tQ8zegWKgSbeo-k-aGe1YnvY,13
+liger_kernel_nightly-0.6.4.dev20251206103502.dist-info/RECORD,,