liger-kernel-nightly 0.6.2.dev20251011154427__py3-none-any.whl → 0.6.4.dev20260107111351__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of liger-kernel-nightly might be problematic. Click here for more details.

Files changed (97) hide show
  1. liger_kernel/chunked_loss/cosine_similarity_loss.py +20 -5
  2. liger_kernel/chunked_loss/fused_linear_distillation.py +23 -5
  3. liger_kernel/chunked_loss/fused_linear_ppo.py +21 -5
  4. liger_kernel/chunked_loss/grpo_loss.py +8 -5
  5. liger_kernel/chunked_loss/jsd_loss.py +39 -11
  6. liger_kernel/ops/__init__.py +141 -0
  7. liger_kernel/ops/backends/README.md +151 -0
  8. liger_kernel/ops/backends/__init__.py +13 -0
  9. liger_kernel/ops/backends/_ascend/__init__.py +5 -0
  10. liger_kernel/ops/backends/_ascend/ascend-ub-manager-design.md +485 -0
  11. liger_kernel/ops/backends/_ascend/ops/__init__.py +43 -0
  12. liger_kernel/ops/backends/_ascend/ops/geglu.py +244 -0
  13. liger_kernel/ops/backends/_ascend/ops/qwen2vl_mrope.py +285 -0
  14. liger_kernel/ops/backends/_ascend/ops/rope.py +290 -0
  15. liger_kernel/ops/backends/_ascend/ops/swiglu.py +142 -0
  16. liger_kernel/ops/backends/_ascend/ub_manager.py +349 -0
  17. liger_kernel/ops/backends/registry.py +61 -0
  18. liger_kernel/ops/cross_entropy.py +75 -12
  19. liger_kernel/ops/dyt.py +5 -2
  20. liger_kernel/ops/fused_add_rms_norm.py +5 -1
  21. liger_kernel/ops/fused_linear_cross_entropy.py +45 -14
  22. liger_kernel/ops/geglu.py +5 -3
  23. liger_kernel/ops/group_norm.py +2 -1
  24. liger_kernel/ops/grpo_loss.py +3 -1
  25. liger_kernel/ops/layer_norm.py +86 -66
  26. liger_kernel/ops/poly_norm.py +390 -0
  27. liger_kernel/ops/rms_norm.py +131 -49
  28. liger_kernel/ops/tiled_mlp.py +136 -0
  29. liger_kernel/ops/utils.py +14 -0
  30. liger_kernel/transformers/__init__.py +30 -0
  31. liger_kernel/transformers/auto_model.py +21 -0
  32. liger_kernel/transformers/cross_entropy.py +9 -4
  33. liger_kernel/transformers/dyt.py +1 -1
  34. liger_kernel/transformers/experimental/embedding.py +1 -1
  35. liger_kernel/transformers/functional.py +48 -25
  36. liger_kernel/transformers/fused_add_rms_norm.py +1 -1
  37. liger_kernel/transformers/fused_linear_cross_entropy.py +9 -4
  38. liger_kernel/transformers/fused_linear_jsd.py +1 -1
  39. liger_kernel/transformers/fused_neighborhood_attention.py +1 -1
  40. liger_kernel/transformers/geglu.py +1 -1
  41. liger_kernel/transformers/group_norm.py +1 -1
  42. liger_kernel/transformers/grpo_loss.py +57 -2
  43. liger_kernel/transformers/jsd.py +1 -1
  44. liger_kernel/transformers/kl_div.py +1 -1
  45. liger_kernel/transformers/layer_norm.py +1 -1
  46. liger_kernel/transformers/llama4_rope.py +1 -1
  47. liger_kernel/transformers/model/falcon_h1.py +19 -5
  48. liger_kernel/transformers/model/gemma.py +17 -6
  49. liger_kernel/transformers/model/gemma2.py +14 -5
  50. liger_kernel/transformers/model/gemma3.py +26 -12
  51. liger_kernel/transformers/model/glm4.py +16 -4
  52. liger_kernel/transformers/model/glm4v.py +16 -4
  53. liger_kernel/transformers/model/glm4v_moe.py +23 -4
  54. liger_kernel/transformers/model/gpt_oss.py +211 -0
  55. liger_kernel/transformers/model/hunyuan_v1.py +134 -0
  56. liger_kernel/transformers/model/internvl.py +12 -5
  57. liger_kernel/transformers/model/llama.py +14 -5
  58. liger_kernel/transformers/model/llama4.py +16 -4
  59. liger_kernel/transformers/model/llava.py +12 -4
  60. liger_kernel/transformers/model/loss_utils.py +31 -3
  61. liger_kernel/transformers/model/mistral.py +15 -6
  62. liger_kernel/transformers/model/mixtral.py +16 -7
  63. liger_kernel/transformers/model/mllama.py +12 -4
  64. liger_kernel/transformers/model/olmo2.py +16 -4
  65. liger_kernel/transformers/model/olmo3.py +142 -0
  66. liger_kernel/transformers/model/output_classes.py +147 -0
  67. liger_kernel/transformers/model/paligemma.py +23 -5
  68. liger_kernel/transformers/model/phi3.py +14 -7
  69. liger_kernel/transformers/model/qwen2.py +16 -3
  70. liger_kernel/transformers/model/qwen2_5_vl.py +14 -6
  71. liger_kernel/transformers/model/qwen2_vl.py +16 -4
  72. liger_kernel/transformers/model/qwen3.py +20 -5
  73. liger_kernel/transformers/model/qwen3_moe.py +19 -5
  74. liger_kernel/transformers/model/qwen3_next.py +146 -0
  75. liger_kernel/transformers/model/qwen3_vl.py +150 -0
  76. liger_kernel/transformers/model/qwen3_vl_moe.py +126 -0
  77. liger_kernel/transformers/model/smollm3.py +15 -6
  78. liger_kernel/transformers/model/smolvlm.py +158 -0
  79. liger_kernel/transformers/monkey_patch.py +702 -48
  80. liger_kernel/transformers/multi_token_attention.py +1 -1
  81. liger_kernel/transformers/poly_norm.py +42 -0
  82. liger_kernel/transformers/qwen2vl_mrope.py +1 -1
  83. liger_kernel/transformers/rms_norm.py +15 -3
  84. liger_kernel/transformers/rope.py +45 -1
  85. liger_kernel/transformers/softmax.py +1 -1
  86. liger_kernel/transformers/sparsemax.py +1 -1
  87. liger_kernel/transformers/swiglu.py +18 -1
  88. liger_kernel/transformers/tiled_mlp.py +133 -0
  89. liger_kernel/transformers/tvd.py +1 -1
  90. liger_kernel/utils.py +52 -0
  91. {liger_kernel_nightly-0.6.2.dev20251011154427.dist-info → liger_kernel_nightly-0.6.4.dev20260107111351.dist-info}/METADATA +12 -3
  92. liger_kernel_nightly-0.6.4.dev20260107111351.dist-info/RECORD +130 -0
  93. liger_kernel_nightly-0.6.2.dev20251011154427.dist-info/RECORD +0 -107
  94. {liger_kernel_nightly-0.6.2.dev20251011154427.dist-info → liger_kernel_nightly-0.6.4.dev20260107111351.dist-info}/LICENSE +0 -0
  95. {liger_kernel_nightly-0.6.2.dev20251011154427.dist-info → liger_kernel_nightly-0.6.4.dev20260107111351.dist-info}/NOTICE +0 -0
  96. {liger_kernel_nightly-0.6.2.dev20251011154427.dist-info → liger_kernel_nightly-0.6.4.dev20260107111351.dist-info}/WHEEL +0 -0
  97. {liger_kernel_nightly-0.6.2.dev20251011154427.dist-info → liger_kernel_nightly-0.6.4.dev20260107111351.dist-info}/top_level.txt +0 -0
@@ -15,6 +15,7 @@ from liger_kernel.transformers.layer_norm import LigerLayerNorm # noqa: F401
15
15
  from liger_kernel.transformers.llama4_rope import liger_llama4_text_rotary_pos_emb # noqa: F401
16
16
  from liger_kernel.transformers.llama4_rope import liger_llama4_vision_rotary_pos_emb # noqa: F401
17
17
  from liger_kernel.transformers.multi_token_attention import LigerMultiTokenAttention # noqa: F401
18
+ from liger_kernel.transformers.poly_norm import LigerPolyNorm # noqa: F401
18
19
  from liger_kernel.transformers.rms_norm import LigerRMSNorm # noqa: F401
19
20
  from liger_kernel.transformers.rope import liger_rotary_pos_emb # noqa: F401
20
21
  from liger_kernel.transformers.softmax import LigerSoftmax # noqa: F401
@@ -23,6 +24,8 @@ from liger_kernel.transformers.swiglu import LigerBlockSparseTop2MLP # noqa: F4
23
24
  from liger_kernel.transformers.swiglu import LigerPhi3SwiGLUMLP # noqa: F401
24
25
  from liger_kernel.transformers.swiglu import LigerQwen3MoeSwiGLUMLP # noqa: F401
25
26
  from liger_kernel.transformers.swiglu import LigerSwiGLUMLP # noqa: F401
27
+ from liger_kernel.transformers.tiled_mlp import LigerTiledGEGLUMLP # noqa: F401
28
+ from liger_kernel.transformers.tiled_mlp import LigerTiledSwiGLUMLP # noqa: F401
26
29
  from liger_kernel.transformers.tvd import LigerTVDLoss # noqa: F401
27
30
 
28
31
  # Static-only imports for IDEs and type checkers
@@ -38,7 +41,10 @@ if TYPE_CHECKING:
38
41
  from liger_kernel.transformers.monkey_patch import apply_liger_kernel_to_glm4 # noqa: F401
39
42
  from liger_kernel.transformers.monkey_patch import apply_liger_kernel_to_glm4v # noqa: F401
40
43
  from liger_kernel.transformers.monkey_patch import apply_liger_kernel_to_glm4v_moe # noqa: F401
44
+ from liger_kernel.transformers.monkey_patch import apply_liger_kernel_to_gpt_oss # noqa: F401
41
45
  from liger_kernel.transformers.monkey_patch import apply_liger_kernel_to_granite # noqa: F401
46
+ from liger_kernel.transformers.monkey_patch import apply_liger_kernel_to_hunyuan_v1_dense # noqa: F401
47
+ from liger_kernel.transformers.monkey_patch import apply_liger_kernel_to_hunyuan_v1_moe # noqa: F401
42
48
  from liger_kernel.transformers.monkey_patch import apply_liger_kernel_to_internvl # noqa: F401
43
49
  from liger_kernel.transformers.monkey_patch import apply_liger_kernel_to_llama # noqa: F401
44
50
  from liger_kernel.transformers.monkey_patch import apply_liger_kernel_to_llama4 # noqa: F401
@@ -47,6 +53,7 @@ if TYPE_CHECKING:
47
53
  from liger_kernel.transformers.monkey_patch import apply_liger_kernel_to_mixtral # noqa: F401
48
54
  from liger_kernel.transformers.monkey_patch import apply_liger_kernel_to_mllama # noqa: F401
49
55
  from liger_kernel.transformers.monkey_patch import apply_liger_kernel_to_olmo2 # noqa: F401
56
+ from liger_kernel.transformers.monkey_patch import apply_liger_kernel_to_olmo3 # noqa: F401
50
57
  from liger_kernel.transformers.monkey_patch import apply_liger_kernel_to_paligemma # noqa: F401
51
58
  from liger_kernel.transformers.monkey_patch import apply_liger_kernel_to_phi3 # noqa: F401
52
59
  from liger_kernel.transformers.monkey_patch import apply_liger_kernel_to_qwen2 # noqa: F401
@@ -54,7 +61,11 @@ if TYPE_CHECKING:
54
61
  from liger_kernel.transformers.monkey_patch import apply_liger_kernel_to_qwen2_vl # noqa: F401
55
62
  from liger_kernel.transformers.monkey_patch import apply_liger_kernel_to_qwen3 # noqa: F401
56
63
  from liger_kernel.transformers.monkey_patch import apply_liger_kernel_to_qwen3_moe # noqa: F401
64
+ from liger_kernel.transformers.monkey_patch import apply_liger_kernel_to_qwen3_next # noqa: F401
65
+ from liger_kernel.transformers.monkey_patch import apply_liger_kernel_to_qwen3_vl # noqa: F401
66
+ from liger_kernel.transformers.monkey_patch import apply_liger_kernel_to_qwen3_vl_moe # noqa: F401
57
67
  from liger_kernel.transformers.monkey_patch import apply_liger_kernel_to_smollm3 # noqa: F401
68
+ from liger_kernel.transformers.monkey_patch import apply_liger_kernel_to_smolvlm # noqa: F401
58
69
 
59
70
 
60
71
  # Check if 'transformers' is installed
@@ -100,6 +111,7 @@ def __getattr__(name: str):
100
111
  "apply_liger_kernel_to_glm4",
101
112
  "apply_liger_kernel_to_glm4v",
102
113
  "apply_liger_kernel_to_glm4v_moe",
114
+ "apply_liger_kernel_to_gpt_oss",
103
115
  "apply_liger_kernel_to_granite",
104
116
  "apply_liger_kernel_to_internvl",
105
117
  "apply_liger_kernel_to_llama",
@@ -109,6 +121,7 @@ def __getattr__(name: str):
109
121
  "apply_liger_kernel_to_mixtral",
110
122
  "apply_liger_kernel_to_mllama",
111
123
  "apply_liger_kernel_to_olmo2",
124
+ "apply_liger_kernel_to_olmo3",
112
125
  "apply_liger_kernel_to_paligemma",
113
126
  "apply_liger_kernel_to_phi3",
114
127
  "apply_liger_kernel_to_qwen2",
@@ -116,7 +129,13 @@ def __getattr__(name: str):
116
129
  "apply_liger_kernel_to_qwen2_vl",
117
130
  "apply_liger_kernel_to_qwen3",
118
131
  "apply_liger_kernel_to_qwen3_moe",
132
+ "apply_liger_kernel_to_qwen3_next",
133
+ "apply_liger_kernel_to_qwen3_vl",
134
+ "apply_liger_kernel_to_qwen3_vl_moe",
119
135
  "apply_liger_kernel_to_smollm3",
136
+ "apply_liger_kernel_to_smolvlm",
137
+ "apply_liger_kernel_to_hunyuan_v1_dense",
138
+ "apply_liger_kernel_to_hunyuan_v1_moe",
120
139
  }
121
140
 
122
141
  if name in monkey_patch_symbols:
@@ -137,6 +156,7 @@ __all__ = [
137
156
  "LigerJSD",
138
157
  "LigerLayerNorm",
139
158
  "LigerFusedAddRMSNorm",
159
+ "LigerPolyNorm",
140
160
  "LigerRMSNorm",
141
161
  "liger_rotary_pos_emb",
142
162
  "liger_llama4_text_rotary_pos_emb",
@@ -145,6 +165,8 @@ __all__ = [
145
165
  "LigerPhi3SwiGLUMLP",
146
166
  "LigerQwen3MoeSwiGLUMLP",
147
167
  "LigerSwiGLUMLP",
168
+ "LigerTiledGEGLUMLP",
169
+ "LigerTiledSwiGLUMLP",
148
170
  "LigerTVDLoss",
149
171
  "LigerKLDIVLoss",
150
172
  "LigerMultiTokenAttention",
@@ -167,6 +189,7 @@ if _TRANSFORMERS_AVAILABLE:
167
189
  "apply_liger_kernel_to_glm4",
168
190
  "apply_liger_kernel_to_glm4v",
169
191
  "apply_liger_kernel_to_glm4v_moe",
192
+ "apply_liger_kernel_to_gpt_oss",
170
193
  "apply_liger_kernel_to_granite",
171
194
  "apply_liger_kernel_to_internvl",
172
195
  "apply_liger_kernel_to_llama",
@@ -176,6 +199,7 @@ if _TRANSFORMERS_AVAILABLE:
176
199
  "apply_liger_kernel_to_mixtral",
177
200
  "apply_liger_kernel_to_mllama",
178
201
  "apply_liger_kernel_to_olmo2",
202
+ "apply_liger_kernel_to_olmo3",
179
203
  "apply_liger_kernel_to_paligemma",
180
204
  "apply_liger_kernel_to_phi3",
181
205
  "apply_liger_kernel_to_qwen2",
@@ -183,6 +207,12 @@ if _TRANSFORMERS_AVAILABLE:
183
207
  "apply_liger_kernel_to_qwen2_vl",
184
208
  "apply_liger_kernel_to_qwen3",
185
209
  "apply_liger_kernel_to_qwen3_moe",
210
+ "apply_liger_kernel_to_qwen3_next",
211
+ "apply_liger_kernel_to_qwen3_vl",
212
+ "apply_liger_kernel_to_qwen3_vl_moe",
186
213
  "apply_liger_kernel_to_smollm3",
214
+ "apply_liger_kernel_to_smolvlm",
215
+ "apply_liger_kernel_to_hunyuan_v1_dense",
216
+ "apply_liger_kernel_to_hunyuan_v1_moe",
187
217
  ]
188
218
  )
@@ -1,4 +1,5 @@
1
1
  import inspect
2
+ import logging
2
3
 
3
4
  from transformers import AutoConfig
4
5
  from transformers import AutoModelForCausalLM
@@ -6,6 +7,8 @@ from transformers import AutoModelForCausalLM
6
7
  from liger_kernel.transformers.monkey_patch import MODEL_TYPE_TO_APPLY_LIGER_FN
7
8
  from liger_kernel.transformers.monkey_patch import _apply_liger_kernel
8
9
 
10
+ logger = logging.getLogger(__name__)
11
+
9
12
 
10
13
  def _get_model_config(model_dir, **model_init_kwargs):
11
14
  config = AutoConfig.from_pretrained(model_dir, **model_init_kwargs)
@@ -36,3 +39,21 @@ class AutoLigerKernelForCausalLM(AutoModelForCausalLM):
36
39
  applicable_kwargs = {key: value for key, value in kwargs.items() if key not in apply_fn_signature.parameters}
37
40
 
38
41
  return super().from_pretrained(pretrained_model_name_or_path, *model_args, **applicable_kwargs)
42
+
43
+ @classmethod
44
+ def from_config(cls, config, **kwargs):
45
+ model_type = getattr(config, "model_type", None)
46
+ if not model_type:
47
+ logger.info("Model type could not be determined from model config. No Liger kernels will be applied.")
48
+ return
49
+ model_type = config.model_type
50
+
51
+ _apply_liger_kernel(model_type, **kwargs)
52
+
53
+ # Filter out kwargs that were passed to the apply_liger_* function, which will cause
54
+ # model initialization errors otherwise
55
+ apply_fn = MODEL_TYPE_TO_APPLY_LIGER_FN[model_type]
56
+ apply_fn_signature = inspect.signature(apply_fn)
57
+ applicable_kwargs = {key: value for key, value in kwargs.items() if key not in apply_fn_signature.parameters}
58
+
59
+ return super().from_config(config, **applicable_kwargs)
@@ -2,7 +2,8 @@ from typing import Optional
2
2
 
3
3
  import torch
4
4
 
5
- from liger_kernel.ops.cross_entropy import LigerCrossEntropyFunction
5
+ from liger_kernel.ops import LigerCrossEntropyFunction
6
+ from liger_kernel.transformers.functional import CrossEntropyOutput
6
7
 
7
8
 
8
9
  class LigerCrossEntropyLoss(torch.nn.Module):
@@ -15,6 +16,7 @@ class LigerCrossEntropyLoss(torch.nn.Module):
15
16
  reduction: str = "mean",
16
17
  softcap: Optional[float] = None,
17
18
  return_z_loss: bool = False,
19
+ return_token_accuracy: bool = False,
18
20
  ):
19
21
  super().__init__()
20
22
  assert (label_smoothing >= 0) and (label_smoothing <= 1), (
@@ -33,9 +35,10 @@ class LigerCrossEntropyLoss(torch.nn.Module):
33
35
  self.reduction = reduction
34
36
  self.softcap = softcap
35
37
  self.return_z_loss = return_z_loss
38
+ self.return_token_accuracy = return_token_accuracy
36
39
 
37
40
  def forward(self, _input: torch.Tensor, target: torch.Tensor):
38
- loss, z_loss = LigerCrossEntropyFunction.apply(
41
+ loss, z_loss, token_accuracy = LigerCrossEntropyFunction.apply(
39
42
  _input,
40
43
  target,
41
44
  self.weight,
@@ -45,7 +48,9 @@ class LigerCrossEntropyLoss(torch.nn.Module):
45
48
  self.reduction,
46
49
  self.softcap,
47
50
  self.return_z_loss,
51
+ self.return_token_accuracy,
48
52
  )
49
- if not self.return_z_loss:
53
+ if not self.return_z_loss and not self.return_token_accuracy:
50
54
  return loss
51
- return loss, z_loss
55
+
56
+ return CrossEntropyOutput(loss=loss, z_loss=z_loss, token_accuracy=token_accuracy)
@@ -1,7 +1,7 @@
1
1
  import torch
2
2
  import torch.nn as nn
3
3
 
4
- from liger_kernel.ops.dyt import LigerDyTFunction
4
+ from liger_kernel.ops import LigerDyTFunction
5
5
 
6
6
 
7
7
  class LigerDyT(nn.Module):
@@ -3,7 +3,7 @@ from typing import Optional
3
3
  import torch
4
4
  import torch.nn as nn
5
5
 
6
- from liger_kernel.ops.experimental.embedding import LigerEmbeddingFunction
6
+ from liger_kernel.ops import LigerEmbeddingFunction
7
7
 
8
8
 
9
9
  class LigerEmbedding(nn.Module):
@@ -1,24 +1,35 @@
1
+ from dataclasses import dataclass
1
2
  from typing import Optional
2
3
 
3
- from liger_kernel.ops.cross_entropy import LigerCrossEntropyFunction
4
- from liger_kernel.ops.dyt import LigerDyTFunction
5
- from liger_kernel.ops.fused_add_rms_norm import LigerFusedAddRMSNormFunction
6
- from liger_kernel.ops.fused_linear_cross_entropy import LigerFusedLinearCrossEntropyFunction
7
- from liger_kernel.ops.fused_linear_jsd import LigerFusedLinearJSDFunction
8
- from liger_kernel.ops.fused_neighborhood_attention import LigerFusedNeighborhoodAttentionFunction
9
- from liger_kernel.ops.geglu import LigerGELUMulFunction
10
- from liger_kernel.ops.group_norm import LigerGroupNormFunction
11
- from liger_kernel.ops.jsd import LigerJSDFunction
12
- from liger_kernel.ops.kl_div import LigerKLDivLossFunction
13
- from liger_kernel.ops.layer_norm import LigerLayerNormFunction
14
- from liger_kernel.ops.multi_token_attention import LigerMultiTokenAttentionFunction
15
- from liger_kernel.ops.qwen2vl_mrope import LigerQwen2VLMRopeFunction
16
- from liger_kernel.ops.rms_norm import LigerRMSNormFunction
17
- from liger_kernel.ops.rope import LigerRopeFunction
18
- from liger_kernel.ops.softmax import LigerSoftmaxFunction
19
- from liger_kernel.ops.sparsemax import LigerSparsemaxFunction
20
- from liger_kernel.ops.swiglu import LigerSiLUMulFunction
21
- from liger_kernel.ops.tvd import LigerTVDLossFunction
4
+ import torch
5
+
6
+ from liger_kernel.ops import LigerCrossEntropyFunction
7
+ from liger_kernel.ops import LigerDyTFunction
8
+ from liger_kernel.ops import LigerFusedAddRMSNormFunction
9
+ from liger_kernel.ops import LigerFusedLinearCrossEntropyFunction
10
+ from liger_kernel.ops import LigerFusedLinearJSDFunction
11
+ from liger_kernel.ops import LigerFusedNeighborhoodAttentionFunction
12
+ from liger_kernel.ops import LigerGELUMulFunction
13
+ from liger_kernel.ops import LigerGroupNormFunction
14
+ from liger_kernel.ops import LigerJSDFunction
15
+ from liger_kernel.ops import LigerKLDivLossFunction
16
+ from liger_kernel.ops import LigerLayerNormFunction
17
+ from liger_kernel.ops import LigerMultiTokenAttentionFunction
18
+ from liger_kernel.ops import LigerPolyNormFunction
19
+ from liger_kernel.ops import LigerQwen2VLMRopeFunction
20
+ from liger_kernel.ops import LigerRMSNormFunction
21
+ from liger_kernel.ops import LigerRopeFunction
22
+ from liger_kernel.ops import LigerSiLUMulFunction
23
+ from liger_kernel.ops import LigerSoftmaxFunction
24
+ from liger_kernel.ops import LigerSparsemaxFunction
25
+ from liger_kernel.ops import LigerTVDLossFunction
26
+
27
+
28
+ @dataclass
29
+ class CrossEntropyOutput:
30
+ loss: torch.Tensor
31
+ z_loss: Optional[torch.Tensor] = None
32
+ token_accuracy: Optional[torch.Tensor] = None
22
33
 
23
34
 
24
35
  # conform to the function signature in https://pytorch.org/docs/stable/generated/torch.nn.functional.cross_entropy.html
@@ -35,8 +46,9 @@ def liger_cross_entropy(
35
46
  lse_square_scale: float = 0.0,
36
47
  softcap: Optional[float] = None,
37
48
  return_z_loss: bool = False,
49
+ return_token_accuracy: bool = False,
38
50
  ):
39
- loss, z_loss = LigerCrossEntropyFunction.apply(
51
+ loss, z_loss, token_accuracy = LigerCrossEntropyFunction.apply(
40
52
  input,
41
53
  target,
42
54
  weight,
@@ -46,10 +58,13 @@ def liger_cross_entropy(
46
58
  reduction,
47
59
  softcap,
48
60
  return_z_loss,
61
+ return_token_accuracy,
49
62
  )
50
- if not return_z_loss:
63
+
64
+ if not return_z_loss and not return_token_accuracy:
51
65
  return loss
52
- return loss, z_loss
66
+
67
+ return CrossEntropyOutput(loss=loss, z_loss=z_loss, token_accuracy=token_accuracy)
53
68
 
54
69
 
55
70
  def liger_fused_linear_cross_entropy(
@@ -66,8 +81,9 @@ def liger_fused_linear_cross_entropy(
66
81
  return_z_loss: bool = False,
67
82
  accum_dtype=None,
68
83
  use_token_scaling: bool = False,
84
+ return_token_accuracy: bool = False,
69
85
  ):
70
- loss, z_loss = LigerFusedLinearCrossEntropyFunction.apply(
86
+ loss, z_loss, token_accuracy = LigerFusedLinearCrossEntropyFunction.apply(
71
87
  input,
72
88
  weight,
73
89
  target,
@@ -81,10 +97,13 @@ def liger_fused_linear_cross_entropy(
81
97
  return_z_loss,
82
98
  accum_dtype,
83
99
  use_token_scaling,
100
+ return_token_accuracy,
84
101
  )
85
- if not return_z_loss:
102
+
103
+ if not return_z_loss and not return_token_accuracy:
86
104
  return loss
87
- return loss, z_loss
105
+
106
+ return CrossEntropyOutput(loss=loss, z_loss=z_loss, token_accuracy=token_accuracy)
88
107
 
89
108
 
90
109
  def liger_fused_linear_jsd(
@@ -258,6 +277,10 @@ def liger_rms_norm(X, W, eps, offset: float = 0.0, casting_mode: str = "llama",
258
277
  return LigerRMSNormFunction.apply(X, W, eps, offset, casting_mode, in_place)
259
278
 
260
279
 
280
+ def liger_poly_norm(X, W, B, eps=1e-6, in_place=True):
281
+ return LigerPolyNormFunction.apply(X, W, B, eps, in_place)
282
+
283
+
261
284
  def liger_fused_add_rms_norm(X, R, W, eps, offset: float = 0.0, casting_mode: str = "llama", in_place: bool = True):
262
285
  return LigerFusedAddRMSNormFunction.apply(X, R, W, eps, offset, casting_mode, in_place)
263
286
 
@@ -1,7 +1,7 @@
1
1
  import torch
2
2
  import torch.nn as nn
3
3
 
4
- from liger_kernel.ops.fused_add_rms_norm import LigerFusedAddRMSNormFunction
4
+ from liger_kernel.ops import LigerFusedAddRMSNormFunction
5
5
 
6
6
 
7
7
  class LigerFusedAddRMSNorm(nn.Module):
@@ -2,7 +2,8 @@ from typing import Optional
2
2
 
3
3
  import torch
4
4
 
5
- from liger_kernel.ops.fused_linear_cross_entropy import LigerFusedLinearCrossEntropyFunction
5
+ from liger_kernel.ops import LigerFusedLinearCrossEntropyFunction
6
+ from liger_kernel.transformers.functional import CrossEntropyOutput
6
7
 
7
8
 
8
9
  class LigerFusedLinearCrossEntropyLoss(torch.nn.Module):
@@ -17,6 +18,7 @@ class LigerFusedLinearCrossEntropyLoss(torch.nn.Module):
17
18
  return_z_loss: bool = False,
18
19
  accum_dtype: Optional[torch.dtype] = None,
19
20
  use_token_scaling: bool = False,
21
+ return_token_accuracy: bool = False,
20
22
  ):
21
23
  super().__init__()
22
24
  assert (label_smoothing >= 0) and (label_smoothing <= 1), (
@@ -37,9 +39,10 @@ class LigerFusedLinearCrossEntropyLoss(torch.nn.Module):
37
39
  self.return_z_loss = return_z_loss
38
40
  self.accum_dtype = accum_dtype
39
41
  self.use_token_scaling = use_token_scaling
42
+ self.return_token_accuracy = return_token_accuracy
40
43
 
41
44
  def forward(self, lin_weight, _input, target, bias=None):
42
- loss, z_loss = LigerFusedLinearCrossEntropyFunction.apply(
45
+ loss, z_loss, token_accuracy = LigerFusedLinearCrossEntropyFunction.apply(
43
46
  _input,
44
47
  lin_weight,
45
48
  target,
@@ -53,7 +56,9 @@ class LigerFusedLinearCrossEntropyLoss(torch.nn.Module):
53
56
  self.return_z_loss,
54
57
  self.accum_dtype,
55
58
  self.use_token_scaling,
59
+ self.return_token_accuracy,
56
60
  )
57
- if not self.return_z_loss:
61
+ if not self.return_z_loss and not self.return_token_accuracy:
58
62
  return loss
59
- return loss, z_loss
63
+
64
+ return CrossEntropyOutput(loss=loss, z_loss=z_loss, token_accuracy=token_accuracy)
@@ -2,7 +2,7 @@ from typing import Optional
2
2
 
3
3
  import torch
4
4
 
5
- from liger_kernel.ops.fused_linear_jsd import LigerFusedLinearJSDFunction
5
+ from liger_kernel.ops import LigerFusedLinearJSDFunction
6
6
 
7
7
 
8
8
  class LigerFusedLinearJSD(torch.nn.Module):
@@ -5,7 +5,7 @@ from typing import Optional
5
5
  import torch
6
6
  import torch.nn as nn
7
7
 
8
- from liger_kernel.ops.fused_neighborhood_attention import LigerFusedNeighborhoodAttentionFunction
8
+ from liger_kernel.ops import LigerFusedNeighborhoodAttentionFunction
9
9
 
10
10
 
11
11
  class LigerFusedNeighborhoodAttention(nn.Module):
@@ -1,6 +1,6 @@
1
1
  import torch.nn as nn
2
2
 
3
- from liger_kernel.ops.geglu import LigerGELUMulFunction
3
+ from liger_kernel.ops import LigerGELUMulFunction
4
4
 
5
5
 
6
6
  class LigerGEGLUMLP(nn.Module):
@@ -1,7 +1,7 @@
1
1
  import torch
2
2
  import torch.nn as nn
3
3
 
4
- from liger_kernel.ops.group_norm import LigerGroupNormFunction
4
+ from liger_kernel.ops import LigerGroupNormFunction
5
5
 
6
6
 
7
7
  class LigerGroupNorm(nn.Module):
@@ -1,4 +1,7 @@
1
- from liger_kernel.ops.grpo_loss import GrpoLossFunction
1
+ import torch
2
+
3
+ from liger_kernel.chunked_loss.fused_linear_ppo import LigerFusedLinearPPOBase
4
+ from liger_kernel.ops import GrpoLossFunction
2
5
 
3
6
 
4
7
  def triton_grpo_loss(
@@ -13,12 +16,20 @@ def triton_grpo_loss(
13
16
  eps_low=0.2,
14
17
  eps_high=0.4,
15
18
  inplace=True,
19
+ loss_type="dapo",
20
+ max_completion_length=None,
21
+ importance_sampling_level="token",
22
+ reduce=False,
16
23
  ):
17
24
  assert logits is not None and completion_ids is not None and advantages is not None, (
18
25
  "must provide logits、completion_ids and advantages"
19
26
  )
27
+ if importance_sampling_level != "token":
28
+ raise ValueError(
29
+ f"Triton GRPO loss only supports token-level importance sampling. Got {importance_sampling_level}."
30
+ )
20
31
 
21
- return GrpoLossFunction.apply(
32
+ per_token_loss, per_token_kl, is_clipped = GrpoLossFunction.apply(
22
33
  logits,
23
34
  old_logp,
24
35
  ref_logp,
@@ -31,6 +42,50 @@ def triton_grpo_loss(
31
42
  eps_high,
32
43
  inplace,
33
44
  )
45
+ if not reduce:
46
+ return per_token_loss, per_token_kl, is_clipped
47
+
48
+ loss = _reduce_grpo_loss(
49
+ per_token_loss,
50
+ completion_mask,
51
+ loss_type=loss_type,
52
+ max_completion_length=max_completion_length,
53
+ )
54
+
55
+ metrics = []
56
+ if beta != 0.0 and per_token_kl is not None:
57
+ metrics.append(_masked_mean(per_token_kl, completion_mask))
58
+ metrics.append(_masked_mean(is_clipped.float(), completion_mask))
59
+ return loss, metrics
60
+
61
+
62
+ def _reduce_grpo_loss(per_token_loss, completion_mask, loss_type, max_completion_length):
63
+ mask = completion_mask
64
+ if mask is None:
65
+ mask = torch.ones_like(per_token_loss, dtype=per_token_loss.dtype, device=per_token_loss.device)
66
+ mask = mask.to(per_token_loss.dtype)
67
+
68
+ if loss_type == "grpo":
69
+ per_seq = (per_token_loss * mask).sum(-1) / mask.sum(-1).clamp(min=1.0)
70
+ return per_seq.mean()
71
+ if loss_type == "bnpo":
72
+ return (per_token_loss * mask).sum() / mask.sum().clamp(min=1.0)
73
+ if loss_type == "dr_grpo":
74
+ if max_completion_length is None:
75
+ raise ValueError("max_completion_length must be provided when using loss_type='dr_grpo'")
76
+ batch = per_token_loss.shape[0]
77
+ return (per_token_loss * mask).sum() / (batch * max_completion_length)
78
+ if loss_type == "dapo":
79
+ normalizer = LigerFusedLinearPPOBase._compute_dapo_normalizer(mask)
80
+ return (per_token_loss * mask).sum() / normalizer
81
+ raise ValueError(f"Unsupported loss_type '{loss_type}' for Triton GRPO loss.")
82
+
83
+
84
+ def _masked_mean(values, mask):
85
+ if mask is None:
86
+ mask = torch.ones_like(values, dtype=values.dtype, device=values.device)
87
+ mask = mask.to(values.dtype)
88
+ return (values * mask).sum() / mask.sum().clamp(min=1.0)
34
89
 
35
90
 
36
91
  # This is a demo how to use grpo_loss in GRPOTrainer. The Trl version must be 0.16
@@ -2,7 +2,7 @@ from typing import Optional
2
2
 
3
3
  import torch
4
4
 
5
- from liger_kernel.ops.jsd import LigerJSDFunction
5
+ from liger_kernel.ops import LigerJSDFunction
6
6
 
7
7
 
8
8
  class LigerJSD(torch.nn.Module):
@@ -1,6 +1,6 @@
1
1
  import torch.nn as nn
2
2
 
3
- from liger_kernel.ops.kl_div import LigerKLDivLossFunction
3
+ from liger_kernel.ops import LigerKLDivLossFunction
4
4
 
5
5
 
6
6
  class LigerKLDIVLoss(nn.KLDivLoss):
@@ -1,7 +1,7 @@
1
1
  import torch
2
2
  import torch.nn as nn
3
3
 
4
- from liger_kernel.ops.layer_norm import LigerLayerNormFunction
4
+ from liger_kernel.ops import LigerLayerNormFunction
5
5
 
6
6
 
7
7
  class LigerLayerNorm(nn.Module):
@@ -5,7 +5,7 @@ Supports both text and vision RoPE variants with fused operations for optimal pe
5
5
 
6
6
  import torch
7
7
 
8
- from liger_kernel.ops.llama4_rope import LigerLlama4RopeFunction
8
+ from liger_kernel.ops import LigerLlama4RopeFunction
9
9
 
10
10
 
11
11
  def liger_llama4_text_rotary_pos_emb(
@@ -4,12 +4,12 @@ from typing import Union
4
4
 
5
5
  import torch
6
6
 
7
- from transformers.modeling_outputs import CausalLMOutputWithPast
8
-
9
7
  if TYPE_CHECKING:
10
8
  from transformers.models.falcon_h1.modeling_falcon_h1 import FalconHybridMambaAttentionDynamicCache
11
9
 
12
10
  from liger_kernel.transformers.model.loss_utils import LigerForCausalLMLoss
11
+ from liger_kernel.transformers.model.loss_utils import unpack_cross_entropy_result
12
+ from liger_kernel.transformers.model.output_classes import LigerCausalLMOutputWithPast
13
13
 
14
14
 
15
15
  def lce_forward(
@@ -26,8 +26,9 @@ def lce_forward(
26
26
  cache_position: Optional[torch.LongTensor] = None,
27
27
  logits_to_keep: Union[int, torch.Tensor] = 0,
28
28
  skip_logits: Optional[bool] = None,
29
+ return_dict: Optional[bool] = None,
29
30
  **kwargs,
30
- ) -> Union[tuple, CausalLMOutputWithPast]:
31
+ ) -> Union[tuple, LigerCausalLMOutputWithPast]:
31
32
  r"""
32
33
  labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
33
34
  Labels for computing the masked language modeling loss. Indices should either be in `[0, ...,
@@ -54,6 +55,7 @@ def lce_forward(
54
55
  output_hidden_states = (
55
56
  output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
56
57
  )
58
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
57
59
 
58
60
  # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn)
59
61
  outputs = self.model(
@@ -77,6 +79,8 @@ def lce_forward(
77
79
  shift_labels = kwargs.pop("shift_labels", None)
78
80
  logits = None
79
81
  loss = None
82
+ token_accuracy = None
83
+
80
84
  # if in training mode, don't materialize logits
81
85
  if skip_logits and labels is None:
82
86
  raise ValueError("skip_logits is True, but labels and shift_labels are None")
@@ -85,8 +89,9 @@ def lce_forward(
85
89
  # By default, if in training mode, don't materialize logits
86
90
  skip_logits = self.training and labels is not None
87
91
 
92
+ # Compute loss
88
93
  if skip_logits:
89
- loss = LigerForCausalLMLoss(
94
+ result = LigerForCausalLMLoss(
90
95
  hidden_states=kept_hidden_states,
91
96
  lm_head_weight=self.lm_head.weight,
92
97
  labels=labels,
@@ -94,15 +99,24 @@ def lce_forward(
94
99
  hidden_size=self.config.hidden_size,
95
100
  **kwargs,
96
101
  )
102
+ loss, _, token_accuracy = unpack_cross_entropy_result(result)
97
103
  else:
98
104
  logits = self.lm_head(kept_hidden_states)
99
105
  if labels is not None or shift_labels is not None:
100
106
  loss = self.loss_function(logits=logits, labels=labels, vocab_size=self.config.vocab_size, **kwargs)
101
107
 
102
- return CausalLMOutputWithPast(
108
+ if not return_dict:
109
+ output = (logits,) + outputs[1:]
110
+ output = ((loss,) + output) if loss is not None else output
111
+ output = output + (token_accuracy,) if token_accuracy is not None else output
112
+ return output
113
+
114
+ # Return custom output class with token_accuracy field
115
+ return LigerCausalLMOutputWithPast(
103
116
  loss=loss,
104
117
  logits=logits,
105
118
  past_key_values=outputs.past_key_values,
106
119
  hidden_states=outputs.hidden_states,
107
120
  attentions=outputs.attentions,
121
+ token_accuracy=token_accuracy,
108
122
  )