liger-kernel-nightly 0.0.1.dev20240819184814__py3-none-any.whl → 0.6.4.dev20251212103629__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- liger_kernel/__init__.py +0 -0
- liger_kernel/chunked_loss/README.md +25 -0
- liger_kernel/chunked_loss/__init__.py +8 -0
- liger_kernel/chunked_loss/cosine_similarity_loss.py +136 -0
- liger_kernel/chunked_loss/cpo_loss.py +157 -0
- liger_kernel/chunked_loss/dpo_loss.py +229 -0
- liger_kernel/chunked_loss/functional.py +17 -0
- liger_kernel/chunked_loss/fused_linear_distillation.py +292 -0
- liger_kernel/chunked_loss/fused_linear_ppo.py +366 -0
- liger_kernel/chunked_loss/fused_linear_preference.py +433 -0
- liger_kernel/chunked_loss/fused_linear_unpaired_preference.py +341 -0
- liger_kernel/chunked_loss/grpo_loss.py +307 -0
- liger_kernel/chunked_loss/jsd_loss.py +200 -0
- liger_kernel/chunked_loss/kto_loss.py +210 -0
- liger_kernel/chunked_loss/orpo_loss.py +144 -0
- liger_kernel/chunked_loss/simpo_loss.py +165 -0
- liger_kernel/env_report.py +63 -0
- liger_kernel/ops/__init__.py +141 -0
- liger_kernel/ops/backends/README.md +151 -0
- liger_kernel/ops/backends/__init__.py +13 -0
- liger_kernel/ops/backends/_ascend/__init__.py +5 -0
- liger_kernel/ops/backends/_ascend/ops/__init__.py +15 -0
- liger_kernel/ops/backends/registry.py +61 -0
- liger_kernel/ops/cross_entropy.py +383 -114
- liger_kernel/ops/dyt.py +160 -0
- liger_kernel/ops/experimental/embedding.py +141 -0
- liger_kernel/ops/experimental/mm_int8int2.py +349 -0
- liger_kernel/ops/fused_add_rms_norm.py +416 -0
- liger_kernel/ops/fused_linear_cross_entropy.py +346 -132
- liger_kernel/ops/fused_linear_jsd.py +228 -0
- liger_kernel/ops/fused_neighborhood_attention.py +1022 -0
- liger_kernel/ops/geglu.py +66 -64
- liger_kernel/ops/group_norm.py +306 -0
- liger_kernel/ops/grpo_loss.py +312 -0
- liger_kernel/ops/jsd.py +201 -0
- liger_kernel/ops/kl_div.py +262 -0
- liger_kernel/ops/layer_norm.py +320 -0
- liger_kernel/ops/llama4_rope.py +225 -0
- liger_kernel/ops/multi_token_attention.py +207 -0
- liger_kernel/ops/poly_norm.py +390 -0
- liger_kernel/ops/qwen2vl_mrope.py +222 -0
- liger_kernel/ops/rms_norm.py +484 -88
- liger_kernel/ops/rope.py +122 -117
- liger_kernel/ops/softmax.py +201 -0
- liger_kernel/ops/sparsemax.py +179 -0
- liger_kernel/ops/swiglu.py +68 -65
- liger_kernel/ops/tiled_mlp.py +136 -0
- liger_kernel/ops/tvd.py +207 -0
- liger_kernel/ops/utils.py +82 -3
- liger_kernel/transformers/__init__.py +218 -6
- liger_kernel/transformers/auto_model.py +38 -0
- liger_kernel/transformers/cross_entropy.py +52 -7
- liger_kernel/transformers/dyt.py +22 -0
- liger_kernel/transformers/experimental/__init__.py +5 -0
- liger_kernel/transformers/experimental/embedding.py +26 -0
- liger_kernel/transformers/fsdp.py +55 -0
- liger_kernel/transformers/functional.py +301 -0
- liger_kernel/transformers/fused_add_rms_norm.py +39 -0
- liger_kernel/transformers/fused_linear_cross_entropy.py +59 -10
- liger_kernel/transformers/fused_linear_jsd.py +95 -0
- liger_kernel/transformers/fused_neighborhood_attention.py +234 -0
- liger_kernel/transformers/geglu.py +6 -7
- liger_kernel/transformers/group_norm.py +50 -0
- liger_kernel/transformers/grpo_loss.py +153 -0
- liger_kernel/transformers/jsd.py +70 -0
- liger_kernel/transformers/kl_div.py +12 -0
- liger_kernel/transformers/layer_norm.py +24 -0
- liger_kernel/transformers/llama4_rope.py +93 -0
- liger_kernel/transformers/model/falcon_h1.py +122 -0
- liger_kernel/transformers/model/gemma.py +261 -0
- liger_kernel/transformers/model/gemma2.py +283 -0
- liger_kernel/transformers/model/gemma3.py +332 -0
- liger_kernel/transformers/model/glm4.py +141 -0
- liger_kernel/transformers/model/glm4v.py +163 -0
- liger_kernel/transformers/model/glm4v_moe.py +172 -0
- liger_kernel/transformers/model/gpt_oss.py +211 -0
- liger_kernel/transformers/model/hunyuan_v1.py +134 -0
- liger_kernel/transformers/model/internvl.py +157 -0
- liger_kernel/transformers/model/llama.py +221 -41
- liger_kernel/transformers/model/llama4.py +121 -0
- liger_kernel/transformers/model/llava.py +344 -0
- liger_kernel/transformers/model/loss_utils.py +95 -0
- liger_kernel/transformers/model/mistral.py +145 -0
- liger_kernel/transformers/model/mixtral.py +293 -0
- liger_kernel/transformers/model/mllama.py +269 -0
- liger_kernel/transformers/model/olmo2.py +141 -0
- liger_kernel/transformers/model/olmo3.py +142 -0
- liger_kernel/transformers/model/output_classes.py +147 -0
- liger_kernel/transformers/model/paligemma.py +433 -0
- liger_kernel/transformers/model/phi3.py +120 -0
- liger_kernel/transformers/model/qwen2.py +259 -0
- liger_kernel/transformers/model/qwen2_5_vl.py +163 -0
- liger_kernel/transformers/model/qwen2_vl.py +159 -0
- liger_kernel/transformers/model/qwen3.py +136 -0
- liger_kernel/transformers/model/qwen3_moe.py +152 -0
- liger_kernel/transformers/model/qwen3_next.py +146 -0
- liger_kernel/transformers/model/qwen3_vl.py +150 -0
- liger_kernel/transformers/model/qwen3_vl_moe.py +126 -0
- liger_kernel/transformers/model/smollm3.py +199 -0
- liger_kernel/transformers/model/smolvlm.py +158 -0
- liger_kernel/transformers/monkey_patch.py +2816 -21
- liger_kernel/transformers/multi_token_attention.py +64 -0
- liger_kernel/transformers/poly_norm.py +42 -0
- liger_kernel/transformers/qwen2vl_mrope.py +20 -0
- liger_kernel/transformers/rms_norm.py +75 -5
- liger_kernel/transformers/rope.py +47 -3
- liger_kernel/transformers/softmax.py +12 -0
- liger_kernel/transformers/sparsemax.py +16 -0
- liger_kernel/transformers/swiglu.py +62 -6
- liger_kernel/transformers/tiled_mlp.py +133 -0
- liger_kernel/transformers/trainer/__init__.py +4 -0
- liger_kernel/transformers/trainer/orpo_trainer.py +130 -0
- liger_kernel/transformers/trainer_integration.py +2 -45
- liger_kernel/transformers/tvd.py +13 -0
- liger_kernel/triton/__init__.py +1 -3
- liger_kernel/triton/monkey_patch.py +1 -5
- liger_kernel/utils.py +96 -0
- liger_kernel_nightly-0.6.4.dev20251212103629.dist-info/METADATA +447 -0
- liger_kernel_nightly-0.6.4.dev20251212103629.dist-info/NOTICE +58 -0
- liger_kernel_nightly-0.6.4.dev20251212103629.dist-info/RECORD +124 -0
- {liger_kernel_nightly-0.0.1.dev20240819184814.dist-info → liger_kernel_nightly-0.6.4.dev20251212103629.dist-info}/WHEEL +1 -1
- liger_kernel_nightly-0.0.1.dev20240819184814.dist-info/METADATA +0 -21
- liger_kernel_nightly-0.0.1.dev20240819184814.dist-info/NOTICE +0 -4
- liger_kernel_nightly-0.0.1.dev20240819184814.dist-info/RECORD +0 -27
- {liger_kernel_nightly-0.0.1.dev20240819184814.dist-info → liger_kernel_nightly-0.6.4.dev20251212103629.dist-info}/LICENSE +0 -0
- {liger_kernel_nightly-0.0.1.dev20240819184814.dist-info → liger_kernel_nightly-0.6.4.dev20251212103629.dist-info}/top_level.txt +0 -0
liger_kernel/chunked_loss/simpo_loss.py

@@ -0,0 +1,165 @@
+import torch
+import torch.nn.functional as F
+
+from liger_kernel.chunked_loss.fused_linear_preference import LigerFusedLinearPreferenceBase
+
+
+class LigerFusedLinearSimPOFunction(LigerFusedLinearPreferenceBase):
+    @staticmethod
+    def preference_loss_fn(
+        chosen_logps,
+        rejected_logps,
+        full_target,
+        beta=0.1,
+        gamma=0.5,
+        label_smoothing=0.0,
+    ):
+        """
+        Paper: https://arxiv.org/pdf/2405.14734
+
+        Formula:
+        L_SimPO(π_θ) = -E [log σ(β/|y_w| log π_θ(y_w|x) - β/|y_l| log π_θ(y_l|x) - γ)]
+
+        Where:
+        - π_θ(y|x): Policy (model) probability
+        - y_w: Chosen sequence
+        - y_l: Rejected sequence
+        - |y_w|, |y_l|: Sequence lengths
+        - σ: Sigmoid function
+        - β: beta weight
+        - γ: gamma margin term
+
+        Args:
+            chosen_logps (torch.Tensor): Avg log probabilities of chosen tokens. Shape: (batch_size,).
+            rejected_logps (torch.Tensor): Avg log probabilities of rejected tokens. Shape: (batch_size,).
+            full_target: Non-chunked full target tensor
+            beta (float): beta weight
+            gamma (float): gamma margin term
+            label_smoothing (float): Label smoothing factor; reduces to the equation above as label_smoothing -> 0.
+        """
+        logits = beta * (chosen_logps - rejected_logps) - gamma
+        loss = (-F.logsigmoid(logits) * (1 - label_smoothing) - F.logsigmoid(-logits) * label_smoothing).sum() / (
+            full_target.shape[0] // 2
+        )
+
+        chosen_rewards = beta * chosen_logps
+        rejected_rewards = beta * rejected_logps
+
+        return loss, chosen_rewards, rejected_rewards
+
+    @classmethod
+    def forward(
+        cls,
+        ctx,
+        _input,
+        weight,
+        target,
+        bias=None,
+        ignore_index=-100,
+        beta=0.1,
+        alpha=1.0,
+        label_smoothing=0.0,
+        compute_nll_loss=False,
+        compiled=True,
+        gamma=0.5,
+        chunk_size=1,
+    ):
+        """
+        Fused linear layer with SimPO loss.
+        Args:
+            _input (torch.Tensor): Input tensor. Shape: (batch_size * seq_len, hidden_size)
+            weight (torch.Tensor): Weight tensor. Shape: (vocab_size, hidden_size)
+            target (torch.LongTensor): Target tensor. Shape: (batch_size * seq_len,)
+            bias (torch.Tensor, optional): Bias tensor. Shape: (vocab_size,)
+            ignore_index (int): Index to ignore in loss computation
+            beta (float): Weight for the odds ratio loss
+            alpha (float): Weight for the alpha parameter
+            label_smoothing (float): Label smoothing factor
+            compute_nll_loss (bool): Whether to compute the NLL loss
+            compiled (bool): Whether to use torch compile
+            gamma (float): Weight for the gamma parameter
+            chunk_size (int): Size of chunks for processing
+        Returns:
+            torch.Tensor: Computed loss
+        """
+        return super().forward(
+            cls=cls,
+            ctx=ctx,
+            _input=_input,
+            weight=weight,
+            target=target,
+            bias=bias,
+            ignore_index=ignore_index,
+            alpha=alpha,
+            beta=beta,
+            label_smoothing=label_smoothing,
+            compute_nll_loss=compute_nll_loss,
+            compiled=compiled,
+            gamma=gamma,
+            chunk_size=chunk_size,
+        )
+
+    @staticmethod
+    def backward(ctx, *grad_output):
+        grads = LigerFusedLinearPreferenceBase.backward(ctx, grad_output)[:4]
+        return *grads, None, None, None, None, None, None, None, None
+
+
+class LigerFusedLinearSimPOLoss(torch.nn.Module):
+    """
+    Fused linear layer with SimPO loss.
+    """
+
+    def __init__(
+        self,
+        ignore_index: int = -100,
+        beta: float = 0.1,
+        alpha: float = 1.0,
+        label_smoothing: float = 0.0,
+        compute_nll_loss: bool = True,
+        compiled: bool = True,
+        gamma: float = 0.5,
+        chunk_size: int = 1,
+    ):
+        """
+        Args:
+            ignore_index (int): Index to ignore in the loss.
+            beta (float): Weight for the odds ratio loss.
+            alpha (float): Weight for the alpha parameter.
+            label_smoothing (float): Label smoothing factor.
+            compute_nll_loss (bool): Whether to compute the NLL loss.
+            compiled (bool): Whether to use the torch compiled kernel.
+            gamma (float): Weight for the gamma parameter.
+            chunk_size (int): Size of chunks for processing.
+        """
+        super().__init__()
+        self.ignore_index = ignore_index
+        self.beta = beta
+        self.alpha = alpha
+        self.label_smoothing = label_smoothing
+        self.compute_nll_loss = compute_nll_loss
+        self.compiled = compiled
+        self.gamma = gamma
+        self.chunk_size = chunk_size
+
+    def forward(
+        self,
+        lin_weight,
+        _input,
+        target,
+        bias=None,
+    ):
+        return LigerFusedLinearSimPOFunction.apply(
+            _input,
+            lin_weight,
+            target,
+            bias,
+            self.ignore_index,
+            self.beta,
+            self.alpha,
+            self.label_smoothing,
+            self.compute_nll_loss,
+            self.compiled,
+            self.gamma,
+            self.chunk_size,
+        )
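For orientation, here is a minimal usage sketch of the module added above. The sizes, hyperparameters, and batch layout are illustrative assumptions (the fused preference base class is assumed to take chosen sequences stacked before rejected ones), and depending on the release the module may return the loss alone or a tuple whose first element is the loss.

```python
import torch

from liger_kernel.chunked_loss.simpo_loss import LigerFusedLinearSimPOLoss

# Illustrative sizes only: B preference pairs, sequence length T, hidden size H, vocab V.
B, T, H, V = 2, 8, 64, 128
lm_head_weight = torch.randn(V, H, requires_grad=True)
# Assumption: chosen sequences are stacked before rejected ones, so the batch is 2 * B.
hidden_states = torch.randn(2 * B * T, H, requires_grad=True)
targets = torch.randint(0, V, (2 * B * T,))

# compiled=False avoids torch.compile for this small CPU-side sketch.
simpo = LigerFusedLinearSimPOLoss(beta=0.1, gamma=0.5, compute_nll_loss=False, compiled=False)
outputs = simpo(lm_head_weight, hidden_states, targets)
# Some versions return (loss, aux_outputs); take the loss either way.
loss = outputs[0] if isinstance(outputs, tuple) else outputs
loss.backward()
```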
liger_kernel/env_report.py

@@ -0,0 +1,63 @@
+import platform
+import sys
+
+from importlib.metadata import version
+
+
+def print_env_report():
+    """
+
+    Prints a report of the environment. Useful for debugging and reproducibility.
+    Usage:
+    ```
+    python -m liger_kernel.env_report
+    ```
+
+    """
+    print("Environment Report:")
+    print("-------------------")
+    print(f"Operating System: {platform.platform()}")
+    print(f"Python version: {sys.version.split()[0]}")
+
+    try:
+        print(f"Liger Kernel version: {version('liger-kernel')}")
+    except ImportError:
+        print("Liger Kernel: Not installed")
+
+    try:
+        import torch
+
+        print(f"PyTorch version: {torch.__version__}")
+        cuda_version = torch.version.cuda if torch.cuda.is_available() else "Not available"
+        print(f"CUDA version: {cuda_version}")
+        hip_version = torch.version.hip if torch.cuda.is_available() and torch.version.hip else "Not available"
+        print(f"HIP(ROCm) version: {hip_version}")
+
+    except ImportError:
+        print("PyTorch: Not installed")
+        print("CUDA version: Unable to query")
+        print("HIP(ROCm) version: Unable to query")
+
+    try:
+        import triton
+
+        print(f"Triton version: {triton.__version__}")
+    except ImportError:
+        print("Triton: Not installed")
+
+    try:
+        import transformers
+
+        print(f"Transformers version: {transformers.__version__}")
+    except ImportError:
+        print("Transformers: Not installed")
+
+    try:
+        xpu_version = torch.version.xpu if torch.xpu.is_available() else "XPU Not Available"
+        print(f"XPU version: {xpu_version}")
+    except ImportError:
+        print("XPU version: Unable to query")
+
+
+if __name__ == "__main__":
+    print_env_report()
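Besides the CLI entry point shown in the docstring, the same report can be produced programmatically; a trivial sketch:

```python
from liger_kernel.env_report import print_env_report

print_env_report()  # same output as `python -m liger_kernel.env_report`
```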
liger_kernel/ops/__init__.py
CHANGED
@@ -0,0 +1,141 @@
+"""
+Liger-Kernel operators with automatic vendor-specific replacement.
+
+This module provides two ways to import operators:
+
+1. Import from this package (recommended for Function classes):
+   from liger_kernel.ops import LigerGELUMulFunction
+
+   This automatically uses vendor-specific implementation if available.
+
+2. Import from submodules (for kernel functions or specific access):
+   from liger_kernel.ops.geglu import geglu_forward, geglu_backward
+
+   This always uses the default implementation (no auto-replacement).
+
+The replacement mechanism:
+1. Default implementations are imported from individual modules (e.g., geglu.py)
+2. On module load, device is detected via infer_device()
+3. If running on a supported vendor device (npu, xpu, etc.), the default
+   implementations are replaced with vendor-specific ones
+4. All subsequent imports from this package get the replaced versions
+
+Note: Direct imports from submodules (e.g., from liger_kernel.ops.geglu import ...)
+are NOT affected by the replacement mechanism.
+"""
+
+# =============================================================================
+# Import default implementations
+# Both Function classes and kernel functions are imported here.
+# All of these can be replaced by vendor-specific implementations.
+# =============================================================================
+
+from liger_kernel.ops.cross_entropy import LigerCrossEntropyFunction  # noqa: F401
+from liger_kernel.ops.cross_entropy import cross_entropy_backward  # noqa: F401
+from liger_kernel.ops.cross_entropy import cross_entropy_forward  # noqa: F401
+from liger_kernel.ops.dyt import LigerDyTFunction  # noqa: F401
+from liger_kernel.ops.experimental.embedding import LigerEmbeddingFunction  # noqa: F401
+from liger_kernel.ops.fused_add_rms_norm import LigerFusedAddRMSNormFunction  # noqa: F401
+from liger_kernel.ops.fused_add_rms_norm import fused_add_rms_norm_backward  # noqa: F401
+from liger_kernel.ops.fused_add_rms_norm import fused_add_rms_norm_forward  # noqa: F401
+from liger_kernel.ops.fused_linear_cross_entropy import LigerFusedLinearCrossEntropyFunction  # noqa: F401
+from liger_kernel.ops.fused_linear_cross_entropy import fused_linear_cross_entropy_backward  # noqa: F401
+from liger_kernel.ops.fused_linear_cross_entropy import fused_linear_cross_entropy_forward  # noqa: F401
+from liger_kernel.ops.fused_linear_jsd import LigerFusedLinearJSDFunction  # noqa: F401
+from liger_kernel.ops.fused_linear_jsd import fused_linear_jsd_backward  # noqa: F401
+from liger_kernel.ops.fused_linear_jsd import fused_linear_jsd_forward  # noqa: F401
+from liger_kernel.ops.fused_neighborhood_attention import LigerFusedNeighborhoodAttentionFunction  # noqa: F401
+from liger_kernel.ops.geglu import LigerGELUMulFunction  # noqa: F401
+from liger_kernel.ops.geglu import geglu_backward  # noqa: F401
+from liger_kernel.ops.geglu import geglu_forward  # noqa: F401
+from liger_kernel.ops.group_norm import LigerGroupNormFunction  # noqa: F401
+from liger_kernel.ops.group_norm import group_norm_backward  # noqa: F401
+from liger_kernel.ops.group_norm import group_norm_forward  # noqa: F401
+from liger_kernel.ops.grpo_loss import GrpoLossFunction  # noqa: F401
+from liger_kernel.ops.jsd import LigerJSDFunction  # noqa: F401
+from liger_kernel.ops.jsd import jsd_backward  # noqa: F401
+from liger_kernel.ops.jsd import jsd_forward  # noqa: F401
+from liger_kernel.ops.kl_div import LigerKLDivLossFunction  # noqa: F401
+from liger_kernel.ops.layer_norm import LigerLayerNormFunction  # noqa: F401
+from liger_kernel.ops.layer_norm import layer_norm_backward  # noqa: F401
+from liger_kernel.ops.layer_norm import layer_norm_forward  # noqa: F401
+from liger_kernel.ops.llama4_rope import LigerLlama4RopeFunction  # noqa: F401
+from liger_kernel.ops.multi_token_attention import LigerMultiTokenAttentionFunction  # noqa: F401
+from liger_kernel.ops.poly_norm import LigerPolyNormFunction  # noqa: F401
+from liger_kernel.ops.poly_norm import poly_norm_backward  # noqa: F401
+from liger_kernel.ops.poly_norm import poly_norm_forward  # noqa: F401
+from liger_kernel.ops.qwen2vl_mrope import LigerQwen2VLMRopeFunction  # noqa: F401
+from liger_kernel.ops.rms_norm import LigerRMSNormFunction  # noqa: F401
+from liger_kernel.ops.rms_norm import rms_norm_backward  # noqa: F401
+from liger_kernel.ops.rms_norm import rms_norm_forward  # noqa: F401
+from liger_kernel.ops.rope import LigerRopeFunction  # noqa: F401
+from liger_kernel.ops.rope import rope_backward  # noqa: F401
+from liger_kernel.ops.rope import rope_forward  # noqa: F401
+from liger_kernel.ops.softmax import LigerSoftmaxFunction  # noqa: F401
+from liger_kernel.ops.sparsemax import LigerSparsemaxFunction  # noqa: F401
+from liger_kernel.ops.swiglu import LigerSiLUMulFunction  # noqa: F401
+from liger_kernel.ops.swiglu import swiglu_backward  # noqa: F401
+from liger_kernel.ops.swiglu import swiglu_forward  # noqa: F401
+from liger_kernel.ops.tiled_mlp import LigerTiledMLPFunction  # noqa: F401
+from liger_kernel.ops.tiled_mlp import apply_tiled_mlp  # noqa: F401
+from liger_kernel.ops.tvd import LigerTVDLossFunction  # noqa: F401
+
+# NOTE: __all__ is intentionally NOT defined.
+# - Import from this package (liger_kernel.ops) -> subject to vendor replacement
+# - Import from submodules (liger_kernel.ops.geglu) -> always use default implementation
+
+
+# =============================================================================
+# Vendor-specific replacement logic
+# =============================================================================
+
+
+def _replace_with_vendor_ops():
+    """
+    Replace/add vendor-specific operator implementations.
+
+    This function is called automatically on module load. It:
+    1. Detects the current device (cuda, npu, xpu, etc.)
+    2. Looks up the vendor for that device via VENDOR_REGISTRY
+    3. Loads and applies vendor-specific implementations
+
+    Vendor implementations should be placed in:
+        liger_kernel/ops/backends/_<vendor>/ops/
+
+    If the vendor module defines __all__, only those symbols are exported.
+    Otherwise, all public symbols (not starting with _) are auto-discovered.
+
+    Note: Vendor can both override existing ops AND add new vendor-specific ops.
+    """
+    from liger_kernel.ops.backends import get_vendor_for_device
+    from liger_kernel.utils import infer_device
+
+    device = infer_device()
+
+    # Look up vendor info for this device
+    vendor_info = get_vendor_for_device(device)
+    if vendor_info is None:
+        return
+
+    try:
+        import importlib
+
+        vendor_ops = importlib.import_module(vendor_info.module_path)
+
+        # Get names to export: use __all__ if defined, otherwise auto-discover
+        names_to_export = getattr(vendor_ops, "__all__", None)
+
+        if names_to_export is None:
+            # Auto-discover: find all public symbols (classes and functions)
+            names_to_export = [name for name in dir(vendor_ops) if not name.startswith("_")]
+
+        # Replace or add to this module's globals
+        for name in names_to_export:
+            globals()[name] = getattr(vendor_ops, name)
+
+    except ImportError:
+        # Vendor module not available, use default implementations
+        pass
+
+
+_replace_with_vendor_ops()
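A short sketch of the two import modes described in the module docstring above, using only names that appear in this diff; the printed value is an example of what one might see, not guaranteed output.

```python
# 1. Package-level import: subject to vendor replacement on npu/xpu etc.
from liger_kernel.ops import LigerGELUMulFunction

# 2. Submodule import: always the default implementation, regardless of device.
from liger_kernel.ops.geglu import LigerGELUMulFunction as DefaultLigerGELUMulFunction

# Inspect which vendor (if any) drives the replacement on this machine.
from liger_kernel.ops.backends import get_vendor_for_device
from liger_kernel.utils import infer_device

vendor_info = get_vendor_for_device(infer_device())
print(vendor_info)  # None on plain CUDA; a VendorInfo entry on a registered vendor device
```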
liger_kernel/ops/backends/README.md

@@ -0,0 +1,151 @@
+# Adding a New Vendor Backend
+
+This directory contains vendor-specific operator implementations that automatically replace the default (CUDA) implementations when running on the corresponding device.
+
+## Concepts
+
+- **Vendor**: Chip manufacturer (e.g., `ascend`, `intel`, `nvidia`)
+- **Device**: Device type (e.g., `npu`, `xpu`, `cuda`)
+- **VendorInfo**: Defines the mapping between vendor and device
+
+## Directory Structure
+
+```
+backends/
+├── README.md
+├── __init__.py
+├── registry.py              # VendorInfo, register_vendor(), VENDOR_REGISTRY
+├── _ascend/                 # Ascend (Huawei) vendor - supports NPU
+│   ├── __init__.py          # Registers VendorInfo for NPU
+│   └── ops/
+│       ├── __init__.py      # Exports vendor-specific implementations
+│       └── geglu.py         # NPU-specific GEGLU implementation
+└── _<vendor>/               # Your new vendor backend
+    └── ...
+```
+
+## How It Works
+
+1. When `liger_kernel.ops.backends` is imported, it imports all vendor packages (e.g., `_ascend`)
+2. Each vendor's `__init__.py` calls `register_vendor()` to register itself
+3. When `liger_kernel.ops` is imported, `_replace_with_vendor_ops()` is called
+4. It detects the current device via `infer_device()` and looks up the vendor
+5. Vendor implementations replace/add to the `liger_kernel.ops` namespace
+
+## Adding a New Vendor
+
+### Step 1: Create Directory Structure
+
+```bash
+mkdir -p backends/_<vendor>/ops
+touch backends/_<vendor>/__init__.py
+touch backends/_<vendor>/ops/__init__.py
+```
+
+### Step 2: Register Your Vendor
+
+In `backends/_<vendor>/__init__.py`, register your vendor:
+
+```python
+"""
+<Vendor> backend for Liger-Kernel.
+"""
+
+from liger_kernel.ops.backends.registry import VendorInfo, register_vendor
+
+register_vendor(
+    VendorInfo(
+        vendor="<vendor>",
+        device="<device>",
+    )
+)
+```
+
+### Step 3: Ensure Device Detection Works
+
+Make sure `infer_device()` in `liger_kernel/utils.py` can detect your device:
+
+```python
+def infer_device():
+    if torch.cuda.is_available():
+        return "cuda"
+    if is_npu_available():
+        return "npu"
+    # Add your device detection here
+    if is_<device>_available():
+        return "<device>"
+    return "cpu"
+```
+
+### Step 4: Implement Vendor-Specific Operators
+
+Create operator files in `backends/_<vendor>/ops/`. For example, `geglu.py`:
+
+```python
+import torch
+
+class LigerGELUMulFunction(torch.autograd.Function):
+    """
+    Vendor-specific LigerGELUMulFunction implementation.
+    """
+    @staticmethod
+    def forward(ctx, a, b):
+        # Your vendor-specific forward implementation
+        ...
+
+    @staticmethod
+    def backward(ctx, dc):
+        # Your vendor-specific backward implementation
+        ...
+
+# Optional: vendor-specific kernel functions
+def geglu_forward_vendor(a, b):
+    ...
+
+def geglu_backward_vendor(a, b, dc):
+    ...
+```
+
+### Step 5: Export in `ops/__init__.py`
+
+In `backends/_<vendor>/ops/__init__.py`, export your implementations:
+
+```python
+"""
+<Vendor>-specific operator implementations.
+"""
+
+from .<module> import (
+    LigerGELUMulFunction,
+    geglu_forward_vendor as geglu_forward,  # Rename to match default API
+    geglu_backward_vendor as geglu_backward,
+)
+
+# Explicitly declare what to export (recommended)
+__all__ = [
+    "LigerGELUMulFunction",
+    "geglu_forward",
+    "geglu_backward",
+]
+```
+
+## Key Points
+
+### Incremental Override
+
+You **don't need to implement all operators**. Only implement the ones that require vendor-specific adaptations. Unimplemented operators will automatically fall back to the default (CUDA) implementation.
+
+### Vendor-Specific Additions
+
+Vendors can also **add new operators** that don't exist in the default implementation. These will be exported to the `liger_kernel.ops` namespace for users to import.
+
+### Naming Convention
+
+- Use the **same class/function names** as the default implementations for overrides
+- This allows seamless replacement without changing user code
+- Use `as` imports to rename if your internal naming differs
+
+## Example: Ascend NPU Backend
+
+See the `_ascend/` directory for a complete example of the Ascend NPU backend implementation.
liger_kernel/ops/backends/__init__.py

@@ -0,0 +1,13 @@
+import importlib
+import pkgutil
+
+from liger_kernel.ops.backends.registry import VENDOR_REGISTRY  # noqa: F401
+from liger_kernel.ops.backends.registry import VendorInfo  # noqa: F401
+from liger_kernel.ops.backends.registry import get_vendor_for_device  # noqa: F401
+from liger_kernel.ops.backends.registry import register_vendor  # noqa: F401
+
+# Auto-import all _<vendor> subpackages to trigger registration
+# Each vendor's __init__.py calls register_vendor() when imported
+for _, modname, ispkg in pkgutil.iter_modules(__path__):
+    if ispkg and modname.startswith("_"):
+        importlib.import_module(f"{__name__}.{modname}")
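The effect of the auto-import loop above can be observed by inspecting the registry after the package is imported. The NPU entry shown in the comment is an assumption based on the bundled `_ascend` backend registering itself as described in the README.

```python
import liger_kernel.ops.backends as backends

# Importing the package runs the loop above: every _<vendor> subpackage is
# imported and gets a chance to call register_vendor().
print(backends.VENDOR_REGISTRY)
# e.g. {'npu': VendorInfo(vendor='ascend', device='npu')}
```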
liger_kernel/ops/backends/_ascend/ops/__init__.py

@@ -0,0 +1,15 @@
+"""
+Ascend NPU operator implementations.
+
+This module exports Ascend NPU-optimized implementations that will automatically
+replace the default implementations when running on NPU devices.
+
+Both Function classes and kernel functions can be exported here.
+
+To add a new operator:
+1. Create the implementation file (e.g., rms_norm.py)
+2. Import the Function class and/or kernel functions here
+3. Optionally add to __all__ for explicit control
+
+If __all__ is not defined, all public symbols will be auto-discovered.
+"""
liger_kernel/ops/backends/registry.py

@@ -0,0 +1,61 @@
+"""
+Vendor registry for Liger-Kernel multi-backend support.
+
+This module defines VendorInfo and the registry for vendor registration.
+Each vendor registers itself by calling register_vendor() in its __init__.py.
+"""
+
+from dataclasses import dataclass
+from typing import Optional
+
+# Dynamically get backends package path to avoid hardcoding
+_BACKENDS_PACKAGE = __name__.rsplit(".", 1)[0]  # "liger_kernel.ops.backends"
+
+
+@dataclass
+class VendorInfo:
+    """
+    Information about a chip vendor and its supported device.
+
+    Attributes:
+        vendor: Vendor name (e.g., "ascend", "intel", "nvidia")
+        device: Device type this vendor supports (e.g., "npu", "xpu")
+    """
+
+    vendor: str
+    device: str
+
+    @property
+    def module_path(self) -> str:
+        """Auto-generated module path based on vendor name."""
+        return f"{_BACKENDS_PACKAGE}._{self.vendor}.ops"
+
+
+# Registry mapping device types to their vendor info
+# Vendors register themselves via register_vendor()
+VENDOR_REGISTRY: dict[str, VendorInfo] = {}
+
+
+def register_vendor(vendor_info: VendorInfo) -> None:
+    """
+    Register a vendor's info in the global registry.
+
+    This should be called in each vendor's __init__.py to register itself.
+
+    Args:
+        vendor_info: VendorInfo instance to register
+    """
+    VENDOR_REGISTRY[vendor_info.device] = vendor_info
+
+
+def get_vendor_for_device(device: str) -> Optional[VendorInfo]:
+    """
+    Get the VendorInfo for a given device type.
+
+    Args:
+        device: Device type (e.g., "npu", "xpu")
+
+    Returns:
+        VendorInfo if found, None otherwise
+    """
+    return VENDOR_REGISTRY.get(device)
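A minimal sketch of the registry API above, using a made-up vendor/device pair purely for illustration:

```python
from liger_kernel.ops.backends.registry import VendorInfo, get_vendor_for_device, register_vendor

# Hypothetical vendor and device names; a real backend would use its own.
register_vendor(VendorInfo(vendor="examplecorp", device="exampledev"))

info = get_vendor_for_device("exampledev")
print(info.vendor)       # "examplecorp"
print(info.module_path)  # "liger_kernel.ops.backends._examplecorp.ops"

print(get_vendor_for_device("unknown"))  # None: no vendor registered for this device
```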