liger-kernel-nightly 0.6.3.dev20251121195543__py3-none-any.whl → 0.6.3.dev20251121202601__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -32,7 +32,7 @@ class LigerFusedLinearPPOBase(torch.autograd.Function):
32
32
  epsilon_low=0.2,
33
33
  epsilon_high=0.2,
34
34
  beta=0.04,
35
- loss_type="bnpo",
35
+ loss_type="dapo",
36
36
  max_completion_length=None,
37
37
  importance_sampling_level="token",
38
38
  temperature=1.0,
@@ -60,7 +60,7 @@ class LigerFusedLinearPPOBase(torch.autograd.Function):
60
60
  epsilon_low: Lower bound for clipping the importance sampling ratio
61
61
  epsilon_high: Upper bound for clipping the importance sampling ratio
62
62
  beta: Weight for the KL penalty
63
- loss_type: Type of loss calculation ("grpo", "bnpo", "dr_grpo")
63
+ loss_type: Type of loss calculation ("grpo", "bnpo", "dr_grpo", "dapo")
64
64
  max_completion_length: Maximum completion length required for "dr_grpo"
65
65
  temperature: Temperature for the logits
66
66
  compiled: Whether to use torch compile
@@ -244,6 +244,21 @@ class LigerFusedLinearPPOBase(torch.autograd.Function):
244
244
 
245
245
  return loss_acc, tuple(final_metrics)
246
246
 
247
@staticmethod
def _compute_dapo_normalizer(attention_mask):
    """Return the denominator used to normalize the "dapo" loss.

    The mask is cast to float32 and summed. When a ``torch.distributed``
    process group is initialized, that sum is all-reduced (SUM) across
    processes and divided by the world size, giving the per-process average
    of the global count. The result is clamped to a minimum of 1.0 so it is
    always safe to divide by.

    Args:
        attention_mask: Tensor whose entries are summed to count tokens.

    Returns:
        A scalar float32 tensor with value >= 1.0.
    """
    token_count = attention_mask.to(torch.float32).sum()

    if not (torch.distributed.is_available() and torch.distributed.is_initialized()):
        # No process group: the world size is 1, so the local sum is the answer.
        return (token_count / 1).clamp(min=1.0)

    import torch.distributed as dist

    # all_reduce writes in place, so reduce into a private copy of the sum.
    global_count = token_count.clone()
    dist.all_reduce(global_count, op=dist.ReduceOp.SUM)
    return (global_count / dist.get_world_size()).clamp(min=1.0)
261
+
247
262
  @staticmethod
248
263
  def _compute_chunk_loss(
249
264
  input_chunk,
@@ -261,7 +276,7 @@ class LigerFusedLinearPPOBase(torch.autograd.Function):
261
276
  epsilon_low=0.2,
262
277
  epsilon_high=0.2,
263
278
  beta=0.04,
264
- loss_type="bnpo",
279
+ loss_type="dapo",
265
280
  max_completion_length=None,
266
281
  importance_sampling_level="token",
267
282
  temperature=1.0,
@@ -341,10 +356,11 @@ class LigerFusedLinearPPOBase(torch.autograd.Function):
341
356
  None, # grad_epsilon_low
342
357
  None, # grad_epsilon_high
343
358
  None, # grad_beta
359
+ None, # grad_loss_type
360
+ None, # grad_max_completion_length
361
+ None, # grad_importance_sampling_level
344
362
  None, # grad_temperature
345
363
  None, # grad_compiled
346
364
  None, # grad_use_ref_model
347
365
  None, # grad_chunk_size
348
- None, # grad_loss_type
349
- None, # grad_max_completion_length
350
366
  )
@@ -29,7 +29,7 @@ class LigerFusedLinearGRPOFunction(LigerFusedLinearPPOBase):
29
29
  epsilon_low=0.2,
30
30
  epsilon_high=0.2,
31
31
  beta=0.04,
32
- loss_type="bnpo", # ["grpo", "bnpo", "dr_grpo"]
32
+ loss_type="dapo", # ["grpo", "bnpo", "dr_grpo", "dapo"]
33
33
  max_completion_length=None, # Required for dr_grpo
34
34
  importance_sampling_level="token", # ["token", "sequence"] - new parameter for GSPO
35
35
  **kwargs,
@@ -94,6 +94,9 @@ class LigerFusedLinearGRPOFunction(LigerFusedLinearPPOBase):
94
94
  if max_completion_length is None:
95
95
  raise ValueError("max_completion_length must be provided for loss_type 'dr_grpo'")
96
96
  loss = (per_token_loss * attention_mask).sum() / (full_attention_mask.shape[0] * max_completion_length)
97
+ elif loss_type == "dapo":
98
+ loss_normalizer = LigerFusedLinearPPOBase._compute_dapo_normalizer(full_attention_mask)
99
+ loss = (per_token_loss * attention_mask).sum() / loss_normalizer
97
100
  else:
98
101
  raise ValueError(f"Unknown loss type: {loss_type}")
99
102
 
@@ -135,7 +138,7 @@ class LigerFusedLinearGRPOFunction(LigerFusedLinearPPOBase):
135
138
  beta=0.04,
136
139
  epsilon_low=0.2,
137
140
  epsilon_high=0.2,
138
- loss_type="bnpo",
141
+ loss_type="dapo",
139
142
  max_completion_length=None,
140
143
  importance_sampling_level="token",
141
144
  temperature=1.0,
@@ -157,7 +160,7 @@ class LigerFusedLinearGRPOFunction(LigerFusedLinearPPOBase):
157
160
  ref_weight (torch.Tensor, optional): Reference model weight tensor. Shape: (vocab_size, hidden_size)
158
161
  ref_bias (torch.Tensor, optional): Reference model bias tensor. Shape: (vocab_size,)
159
162
  beta (float): Weight for the KL penalty
160
- loss_type (str): Type of loss calculation ("grpo", "bnpo", "dr_grpo"). Defaults to "bnpo".
163
+ loss_type (str): Type of loss calculation ("grpo", "bnpo", "dr_grpo", "dapo"). Defaults to "dapo".
161
164
  max_completion_length (int, optional): Maximum completion length, required for "dr_grpo". Defaults to None.
162
165
  importance_sampling_level (str): Level of importance sampling ("token" or "sequence"). Defaults to "token".
163
166
  temperature (float): Temperature for the logits
@@ -235,7 +238,7 @@ class LigerFusedLinearGRPOLoss(torch.nn.Module):
235
238
  chunk_size: int = 1,
236
239
  epsilon_low: float = 0.2,
237
240
  epsilon_high: float = 0.2,
238
- loss_type: str = "bnpo",
241
+ loss_type: str = "dapo",
239
242
  max_completion_length: Optional[int] = None,
240
243
  importance_sampling_level: str = "token",
241
244
  temperature: float = 1.0,
@@ -248,7 +251,7 @@ class LigerFusedLinearGRPOLoss(torch.nn.Module):
248
251
  chunk_size (int): Size of chunks for processing.
249
252
  epsilon_low (float): Lower bound for the importance sampling ratio.
250
253
  epsilon_high (float): Upper bound for the importance sampling ratio.
251
- loss_type (str): Type of loss calculation ("grpo", "bnpo", "dr_grpo"). Defaults to "bnpo".
254
+ loss_type (str): Type of loss calculation ("grpo", "bnpo", "dr_grpo", "dapo"). Defaults to "dapo".
252
255
  max_completion_length (int, optional): Maximum completion length, required for "dr_grpo". Defaults to None.
253
256
  importance_sampling_level (str): Level of importance sampling ("token" or "sequence"). Defaults to "token".
254
257
  temperature (float): Temperature for the logits.
@@ -128,7 +128,9 @@ def _grpo_loss_fwd_kernel(
128
128
  per_token_loss1 = coef_1 * advantage
129
129
  per_token_loss2 = coef_2 * advantage
130
130
  per_token_loss = -tl.minimum(per_token_loss1, per_token_loss2)
131
- is_clipped = per_token_loss1 < per_token_loss2
131
+ is_low_clipped = (coef_1 < 1 - EPS_LOW) & (advantage < 0)
132
+ is_high_clipped = (coef_1 > 1 + EPS_HIGH) & (advantage > 0)
133
+ is_clipped = is_low_clipped | is_high_clipped
132
134
 
133
135
  if BETA != 0.0:
134
136
  REF_LOGP += off_b * L + off_l
@@ -52,6 +52,7 @@ if TYPE_CHECKING:
52
52
  from liger_kernel.transformers.monkey_patch import apply_liger_kernel_to_mixtral # noqa: F401
53
53
  from liger_kernel.transformers.monkey_patch import apply_liger_kernel_to_mllama # noqa: F401
54
54
  from liger_kernel.transformers.monkey_patch import apply_liger_kernel_to_olmo2 # noqa: F401
55
+ from liger_kernel.transformers.monkey_patch import apply_liger_kernel_to_olmo3 # noqa: F401
55
56
  from liger_kernel.transformers.monkey_patch import apply_liger_kernel_to_paligemma # noqa: F401
56
57
  from liger_kernel.transformers.monkey_patch import apply_liger_kernel_to_phi3 # noqa: F401
57
58
  from liger_kernel.transformers.monkey_patch import apply_liger_kernel_to_qwen2 # noqa: F401
@@ -118,6 +119,7 @@ def __getattr__(name: str):
118
119
  "apply_liger_kernel_to_mixtral",
119
120
  "apply_liger_kernel_to_mllama",
120
121
  "apply_liger_kernel_to_olmo2",
122
+ "apply_liger_kernel_to_olmo3",
121
123
  "apply_liger_kernel_to_paligemma",
122
124
  "apply_liger_kernel_to_phi3",
123
125
  "apply_liger_kernel_to_qwen2",
@@ -194,6 +196,7 @@ if _TRANSFORMERS_AVAILABLE:
194
196
  "apply_liger_kernel_to_mixtral",
195
197
  "apply_liger_kernel_to_mllama",
196
198
  "apply_liger_kernel_to_olmo2",
199
+ "apply_liger_kernel_to_olmo3",
197
200
  "apply_liger_kernel_to_paligemma",
198
201
  "apply_liger_kernel_to_phi3",
199
202
  "apply_liger_kernel_to_qwen2",
@@ -1,3 +1,6 @@
1
+ import torch
2
+
3
+ from liger_kernel.chunked_loss.fused_linear_ppo import LigerFusedLinearPPOBase
1
4
  from liger_kernel.ops.grpo_loss import GrpoLossFunction
2
5
 
3
6
 
@@ -13,12 +16,20 @@ def triton_grpo_loss(
13
16
  eps_low=0.2,
14
17
  eps_high=0.4,
15
18
  inplace=True,
19
+ loss_type="dapo",
20
+ max_completion_length=None,
21
+ importance_sampling_level="token",
22
+ reduce=False,
16
23
  ):
17
24
  assert logits is not None and completion_ids is not None and advantages is not None, (
18
25
  "must provide logits、completion_ids and advantages"
19
26
  )
27
+ if importance_sampling_level != "token":
28
+ raise ValueError(
29
+ f"Triton GRPO loss only supports token-level importance sampling. Got {importance_sampling_level}."
30
+ )
20
31
 
21
- return GrpoLossFunction.apply(
32
+ per_token_loss, per_token_kl, is_clipped = GrpoLossFunction.apply(
22
33
  logits,
23
34
  old_logp,
24
35
  ref_logp,
@@ -31,6 +42,50 @@ def triton_grpo_loss(
31
42
  eps_high,
32
43
  inplace,
33
44
  )
45
+ if not reduce:
46
+ return per_token_loss, per_token_kl, is_clipped
47
+
48
+ loss = _reduce_grpo_loss(
49
+ per_token_loss,
50
+ completion_mask,
51
+ loss_type=loss_type,
52
+ max_completion_length=max_completion_length,
53
+ )
54
+
55
+ metrics = []
56
+ if beta != 0.0 and per_token_kl is not None:
57
+ metrics.append(_masked_mean(per_token_kl, completion_mask))
58
+ metrics.append(_masked_mean(is_clipped.float(), completion_mask))
59
+ return loss, metrics
60
+
61
+
62
def _reduce_grpo_loss(per_token_loss, completion_mask, loss_type, max_completion_length):
    """Collapse a per-token loss tensor into a scalar according to ``loss_type``.

    Args:
        per_token_loss: Tensor of per-token losses; the last dimension is
            reduced per sequence for "grpo" and ``shape[0]`` is used as the
            batch size for "dr_grpo".
        completion_mask: Weights multiplied into ``per_token_loss``; ``None``
            means every position counts with weight one.
        loss_type: One of "grpo", "bnpo", "dr_grpo" or "dapo".
        max_completion_length: Needed only for "dr_grpo", where it scales the
            denominator together with the batch size.

    Returns:
        The reduced loss tensor.

    Raises:
        ValueError: If ``loss_type`` is "dr_grpo" without
            ``max_completion_length``, or if ``loss_type`` is not recognized.
    """
    weights = torch.ones_like(per_token_loss) if completion_mask is None else completion_mask
    weights = weights.to(per_token_loss.dtype)

    if loss_type == "grpo":
        # Mean over each sequence's active tokens, then mean over sequences.
        tokens_per_seq = weights.sum(-1).clamp(min=1.0)
        seq_loss = (per_token_loss * weights).sum(-1) / tokens_per_seq
        return seq_loss.mean()
    elif loss_type == "bnpo":
        # Single mean over all active tokens in the batch.
        active_tokens = weights.sum().clamp(min=1.0)
        return (per_token_loss * weights).sum() / active_tokens
    elif loss_type == "dr_grpo":
        if max_completion_length is None:
            raise ValueError("max_completion_length must be provided when using loss_type='dr_grpo'")
        # Fixed denominator: batch size times the maximum completion length.
        fixed_denominator = per_token_loss.shape[0] * max_completion_length
        return (per_token_loss * weights).sum() / fixed_denominator
    elif loss_type == "dapo":
        dapo_denominator = LigerFusedLinearPPOBase._compute_dapo_normalizer(weights)
        return (per_token_loss * weights).sum() / dapo_denominator
    else:
        raise ValueError(f"Unsupported loss_type '{loss_type}' for Triton GRPO loss.")
82
+
83
+
84
def _masked_mean(values, mask):
    """Return the mask-weighted mean of ``values`` as a scalar tensor.

    Args:
        values: Tensor to average.
        mask: Weights multiplied into ``values``; ``None`` means every
            element counts with weight one.

    Returns:
        ``sum(values * mask) / max(sum(mask), 1.0)``, with the mask cast to
        the dtype of ``values``.
    """
    weights = torch.ones_like(values) if mask is None else mask
    weights = weights.to(values.dtype)
    weighted_total = (values * weights).sum()
    # Clamp so an all-zero mask yields 0 instead of a division by zero.
    weight_sum = weights.sum().clamp(min=1.0)
    return weighted_total / weight_sum
34
89
 
35
90
 
36
91
  # This is a demo of how to use grpo_loss in GRPOTrainer. The TRL version must be 0.16
@@ -0,0 +1,142 @@
1
+ from typing import List
2
+ from typing import Optional
3
+ from typing import Tuple
4
+ from typing import Union
5
+
6
+ import torch
7
+
8
+ from transformers.modeling_outputs import BaseModelOutputWithPast
9
+ from transformers.utils.deprecation import deprecate_kwarg
10
+
11
+ from liger_kernel.transformers.model.loss_utils import LigerForCausalLMLoss
12
+ from liger_kernel.transformers.model.loss_utils import unpack_cross_entropy_result
13
+ from liger_kernel.transformers.model.output_classes import LigerCausalLMOutputWithPast
14
+
15
+
16
@deprecate_kwarg("num_logits_to_keep", version="4.50", new_name="logits_to_keep")
def lce_forward(
    self,
    input_ids: torch.LongTensor = None,
    attention_mask: Optional[torch.Tensor] = None,
    position_ids: Optional[torch.LongTensor] = None,
    past_key_values: Optional[List[torch.FloatTensor]] = None,
    inputs_embeds: Optional[torch.FloatTensor] = None,
    labels: Optional[torch.LongTensor] = None,
    use_cache: Optional[bool] = None,
    output_attentions: Optional[bool] = None,
    output_hidden_states: Optional[bool] = None,
    return_dict: Optional[bool] = None,
    cache_position: Optional[torch.LongTensor] = None,
    logits_to_keep: Union[int, torch.Tensor] = 0,
    skip_logits: Optional[bool] = None,
    **kwargs,
) -> Union[Tuple, LigerCausalLMOutputWithPast]:
    r"""
    Causal-LM forward pass for Olmo3 that can compute the loss without materializing logits.

    Args:
        labels (`torch.LongTensor`, *optional*):
            Targets forwarded to the loss computation. When neither `labels` nor a `shift_labels`
            keyword argument is given, no loss is computed.

        logits_to_keep (`int` or `torch.Tensor`, *optional*):
            If an `int`, only the last `logits_to_keep` positions of the hidden states are kept
            (`0` keeps every position). Otherwise the value is used directly as the index along
            the sequence dimension.

        skip_logits (`bool`, *optional*):
            If truthy, the loss is computed by `LigerForCausalLMLoss` from the hidden states and
            the `lm_head` weight, and `logits` is returned as `None`. If `None`, it defaults to
            `self.training and (labels or shift_labels were provided)`.

    Returns:
        A `LigerCausalLMOutputWithPast` (including `token_accuracy`) when `return_dict` is truthy,
        otherwise a tuple `(loss?, logits, *outputs[1:], token_accuracy?)` in which `loss` and
        `token_accuracy` appear only when they are not `None`.

    Raises:
        ValueError: If `skip_logits` is truthy but neither `labels` nor `shift_labels` is provided.

    Example:

    ```python
    >>> from transformers import AutoTokenizer, Olmo3ForCausalLM

    >>> model = Olmo3ForCausalLM.from_pretrained("allenai/Olmo-3-7B-Instruct")
    >>> tokenizer = AutoTokenizer.from_pretrained("allenai/Olmo-3-7B-Instruct")

    >>> prompt = "Hey, are you conscious? Can you talk to me?"
    >>> inputs = tokenizer(prompt, return_tensors="pt")

    >>> # Generate
    >>> generate_ids = model.generate(inputs.input_ids, max_length=30)
    >>> tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]
    'Hey, are you conscious? Can you talk to me?\nI’m not sure if you’re conscious of this, but I’m'
    ```
    """
    # Fall back to the model config for any output flag the caller left unset.
    if output_attentions is None:
        output_attentions = self.config.output_attentions
    if output_hidden_states is None:
        output_hidden_states = self.config.output_hidden_states
    if return_dict is None:
        return_dict = self.config.use_return_dict

    outputs: BaseModelOutputWithPast = self.model(
        input_ids=input_ids,
        attention_mask=attention_mask,
        position_ids=position_ids,
        past_key_values=past_key_values,
        inputs_embeds=inputs_embeds,
        use_cache=use_cache,
        output_attentions=output_attentions,
        output_hidden_states=output_hidden_states,
        return_dict=return_dict,
        cache_position=cache_position,
        **kwargs,
    )

    hidden_states = outputs[0]
    # Restrict to the requested positions before the lm_head / loss is applied.
    if isinstance(logits_to_keep, int):
        slice_indices = slice(-logits_to_keep, None)
    else:
        slice_indices = logits_to_keep
    kept_hidden_states = hidden_states[:, slice_indices, :]

    # Removed from kwargs here (after the backbone call) so it is passed explicitly to the loss below.
    shift_labels = kwargs.pop("shift_labels", None)
    has_targets = labels is not None or shift_labels is not None

    if skip_logits and not has_targets:
        raise ValueError("skip_logits is True, but labels and shift_labels are None")

    if skip_logits is None:
        # By default, if in training mode, don't materialize logits
        skip_logits = self.training and has_targets

    logits = None
    loss = None
    token_accuracy = None

    if skip_logits:
        # Loss is computed from the hidden states and lm_head weight; logits stay None.
        result = LigerForCausalLMLoss(
            hidden_states=kept_hidden_states,
            lm_head_weight=self.lm_head.weight,
            labels=labels,
            shift_labels=shift_labels,
            hidden_size=self.config.hidden_size,
            **kwargs,
        )
        loss, _, token_accuracy = unpack_cross_entropy_result(result)
    else:
        logits = self.lm_head(kept_hidden_states)
        if has_targets:
            loss = self.loss_function(
                logits=logits,
                labels=labels,
                shift_labels=shift_labels,
                vocab_size=self.config.vocab_size,
                **kwargs,
            )

    if not return_dict:
        legacy_output = (logits,) + outputs[1:]
        if loss is not None:
            legacy_output = (loss,) + legacy_output
        if token_accuracy is not None:
            legacy_output = legacy_output + (token_accuracy,)
        return legacy_output

    # Return custom output class with token_accuracy field
    return LigerCausalLMOutputWithPast(
        loss=loss,
        logits=logits,
        past_key_values=outputs.past_key_values,
        hidden_states=outputs.hidden_states,
        attentions=outputs.attentions,
        token_accuracy=token_accuracy,
    )
@@ -1928,6 +1928,74 @@ def apply_liger_kernel_to_olmo2(
1928
1928
  _patch_rms_norm_module(decoder_layer.post_feedforward_layernorm, in_place=False)
1929
1929
 
1930
1930
 
1931
def apply_liger_kernel_to_olmo3(
    rope: bool = True,
    cross_entropy: bool = False,
    fused_linear_cross_entropy: bool = True,
    rms_norm: bool = True,
    swiglu: bool = True,
    model: PreTrainedModel = None,
) -> None:
    """
    Patch HuggingFace's Olmo3 implementation to use Liger kernels.

    Module-level names in `transformers.models.olmo3.modeling_olmo3` are replaced according to the
    flags below. If `model` is given, the already-instantiated submodules of that instance are
    patched as well.

    Args:
        rope (bool): Whether to apply Liger's rotary position embedding. Default is True.
        cross_entropy (bool): Whether to apply Liger's cross entropy loss. Default is False.
        fused_linear_cross_entropy (bool):
            Whether to apply Liger's fused linear cross entropy loss. Default is True.
            `cross_entropy` and `fused_linear_cross_entropy` cannot both be True.
            When True, the forward is replaced by Liger's Olmo3 `lce_forward`, on the instance if
            `model` is given and on `Olmo3ForCausalLM` otherwise.
        rms_norm (bool): Whether to apply Liger's RMSNorm. Default is True.
        swiglu (bool): Whether to apply Liger's SwiGLU to Olmo3MLP. Default is True.
        model (PreTrainedModel): The model instance to apply Liger kernels to, if the model has already been
            loaded. Default is None.
    """
    assert not cross_entropy or not fused_linear_cross_entropy, (
        "cross_entropy and fused_linear_cross_entropy cannot both be True."
    )

    from transformers.models.olmo3 import modeling_olmo3
    from transformers.models.olmo3.modeling_olmo3 import Olmo3Model

    from liger_kernel.transformers.model.olmo3 import lce_forward as olmo3_lce_forward
    from liger_kernel.transformers.rms_norm import LigerRMSNormForOlmo2

    # Class/function-level patches: the same Liger components as for Olmo2 are reused here.
    if rope:
        modeling_olmo3.apply_rotary_pos_emb = liger_rotary_pos_emb
    if rms_norm:
        modeling_olmo3.Olmo3RMSNorm = LigerRMSNormForOlmo2
    if swiglu:
        modeling_olmo3.Olmo3MLP = LigerSwiGLUMLP
    if cross_entropy:
        from transformers.loss.loss_utils import nn

        nn.functional.cross_entropy = liger_cross_entropy
    if fused_linear_cross_entropy and model is not None:
        # Bind to this instance only.
        model.forward = MethodType(olmo3_lce_forward, model)
    elif fused_linear_cross_entropy:
        modeling_olmo3.Olmo3ForCausalLM.forward = olmo3_lce_forward

    if model is None:
        return

    # An instance already exists, so its instantiated submodules must be patched
    # too; replacing the names on the modeling module alone does not affect them.
    base_model: Olmo3Model = getattr(model, model.base_model_prefix, model)

    if rms_norm:
        _patch_rms_norm_module(base_model.norm)

    for decoder_layer in base_model.layers:
        if swiglu:
            _patch_swiglu_module(decoder_layer.mlp, LigerSwiGLUMLP)
        if rms_norm:
            _patch_rms_norm_module(decoder_layer.post_attention_layernorm, in_place=False)
            _patch_rms_norm_module(decoder_layer.post_feedforward_layernorm, in_place=False)
1997
+
1998
+
1931
1999
  def apply_liger_kernel_to_glm4(
1932
2000
  rope: bool = False,
1933
2001
  cross_entropy: bool = False,
@@ -2589,7 +2657,7 @@ def apply_liger_kernel_to_hunyuan_v1_dense(
2589
2657
  from transformers.loss.loss_utils import nn
2590
2658
 
2591
2659
  nn.functional.cross_entropy = liger_cross_entropy
2592
-
2660
+
2593
2661
  if fused_linear_cross_entropy:
2594
2662
  if model is not None:
2595
2663
  model.forward = MethodType(hunyuan_v1_lce_forward, model)
@@ -2695,6 +2763,7 @@ MODEL_TYPE_TO_APPLY_LIGER_FN = {
2695
2763
  "mistral": apply_liger_kernel_to_mistral,
2696
2764
  "mixtral": apply_liger_kernel_to_mixtral,
2697
2765
  "olmo2": apply_liger_kernel_to_olmo2,
2766
+ "olmo3": apply_liger_kernel_to_olmo3,
2698
2767
  "qwen2": apply_liger_kernel_to_qwen2,
2699
2768
  "qwen3": apply_liger_kernel_to_qwen3,
2700
2769
  "qwen3_moe": apply_liger_kernel_to_qwen3_moe,
@@ -93,4 +93,4 @@ class LigerHunyuanV1SwiGLUMLP(nn.Module):
93
93
  raise ValueError(f"Activation function {config.hidden_act} not supported.")
94
94
 
95
95
  def forward(self, x):
96
- return self.down_proj(LigerSiLUMulFunction.apply(self.gate_proj(x), self.up_proj(x)))
96
+ return self.down_proj(LigerSiLUMulFunction.apply(self.gate_proj(x), self.up_proj(x)))
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.1
2
2
  Name: liger_kernel_nightly
3
- Version: 0.6.3.dev20251121195543
3
+ Version: 0.6.3.dev20251121202601
4
4
  Summary: Efficient Triton kernels for LLM Training
5
5
  License: BSD 2-CLAUSE LICENSE
6
6
  Copyright 2024 LinkedIn Corporation
@@ -310,6 +310,7 @@ loss.backward()
310
310
  | Phi3 & Phi3.5 | `liger_kernel.transformers.apply_liger_kernel_to_phi3` | RoPE, RMSNorm, SwiGLU, CrossEntropyLoss, FusedLinearCrossEntropy |
311
311
  | Granite 3.0 & 3.1 | `liger_kernel.transformers.apply_liger_kernel_to_granite` | RoPE, RMSNorm, SwiGLU, CrossEntropyLoss |
312
312
  | OLMo2 | `liger_kernel.transformers.apply_liger_kernel_to_olmo2` | RoPE, RMSNorm, SwiGLU, CrossEntropyLoss, FusedLinearCrossEntropy |
313
+ | Olmo3 | `liger_kernel.transformers.apply_liger_kernel_to_olmo3` | RoPE, RMSNorm, SwiGLU, CrossEntropyLoss, FusedLinearCrossEntropy |
313
314
  | GLM-4 | `liger_kernel.transformers.apply_liger_kernel_to_glm4` | RoPE, RMSNorm, SwiGLU, CrossEntropyLoss, FusedLinearCrossEntropy |
314
315
  | InternVL3 | `liger_kernel.transformers.apply_liger_kernel_to_internvl` | RoPE, RMSNorm, SwiGLU, CrossEntropyLoss, FusedLinearCrossEntropy |
315
316
  | HunyuanV1 | `liger_kernel.transformers.apply_liger_kernel_to_hunyuan_v1_dense` | RoPE, RMSNorm, SwiGLU, CrossEntropyLoss, FusedLinearCrossEntropy |
@@ -8,10 +8,10 @@ liger_kernel/chunked_loss/cpo_loss.py,sha256=Gzz1eU4kgcbdubFVRy55e8A1Cr-r45UgNic
8
8
  liger_kernel/chunked_loss/dpo_loss.py,sha256=I83khNs3QQjuhr8U3NIOAACkbse6DNiBV-TulPZ0lXw,9006
9
9
  liger_kernel/chunked_loss/functional.py,sha256=-XPDbLml9dHmvoSU2VNTUrBDFehuzvuAGPikVetBMtI,1132
10
10
  liger_kernel/chunked_loss/fused_linear_distillation.py,sha256=yRtolfFGfKB-SxGQQyF68GYXd11Zlvh1InLdGeWNFIE,12652
11
- liger_kernel/chunked_loss/fused_linear_ppo.py,sha256=ZjpNP5VC-tXXIKb4AckkQ3iWWQeej-JoG4StJq3N0wg,13650
11
+ liger_kernel/chunked_loss/fused_linear_ppo.py,sha256=baU19PwqO1FTVxwlB-eyJv6gOLtL7baXGzSncYQ8Ktc,14296
12
12
  liger_kernel/chunked_loss/fused_linear_preference.py,sha256=FIH85uUXAOgYx5Ax8MjFhJHVu-2pKtY7wSegd0zSyyY,18336
13
13
  liger_kernel/chunked_loss/fused_linear_unpaired_preference.py,sha256=RiuK3UtRwH9T6jZ36sA8Urj-TVuOLOO2syLg_JOQapY,13437
14
- liger_kernel/chunked_loss/grpo_loss.py,sha256=SkZuKoW8K94UbWR-OtfopsQkuQ8tFOr_90AGR6_Mhes,12844
14
+ liger_kernel/chunked_loss/grpo_loss.py,sha256=bmuZaNgqNbJ5pJGFDXWE-B4BGYF7xWVSN15UyCfuq_s,13079
15
15
  liger_kernel/chunked_loss/jsd_loss.py,sha256=G0RghPYYelyZ6DOEiwS8we9TT5MY2iHpiFqzZ2Xy87g,8038
16
16
  liger_kernel/chunked_loss/kto_loss.py,sha256=llVCe6DkcpCo57seGWoMikaQVFApx764jsmSbQyqwQY,7529
17
17
  liger_kernel/chunked_loss/orpo_loss.py,sha256=nu9UYG16dcMw93lvHi4_hYs3Q0FK1KnlmMRj7OpYU8s,4872
@@ -25,7 +25,7 @@ liger_kernel/ops/fused_linear_jsd.py,sha256=CSoprxb-YcJy-YUKiTcYkxN8sb9h2kdk_iHu
25
25
  liger_kernel/ops/fused_neighborhood_attention.py,sha256=vPi5xbnh6wxyZehaqo6Tuilqo2fN5SGDiONjnNmIKqs,35556
26
26
  liger_kernel/ops/geglu.py,sha256=r0WSq9E93zzynL44Wh8femzOWK07_SseBM_pJUyxT3s,4144
27
27
  liger_kernel/ops/group_norm.py,sha256=qD4D4lSjSgVtO52EBNLC2iTseALRgPgqXE50U2woggk,10837
28
- liger_kernel/ops/grpo_loss.py,sha256=anRnv7k1-AV3pCC6_TqP0GMg78YYUfRAJrbpx6PVhl0,9448
28
+ liger_kernel/ops/grpo_loss.py,sha256=2SyOujtF9I3xiNo4wFf4s6MeiDotE_qeYfRWgj_bOBE,9573
29
29
  liger_kernel/ops/jsd.py,sha256=onHp5T3MbvJaVz5Vup7Ww6EQp_HTaZeayTjJk6FgQMY,7042
30
30
  liger_kernel/ops/kl_div.py,sha256=ZjGdDLKWksHT9dZ0xF_TDgAkj5cuMTwwT5tr9E-_24o,8734
31
31
  liger_kernel/ops/layer_norm.py,sha256=WmiORsIyufOhazmYZTPjeSc5Z-xTAYwXAKqUcCv_dlY,9807
@@ -43,7 +43,7 @@ liger_kernel/ops/tvd.py,sha256=FHJtLQI95ijqgg9UtaHpMAjSCiPxB6CduPwPMcGxelc,6405
43
43
  liger_kernel/ops/utils.py,sha256=uoFKQqo-34N2TWQNvXMFywqGiOMMXNEVBxVojzlUAa0,3836
44
44
  liger_kernel/ops/experimental/embedding.py,sha256=tolj3tItkzpSb30zWqDN2_yX4ectflaQ8HMyKyFIQc8,4172
45
45
  liger_kernel/ops/experimental/mm_int8int2.py,sha256=TrS9lpwekrik_w5qE7AhMJD1bcq-OidjtbsW80oZ6IM,13314
46
- liger_kernel/transformers/__init__.py,sha256=zApQL1sGf2GOo7gWHfQxJ9X6D7QFwAuds4cPJJ-H81Y,10508
46
+ liger_kernel/transformers/__init__.py,sha256=CgwhrY5cdx6OcRgR2ZZJbOIkLswQWPTr-BAaoxDNNOY,10687
47
47
  liger_kernel/transformers/auto_model.py,sha256=0qCTRZt280Bj_LcFdzo9hlaR-BWNazawXOGgoCZjgEg,1545
48
48
  liger_kernel/transformers/cross_entropy.py,sha256=DMtHkKrVJDSsels7KgGQJqrXkEAd6Zopcdr-5oRmQgE,2010
49
49
  liger_kernel/transformers/dyt.py,sha256=i-4GPaMrl-jab9TVI5qN0-H9qycn_mCbV82ozU4nbmU,723
@@ -55,12 +55,12 @@ liger_kernel/transformers/fused_linear_jsd.py,sha256=bZ4otCvWBuOnA5XdQL-FzZVItJl
55
55
  liger_kernel/transformers/fused_neighborhood_attention.py,sha256=TxYDUAt9B6WSP14aJP66C_2Mbds2sSIPGnamhUSTrC8,7957
56
56
  liger_kernel/transformers/geglu.py,sha256=mrgqzIUVd6lN7fkDKLkw5YaESDxDtFgbot430WwPVOQ,1107
57
57
  liger_kernel/transformers/group_norm.py,sha256=6qMAWOprr4SzP0YhNVNGQIBpM5aUHplUD2VuGJrMBz0,2173
58
- liger_kernel/transformers/grpo_loss.py,sha256=uAkUNKSnUGEOqa82L9w2e6AI1kcmG8K45-QxyaT8zhM,3897
58
+ liger_kernel/transformers/grpo_loss.py,sha256=QS6Ycct1E2yMfqoHPBa2sUAu5cmweNPK_-Q_KJE8hb4,6098
59
59
  liger_kernel/transformers/jsd.py,sha256=DGqRnxIZxsvxo0_tbbxX3b-sDbDjC_yKufyRIHCcScY,2979
60
60
  liger_kernel/transformers/kl_div.py,sha256=WLffFbh1EExD2Eb1F7lN11fo9JJC-0751WJjZAF1Fj8,409
61
61
  liger_kernel/transformers/layer_norm.py,sha256=c9pk3PEasOKYR0rhe5e5nNrnYKVCEW4VC8S6LpCq9EQ,906
62
62
  liger_kernel/transformers/llama4_rope.py,sha256=kS6PSHEwf3dS7hD7C7p8S0geugx2EMCiP0h0F7LsUoY,3639
63
- liger_kernel/transformers/monkey_patch.py,sha256=Th4XiYT2fRRo1YNOCLkLLiTgjEpCdidWOT8-ozxgFsE,129377
63
+ liger_kernel/transformers/monkey_patch.py,sha256=4LV6LSz_AAop6HWk1spZm1QigPN9nUDPJu9tK21-jIo,132446
64
64
  liger_kernel/transformers/multi_token_attention.py,sha256=K3NIY9_5TPgZ4_Rahn0xnkMXxD_fmlJHK4CWGYvGQp0,1752
65
65
  liger_kernel/transformers/poly_norm.py,sha256=g5tC75i3qy1_N26ZUP-jfpct7ivQAEdJfIfx8IXzeyE,1377
66
66
  liger_kernel/transformers/qwen2vl_mrope.py,sha256=5EwSqrMdsL9MYspeBMXBsNJKvH0MOmRrtJXAJlnnlOI,1047
@@ -68,7 +68,7 @@ liger_kernel/transformers/rms_norm.py,sha256=HwddVqrqS58jE-M2_4NkFGARtCDBhGnkKyj
68
68
  liger_kernel/transformers/rope.py,sha256=VMlDZI6zss9mLaLcN5XCE_ktmYRwAi_Eh4TIgO6NrIQ,2361
69
69
  liger_kernel/transformers/softmax.py,sha256=yadlAgE4V2JByMwrDDa2s5SUBp8Jgd57xwnVvAWoBaI,264
70
70
  liger_kernel/transformers/sparsemax.py,sha256=0lQA0UEOs4mu8CMruZ3VLhImxQVXJWhPsAKUsYA7vj8,403
71
- liger_kernel/transformers/swiglu.py,sha256=FLvxamjGru9N-ZelsccTvNn0CjUnId9ldiBrOnH-8QQ,4290
71
+ liger_kernel/transformers/swiglu.py,sha256=dRR69wDWSWfdjtnsTECyxQqWVo5QkdXdXm9SpSQ4Jvw,4291
72
72
  liger_kernel/transformers/tiled_mlp.py,sha256=J51-kpzwikDMMhT5bX-RZCKMaXBK6zZc1bhgRYTK5F0,4651
73
73
  liger_kernel/transformers/trainer_integration.py,sha256=W3ON51O5GkyzNJsItz0y5rKx-uy2f2cFfveZpqbUdhw,123
74
74
  liger_kernel/transformers/tvd.py,sha256=XrRfyJIqN6HFxXk8MYyFVZM1OLz3mtSbRZvWfZ_JerQ,450
@@ -92,6 +92,7 @@ liger_kernel/transformers/model/mistral.py,sha256=OcwOzVDMwwDbVccVPv-AaocznzWwzL
92
92
  liger_kernel/transformers/model/mixtral.py,sha256=YcBDoTEJDgLFJ_RTo180DYGxR8D5Ad9-idumif7kCPE,12130
93
93
  liger_kernel/transformers/model/mllama.py,sha256=vAHwCm63sn4kpAY0rDGf_N0HR7KRTBVpBYDVTPOaZTg,12079
94
94
  liger_kernel/transformers/model/olmo2.py,sha256=-h2bUOeuPfY1MdShdRvq5_wFDHKP4PEimgIl0fL-BT4,5902
95
+ liger_kernel/transformers/model/olmo3.py,sha256=k2zYOlS8U_b5MwjdToB3tDRQ0bH_mWapVQqJcH8-qAo,6007
95
96
  liger_kernel/transformers/model/output_classes.py,sha256=0BGXVR4dYQpSHLkSqpRoXuHMryrceGSlTYRu6pvd8ZY,4542
96
97
  liger_kernel/transformers/model/paligemma.py,sha256=r0smHLADkEwfLS6d6ArWoSWEeLt2d_8pmgOO5F04b1o,20793
97
98
  liger_kernel/transformers/model/phi3.py,sha256=PT7Kw6yySg-7TsssWfi82eVMN3SWujCqzCqHigAdfeQ,4574
@@ -109,9 +110,9 @@ liger_kernel/transformers/trainer/__init__.py,sha256=p7yQfklV8-467qSz_ZMimkbDF7H
109
110
  liger_kernel/transformers/trainer/orpo_trainer.py,sha256=tX0h63aOFe3rNqTmk6JpMf75UPo981yzEa6TghnjS0Q,5370
110
111
  liger_kernel/triton/__init__.py,sha256=qCiCamzCRv6lpV8IqpAc9YMdNKC7GKurClWceQPnlis,92
111
112
  liger_kernel/triton/monkey_patch.py,sha256=Rd0hUHAzDkFfHvnX7-PBaNK5EKnZhtfM_h-fgQH9HPY,1568
112
- liger_kernel_nightly-0.6.3.dev20251121195543.dist-info/LICENSE,sha256=OhzLDHJ0to4a8sodVLELZiCFylZ1NAAYLs-HrjPy0ag,1312
113
- liger_kernel_nightly-0.6.3.dev20251121195543.dist-info/METADATA,sha256=3jCmn8uhrcmwtmASbQwo4NBudeZ9NJOfj5RN5Xlylr0,25097
114
- liger_kernel_nightly-0.6.3.dev20251121195543.dist-info/NOTICE,sha256=njwnoPZLh9AN8SJQzxvCGLHi-8X__AvWRze6joNXIY8,2066
115
- liger_kernel_nightly-0.6.3.dev20251121195543.dist-info/WHEEL,sha256=iAkIy5fosb7FzIOwONchHf19Qu7_1wCWyFNR5gu9nU0,91
116
- liger_kernel_nightly-0.6.3.dev20251121195543.dist-info/top_level.txt,sha256=2eghu4hA3LnkM7ElW92tQ8zegWKgSbeo-k-aGe1YnvY,13
117
- liger_kernel_nightly-0.6.3.dev20251121195543.dist-info/RECORD,,
113
+ liger_kernel_nightly-0.6.3.dev20251121202601.dist-info/LICENSE,sha256=OhzLDHJ0to4a8sodVLELZiCFylZ1NAAYLs-HrjPy0ag,1312
114
+ liger_kernel_nightly-0.6.3.dev20251121202601.dist-info/METADATA,sha256=g6I_79r5KT_3WD4b8mKW-bb_Q_Gx1O8pSZ2LwUOu-OA,25238
115
+ liger_kernel_nightly-0.6.3.dev20251121202601.dist-info/NOTICE,sha256=njwnoPZLh9AN8SJQzxvCGLHi-8X__AvWRze6joNXIY8,2066
116
+ liger_kernel_nightly-0.6.3.dev20251121202601.dist-info/WHEEL,sha256=iAkIy5fosb7FzIOwONchHf19Qu7_1wCWyFNR5gu9nU0,91
117
+ liger_kernel_nightly-0.6.3.dev20251121202601.dist-info/top_level.txt,sha256=2eghu4hA3LnkM7ElW92tQ8zegWKgSbeo-k-aGe1YnvY,13
118
+ liger_kernel_nightly-0.6.3.dev20251121202601.dist-info/RECORD,,