liger-kernel-nightly 0.5.10.dev20250611191801__py3-none-any.whl → 0.6.4.dev20260112233432__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of liger-kernel-nightly might be problematic. Click here for more details.
- liger_kernel/chunked_loss/__init__.py +1 -0
- liger_kernel/chunked_loss/cosine_similarity_loss.py +142 -0
- liger_kernel/chunked_loss/dpo_loss.py +54 -3
- liger_kernel/chunked_loss/functional.py +2 -0
- liger_kernel/chunked_loss/fused_linear_distillation.py +23 -5
- liger_kernel/chunked_loss/fused_linear_ppo.py +25 -5
- liger_kernel/chunked_loss/grpo_loss.py +46 -9
- liger_kernel/chunked_loss/jsd_loss.py +44 -13
- liger_kernel/ops/__init__.py +141 -0
- liger_kernel/ops/backends/README.md +151 -0
- liger_kernel/ops/backends/__init__.py +13 -0
- liger_kernel/ops/backends/_ascend/__init__.py +5 -0
- liger_kernel/ops/backends/_ascend/ascend-ub-manager-design.md +485 -0
- liger_kernel/ops/backends/_ascend/ops/__init__.py +49 -0
- liger_kernel/ops/backends/_ascend/ops/geglu.py +266 -0
- liger_kernel/ops/backends/_ascend/ops/qwen2vl_mrope.py +285 -0
- liger_kernel/ops/backends/_ascend/ops/rope.py +290 -0
- liger_kernel/ops/backends/_ascend/ops/swiglu.py +142 -0
- liger_kernel/ops/backends/_ascend/ops/tvd.py +221 -0
- liger_kernel/ops/backends/_ascend/ub_manager.py +349 -0
- liger_kernel/ops/backends/registry.py +61 -0
- liger_kernel/ops/cross_entropy.py +130 -64
- liger_kernel/ops/dyt.py +5 -4
- liger_kernel/ops/fused_add_rms_norm.py +416 -0
- liger_kernel/ops/fused_linear_cross_entropy.py +115 -22
- liger_kernel/ops/geglu.py +6 -4
- liger_kernel/ops/group_norm.py +7 -7
- liger_kernel/ops/grpo_loss.py +3 -1
- liger_kernel/ops/kl_div.py +8 -11
- liger_kernel/ops/layer_norm.py +135 -80
- liger_kernel/ops/llama4_rope.py +225 -0
- liger_kernel/ops/poly_norm.py +390 -0
- liger_kernel/ops/rms_norm.py +148 -71
- liger_kernel/ops/rope.py +1 -1
- liger_kernel/ops/swiglu.py +1 -1
- liger_kernel/ops/tiled_mlp.py +136 -0
- liger_kernel/ops/utils.py +14 -0
- liger_kernel/transformers/__init__.py +65 -0
- liger_kernel/transformers/auto_model.py +21 -0
- liger_kernel/transformers/cross_entropy.py +9 -4
- liger_kernel/transformers/dyt.py +1 -1
- liger_kernel/transformers/experimental/__init__.py +5 -0
- liger_kernel/transformers/experimental/embedding.py +1 -1
- liger_kernel/transformers/functional.py +56 -24
- liger_kernel/transformers/fused_add_rms_norm.py +39 -0
- liger_kernel/transformers/fused_linear_cross_entropy.py +17 -5
- liger_kernel/transformers/fused_linear_jsd.py +1 -1
- liger_kernel/transformers/fused_neighborhood_attention.py +1 -1
- liger_kernel/transformers/geglu.py +1 -1
- liger_kernel/transformers/group_norm.py +1 -1
- liger_kernel/transformers/grpo_loss.py +57 -2
- liger_kernel/transformers/jsd.py +1 -1
- liger_kernel/transformers/kl_div.py +1 -1
- liger_kernel/transformers/layer_norm.py +1 -1
- liger_kernel/transformers/llama4_rope.py +93 -0
- liger_kernel/transformers/model/exaone4.py +136 -0
- liger_kernel/transformers/model/falcon_h1.py +122 -0
- liger_kernel/transformers/model/gemma.py +28 -8
- liger_kernel/transformers/model/gemma2.py +34 -11
- liger_kernel/transformers/model/gemma3.py +102 -112
- liger_kernel/transformers/model/glm4.py +18 -5
- liger_kernel/transformers/model/glm4v.py +163 -0
- liger_kernel/transformers/model/glm4v_moe.py +172 -0
- liger_kernel/transformers/model/gpt_oss.py +211 -0
- liger_kernel/transformers/model/hunyuan_v1.py +134 -0
- liger_kernel/transformers/model/internvl.py +157 -0
- liger_kernel/transformers/model/llama.py +26 -7
- liger_kernel/transformers/model/llama4.py +121 -0
- liger_kernel/transformers/model/llava.py +18 -6
- liger_kernel/transformers/model/loss_utils.py +34 -3
- liger_kernel/transformers/model/mistral.py +17 -10
- liger_kernel/transformers/model/mixtral.py +24 -9
- liger_kernel/transformers/model/mllama.py +18 -7
- liger_kernel/transformers/model/olmo2.py +18 -5
- liger_kernel/transformers/model/olmo3.py +142 -0
- liger_kernel/transformers/model/output_classes.py +147 -0
- liger_kernel/transformers/model/paligemma.py +42 -5
- liger_kernel/transformers/model/phi3.py +24 -159
- liger_kernel/transformers/model/qwen2.py +26 -4
- liger_kernel/transformers/model/qwen2_5_vl.py +21 -8
- liger_kernel/transformers/model/qwen2_vl.py +24 -7
- liger_kernel/transformers/model/qwen3.py +22 -6
- liger_kernel/transformers/model/qwen3_moe.py +27 -7
- liger_kernel/transformers/model/qwen3_next.py +146 -0
- liger_kernel/transformers/model/qwen3_vl.py +150 -0
- liger_kernel/transformers/model/qwen3_vl_moe.py +126 -0
- liger_kernel/transformers/model/smollm3.py +199 -0
- liger_kernel/transformers/model/smolvlm.py +158 -0
- liger_kernel/transformers/monkey_patch.py +1423 -100
- liger_kernel/transformers/multi_token_attention.py +2 -2
- liger_kernel/transformers/poly_norm.py +42 -0
- liger_kernel/transformers/qwen2vl_mrope.py +1 -1
- liger_kernel/transformers/rms_norm.py +15 -5
- liger_kernel/transformers/rope.py +45 -1
- liger_kernel/transformers/softmax.py +1 -1
- liger_kernel/transformers/sparsemax.py +1 -1
- liger_kernel/transformers/swiglu.py +18 -1
- liger_kernel/transformers/tiled_mlp.py +125 -0
- liger_kernel/transformers/tvd.py +1 -1
- liger_kernel/utils.py +52 -0
- {liger_kernel_nightly-0.5.10.dev20250611191801.dist-info → liger_kernel_nightly-0.6.4.dev20260112233432.dist-info}/METADATA +37 -25
- liger_kernel_nightly-0.6.4.dev20260112233432.dist-info/RECORD +132 -0
- liger_kernel_nightly-0.5.10.dev20250611191801.dist-info/RECORD +0 -95
- {liger_kernel_nightly-0.5.10.dev20250611191801.dist-info → liger_kernel_nightly-0.6.4.dev20260112233432.dist-info}/LICENSE +0 -0
- {liger_kernel_nightly-0.5.10.dev20250611191801.dist-info → liger_kernel_nightly-0.6.4.dev20260112233432.dist-info}/NOTICE +0 -0
- {liger_kernel_nightly-0.5.10.dev20250611191801.dist-info → liger_kernel_nightly-0.6.4.dev20260112233432.dist-info}/WHEEL +0 -0
- {liger_kernel_nightly-0.5.10.dev20250611191801.dist-info → liger_kernel_nightly-0.6.4.dev20260112233432.dist-info}/top_level.txt +0 -0
|
@@ -1,3 +1,8 @@
|
|
|
1
|
+
import math
|
|
2
|
+
|
|
3
|
+
from typing import Tuple
|
|
4
|
+
from typing import Union
|
|
5
|
+
|
|
1
6
|
import torch
|
|
2
7
|
import torch.nn.functional as F
|
|
3
8
|
|
|
@@ -6,34 +11,50 @@ from liger_kernel.chunked_loss.fused_linear_distillation import LigerFusedLinear
|
|
|
6
11
|
|
|
7
12
|
class LigerFusedLinearJSDFunction(LigerFusedLinearDistillationBase):
|
|
8
13
|
@staticmethod
|
|
9
|
-
def distillation_loss_fn(student_logits, teacher_logits, beta=0.5):
|
|
14
|
+
def distillation_loss_fn(student_logits, teacher_logits, beta=0.5, target=None, ignore_index=-100):
|
|
10
15
|
"""
|
|
11
16
|
Compute JSD loss (Jensen-Shannon Divergence Loss).
|
|
12
17
|
Args:
|
|
13
18
|
student_logits (torch.Tensor): Logits of student tokens. Shape: (batch_size * seq_len,).
|
|
14
19
|
teacher_logits (torch.Tensor): Logits of teacher tokens. Shape: (batch_size * seq_len,).
|
|
15
20
|
beta (float): Coefficient beta of generalized JSD in the interval [0, 1]. Default: `0.5`.
|
|
21
|
+
target (torch.Tensor): Target labels for masking. Shape: (chunk_size,).
|
|
22
|
+
ignore_index (int): Index to ignore in loss computation.
|
|
16
23
|
Returns:
|
|
17
24
|
torch.Tensor: Jensen-Shannon Divergence loss
|
|
25
|
+
Note:
|
|
26
|
+
- Uses reduction="none" to preserve per-token losses for masking
|
|
27
|
+
- KL divergence requires summing over vocab dimension (not mean)
|
|
28
|
+
- Masking excludes padding/prompt tokens from loss computation
|
|
18
29
|
"""
|
|
19
30
|
student_log_probs = F.log_softmax(student_logits, dim=-1)
|
|
20
31
|
teacher_log_probs = F.log_softmax(teacher_logits, dim=-1)
|
|
21
32
|
|
|
22
33
|
if beta == 0:
|
|
23
|
-
jsd_loss = F.kl_div(student_log_probs, teacher_log_probs, reduction="
|
|
34
|
+
jsd_loss = F.kl_div(student_log_probs, teacher_log_probs, reduction="none", log_target=True)
|
|
24
35
|
elif beta == 1:
|
|
25
|
-
jsd_loss = F.kl_div(teacher_log_probs, student_log_probs, reduction="
|
|
36
|
+
jsd_loss = F.kl_div(teacher_log_probs, student_log_probs, reduction="none", log_target=True)
|
|
26
37
|
else:
|
|
27
38
|
# Compute probabilities (only required for mean calculation)
|
|
28
|
-
|
|
29
|
-
|
|
39
|
+
log_mean_probs = torch.logsumexp(
|
|
40
|
+
torch.stack([student_log_probs + math.log(1 - beta), teacher_log_probs + math.log(beta)], dim=0), dim=0
|
|
41
|
+
)
|
|
30
42
|
|
|
31
|
-
student_kl = F.kl_div(log_mean_probs, student_log_probs, reduction="
|
|
32
|
-
teacher_kl = F.kl_div(log_mean_probs, teacher_log_probs, reduction="
|
|
43
|
+
student_kl = F.kl_div(log_mean_probs, student_log_probs, reduction="none", log_target=True)
|
|
44
|
+
teacher_kl = F.kl_div(log_mean_probs, teacher_log_probs, reduction="none", log_target=True)
|
|
33
45
|
|
|
34
46
|
# JSD is the weighted average of the KL divergences
|
|
35
47
|
jsd_loss = beta * teacher_kl + (1 - beta) * student_kl
|
|
36
|
-
|
|
48
|
+
|
|
49
|
+
# Sum over vocab dimension (KL divergence definition)
|
|
50
|
+
jsd_loss = jsd_loss.sum(dim=-1) # (chunk_size,)
|
|
51
|
+
|
|
52
|
+
# Apply ignore_index mask
|
|
53
|
+
if target is not None:
|
|
54
|
+
mask = target != ignore_index
|
|
55
|
+
jsd_loss = jsd_loss.masked_fill(~mask, 0.0)
|
|
56
|
+
|
|
57
|
+
return jsd_loss.sum()
|
|
37
58
|
|
|
38
59
|
@classmethod
|
|
39
60
|
def forward(
|
|
@@ -53,6 +74,7 @@ class LigerFusedLinearJSDFunction(LigerFusedLinearDistillationBase):
|
|
|
53
74
|
temperature: float = 1.0,
|
|
54
75
|
compiled: bool = True,
|
|
55
76
|
chunk_size: int = 1024,
|
|
77
|
+
return_soft_hard_loss: bool = False,
|
|
56
78
|
):
|
|
57
79
|
"""
|
|
58
80
|
Fused linear layer with JSD distillation loss.
|
|
@@ -69,8 +91,9 @@ class LigerFusedLinearJSDFunction(LigerFusedLinearDistillationBase):
|
|
|
69
91
|
temperature (float): Temperature for softening/sharpening distributions
|
|
70
92
|
compiled (bool): Whether to use torch compile
|
|
71
93
|
chunk_size (int): Size of chunks for processing.
|
|
94
|
+
return_soft_hard_loss (bool): Whether to return soft and hard losses separately. Default: False.
|
|
72
95
|
Returns:
|
|
73
|
-
torch.Tensor: Computed loss
|
|
96
|
+
torch.Tensor: Computed loss, or tuple (loss, soft_loss, hard_loss) if return_soft_hard_loss=True
|
|
74
97
|
"""
|
|
75
98
|
return super().forward(
|
|
76
99
|
cls=cls,
|
|
@@ -89,11 +112,12 @@ class LigerFusedLinearJSDFunction(LigerFusedLinearDistillationBase):
|
|
|
89
112
|
ignore_index=ignore_index,
|
|
90
113
|
temperature=temperature,
|
|
91
114
|
compiled=compiled,
|
|
115
|
+
return_soft_hard_loss=return_soft_hard_loss,
|
|
92
116
|
)
|
|
93
117
|
|
|
94
118
|
@staticmethod
|
|
95
|
-
def backward(ctx, grad_output):
|
|
96
|
-
grads = LigerFusedLinearDistillationBase.backward(ctx, grad_output)[:6]
|
|
119
|
+
def backward(ctx, grad_output, *args):
|
|
120
|
+
grads = LigerFusedLinearDistillationBase.backward(ctx, grad_output, *args)[:6]
|
|
97
121
|
|
|
98
122
|
return (
|
|
99
123
|
*grads,
|
|
@@ -105,6 +129,7 @@ class LigerFusedLinearJSDFunction(LigerFusedLinearDistillationBase):
|
|
|
105
129
|
None, # temperature
|
|
106
130
|
None, # compiled
|
|
107
131
|
None, # chunk_size
|
|
132
|
+
None, # return_soft_hard_loss
|
|
108
133
|
)
|
|
109
134
|
|
|
110
135
|
|
|
@@ -122,6 +147,7 @@ class LigerFusedLinearJSDLoss(torch.nn.Module):
|
|
|
122
147
|
temperature: float = 1.0,
|
|
123
148
|
compiled: bool = True,
|
|
124
149
|
chunk_size: int = 1024,
|
|
150
|
+
return_soft_hard_loss: bool = False,
|
|
125
151
|
):
|
|
126
152
|
"""
|
|
127
153
|
Args:
|
|
@@ -132,6 +158,7 @@ class LigerFusedLinearJSDLoss(torch.nn.Module):
|
|
|
132
158
|
compiled (bool): Whether to use torch compile
|
|
133
159
|
beta (float): Coefficient beta of generalized JSD in the interval [0, 1]. Default: `0.5`.
|
|
134
160
|
chunk_size (int): Size of chunks for processing.
|
|
161
|
+
return_soft_hard_loss (bool): Whether to return soft and hard losses separately. Default: False.
|
|
135
162
|
"""
|
|
136
163
|
super().__init__()
|
|
137
164
|
assert temperature != 0, "Temperature cannot be 0."
|
|
@@ -142,6 +169,7 @@ class LigerFusedLinearJSDLoss(torch.nn.Module):
|
|
|
142
169
|
self.compiled = compiled
|
|
143
170
|
self.beta = beta
|
|
144
171
|
self.chunk_size = chunk_size
|
|
172
|
+
self.return_soft_hard_loss = return_soft_hard_loss
|
|
145
173
|
|
|
146
174
|
def forward(
|
|
147
175
|
self,
|
|
@@ -152,7 +180,7 @@ class LigerFusedLinearJSDLoss(torch.nn.Module):
|
|
|
152
180
|
true_labels: torch.LongTensor,
|
|
153
181
|
student_bias: torch.Tensor = None,
|
|
154
182
|
teacher_bias: torch.Tensor = None,
|
|
155
|
-
) -> torch.Tensor:
|
|
183
|
+
) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor, torch.Tensor]]:
|
|
156
184
|
"""
|
|
157
185
|
Compute the JSD distillation loss.
|
|
158
186
|
|
|
@@ -164,7 +192,9 @@ class LigerFusedLinearJSDLoss(torch.nn.Module):
|
|
|
164
192
|
true_labels (torch.LongTensor): Target labels tensor
|
|
165
193
|
|
|
166
194
|
Returns:
|
|
167
|
-
torch.Tensor
|
|
195
|
+
torch.Tensor or Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
|
|
196
|
+
If return_soft_hard_loss is False: Computed combined loss
|
|
197
|
+
If return_soft_hard_loss is True: Tuple of (combined_loss, soft_loss, hard_loss)
|
|
168
198
|
"""
|
|
169
199
|
return LigerFusedLinearJSDFunction.apply(
|
|
170
200
|
student_input,
|
|
@@ -181,4 +211,5 @@ class LigerFusedLinearJSDLoss(torch.nn.Module):
|
|
|
181
211
|
self.temperature,
|
|
182
212
|
self.compiled,
|
|
183
213
|
self.chunk_size,
|
|
214
|
+
self.return_soft_hard_loss,
|
|
184
215
|
)
|
liger_kernel/ops/__init__.py
CHANGED
|
@@ -0,0 +1,141 @@
|
|
|
1
|
+
"""
|
|
2
|
+
Liger-Kernel operators with automatic vendor-specific replacement.
|
|
3
|
+
|
|
4
|
+
This module provides two ways to import operators:
|
|
5
|
+
|
|
6
|
+
1. Import from this package (recommended for Function classes):
|
|
7
|
+
from liger_kernel.ops import LigerGELUMulFunction
|
|
8
|
+
|
|
9
|
+
This automatically uses vendor-specific implementation if available.
|
|
10
|
+
|
|
11
|
+
2. Import from submodules (for kernel functions or specific access):
|
|
12
|
+
from liger_kernel.ops.geglu import geglu_forward, geglu_backward
|
|
13
|
+
|
|
14
|
+
This always uses the default implementation (no auto-replacement).
|
|
15
|
+
|
|
16
|
+
The replacement mechanism:
|
|
17
|
+
1. Default implementations are imported from individual modules (e.g., geglu.py)
|
|
18
|
+
2. On module load, device is detected via infer_device()
|
|
19
|
+
3. If running on a supported vendor device (npu, xpu, etc.), the default
|
|
20
|
+
implementations are replaced with vendor-specific ones
|
|
21
|
+
4. All subsequent imports from this package get the replaced versions
|
|
22
|
+
|
|
23
|
+
Note: Direct imports from submodules (e.g., from liger_kernel.ops.geglu import ...)
|
|
24
|
+
are NOT affected by the replacement mechanism.
|
|
25
|
+
"""
|
|
26
|
+
|
|
27
|
+
# =============================================================================
|
|
28
|
+
# Import default implementations
|
|
29
|
+
# Both Function classes and kernel functions are imported here.
|
|
30
|
+
# All of these can be replaced by vendor-specific implementations.
|
|
31
|
+
# =============================================================================
|
|
32
|
+
|
|
33
|
+
from liger_kernel.ops.cross_entropy import LigerCrossEntropyFunction # noqa: F401
|
|
34
|
+
from liger_kernel.ops.cross_entropy import cross_entropy_backward # noqa: F401
|
|
35
|
+
from liger_kernel.ops.cross_entropy import cross_entropy_forward # noqa: F401
|
|
36
|
+
from liger_kernel.ops.dyt import LigerDyTFunction # noqa: F401
|
|
37
|
+
from liger_kernel.ops.experimental.embedding import LigerEmbeddingFunction # noqa: F401
|
|
38
|
+
from liger_kernel.ops.fused_add_rms_norm import LigerFusedAddRMSNormFunction # noqa: F401
|
|
39
|
+
from liger_kernel.ops.fused_add_rms_norm import fused_add_rms_norm_backward # noqa: F401
|
|
40
|
+
from liger_kernel.ops.fused_add_rms_norm import fused_add_rms_norm_forward # noqa: F401
|
|
41
|
+
from liger_kernel.ops.fused_linear_cross_entropy import LigerFusedLinearCrossEntropyFunction # noqa: F401
|
|
42
|
+
from liger_kernel.ops.fused_linear_cross_entropy import fused_linear_cross_entropy_backward # noqa: F401
|
|
43
|
+
from liger_kernel.ops.fused_linear_cross_entropy import fused_linear_cross_entropy_forward # noqa: F401
|
|
44
|
+
from liger_kernel.ops.fused_linear_jsd import LigerFusedLinearJSDFunction # noqa: F401
|
|
45
|
+
from liger_kernel.ops.fused_linear_jsd import fused_linear_jsd_backward # noqa: F401
|
|
46
|
+
from liger_kernel.ops.fused_linear_jsd import fused_linear_jsd_forward # noqa: F401
|
|
47
|
+
from liger_kernel.ops.fused_neighborhood_attention import LigerFusedNeighborhoodAttentionFunction # noqa: F401
|
|
48
|
+
from liger_kernel.ops.geglu import LigerGELUMulFunction # noqa: F401
|
|
49
|
+
from liger_kernel.ops.geglu import geglu_backward # noqa: F401
|
|
50
|
+
from liger_kernel.ops.geglu import geglu_forward # noqa: F401
|
|
51
|
+
from liger_kernel.ops.group_norm import LigerGroupNormFunction # noqa: F401
|
|
52
|
+
from liger_kernel.ops.group_norm import group_norm_backward # noqa: F401
|
|
53
|
+
from liger_kernel.ops.group_norm import group_norm_forward # noqa: F401
|
|
54
|
+
from liger_kernel.ops.grpo_loss import GrpoLossFunction # noqa: F401
|
|
55
|
+
from liger_kernel.ops.jsd import LigerJSDFunction # noqa: F401
|
|
56
|
+
from liger_kernel.ops.jsd import jsd_backward # noqa: F401
|
|
57
|
+
from liger_kernel.ops.jsd import jsd_forward # noqa: F401
|
|
58
|
+
from liger_kernel.ops.kl_div import LigerKLDivLossFunction # noqa: F401
|
|
59
|
+
from liger_kernel.ops.layer_norm import LigerLayerNormFunction # noqa: F401
|
|
60
|
+
from liger_kernel.ops.layer_norm import layer_norm_backward # noqa: F401
|
|
61
|
+
from liger_kernel.ops.layer_norm import layer_norm_forward # noqa: F401
|
|
62
|
+
from liger_kernel.ops.llama4_rope import LigerLlama4RopeFunction # noqa: F401
|
|
63
|
+
from liger_kernel.ops.multi_token_attention import LigerMultiTokenAttentionFunction # noqa: F401
|
|
64
|
+
from liger_kernel.ops.poly_norm import LigerPolyNormFunction # noqa: F401
|
|
65
|
+
from liger_kernel.ops.poly_norm import poly_norm_backward # noqa: F401
|
|
66
|
+
from liger_kernel.ops.poly_norm import poly_norm_forward # noqa: F401
|
|
67
|
+
from liger_kernel.ops.qwen2vl_mrope import LigerQwen2VLMRopeFunction # noqa: F401
|
|
68
|
+
from liger_kernel.ops.rms_norm import LigerRMSNormFunction # noqa: F401
|
|
69
|
+
from liger_kernel.ops.rms_norm import rms_norm_backward # noqa: F401
|
|
70
|
+
from liger_kernel.ops.rms_norm import rms_norm_forward # noqa: F401
|
|
71
|
+
from liger_kernel.ops.rope import LigerRopeFunction # noqa: F401
|
|
72
|
+
from liger_kernel.ops.rope import rope_backward # noqa: F401
|
|
73
|
+
from liger_kernel.ops.rope import rope_forward # noqa: F401
|
|
74
|
+
from liger_kernel.ops.softmax import LigerSoftmaxFunction # noqa: F401
|
|
75
|
+
from liger_kernel.ops.sparsemax import LigerSparsemaxFunction # noqa: F401
|
|
76
|
+
from liger_kernel.ops.swiglu import LigerSiLUMulFunction # noqa: F401
|
|
77
|
+
from liger_kernel.ops.swiglu import swiglu_backward # noqa: F401
|
|
78
|
+
from liger_kernel.ops.swiglu import swiglu_forward # noqa: F401
|
|
79
|
+
from liger_kernel.ops.tiled_mlp import LigerTiledMLPFunction # noqa: F401
|
|
80
|
+
from liger_kernel.ops.tiled_mlp import apply_tiled_mlp # noqa: F401
|
|
81
|
+
from liger_kernel.ops.tvd import LigerTVDLossFunction # noqa: F401
|
|
82
|
+
|
|
83
|
+
# NOTE: __all__ is intentionally NOT defined.
|
|
84
|
+
# - Import from this package (liger_kernel.ops) -> subject to vendor replacement
|
|
85
|
+
# - Import from submodules (liger_kernel.ops.geglu) -> always use default implementation
|
|
86
|
+
|
|
87
|
+
|
|
88
|
+
# =============================================================================
|
|
89
|
+
# Vendor-specific replacement logic
|
|
90
|
+
# =============================================================================
|
|
91
|
+
|
|
92
|
+
|
|
93
|
+
def _replace_with_vendor_ops():
|
|
94
|
+
"""
|
|
95
|
+
Replace/add vendor-specific operator implementations.
|
|
96
|
+
|
|
97
|
+
This function is called automatically on module load. It:
|
|
98
|
+
1. Detects the current device (cuda, npu, xpu, etc.)
|
|
99
|
+
2. Looks up the vendor for that device via VENDOR_REGISTRY
|
|
100
|
+
3. Loads and applies vendor-specific implementations
|
|
101
|
+
|
|
102
|
+
Vendor implementations should be placed in:
|
|
103
|
+
liger_kernel/ops/backends/_<vendor>/ops/
|
|
104
|
+
|
|
105
|
+
If the vendor module defines __all__, only those symbols are exported.
|
|
106
|
+
Otherwise, all public symbols (not starting with _) are auto-discovered.
|
|
107
|
+
|
|
108
|
+
Note: Vendor can both override existing ops AND add new vendor-specific ops.
|
|
109
|
+
"""
|
|
110
|
+
from liger_kernel.ops.backends import get_vendor_for_device
|
|
111
|
+
from liger_kernel.utils import infer_device
|
|
112
|
+
|
|
113
|
+
device = infer_device()
|
|
114
|
+
|
|
115
|
+
# Look up vendor info for this device
|
|
116
|
+
vendor_info = get_vendor_for_device(device)
|
|
117
|
+
if vendor_info is None:
|
|
118
|
+
return
|
|
119
|
+
|
|
120
|
+
try:
|
|
121
|
+
import importlib
|
|
122
|
+
|
|
123
|
+
vendor_ops = importlib.import_module(vendor_info.module_path)
|
|
124
|
+
|
|
125
|
+
# Get names to export: use __all__ if defined, otherwise auto-discover
|
|
126
|
+
names_to_export = getattr(vendor_ops, "__all__", None)
|
|
127
|
+
|
|
128
|
+
if names_to_export is None:
|
|
129
|
+
# Auto-discover: find all public symbols (classes and functions)
|
|
130
|
+
names_to_export = [name for name in dir(vendor_ops) if not name.startswith("_")]
|
|
131
|
+
|
|
132
|
+
# Replace or add to this module's globals
|
|
133
|
+
for name in names_to_export:
|
|
134
|
+
globals()[name] = getattr(vendor_ops, name)
|
|
135
|
+
|
|
136
|
+
except ImportError:
|
|
137
|
+
# Vendor module not available, use default implementations
|
|
138
|
+
pass
|
|
139
|
+
|
|
140
|
+
|
|
141
|
+
_replace_with_vendor_ops()
|
|
@@ -0,0 +1,151 @@
|
|
|
1
|
+
# Adding a New Vendor Backend
|
|
2
|
+
|
|
3
|
+
This directory contains vendor-specific operator implementations that automatically replace the default (CUDA) implementations when running on the corresponding device.
|
|
4
|
+
|
|
5
|
+
## Concepts
|
|
6
|
+
|
|
7
|
+
- **Vendor**: Name of the hardware vendor/backend (e.g., `ascend` for Huawei Ascend, `intel`, `nvidia`)
|
|
8
|
+
- **Device**: Device type (e.g., `npu`, `xpu`, `cuda`)
|
|
9
|
+
- **VendorInfo**: Defines the mapping between vendor and device
|
|
10
|
+
|
|
11
|
+
## Directory Structure
|
|
12
|
+
|
|
13
|
+
```
|
|
14
|
+
backends/
|
|
15
|
+
├── README.md
|
|
16
|
+
├── __init__.py
|
|
17
|
+
├── registry.py # VendorInfo, register_vendor(), VENDOR_REGISTRY
|
|
18
|
+
├── _ascend/ # Ascend (Huawei) vendor - supports NPU
|
|
19
|
+
│ ├── __init__.py # Registers VendorInfo for NPU
|
|
20
|
+
│ └── ops/
|
|
21
|
+
│ ├── __init__.py # Exports vendor-specific implementations
|
|
22
|
+
│ └── geglu.py # NPU-specific GEGLU implementation
|
|
23
|
+
└── _<vendor>/ # Your new vendor backend
|
|
24
|
+
└── ...
|
|
25
|
+
```
|
|
26
|
+
|
|
27
|
+
## How It Works
|
|
28
|
+
|
|
29
|
+
1. When `liger_kernel.ops.backends` is imported, it imports every vendor subpackage whose name starts with `_` (e.g., `_ascend`)
|
|
30
|
+
2. Each vendor's `__init__.py` calls `register_vendor()` to register itself
|
|
31
|
+
3. When `liger_kernel.ops` is imported, `_replace_with_vendor_ops()` is called
|
|
32
|
+
4. It detects the current device via `infer_device()` and looks up the vendor
|
|
33
|
+
5. Vendor implementations replace/add to the `liger_kernel.ops` namespace
|
|
34
|
+
|
|
35
|
+
## Adding a New Vendor
|
|
36
|
+
|
|
37
|
+
### Step 1: Create Directory Structure
|
|
38
|
+
|
|
39
|
+
```bash
|
|
40
|
+
mkdir -p backends/_<vendor>/ops
|
|
41
|
+
touch backends/_<vendor>/__init__.py
|
|
42
|
+
touch backends/_<vendor>/ops/__init__.py
|
|
43
|
+
```
|
|
44
|
+
|
|
45
|
+
### Step 2: Register Your Vendor
|
|
46
|
+
|
|
47
|
+
In `backends/_<vendor>/__init__.py`, register your vendor:
|
|
48
|
+
|
|
49
|
+
```python
|
|
50
|
+
"""
|
|
51
|
+
<Vendor> backend for Liger-Kernel.
|
|
52
|
+
"""
|
|
53
|
+
|
|
54
|
+
from liger_kernel.ops.backends.registry import VendorInfo, register_vendor
|
|
55
|
+
|
|
56
|
+
register_vendor(
|
|
57
|
+
VendorInfo(
|
|
58
|
+
vendor="<vendor>",
|
|
59
|
+
device="<device>",
|
|
60
|
+
)
|
|
61
|
+
)
|
|
62
|
+
```
|
|
63
|
+
|
|
64
|
+
|
|
65
|
+
### Step 3: Ensure Device Detection Works
|
|
66
|
+
|
|
67
|
+
Make sure `infer_device()` in `liger_kernel/utils.py` can detect your device:
|
|
68
|
+
|
|
69
|
+
```python
|
|
70
|
+
def infer_device():
|
|
71
|
+
if torch.cuda.is_available():
|
|
72
|
+
return "cuda"
|
|
73
|
+
if is_npu_available():
|
|
74
|
+
return "npu"
|
|
75
|
+
# Add your device detection here
|
|
76
|
+
if is_<device>_available():
|
|
77
|
+
return "<device>"
|
|
78
|
+
return "cpu"
|
|
79
|
+
```
|
|
80
|
+
|
|
81
|
+
### Step 4: Implement Vendor-Specific Operators
|
|
82
|
+
|
|
83
|
+
Create operator files in `backends/_<vendor>/ops/`. For example, `geglu.py`:
|
|
84
|
+
|
|
85
|
+
```python
|
|
86
|
+
import torch
|
|
87
|
+
|
|
88
|
+
class LigerGELUMulFunction(torch.autograd.Function):
|
|
89
|
+
"""
|
|
90
|
+
Vendor-specific LigerGELUMulFunction implementation.
|
|
91
|
+
"""
|
|
92
|
+
@staticmethod
|
|
93
|
+
def forward(ctx, a, b):
|
|
94
|
+
# Your vendor-specific forward implementation
|
|
95
|
+
...
|
|
96
|
+
|
|
97
|
+
@staticmethod
|
|
98
|
+
def backward(ctx, dc):
|
|
99
|
+
# Your vendor-specific backward implementation
|
|
100
|
+
...
|
|
101
|
+
|
|
102
|
+
# Optional: vendor-specific kernel functions
|
|
103
|
+
def geglu_forward_vendor(a, b):
|
|
104
|
+
...
|
|
105
|
+
|
|
106
|
+
def geglu_backward_vendor(a, b, dc):
|
|
107
|
+
...
|
|
108
|
+
```
|
|
109
|
+
|
|
110
|
+
### Step 5: Export in `ops/__init__.py`
|
|
111
|
+
|
|
112
|
+
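In `backends/_<vendor>/ops/__init__.py`, export your implementations. Defining `__all__` is strongly recommended: without it, every public name returned by `dir()` on the package is copied into `liger_kernel.ops`. That includes submodule names such as `geglu`, which would then replace the `liger_kernel.ops.geglu` attribute: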
|
|
113
|
+
|
|
114
|
+
```python
|
|
115
|
+
"""
|
|
116
|
+
<Vendor>-specific operator implementations.
|
|
117
|
+
"""
|
|
118
|
+
|
|
119
|
+
from .<module> import (
|
|
120
|
+
LigerGELUMulFunction,
|
|
121
|
+
geglu_forward_vendor as geglu_forward, # Rename to match default API
|
|
122
|
+
geglu_backward_vendor as geglu_backward,
|
|
123
|
+
)
|
|
124
|
+
|
|
125
|
+
# Explicitly declare what to export (recommended)
|
|
126
|
+
__all__ = [
|
|
127
|
+
"LigerGELUMulFunction",
|
|
128
|
+
"geglu_forward",
|
|
129
|
+
"geglu_backward",
|
|
130
|
+
]
|
|
131
|
+
```
|
|
132
|
+
|
|
133
|
+
## Key Points
|
|
134
|
+
|
|
135
|
+
### Incremental Override
|
|
136
|
+
|
|
137
|
+
You **don't need to implement all operators**. Only implement the ones that require vendor-specific adaptations. Unimplemented operators will automatically fall back to the default (CUDA) implementation.
|
|
138
|
+
|
|
139
|
+
### Vendor-Specific Additions
|
|
140
|
+
|
|
141
|
+
Vendors can also **add new operators** that don't exist in the default implementation. These are exported to the `liger_kernel.ops` namespace, so users can import them from there.
|
|
142
|
+
|
|
143
|
+
### Naming Convention
|
|
144
|
+
|
|
145
|
+
- Use the **same class/function names** as the default implementations for overrides
|
|
146
|
+
- This allows seamless replacement without changing user code
|
|
147
|
+
- Use `as` imports to rename if your internal naming differs
|
|
148
|
+
|
|
149
|
+
## Example: Ascend NPU Backend
|
|
150
|
+
|
|
151
|
+
See the `_ascend/` directory for a complete example of an Ascend NPU backend implementation.
|
|
@@ -0,0 +1,13 @@
|
|
|
1
|
+
import importlib
|
|
2
|
+
import pkgutil
|
|
3
|
+
|
|
4
|
+
from liger_kernel.ops.backends.registry import VENDOR_REGISTRY # noqa: F401
|
|
5
|
+
from liger_kernel.ops.backends.registry import VendorInfo # noqa: F401
|
|
6
|
+
from liger_kernel.ops.backends.registry import get_vendor_for_device # noqa: F401
|
|
7
|
+
from liger_kernel.ops.backends.registry import register_vendor # noqa: F401
|
|
8
|
+
|
|
9
|
+
# Auto-import all _<vendor> subpackages to trigger registration
|
|
10
|
+
# Each vendor's __init__.py calls register_vendor() when imported
|
|
11
|
+
for _, modname, ispkg in pkgutil.iter_modules(__path__):
|
|
12
|
+
if ispkg and modname.startswith("_"):
|
|
13
|
+
importlib.import_module(f"{__name__}.{modname}")
|