liger-kernel-nightly 0.4.0.dev20241107052928__py3-none-any.whl → 0.6.3.dev20251121010306__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release: this version of liger-kernel-nightly might be problematic.

Files changed (114)
  1. liger_kernel/__init__.py +0 -0
  2. liger_kernel/chunked_loss/README.md +25 -0
  3. liger_kernel/chunked_loss/__init__.py +8 -0
  4. liger_kernel/chunked_loss/cosine_similarity_loss.py +136 -0
  5. liger_kernel/chunked_loss/cpo_loss.py +157 -0
  6. liger_kernel/chunked_loss/dpo_loss.py +229 -0
  7. liger_kernel/chunked_loss/functional.py +17 -0
  8. liger_kernel/chunked_loss/fused_linear_distillation.py +292 -0
  9. liger_kernel/chunked_loss/fused_linear_ppo.py +350 -0
  10. liger_kernel/chunked_loss/fused_linear_preference.py +433 -0
  11. liger_kernel/chunked_loss/fused_linear_unpaired_preference.py +341 -0
  12. liger_kernel/chunked_loss/grpo_loss.py +304 -0
  13. liger_kernel/chunked_loss/jsd_loss.py +200 -0
  14. liger_kernel/chunked_loss/kto_loss.py +210 -0
  15. liger_kernel/chunked_loss/orpo_loss.py +144 -0
  16. liger_kernel/chunked_loss/simpo_loss.py +165 -0
  17. liger_kernel/env_report.py +21 -4
  18. liger_kernel/ops/cross_entropy.py +235 -84
  19. liger_kernel/ops/dyt.py +157 -0
  20. liger_kernel/ops/experimental/embedding.py +1 -3
  21. liger_kernel/ops/experimental/mm_int8int2.py +3 -9
  22. liger_kernel/ops/fused_add_rms_norm.py +412 -0
  23. liger_kernel/ops/fused_linear_cross_entropy.py +197 -75
  24. liger_kernel/ops/fused_linear_jsd.py +17 -34
  25. liger_kernel/ops/fused_neighborhood_attention.py +1022 -0
  26. liger_kernel/ops/geglu.py +7 -18
  27. liger_kernel/ops/group_norm.py +305 -0
  28. liger_kernel/ops/grpo_loss.py +310 -0
  29. liger_kernel/ops/jsd.py +46 -21
  30. liger_kernel/ops/kl_div.py +23 -19
  31. liger_kernel/ops/layer_norm.py +150 -86
  32. liger_kernel/ops/llama4_rope.py +225 -0
  33. liger_kernel/ops/multi_token_attention.py +207 -0
  34. liger_kernel/ops/poly_norm.py +386 -0
  35. liger_kernel/ops/qwen2vl_mrope.py +222 -0
  36. liger_kernel/ops/rms_norm.py +314 -84
  37. liger_kernel/ops/rope.py +32 -34
  38. liger_kernel/ops/softmax.py +201 -0
  39. liger_kernel/ops/sparsemax.py +179 -0
  40. liger_kernel/ops/swiglu.py +5 -9
  41. liger_kernel/ops/tiled_mlp.py +136 -0
  42. liger_kernel/ops/tvd.py +207 -0
  43. liger_kernel/ops/utils.py +8 -4
  44. liger_kernel/transformers/__init__.py +199 -24
  45. liger_kernel/transformers/auto_model.py +6 -13
  46. liger_kernel/transformers/cross_entropy.py +33 -20
  47. liger_kernel/transformers/dyt.py +22 -0
  48. liger_kernel/transformers/experimental/__init__.py +5 -0
  49. liger_kernel/transformers/experimental/embedding.py +1 -3
  50. liger_kernel/transformers/fsdp.py +55 -0
  51. liger_kernel/transformers/functional.py +291 -13
  52. liger_kernel/transformers/fused_add_rms_norm.py +39 -0
  53. liger_kernel/transformers/fused_linear_cross_entropy.py +43 -14
  54. liger_kernel/transformers/fused_linear_jsd.py +1 -4
  55. liger_kernel/transformers/fused_neighborhood_attention.py +234 -0
  56. liger_kernel/transformers/geglu.py +1 -4
  57. liger_kernel/transformers/group_norm.py +50 -0
  58. liger_kernel/transformers/grpo_loss.py +98 -0
  59. liger_kernel/transformers/jsd.py +2 -7
  60. liger_kernel/transformers/kl_div.py +1 -3
  61. liger_kernel/transformers/layer_norm.py +3 -9
  62. liger_kernel/transformers/llama4_rope.py +93 -0
  63. liger_kernel/transformers/model/falcon_h1.py +122 -0
  64. liger_kernel/transformers/model/gemma.py +77 -77
  65. liger_kernel/transformers/model/gemma2.py +283 -0
  66. liger_kernel/transformers/model/gemma3.py +331 -0
  67. liger_kernel/transformers/model/glm4.py +141 -0
  68. liger_kernel/transformers/model/glm4v.py +163 -0
  69. liger_kernel/transformers/model/glm4v_moe.py +172 -0
  70. liger_kernel/transformers/model/internvl.py +157 -0
  71. liger_kernel/transformers/model/llama.py +128 -79
  72. liger_kernel/transformers/model/llama4.py +121 -0
  73. liger_kernel/transformers/model/llava.py +344 -0
  74. liger_kernel/transformers/model/loss_utils.py +95 -0
  75. liger_kernel/transformers/model/mistral.py +68 -64
  76. liger_kernel/transformers/model/mixtral.py +75 -91
  77. liger_kernel/transformers/model/mllama.py +63 -68
  78. liger_kernel/transformers/model/olmo2.py +141 -0
  79. liger_kernel/transformers/model/output_classes.py +147 -0
  80. liger_kernel/transformers/model/paligemma.py +432 -0
  81. liger_kernel/transformers/model/phi3.py +59 -213
  82. liger_kernel/transformers/model/qwen2.py +75 -72
  83. liger_kernel/transformers/model/qwen2_5_vl.py +163 -0
  84. liger_kernel/transformers/model/qwen2_vl.py +78 -98
  85. liger_kernel/transformers/model/qwen3.py +136 -0
  86. liger_kernel/transformers/model/qwen3_moe.py +152 -0
  87. liger_kernel/transformers/model/qwen3_next.py +146 -0
  88. liger_kernel/transformers/model/qwen3_vl.py +150 -0
  89. liger_kernel/transformers/model/qwen3_vl_moe.py +126 -0
  90. liger_kernel/transformers/model/smollm3.py +199 -0
  91. liger_kernel/transformers/model/smolvlm.py +158 -0
  92. liger_kernel/transformers/monkey_patch.py +2106 -289
  93. liger_kernel/transformers/multi_token_attention.py +64 -0
  94. liger_kernel/transformers/poly_norm.py +42 -0
  95. liger_kernel/transformers/qwen2vl_mrope.py +20 -0
  96. liger_kernel/transformers/rms_norm.py +57 -6
  97. liger_kernel/transformers/rope.py +45 -2
  98. liger_kernel/transformers/softmax.py +12 -0
  99. liger_kernel/transformers/sparsemax.py +16 -0
  100. liger_kernel/transformers/swiglu.py +23 -8
  101. liger_kernel/transformers/tiled_mlp.py +133 -0
  102. liger_kernel/transformers/trainer/__init__.py +4 -0
  103. liger_kernel/transformers/trainer/orpo_trainer.py +130 -0
  104. liger_kernel/transformers/tvd.py +13 -0
  105. liger_kernel/triton/__init__.py +1 -3
  106. liger_kernel/triton/monkey_patch.py +1 -3
  107. liger_kernel/utils.py +71 -0
  108. {liger_kernel_nightly-0.4.0.dev20241107052928.dist-info → liger_kernel_nightly-0.6.3.dev20251121010306.dist-info}/METADATA +150 -137
  109. liger_kernel_nightly-0.6.3.dev20251121010306.dist-info/RECORD +116 -0
  110. {liger_kernel_nightly-0.4.0.dev20241107052928.dist-info → liger_kernel_nightly-0.6.3.dev20251121010306.dist-info}/WHEEL +1 -1
  111. liger_kernel_nightly-0.4.0.dev20241107052928.dist-info/RECORD +0 -48
  112. {liger_kernel_nightly-0.4.0.dev20241107052928.dist-info → liger_kernel_nightly-0.6.3.dev20251121010306.dist-info}/LICENSE +0 -0
  113. {liger_kernel_nightly-0.4.0.dev20241107052928.dist-info → liger_kernel_nightly-0.6.3.dev20251121010306.dist-info}/NOTICE +0 -0
  114. {liger_kernel_nightly-0.4.0.dev20241107052928.dist-info → liger_kernel_nightly-0.6.3.dev20251121010306.dist-info}/top_level.txt +0 -0
liger_kernel/ops/tvd.py ADDED
@@ -0,0 +1,207 @@
+ from typing import Literal
+ from typing import Optional
+
+ import torch
+ import triton
+ import triton.language as tl
+
+ from liger_kernel.ops.utils import ensure_contiguous
+
+ MAX_FUSED_SIZE = 65536 // 4
+
+ REDUCTION_LITERAL = Literal["none", "sum", "mean", "batchmean"]
+
+ _REDUCTION_MODE_NONE = tl.constexpr(0)
+ _REDUCTION_MODE_SUM = tl.constexpr(1)
+ _REDUCTION_MODE_MEAN = tl.constexpr(2)
+ _REDUCTION_MODE_BATCHMEAN = tl.constexpr(3)
+
+ _str_to_reduction_mode = {
+     "none": _REDUCTION_MODE_NONE.value,
+     "sum": _REDUCTION_MODE_SUM.value,
+     "mean": _REDUCTION_MODE_MEAN.value,
+     "batchmean": _REDUCTION_MODE_BATCHMEAN.value,
+ }
+
+
+ def get_num_warps(BLOCK_SIZE):
+     num_warps = 4
+     if BLOCK_SIZE >= 32768:
+         num_warps = 32
+     elif BLOCK_SIZE >= 8192:
+         num_warps = 16
+     elif BLOCK_SIZE >= 2048:
+         num_warps = 8
+
+     return num_warps
+
+
+ @triton.jit
+ def _tv_distance_kernel(
+     p_ptr,
+     p_stride,
+     q_ptr,
+     q_stride,
+     loss_ptr,
+     loss_stride,
+     grads_ptr,
+     grads_stride,
+     label_ptr,
+     ignore_index: tl.constexpr,
+     n_cols,
+     BLOCK_SIZE: tl.constexpr,
+     HAS_LABEL: tl.constexpr,
+     reduction: tl.constexpr = _REDUCTION_MODE_BATCHMEAN,
+ ):
+     pid = tl.program_id(0).to(tl.int64)
+     p_ptr += pid * p_stride
+     q_ptr += pid * q_stride
+     loss_ptr += pid * loss_stride
+     grads_ptr += pid * grads_stride
+     label_ptr += pid
+
+     base_offsets = tl.arange(0, BLOCK_SIZE)
+
+     if HAS_LABEL:
+         label = tl.load(label_ptr)
+         if label == ignore_index:
+             for i in range(0, n_cols, BLOCK_SIZE):
+                 offsets = i + base_offsets
+                 mask = offsets < n_cols
+                 tl.store(grads_ptr + offsets, 0.0, mask=mask)
+                 if reduction == _REDUCTION_MODE_NONE:
+                     tl.store(loss_ptr + offsets, 0.0, mask=mask)
+             return
+
+     loss_sum = 0.0
+     for i in range(0, n_cols, BLOCK_SIZE):
+         offsets = i + base_offsets
+         mask = offsets < n_cols
+
+         p = tl.load(p_ptr + offsets, mask=mask, other=0.0)
+         q = tl.load(q_ptr + offsets, mask=mask, other=0.0)
+
+         # TVD(P || Q) = 0.5 * |P - Q|
+         tv_loss = 0.5 * tl.abs(p - q)
+
+         grad_res = tl.where(p > q, 0.5, -0.5)
+
+         tl.store(grads_ptr + offsets, grad_res, mask=mask)
+
+         if reduction == _REDUCTION_MODE_NONE:
+             tl.store(loss_ptr + offsets, tv_loss, mask=mask)
+         else:
+             loss_sum += tl.sum(tv_loss, axis=0)
+
+     if reduction != _REDUCTION_MODE_NONE:
+         tl.store(loss_ptr, loss_sum)
+
+
+ def tv_distance_forward_triton(p, q, shift_labels, reduction, ignore_index, has_label):
+     BT, V = p.shape
+
+     BLOCK_SIZE = min(MAX_FUSED_SIZE, triton.next_power_of_2(V))
+     num_warps = get_num_warps(BLOCK_SIZE)
+
+     grid = (BT,)
+
+     reduction = _str_to_reduction_mode[reduction]
+
+     out_size = (BT, V) if reduction == _REDUCTION_MODE_NONE.value else (BT,)
+     output_tensor = torch.zeros(out_size, device=p.device, dtype=torch.float32)
+     grads = torch.empty_like(p)
+
+     n_non_ignore = (shift_labels != ignore_index).sum().item() if has_label else BT
+
+     _tv_distance_kernel[grid](
+         p,
+         p.stride(0),
+         q,
+         q.stride(0),
+         output_tensor,
+         output_tensor.stride(0),
+         grads,
+         grads.stride(0),
+         shift_labels if has_label else torch.empty(1, device=p.device),
+         ignore_index,
+         V,
+         BLOCK_SIZE=BLOCK_SIZE,
+         HAS_LABEL=has_label,
+         num_warps=num_warps,
+         reduction=reduction,
+     )
+
+     if reduction == _REDUCTION_MODE_BATCHMEAN.value:
+         return output_tensor.sum() / n_non_ignore, grads / n_non_ignore
+     elif reduction == _REDUCTION_MODE_SUM.value:
+         return output_tensor.sum(dim=0), grads
+     elif reduction == _REDUCTION_MODE_MEAN.value:
+         return output_tensor.sum() / (n_non_ignore * V), grads / (n_non_ignore * V)
+     else:
+         return output_tensor, grads
+
+
+ def tvd_backward_triton(grad_output, grads):
+     # If cross entropy is the last layer, grad_output is 1.0. Skip the mul then.
+     if torch.equal(grad_output, torch.tensor(1.0, device=grad_output.device)):
+         return grads
+
+     return grads * grad_output
+
+
+ class LigerTVDLossFunction(torch.autograd.Function):
+     """
+     Class implementing the forward and backward pass for the Total Variation Distance Loss using Triton.
+     """
+
+     @staticmethod
+     @ensure_contiguous
+     def forward(
+         ctx,
+         p: torch.Tensor,
+         q: torch.Tensor,
+         shift_labels: Optional[torch.Tensor] = None,
+         reduction: REDUCTION_LITERAL = "batchmean",
+         ignore_index: int = -100,
+     ) -> torch.Tensor:
+         """A forward pass for the Total Variation Distance Loss.
+
+         Args:
+             ctx: Torch autograd context
+             p (torch.Tensor): A tensor of shape (BT, V) containing the first distribution.
+             q (torch.Tensor): A tensor of shape (BT, V) containing the second distribution.
+             shift_labels (Optional[torch.Tensor]): A tensor of shape (BT,) containing the labels.
+             reduction (REDUCTION_LITERAL, optional): The reduction method to be applied. Defaults to "batchmean".
+             ignore_index (int, optional): The index to ignore during loss calculation. Defaults to -100.
+
+         Returns:
+             torch.Tensor: The computed Total Variation Distance Loss.
+         """
+         has_label = False
+         if shift_labels is not None:
+             assert shift_labels.shape == (p.shape[0],), (
+                 f"the shape of shift_labels must be (BT,). Got: {shift_labels.shape}"
+             )
+             shift_labels = shift_labels.contiguous()
+             has_label = True
+
+         loss, grads = tv_distance_forward_triton(p, q, shift_labels, reduction, ignore_index, has_label)
+         ctx.save_for_backward(grads)
+         return loss
+
+     @staticmethod
+     @ensure_contiguous
+     def backward(ctx, grad_output: torch.Tensor) -> torch.Tensor:
+         """A backward pass for the Total Variation Distance Loss.
+
+         Args:
+             ctx: Torch autograd context
+             grad_output (torch.Tensor): The gradient of the loss with respect to the output.
+
+         Returns:
+             tuple[torch.Tensor, None, None, None, None]: The gradient of the loss with respect to the inputs.
+         """
+         (grads,) = ctx.saved_tensors
+         grads = tvd_backward_triton(grad_output, grads)
+
+         return grads, None, None, None, None
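
For orientation, a minimal usage sketch of the new TVD op (not part of the diff; assumes a Triton-capable GPU and the (BT, V) shape convention described in the docstrings above):

import torch

from liger_kernel.ops.tvd import LigerTVDLossFunction

# Two row-normalized distributions over a vocabulary of size V = 128,
# flattened to (batch * seq_len, V) = (8, 128) as the kernel expects.
logits = torch.randn(8, 128, device="cuda", requires_grad=True)
p = logits.softmax(dim=-1)
q = torch.randn(8, 128, device="cuda").softmax(dim=-1)

# Defaults: shift_labels=None, reduction="batchmean", ignore_index=-100.
loss = LigerTVDLossFunction.apply(p, q)
loss.backward()  # per-element grad w.r.t. p is +/-0.5, rescaled by the reduction
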
liger_kernel/ops/utils.py CHANGED
@@ -13,13 +13,17 @@ Modifications made by Yanning Chen, 2024.
  import functools
  import importlib
  import operator
+
  from typing import Callable

  import torch
  import triton
  import triton.language as tl
+
  from packaging.version import Version

+ from liger_kernel.utils import infer_device
+

  def is_hip() -> bool:
      return torch.version.hip is not None
@@ -45,8 +49,7 @@ def calculate_settings(n):
      BLOCK_SIZE = triton.next_power_of_2(n)
      if BLOCK_SIZE > MAX_FUSED_SIZE:
          raise RuntimeError(
-             f"Cannot launch Triton kernel since n = {n} exceeds "
-             f"the recommended Triton blocksize = {MAX_FUSED_SIZE}."
+             f"Cannot launch Triton kernel since n = {n} exceeds the recommended Triton blocksize = {MAX_FUSED_SIZE}."
          )

      num_warps = 4
@@ -69,10 +72,11 @@ def compare_version(package: str, operator: Callable, target: str):


  def get_amp_custom_fwd_bwd() -> Callable:
+     device = infer_device()
      if compare_version("torch", operator.ge, "2.4.0"):
          return (
-             functools.partial(torch.amp.custom_fwd, device_type="cuda"),
-             functools.partial(torch.amp.custom_bwd, device_type="cuda"),
+             functools.partial(torch.amp.custom_fwd, device_type=device),
+             functools.partial(torch.amp.custom_bwd, device_type=device),
          )
      return torch.cuda.amp.custom_fwd, torch.cuda.amp.custom_bwd
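
The get_amp_custom_fwd_bwd() change above replaces the hard-coded "cuda" device type with infer_device(), so the autocast decorators follow the active backend (e.g. CUDA or XPU). A sketch of how the returned pair is typically consumed; MyFusedOp below is illustrative, not taken from the diff:

import torch

from liger_kernel.ops.utils import get_amp_custom_fwd_bwd

amp_custom_fwd, amp_custom_bwd = get_amp_custom_fwd_bwd()


class MyFusedOp(torch.autograd.Function):
    @staticmethod
    @amp_custom_fwd
    def forward(ctx, x):
        # Runs under the AMP cast rules for the inferred device type.
        return x * 2

    @staticmethod
    @amp_custom_bwd
    def backward(ctx, grad_output):
        return grad_output * 2
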
liger_kernel/transformers/__init__.py CHANGED
@@ -1,31 +1,206 @@
- from liger_kernel.transformers.auto_model import (  # noqa: F401
-     AutoLigerKernelForCausalLM,
- )
+ import importlib
+
+ from typing import TYPE_CHECKING
+
+ # Always-safe imports (independent of 'transformers')
  from liger_kernel.transformers.cross_entropy import LigerCrossEntropyLoss  # noqa: F401
- from liger_kernel.transformers.fused_linear_cross_entropy import (  # noqa: F401
-     LigerFusedLinearCrossEntropyLoss,
- )
+ from liger_kernel.transformers.dyt import LigerDyT  # noqa: F401
+ from liger_kernel.transformers.fused_add_rms_norm import LigerFusedAddRMSNorm  # noqa: F401
+ from liger_kernel.transformers.fused_linear_cross_entropy import LigerFusedLinearCrossEntropyLoss  # noqa: F401
  from liger_kernel.transformers.fused_linear_jsd import LigerFusedLinearJSD  # noqa: F401
  from liger_kernel.transformers.geglu import LigerGEGLUMLP  # noqa: F401
  from liger_kernel.transformers.jsd import LigerJSD  # noqa: F401
+ from liger_kernel.transformers.kl_div import LigerKLDIVLoss  # noqa: F401
  from liger_kernel.transformers.layer_norm import LigerLayerNorm  # noqa: F401
- from liger_kernel.transformers.monkey_patch import (  # noqa: F401
-     _apply_liger_kernel,
-     _apply_liger_kernel_to_instance,
-     apply_liger_kernel_to_gemma,
-     apply_liger_kernel_to_gemma2,
-     apply_liger_kernel_to_llama,
-     apply_liger_kernel_to_mistral,
-     apply_liger_kernel_to_mixtral,
-     apply_liger_kernel_to_mllama,
-     apply_liger_kernel_to_phi3,
-     apply_liger_kernel_to_qwen2,
-     apply_liger_kernel_to_qwen2_vl,
- )
+ from liger_kernel.transformers.llama4_rope import liger_llama4_text_rotary_pos_emb  # noqa: F401
+ from liger_kernel.transformers.llama4_rope import liger_llama4_vision_rotary_pos_emb  # noqa: F401
+ from liger_kernel.transformers.multi_token_attention import LigerMultiTokenAttention  # noqa: F401
+ from liger_kernel.transformers.poly_norm import LigerPolyNorm  # noqa: F401
  from liger_kernel.transformers.rms_norm import LigerRMSNorm  # noqa: F401
  from liger_kernel.transformers.rope import liger_rotary_pos_emb  # noqa: F401
- from liger_kernel.transformers.swiglu import (  # noqa: F401
-     LigerBlockSparseTop2MLP,
-     LigerPhi3SwiGLUMLP,
-     LigerSwiGLUMLP,
- )
+ from liger_kernel.transformers.softmax import LigerSoftmax  # noqa: F401
+ from liger_kernel.transformers.sparsemax import LigerSparsemax  # noqa: F401
+ from liger_kernel.transformers.swiglu import LigerBlockSparseTop2MLP  # noqa: F401
+ from liger_kernel.transformers.swiglu import LigerPhi3SwiGLUMLP  # noqa: F401
+ from liger_kernel.transformers.swiglu import LigerQwen3MoeSwiGLUMLP  # noqa: F401
+ from liger_kernel.transformers.swiglu import LigerSwiGLUMLP  # noqa: F401
+ from liger_kernel.transformers.tiled_mlp import LigerTiledGEGLUMLP  # noqa: F401
+ from liger_kernel.transformers.tiled_mlp import LigerTiledSwiGLUMLP  # noqa: F401
+ from liger_kernel.transformers.tvd import LigerTVDLoss  # noqa: F401
+
+ # Static-only imports for IDEs and type checkers
+ if TYPE_CHECKING:
+     from liger_kernel.transformers.auto_model import AutoLigerKernelForCausalLM  # noqa: F401
+     from liger_kernel.transformers.monkey_patch import _apply_liger_kernel  # noqa: F401
+     from liger_kernel.transformers.monkey_patch import _apply_liger_kernel_to_instance  # noqa: F401
+     from liger_kernel.transformers.monkey_patch import apply_liger_kernel_to_falcon_h1  # noqa: F401
+     from liger_kernel.transformers.monkey_patch import apply_liger_kernel_to_gemma  # noqa: F401
+     from liger_kernel.transformers.monkey_patch import apply_liger_kernel_to_gemma2  # noqa: F401
+     from liger_kernel.transformers.monkey_patch import apply_liger_kernel_to_gemma3  # noqa: F401
+     from liger_kernel.transformers.monkey_patch import apply_liger_kernel_to_gemma3_text  # noqa: F401
+     from liger_kernel.transformers.monkey_patch import apply_liger_kernel_to_glm4  # noqa: F401
+     from liger_kernel.transformers.monkey_patch import apply_liger_kernel_to_glm4v  # noqa: F401
+     from liger_kernel.transformers.monkey_patch import apply_liger_kernel_to_glm4v_moe  # noqa: F401
+     from liger_kernel.transformers.monkey_patch import apply_liger_kernel_to_granite  # noqa: F401
+     from liger_kernel.transformers.monkey_patch import apply_liger_kernel_to_internvl  # noqa: F401
+     from liger_kernel.transformers.monkey_patch import apply_liger_kernel_to_llama  # noqa: F401
+     from liger_kernel.transformers.monkey_patch import apply_liger_kernel_to_llama4  # noqa: F401
+     from liger_kernel.transformers.monkey_patch import apply_liger_kernel_to_llava  # noqa: F401
+     from liger_kernel.transformers.monkey_patch import apply_liger_kernel_to_mistral  # noqa: F401
+     from liger_kernel.transformers.monkey_patch import apply_liger_kernel_to_mixtral  # noqa: F401
+     from liger_kernel.transformers.monkey_patch import apply_liger_kernel_to_mllama  # noqa: F401
+     from liger_kernel.transformers.monkey_patch import apply_liger_kernel_to_olmo2  # noqa: F401
+     from liger_kernel.transformers.monkey_patch import apply_liger_kernel_to_paligemma  # noqa: F401
+     from liger_kernel.transformers.monkey_patch import apply_liger_kernel_to_phi3  # noqa: F401
+     from liger_kernel.transformers.monkey_patch import apply_liger_kernel_to_qwen2  # noqa: F401
+     from liger_kernel.transformers.monkey_patch import apply_liger_kernel_to_qwen2_5_vl  # noqa: F401
+     from liger_kernel.transformers.monkey_patch import apply_liger_kernel_to_qwen2_vl  # noqa: F401
+     from liger_kernel.transformers.monkey_patch import apply_liger_kernel_to_qwen3  # noqa: F401
+     from liger_kernel.transformers.monkey_patch import apply_liger_kernel_to_qwen3_moe  # noqa: F401
+     from liger_kernel.transformers.monkey_patch import apply_liger_kernel_to_qwen3_next  # noqa: F401
+     from liger_kernel.transformers.monkey_patch import apply_liger_kernel_to_qwen3_vl  # noqa: F401
+     from liger_kernel.transformers.monkey_patch import apply_liger_kernel_to_qwen3_vl_moe  # noqa: F401
+     from liger_kernel.transformers.monkey_patch import apply_liger_kernel_to_smollm3  # noqa: F401
+     from liger_kernel.transformers.monkey_patch import apply_liger_kernel_to_smolvlm  # noqa: F401
+
+
+ # Check if 'transformers' is installed
+ try:
+     import transformers  # noqa: F401
+
+     _TRANSFORMERS_AVAILABLE = True
+ except ImportError:
+     _TRANSFORMERS_AVAILABLE = False
+
+
+ def is_transformers_available() -> bool:
+     """
+     Returns True if the 'transformers' package is available.
+     Useful for conditional logic in downstream code.
+     """
+     return _TRANSFORMERS_AVAILABLE
+
+
+ def __getattr__(name: str):
+     """
+     Handles lazy access to transformer-dependent attributes.
+     If 'transformers' is not installed, raises a user-friendly ImportError.
+     """
+     if not _TRANSFORMERS_AVAILABLE:
+         raise ImportError(
+             f"The attribute '{name}' requires the 'transformers' library, which is not installed.\n"
+             f"Please install it with `pip install transformers` to use this functionality."
+         )
+
+     if name == "AutoLigerKernelForCausalLM":
+         module = importlib.import_module("liger_kernel.transformers.auto_model")
+         return getattr(module, name)
+
+     monkey_patch_symbols = {
+         "_apply_liger_kernel",
+         "_apply_liger_kernel_to_instance",
+         "apply_liger_kernel_to_falcon_h1",
+         "apply_liger_kernel_to_gemma",
+         "apply_liger_kernel_to_gemma2",
+         "apply_liger_kernel_to_gemma3",
+         "apply_liger_kernel_to_gemma3_text",
+         "apply_liger_kernel_to_glm4",
+         "apply_liger_kernel_to_glm4v",
+         "apply_liger_kernel_to_glm4v_moe",
+         "apply_liger_kernel_to_granite",
+         "apply_liger_kernel_to_internvl",
+         "apply_liger_kernel_to_llama",
+         "apply_liger_kernel_to_llava",
+         "apply_liger_kernel_to_llama4",
+         "apply_liger_kernel_to_mistral",
+         "apply_liger_kernel_to_mixtral",
+         "apply_liger_kernel_to_mllama",
+         "apply_liger_kernel_to_olmo2",
+         "apply_liger_kernel_to_paligemma",
+         "apply_liger_kernel_to_phi3",
+         "apply_liger_kernel_to_qwen2",
+         "apply_liger_kernel_to_qwen2_5_vl",
+         "apply_liger_kernel_to_qwen2_vl",
+         "apply_liger_kernel_to_qwen3",
+         "apply_liger_kernel_to_qwen3_moe",
+         "apply_liger_kernel_to_qwen3_next",
+         "apply_liger_kernel_to_qwen3_vl",
+         "apply_liger_kernel_to_qwen3_vl_moe",
+         "apply_liger_kernel_to_smollm3",
+         "apply_liger_kernel_to_smolvlm",
+     }
+
+     if name in monkey_patch_symbols:
+         module = importlib.import_module("liger_kernel.transformers.monkey_patch")
+         return getattr(module, name)
+
+     raise AttributeError(f"module {__name__} has no attribute {name}")
+
+
+ # Shared symbols in all environments
+ __all__ = [
+     "is_transformers_available",
+     "LigerCrossEntropyLoss",
+     "LigerDyT",
+     "LigerFusedLinearCrossEntropyLoss",
+     "LigerFusedLinearJSD",
+     "LigerGEGLUMLP",
+     "LigerJSD",
+     "LigerLayerNorm",
+     "LigerFusedAddRMSNorm",
+     "LigerPolyNorm",
+     "LigerRMSNorm",
+     "liger_rotary_pos_emb",
+     "liger_llama4_text_rotary_pos_emb",
+     "liger_llama4_vision_rotary_pos_emb",
+     "LigerBlockSparseTop2MLP",
+     "LigerPhi3SwiGLUMLP",
+     "LigerQwen3MoeSwiGLUMLP",
+     "LigerSwiGLUMLP",
+     "LigerTiledGEGLUMLP",
+     "LigerTiledSwiGLUMLP",
+     "LigerTVDLoss",
+     "LigerKLDIVLoss",
+     "LigerMultiTokenAttention",
+     "LigerSoftmax",
+     "LigerSparsemax",
+ ]
+
+ # Add transformer-dependent symbols only if available
+ if _TRANSFORMERS_AVAILABLE:
+     __all__.extend(
+         [
+             "AutoLigerKernelForCausalLM",
+             "_apply_liger_kernel",
+             "_apply_liger_kernel_to_instance",
+             "apply_liger_kernel_to_falcon_h1",
+             "apply_liger_kernel_to_gemma",
+             "apply_liger_kernel_to_gemma2",
+             "apply_liger_kernel_to_gemma3",
+             "apply_liger_kernel_to_gemma3_text",
+             "apply_liger_kernel_to_glm4",
+             "apply_liger_kernel_to_glm4v",
+             "apply_liger_kernel_to_glm4v_moe",
+             "apply_liger_kernel_to_granite",
+             "apply_liger_kernel_to_internvl",
+             "apply_liger_kernel_to_llama",
+             "apply_liger_kernel_to_llava",
+             "apply_liger_kernel_to_llama4",
+             "apply_liger_kernel_to_mistral",
+             "apply_liger_kernel_to_mixtral",
+             "apply_liger_kernel_to_mllama",
+             "apply_liger_kernel_to_olmo2",
+             "apply_liger_kernel_to_paligemma",
+             "apply_liger_kernel_to_phi3",
+             "apply_liger_kernel_to_qwen2",
+             "apply_liger_kernel_to_qwen2_5_vl",
+             "apply_liger_kernel_to_qwen2_vl",
+             "apply_liger_kernel_to_qwen3",
+             "apply_liger_kernel_to_qwen3_moe",
+             "apply_liger_kernel_to_qwen3_next",
+             "apply_liger_kernel_to_qwen3_vl",
+             "apply_liger_kernel_to_qwen3_vl_moe",
+             "apply_liger_kernel_to_smollm3",
+             "apply_liger_kernel_to_smolvlm",
+         ]
+     )
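
The rewritten package __init__ defers all transformers-dependent attributes behind a module-level __getattr__ (PEP 562), so importing liger_kernel.transformers no longer requires transformers itself. A behavior sketch (illustrative; the patch function runs with its default arguments here):

import liger_kernel.transformers as lkt

# Kernel modules such as LigerRMSNorm are imported eagerly and always available.
rms_norm = lkt.LigerRMSNorm(hidden_size=4096)

if lkt.is_transformers_available():
    # Resolved lazily: __getattr__ imports liger_kernel.transformers.monkey_patch
    # on first access and returns the patch function.
    lkt.apply_liger_kernel_to_llama()
# Without 'transformers' installed, the same attribute access raises the
# ImportError constructed in __getattr__ above.
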
liger_kernel/transformers/auto_model.py CHANGED
@@ -1,11 +1,10 @@
  import inspect

- from transformers import AutoConfig, AutoModelForCausalLM
+ from transformers import AutoConfig
+ from transformers import AutoModelForCausalLM

- from liger_kernel.transformers.monkey_patch import (
-     MODEL_TYPE_TO_APPLY_LIGER_FN,
-     _apply_liger_kernel,
- )
+ from liger_kernel.transformers.monkey_patch import MODEL_TYPE_TO_APPLY_LIGER_FN
+ from liger_kernel.transformers.monkey_patch import _apply_liger_kernel


  def _get_model_config(model_dir, **model_init_kwargs):
@@ -34,12 +33,6 @@ class AutoLigerKernelForCausalLM(AutoModelForCausalLM):
          apply_fn = MODEL_TYPE_TO_APPLY_LIGER_FN[model_type]
          apply_fn_signature = inspect.signature(apply_fn)

-         applicable_kwargs = {
-             key: value
-             for key, value in kwargs.items()
-             if key not in apply_fn_signature.parameters
-         }
+         applicable_kwargs = {key: value for key, value in kwargs.items() if key not in apply_fn_signature.parameters}

-         return super().from_pretrained(
-             pretrained_model_name_or_path, *model_args, **applicable_kwargs
-         )
+         return super().from_pretrained(pretrained_model_name_or_path, *model_args, **applicable_kwargs)
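
In practice the wrapper is used like AutoModelForCausalLM: it looks up the model_type-specific apply_liger_kernel_to_* function, filters out the kwargs that belong to that patch function, and forwards the remainder to from_pretrained. A usage sketch (the checkpoint path is illustrative):

from liger_kernel.transformers import AutoLigerKernelForCausalLM

# Patches the matching model implementation, then loads weights as usual.
model = AutoLigerKernelForCausalLM.from_pretrained("path/to/llama-checkpoint")
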
liger_kernel/transformers/cross_entropy.py CHANGED
@@ -1,43 +1,56 @@
- import torch.nn as nn
+ from typing import Optional
+
+ import torch

  from liger_kernel.ops.cross_entropy import LigerCrossEntropyFunction
+ from liger_kernel.transformers.functional import CrossEntropyOutput


- class LigerCrossEntropyLoss(nn.Module):
+ class LigerCrossEntropyLoss(torch.nn.Module):
      def __init__(
          self,
-         ignore_index=-100,
-         lse_square_scale=0.0,
-         label_smoothing=0.0,
-         reduction="mean",
-         return_z_loss=False,
+         weight: Optional[torch.FloatTensor] = None,
+         ignore_index: int = -100,
+         lse_square_scale: float = 0.0,
+         label_smoothing: float = 0.0,
+         reduction: str = "mean",
+         softcap: Optional[float] = None,
+         return_z_loss: bool = False,
+         return_token_accuracy: bool = False,
      ):
          super().__init__()
+         assert (label_smoothing >= 0) and (label_smoothing <= 1), (
+             f"label_smoothing must be between 0.0 and 1.0. Got: {label_smoothing}"
+         )
+         assert reduction in {
+             "mean",
+             "sum",
+             "none",
+         }, f"reduction must be one of 'mean', 'sum', or 'none'. Got: {reduction}"
+         assert softcap is None or softcap > 0, f"softcap must greater than 0.0 or None. Got: {softcap}"
+         self.weight = weight
          self.ignore_index = ignore_index
          self.lse_square_scale = lse_square_scale
          self.label_smoothing = label_smoothing
          self.reduction = reduction
+         self.softcap = softcap
          self.return_z_loss = return_z_loss
+         self.return_token_accuracy = return_token_accuracy

-         assert (self.label_smoothing >= 0) and (
-             self.label_smoothing <= 1
-         ), f"label_smoothing must be between 0.0 and 1.0. Got: {self.label_smoothing}"
-         assert self.reduction in {
-             "mean",
-             "sum",
-             "none",
-         }, f"reduction must be one of 'mean', 'sum', or 'none'. Got: {self.reduction}"
-
-     def forward(self, _input, target):
-         loss, z_loss = LigerCrossEntropyFunction.apply(
+     def forward(self, _input: torch.Tensor, target: torch.Tensor):
+         loss, z_loss, token_accuracy = LigerCrossEntropyFunction.apply(
              _input,
              target,
+             self.weight,
              self.ignore_index,
              self.lse_square_scale,
              self.label_smoothing,
              self.reduction,
+             self.softcap,
              self.return_z_loss,
+             self.return_token_accuracy,
          )
-         if not self.return_z_loss:
+         if not self.return_z_loss and not self.return_token_accuracy:
              return loss
-         return loss, z_loss
+
+         return CrossEntropyOutput(loss=loss, z_loss=z_loss, token_accuracy=token_accuracy)
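
A usage sketch of the extended module (shapes and device are illustrative, and a Triton-capable GPU is assumed). With neither return_z_loss nor return_token_accuracy set, the call still returns a bare loss tensor as before:

import torch

from liger_kernel.transformers import LigerCrossEntropyLoss

ce = LigerCrossEntropyLoss(return_z_loss=True, return_token_accuracy=True)
logits = torch.randn(16, 32000, device="cuda", requires_grad=True)
labels = torch.randint(0, 32000, (16,), device="cuda")

out = ce(logits, labels)  # CrossEntropyOutput(loss=..., z_loss=..., token_accuracy=...)
out.loss.backward()
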
liger_kernel/transformers/dyt.py ADDED
@@ -0,0 +1,22 @@
+ import torch
+ import torch.nn as nn
+
+ from liger_kernel.ops.dyt import LigerDyTFunction
+
+
+ class LigerDyT(nn.Module):
+     def __init__(self, hidden_size, beta=True, init_alpha=0.5):
+         super().__init__()
+         self.hidden_size = hidden_size
+         self.init_alpha = init_alpha
+         self.alpha = nn.Parameter(torch.ones(1) * init_alpha)
+         self.gamma = nn.Parameter(torch.ones(hidden_size))
+         self.beta = None
+         if beta:
+             self.beta = nn.Parameter(torch.zeros(hidden_size))
+
+     def forward(self, x):
+         return LigerDyTFunction.apply(x, self.alpha, self.gamma, self.beta)
+
+     def extra_repr(self):
+         return f"{self.hidden_size}, init_alpha={self.init_alpha}, beta={self.beta}"
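
LigerDyT wraps DyT (dynamic tanh), which in the DyT formulation replaces a normalization layer with an elementwise gamma * tanh(alpha * x) + beta (beta optional). A sketch of the expected forward semantics next to the fused call; the reference line is for comparison only, the actual implementation is the Triton kernel in liger_kernel/ops/dyt.py:

import torch

from liger_kernel.transformers import LigerDyT

dyt = LigerDyT(hidden_size=2048).to("cuda")
x = torch.randn(2, 16, 2048, device="cuda")

y = dyt(x)  # fused Triton path
y_ref = torch.tanh(dyt.alpha * x) * dyt.gamma + dyt.beta  # elementwise reference
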
liger_kernel/transformers/experimental/__init__.py ADDED
@@ -0,0 +1,5 @@
+ from liger_kernel.transformers.experimental.embedding import LigerEmbedding  # noqa: F401
+
+ __all__ = [
+     "LigerEmbedding",
+ ]
liger_kernel/transformers/experimental/embedding.py CHANGED
@@ -7,9 +7,7 @@ from liger_kernel.ops.experimental.embedding import LigerEmbeddingFunction


  class LigerEmbedding(nn.Module):
-     def __init__(
-         self, num_embeddings, embedding_dim, padding_idx: Optional[int] = None
-     ):
+     def __init__(self, num_embeddings, embedding_dim, padding_idx: Optional[int] = None):
          super().__init__()
          self.num_embeddings = num_embeddings
          self.embedding_dim = embedding_dim