liger-kernel-nightly 0.0.1.dev20240819184814__py3-none-any.whl → 0.6.4.dev20251212103629__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (126)
  1. liger_kernel/__init__.py +0 -0
  2. liger_kernel/chunked_loss/README.md +25 -0
  3. liger_kernel/chunked_loss/__init__.py +8 -0
  4. liger_kernel/chunked_loss/cosine_similarity_loss.py +136 -0
  5. liger_kernel/chunked_loss/cpo_loss.py +157 -0
  6. liger_kernel/chunked_loss/dpo_loss.py +229 -0
  7. liger_kernel/chunked_loss/functional.py +17 -0
  8. liger_kernel/chunked_loss/fused_linear_distillation.py +292 -0
  9. liger_kernel/chunked_loss/fused_linear_ppo.py +366 -0
  10. liger_kernel/chunked_loss/fused_linear_preference.py +433 -0
  11. liger_kernel/chunked_loss/fused_linear_unpaired_preference.py +341 -0
  12. liger_kernel/chunked_loss/grpo_loss.py +307 -0
  13. liger_kernel/chunked_loss/jsd_loss.py +200 -0
  14. liger_kernel/chunked_loss/kto_loss.py +210 -0
  15. liger_kernel/chunked_loss/orpo_loss.py +144 -0
  16. liger_kernel/chunked_loss/simpo_loss.py +165 -0
  17. liger_kernel/env_report.py +63 -0
  18. liger_kernel/ops/__init__.py +141 -0
  19. liger_kernel/ops/backends/README.md +151 -0
  20. liger_kernel/ops/backends/__init__.py +13 -0
  21. liger_kernel/ops/backends/_ascend/__init__.py +5 -0
  22. liger_kernel/ops/backends/_ascend/ops/__init__.py +15 -0
  23. liger_kernel/ops/backends/registry.py +61 -0
  24. liger_kernel/ops/cross_entropy.py +383 -114
  25. liger_kernel/ops/dyt.py +160 -0
  26. liger_kernel/ops/experimental/embedding.py +141 -0
  27. liger_kernel/ops/experimental/mm_int8int2.py +349 -0
  28. liger_kernel/ops/fused_add_rms_norm.py +416 -0
  29. liger_kernel/ops/fused_linear_cross_entropy.py +346 -132
  30. liger_kernel/ops/fused_linear_jsd.py +228 -0
  31. liger_kernel/ops/fused_neighborhood_attention.py +1022 -0
  32. liger_kernel/ops/geglu.py +66 -64
  33. liger_kernel/ops/group_norm.py +306 -0
  34. liger_kernel/ops/grpo_loss.py +312 -0
  35. liger_kernel/ops/jsd.py +201 -0
  36. liger_kernel/ops/kl_div.py +262 -0
  37. liger_kernel/ops/layer_norm.py +320 -0
  38. liger_kernel/ops/llama4_rope.py +225 -0
  39. liger_kernel/ops/multi_token_attention.py +207 -0
  40. liger_kernel/ops/poly_norm.py +390 -0
  41. liger_kernel/ops/qwen2vl_mrope.py +222 -0
  42. liger_kernel/ops/rms_norm.py +484 -88
  43. liger_kernel/ops/rope.py +122 -117
  44. liger_kernel/ops/softmax.py +201 -0
  45. liger_kernel/ops/sparsemax.py +179 -0
  46. liger_kernel/ops/swiglu.py +68 -65
  47. liger_kernel/ops/tiled_mlp.py +136 -0
  48. liger_kernel/ops/tvd.py +207 -0
  49. liger_kernel/ops/utils.py +82 -3
  50. liger_kernel/transformers/__init__.py +218 -6
  51. liger_kernel/transformers/auto_model.py +38 -0
  52. liger_kernel/transformers/cross_entropy.py +52 -7
  53. liger_kernel/transformers/dyt.py +22 -0
  54. liger_kernel/transformers/experimental/__init__.py +5 -0
  55. liger_kernel/transformers/experimental/embedding.py +26 -0
  56. liger_kernel/transformers/fsdp.py +55 -0
  57. liger_kernel/transformers/functional.py +301 -0
  58. liger_kernel/transformers/fused_add_rms_norm.py +39 -0
  59. liger_kernel/transformers/fused_linear_cross_entropy.py +59 -10
  60. liger_kernel/transformers/fused_linear_jsd.py +95 -0
  61. liger_kernel/transformers/fused_neighborhood_attention.py +234 -0
  62. liger_kernel/transformers/geglu.py +6 -7
  63. liger_kernel/transformers/group_norm.py +50 -0
  64. liger_kernel/transformers/grpo_loss.py +153 -0
  65. liger_kernel/transformers/jsd.py +70 -0
  66. liger_kernel/transformers/kl_div.py +12 -0
  67. liger_kernel/transformers/layer_norm.py +24 -0
  68. liger_kernel/transformers/llama4_rope.py +93 -0
  69. liger_kernel/transformers/model/falcon_h1.py +122 -0
  70. liger_kernel/transformers/model/gemma.py +261 -0
  71. liger_kernel/transformers/model/gemma2.py +283 -0
  72. liger_kernel/transformers/model/gemma3.py +332 -0
  73. liger_kernel/transformers/model/glm4.py +141 -0
  74. liger_kernel/transformers/model/glm4v.py +163 -0
  75. liger_kernel/transformers/model/glm4v_moe.py +172 -0
  76. liger_kernel/transformers/model/gpt_oss.py +211 -0
  77. liger_kernel/transformers/model/hunyuan_v1.py +134 -0
  78. liger_kernel/transformers/model/internvl.py +157 -0
  79. liger_kernel/transformers/model/llama.py +221 -41
  80. liger_kernel/transformers/model/llama4.py +121 -0
  81. liger_kernel/transformers/model/llava.py +344 -0
  82. liger_kernel/transformers/model/loss_utils.py +95 -0
  83. liger_kernel/transformers/model/mistral.py +145 -0
  84. liger_kernel/transformers/model/mixtral.py +293 -0
  85. liger_kernel/transformers/model/mllama.py +269 -0
  86. liger_kernel/transformers/model/olmo2.py +141 -0
  87. liger_kernel/transformers/model/olmo3.py +142 -0
  88. liger_kernel/transformers/model/output_classes.py +147 -0
  89. liger_kernel/transformers/model/paligemma.py +433 -0
  90. liger_kernel/transformers/model/phi3.py +120 -0
  91. liger_kernel/transformers/model/qwen2.py +259 -0
  92. liger_kernel/transformers/model/qwen2_5_vl.py +163 -0
  93. liger_kernel/transformers/model/qwen2_vl.py +159 -0
  94. liger_kernel/transformers/model/qwen3.py +136 -0
  95. liger_kernel/transformers/model/qwen3_moe.py +152 -0
  96. liger_kernel/transformers/model/qwen3_next.py +146 -0
  97. liger_kernel/transformers/model/qwen3_vl.py +150 -0
  98. liger_kernel/transformers/model/qwen3_vl_moe.py +126 -0
  99. liger_kernel/transformers/model/smollm3.py +199 -0
  100. liger_kernel/transformers/model/smolvlm.py +158 -0
  101. liger_kernel/transformers/monkey_patch.py +2816 -21
  102. liger_kernel/transformers/multi_token_attention.py +64 -0
  103. liger_kernel/transformers/poly_norm.py +42 -0
  104. liger_kernel/transformers/qwen2vl_mrope.py +20 -0
  105. liger_kernel/transformers/rms_norm.py +75 -5
  106. liger_kernel/transformers/rope.py +47 -3
  107. liger_kernel/transformers/softmax.py +12 -0
  108. liger_kernel/transformers/sparsemax.py +16 -0
  109. liger_kernel/transformers/swiglu.py +62 -6
  110. liger_kernel/transformers/tiled_mlp.py +133 -0
  111. liger_kernel/transformers/trainer/__init__.py +4 -0
  112. liger_kernel/transformers/trainer/orpo_trainer.py +130 -0
  113. liger_kernel/transformers/trainer_integration.py +2 -45
  114. liger_kernel/transformers/tvd.py +13 -0
  115. liger_kernel/triton/__init__.py +1 -3
  116. liger_kernel/triton/monkey_patch.py +1 -5
  117. liger_kernel/utils.py +96 -0
  118. liger_kernel_nightly-0.6.4.dev20251212103629.dist-info/METADATA +447 -0
  119. liger_kernel_nightly-0.6.4.dev20251212103629.dist-info/NOTICE +58 -0
  120. liger_kernel_nightly-0.6.4.dev20251212103629.dist-info/RECORD +124 -0
  121. {liger_kernel_nightly-0.0.1.dev20240819184814.dist-info → liger_kernel_nightly-0.6.4.dev20251212103629.dist-info}/WHEEL +1 -1
  122. liger_kernel_nightly-0.0.1.dev20240819184814.dist-info/METADATA +0 -21
  123. liger_kernel_nightly-0.0.1.dev20240819184814.dist-info/NOTICE +0 -4
  124. liger_kernel_nightly-0.0.1.dev20240819184814.dist-info/RECORD +0 -27
  125. {liger_kernel_nightly-0.0.1.dev20240819184814.dist-info → liger_kernel_nightly-0.6.4.dev20251212103629.dist-info}/LICENSE +0 -0
  126. {liger_kernel_nightly-0.0.1.dev20240819184814.dist-info → liger_kernel_nightly-0.6.4.dev20251212103629.dist-info}/top_level.txt +0 -0
liger_kernel/chunked_loss/simpo_loss.py
@@ -0,0 +1,165 @@
+ import torch
+ import torch.nn.functional as F
+
+ from liger_kernel.chunked_loss.fused_linear_preference import LigerFusedLinearPreferenceBase
+
+
+ class LigerFusedLinearSimPOFunction(LigerFusedLinearPreferenceBase):
+     @staticmethod
+     def preference_loss_fn(
+         chosen_logps,
+         rejected_logps,
+         full_target,
+         beta=0.1,
+         gamma=0.5,
+         label_smoothing=0.0,
+     ):
+         """
+         Paper: https://arxiv.org/pdf/2405.14734
+
+         Formula:
+         L_SimPO(π_θ) = -E [log σ(β/|y_w| log π_θ(y_w|x) - β/|y_l| log π_θ(y_l|x) - γ)]
+
+         Where:
+         - π_θ(y|x): Policy (model) probability
+         - y_w: Chosen sequence
+         - y_l: Rejected sequence
+         - |y_w|, |y_l|: Sequence lengths
+         - σ: Sigmoid function
+         - β: beta weight
+         - γ: gamma margin term
+
+         Args:
+             chosen_logps (torch.Tensor): Avg log probabilities of chosen tokens. Shape: (batch_size,).
+             rejected_logps (torch.Tensor): Avg log probabilities of rejected tokens. Shape: (batch_size,).
+             full_target: Non-chunked full target tensor
+             beta (float): beta weight
+             gamma (float): gamma margin term
+             label_smoothing (float): Label smoothing factor; reduces to the equation above as label_smoothing -> 0.
+         """
+         logits = beta * (chosen_logps - rejected_logps) - gamma
+         loss = (-F.logsigmoid(logits) * (1 - label_smoothing) - F.logsigmoid(-logits) * label_smoothing).sum() / (
+             full_target.shape[0] // 2
+         )
+
+         chosen_rewards = beta * chosen_logps
+         rejected_rewards = beta * rejected_logps
+
+         return loss, chosen_rewards, rejected_rewards
+
+     @classmethod
+     def forward(
+         cls,
+         ctx,
+         _input,
+         weight,
+         target,
+         bias=None,
+         ignore_index=-100,
+         beta=0.1,
+         alpha=1.0,
+         label_smoothing=0.0,
+         compute_nll_loss=False,
+         compiled=True,
+         gamma=0.5,
+         chunk_size=1,
+     ):
+         """
+         Fused linear layer with SimPO loss.
+         Args:
+             _input (torch.Tensor): Input tensor. Shape: (batch_size * seq_len, hidden_size)
+             weight (torch.Tensor): Weight tensor. Shape: (vocab_size, hidden_size)
+             target (torch.LongTensor): Target tensor. Shape: (batch_size * seq_len,)
+             bias (torch.Tensor, optional): Bias tensor. Shape: (vocab_size,)
+             ignore_index (int): Index to ignore in loss computation
+             beta (float): Weight for the odds ratio loss
+             alpha (float): Weight for the alpha parameter
+             label_smoothing (float): Label smoothing factor
+             compute_nll_loss (bool): Whether to compute the NLL loss
+             compiled (bool): Whether to use torch compile
+             gamma (float): Weight for the gamma parameter
+             chunk_size (int): Size of chunks for processing
+         Returns:
+             torch.Tensor: Computed loss
+         """
+         return super().forward(
+             cls=cls,
+             ctx=ctx,
+             _input=_input,
+             weight=weight,
+             target=target,
+             bias=bias,
+             ignore_index=ignore_index,
+             alpha=alpha,
+             beta=beta,
+             label_smoothing=label_smoothing,
+             compute_nll_loss=compute_nll_loss,
+             compiled=compiled,
+             gamma=gamma,
+             chunk_size=chunk_size,
+         )
+
+     @staticmethod
+     def backward(ctx, *grad_output):
+         grads = LigerFusedLinearPreferenceBase.backward(ctx, grad_output)[:4]
+         return *grads, None, None, None, None, None, None, None, None
+
+
+ class LigerFusedLinearSimPOLoss(torch.nn.Module):
+     """
+     Fused linear layer with SimPO loss.
+     """
+
+     def __init__(
+         self,
+         ignore_index: int = -100,
+         beta: float = 0.1,
+         alpha: float = 1.0,
+         label_smoothing: float = 0.0,
+         compute_nll_loss: bool = True,
+         compiled: bool = True,
+         gamma: float = 0.5,
+         chunk_size: int = 1,
+     ):
+         """
+         Args:
+             ignore_index (int): Index to ignore in the loss.
+             beta (float): Weight for the odds ratio loss.
+             alpha (float): Weight for the alpha parameter.
+             label_smoothing (float): Label smoothing factor.
+             compute_nll_loss (bool): Whether to compute the NLL loss.
+             compiled (bool): Whether to use the torch compiled kernel.
+             gamma (float): Weight for the gamma parameter.
+             chunk_size (int): Size of chunks for processing.
+         """
+         super().__init__()
+         self.ignore_index = ignore_index
+         self.beta = beta
+         self.alpha = alpha
+         self.label_smoothing = label_smoothing
+         self.compute_nll_loss = compute_nll_loss
+         self.compiled = compiled
+         self.gamma = gamma
+         self.chunk_size = chunk_size
+
+     def forward(
+         self,
+         lin_weight,
+         _input,
+         target,
+         bias=None,
+     ):
+         return LigerFusedLinearSimPOFunction.apply(
+             _input,
+             lin_weight,
+             target,
+             bias,
+             self.ignore_index,
+             self.beta,
+             self.alpha,
+             self.label_smoothing,
+             self.compute_nll_loss,
+             self.compiled,
+             self.gamma,
+             self.chunk_size,
+         )
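The fused function above chunks the linear projection together with the loss, but the preference term itself is exactly the SimPO objective written in the docstring. As a stand-alone sanity-check sketch (not part of the package), the same term can be computed on made-up, length-normalized log-probabilities; `beta`, `gamma`, and the tensor values below are illustrative assumptions:

```python
import torch
import torch.nn.functional as F

# Stand-alone computation of the SimPO preference term from preference_loss_fn,
# using dummy average log-probabilities (illustrative values, not model outputs).
beta, gamma, label_smoothing = 0.1, 0.5, 0.0
chosen_logps = torch.tensor([-0.8, -1.1])    # shape: (batch_size,)
rejected_logps = torch.tensor([-1.5, -1.4])  # shape: (batch_size,)

logits = beta * (chosen_logps - rejected_logps) - gamma
loss = (
    -F.logsigmoid(logits) * (1 - label_smoothing)
    - F.logsigmoid(-logits) * label_smoothing
).sum() / chosen_logps.shape[0]  # the fused path divides by half of full_target's first dimension

chosen_rewards = beta * chosen_logps
rejected_rewards = beta * rejected_logps
print(loss.item(), chosen_rewards, rejected_rewards)
```

At the module level, note that `LigerFusedLinearSimPOLoss` is invoked as `loss_fn(lin_weight, _input, target, bias)`, i.e. with the LM-head weight first and the pre-projection hidden states second, per the `forward` signature above.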
liger_kernel/env_report.py
@@ -0,0 +1,63 @@
+ import platform
+ import sys
+
+ from importlib.metadata import version
+
+
+ def print_env_report():
+     """
+
+     Prints a report of the environment. Useful for debugging and reproducibility.
+     Usage:
+     ```
+     python -m liger_kernel.env_report
+     ```
+
+     """
+     print("Environment Report:")
+     print("-------------------")
+     print(f"Operating System: {platform.platform()}")
+     print(f"Python version: {sys.version.split()[0]}")
+
+     try:
+         print(f"Liger Kernel version: {version('liger-kernel')}")
+     except ImportError:
+         print("Liger Kernel: Not installed")
+
+     try:
+         import torch
+
+         print(f"PyTorch version: {torch.__version__}")
+         cuda_version = torch.version.cuda if torch.cuda.is_available() else "Not available"
+         print(f"CUDA version: {cuda_version}")
+         hip_version = torch.version.hip if torch.cuda.is_available() and torch.version.hip else "Not available"
+         print(f"HIP(ROCm) version: {hip_version}")
+
+     except ImportError:
+         print("PyTorch: Not installed")
+         print("CUDA version: Unable to query")
+         print("HIP(ROCm) version: Unable to query")
+
+     try:
+         import triton
+
+         print(f"Triton version: {triton.__version__}")
+     except ImportError:
+         print("Triton: Not installed")
+
+     try:
+         import transformers
+
+         print(f"Transformers version: {transformers.__version__}")
+     except ImportError:
+         print("Transformers: Not installed")
+
+     try:
+         xpu_version = torch.version.xpu if torch.xpu.is_available() else "XPU Not Available"
+         print(f"XPU version: {xpu_version}")
+     except ImportError:
+         print("XPU version: Unable to query")
+
+
+ if __name__ == "__main__":
+     print_env_report()
liger_kernel/ops/__init__.py
@@ -0,0 +1,141 @@
+ """
+ Liger-Kernel operators with automatic vendor-specific replacement.
+
+ This module provides two ways to import operators:
+
+ 1. Import from this package (recommended for Function classes):
+     from liger_kernel.ops import LigerGELUMulFunction
+
+     This automatically uses vendor-specific implementation if available.
+
+ 2. Import from submodules (for kernel functions or specific access):
+     from liger_kernel.ops.geglu import geglu_forward, geglu_backward
+
+     This always uses the default implementation (no auto-replacement).
+
+ The replacement mechanism:
+ 1. Default implementations are imported from individual modules (e.g., geglu.py)
+ 2. On module load, device is detected via infer_device()
+ 3. If running on a supported vendor device (npu, xpu, etc.), the default
+    implementations are replaced with vendor-specific ones
+ 4. All subsequent imports from this package get the replaced versions
+
+ Note: Direct imports from submodules (e.g., from liger_kernel.ops.geglu import ...)
+ are NOT affected by the replacement mechanism.
+ """
+
+ # =============================================================================
+ # Import default implementations
+ # Both Function classes and kernel functions are imported here.
+ # All of these can be replaced by vendor-specific implementations.
+ # =============================================================================
+
+ from liger_kernel.ops.cross_entropy import LigerCrossEntropyFunction  # noqa: F401
+ from liger_kernel.ops.cross_entropy import cross_entropy_backward  # noqa: F401
+ from liger_kernel.ops.cross_entropy import cross_entropy_forward  # noqa: F401
+ from liger_kernel.ops.dyt import LigerDyTFunction  # noqa: F401
+ from liger_kernel.ops.experimental.embedding import LigerEmbeddingFunction  # noqa: F401
+ from liger_kernel.ops.fused_add_rms_norm import LigerFusedAddRMSNormFunction  # noqa: F401
+ from liger_kernel.ops.fused_add_rms_norm import fused_add_rms_norm_backward  # noqa: F401
+ from liger_kernel.ops.fused_add_rms_norm import fused_add_rms_norm_forward  # noqa: F401
+ from liger_kernel.ops.fused_linear_cross_entropy import LigerFusedLinearCrossEntropyFunction  # noqa: F401
+ from liger_kernel.ops.fused_linear_cross_entropy import fused_linear_cross_entropy_backward  # noqa: F401
+ from liger_kernel.ops.fused_linear_cross_entropy import fused_linear_cross_entropy_forward  # noqa: F401
+ from liger_kernel.ops.fused_linear_jsd import LigerFusedLinearJSDFunction  # noqa: F401
+ from liger_kernel.ops.fused_linear_jsd import fused_linear_jsd_backward  # noqa: F401
+ from liger_kernel.ops.fused_linear_jsd import fused_linear_jsd_forward  # noqa: F401
+ from liger_kernel.ops.fused_neighborhood_attention import LigerFusedNeighborhoodAttentionFunction  # noqa: F401
+ from liger_kernel.ops.geglu import LigerGELUMulFunction  # noqa: F401
+ from liger_kernel.ops.geglu import geglu_backward  # noqa: F401
+ from liger_kernel.ops.geglu import geglu_forward  # noqa: F401
+ from liger_kernel.ops.group_norm import LigerGroupNormFunction  # noqa: F401
+ from liger_kernel.ops.group_norm import group_norm_backward  # noqa: F401
+ from liger_kernel.ops.group_norm import group_norm_forward  # noqa: F401
+ from liger_kernel.ops.grpo_loss import GrpoLossFunction  # noqa: F401
+ from liger_kernel.ops.jsd import LigerJSDFunction  # noqa: F401
+ from liger_kernel.ops.jsd import jsd_backward  # noqa: F401
+ from liger_kernel.ops.jsd import jsd_forward  # noqa: F401
+ from liger_kernel.ops.kl_div import LigerKLDivLossFunction  # noqa: F401
+ from liger_kernel.ops.layer_norm import LigerLayerNormFunction  # noqa: F401
+ from liger_kernel.ops.layer_norm import layer_norm_backward  # noqa: F401
+ from liger_kernel.ops.layer_norm import layer_norm_forward  # noqa: F401
+ from liger_kernel.ops.llama4_rope import LigerLlama4RopeFunction  # noqa: F401
+ from liger_kernel.ops.multi_token_attention import LigerMultiTokenAttentionFunction  # noqa: F401
+ from liger_kernel.ops.poly_norm import LigerPolyNormFunction  # noqa: F401
+ from liger_kernel.ops.poly_norm import poly_norm_backward  # noqa: F401
+ from liger_kernel.ops.poly_norm import poly_norm_forward  # noqa: F401
+ from liger_kernel.ops.qwen2vl_mrope import LigerQwen2VLMRopeFunction  # noqa: F401
+ from liger_kernel.ops.rms_norm import LigerRMSNormFunction  # noqa: F401
+ from liger_kernel.ops.rms_norm import rms_norm_backward  # noqa: F401
+ from liger_kernel.ops.rms_norm import rms_norm_forward  # noqa: F401
+ from liger_kernel.ops.rope import LigerRopeFunction  # noqa: F401
+ from liger_kernel.ops.rope import rope_backward  # noqa: F401
+ from liger_kernel.ops.rope import rope_forward  # noqa: F401
+ from liger_kernel.ops.softmax import LigerSoftmaxFunction  # noqa: F401
+ from liger_kernel.ops.sparsemax import LigerSparsemaxFunction  # noqa: F401
+ from liger_kernel.ops.swiglu import LigerSiLUMulFunction  # noqa: F401
+ from liger_kernel.ops.swiglu import swiglu_backward  # noqa: F401
+ from liger_kernel.ops.swiglu import swiglu_forward  # noqa: F401
+ from liger_kernel.ops.tiled_mlp import LigerTiledMLPFunction  # noqa: F401
+ from liger_kernel.ops.tiled_mlp import apply_tiled_mlp  # noqa: F401
+ from liger_kernel.ops.tvd import LigerTVDLossFunction  # noqa: F401
+
+ # NOTE: __all__ is intentionally NOT defined.
+ # - Import from this package (liger_kernel.ops) -> subject to vendor replacement
+ # - Import from submodules (liger_kernel.ops.geglu) -> always use default implementation
+
+
+ # =============================================================================
+ # Vendor-specific replacement logic
+ # =============================================================================
+
+
+ def _replace_with_vendor_ops():
+     """
+     Replace/add vendor-specific operator implementations.
+
+     This function is called automatically on module load. It:
+     1. Detects the current device (cuda, npu, xpu, etc.)
+     2. Looks up the vendor for that device via VENDOR_REGISTRY
+     3. Loads and applies vendor-specific implementations
+
+     Vendor implementations should be placed in:
+         liger_kernel/ops/backends/_<vendor>/ops/
+
+     If the vendor module defines __all__, only those symbols are exported.
+     Otherwise, all public symbols (not starting with _) are auto-discovered.
+
+     Note: Vendor can both override existing ops AND add new vendor-specific ops.
+     """
+     from liger_kernel.ops.backends import get_vendor_for_device
+     from liger_kernel.utils import infer_device
+
+     device = infer_device()
+
+     # Look up vendor info for this device
+     vendor_info = get_vendor_for_device(device)
+     if vendor_info is None:
+         return
+
+     try:
+         import importlib
+
+         vendor_ops = importlib.import_module(vendor_info.module_path)
+
+         # Get names to export: use __all__ if defined, otherwise auto-discover
+         names_to_export = getattr(vendor_ops, "__all__", None)
+
+         if names_to_export is None:
+             # Auto-discover: find all public symbols (classes and functions)
+             names_to_export = [name for name in dir(vendor_ops) if not name.startswith("_")]
+
+         # Replace or add to this module's globals
+         for name in names_to_export:
+             globals()[name] = getattr(vendor_ops, name)
+
+     except ImportError:
+         # Vendor module not available, use default implementations
+         pass
+
+
+ _replace_with_vendor_ops()
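To make the two import paths described in the module docstring concrete, the sketch below contrasts them; the `__module__` check is only one informal way to see whether a vendor override was applied, not an API the package guarantees.

```python
# Path 1: package-level import. After _replace_with_vendor_ops() runs at import
# time, this name may point at a vendor-specific class.
from liger_kernel.ops import LigerGELUMulFunction as PackageLevelGELUMul

# Path 2: submodule import. This always resolves to the default implementation
# and is never touched by the replacement mechanism.
from liger_kernel.ops.geglu import LigerGELUMulFunction as DefaultGELUMul

# If no vendor backend matched the detected device, both names refer to the
# same class; if an override was applied, the defining modules will differ.
print(PackageLevelGELUMul is DefaultGELUMul)
print(PackageLevelGELUMul.__module__, DefaultGELUMul.__module__)
```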
liger_kernel/ops/backends/README.md
@@ -0,0 +1,151 @@
+ # Adding a New Vendor Backend
+
+ This directory contains vendor-specific operator implementations that automatically replace the default (CUDA) implementations when running on the corresponding device.
+
+ ## Concepts
+
+ - **Vendor**: Chip manufacturer (e.g., `ascend`, `intel`, `nvidia`)
+ - **Device**: Device type (e.g., `npu`, `xpu`, `cuda`)
+ - **VendorInfo**: Defines the mapping between vendor and device
+
+ ## Directory Structure
+
+ ```
+ backends/
+ ├── README.md
+ ├── __init__.py
+ ├── registry.py              # VendorInfo, register_vendor(), VENDOR_REGISTRY
+ ├── _ascend/                 # Ascend (Huawei) vendor - supports NPU
+ │   ├── __init__.py          # Registers VendorInfo for NPU
+ │   └── ops/
+ │       ├── __init__.py      # Exports vendor-specific implementations
+ │       └── geglu.py         # NPU-specific GEGLU implementation
+ └── _<vendor>/               # Your new vendor backend
+     └── ...
+ ```
+
+ ## How It Works
+
+ 1. When `liger_kernel.ops.backends` is imported, it imports all vendor packages (e.g., `_ascend`)
+ 2. Each vendor's `__init__.py` calls `register_vendor()` to register itself
+ 3. When `liger_kernel.ops` is imported, `_replace_with_vendor_ops()` is called
+ 4. It detects the current device via `infer_device()` and looks up the vendor
+ 5. Vendor implementations replace/add to the `liger_kernel.ops` namespace
+
+ ## Adding a New Vendor
+
+ ### Step 1: Create Directory Structure
+
+ ```bash
+ mkdir -p backends/_<vendor>/ops
+ touch backends/_<vendor>/__init__.py
+ touch backends/_<vendor>/ops/__init__.py
+ ```
+
+ ### Step 2: Register Your Vendor
+
+ In `backends/_<vendor>/__init__.py`, register your vendor:
+
+ ```python
+ """
+ <Vendor> backend for Liger-Kernel.
+ """
+
+ from liger_kernel.ops.backends.registry import VendorInfo, register_vendor
+
+ register_vendor(
+     VendorInfo(
+         vendor="<vendor>",
+         device="<device>",
+     )
+ )
+ ```
+
+
+ ### Step 3: Ensure Device Detection Works
+
+ Make sure `infer_device()` in `liger_kernel/utils.py` can detect your device:
+
+ ```python
+ def infer_device():
+     if torch.cuda.is_available():
+         return "cuda"
+     if is_npu_available():
+         return "npu"
+     # Add your device detection here
+     if is_<device>_available():
+         return "<device>"
+     return "cpu"
+ ```
+
+ ### Step 4: Implement Vendor-Specific Operators
+
+ Create operator files in `backends/_<vendor>/ops/`. For example, `geglu.py`:
+
+ ```python
+ import torch
+
+ class LigerGELUMulFunction(torch.autograd.Function):
+     """
+     Vendor-specific LigerGELUMulFunction implementation.
+     """
+     @staticmethod
+     def forward(ctx, a, b):
+         # Your vendor-specific forward implementation
+         ...
+
+     @staticmethod
+     def backward(ctx, dc):
+         # Your vendor-specific backward implementation
+         ...
+
+ # Optional: vendor-specific kernel functions
+ def geglu_forward_vendor(a, b):
+     ...
+
+ def geglu_backward_vendor(a, b, dc):
+     ...
+ ```
+
+ ### Step 5: Export in `ops/__init__.py`
+
+ In `backends/_<vendor>/ops/__init__.py`, export your implementations:
+
+ ```python
+ """
+ <Vendor>-specific operator implementations.
+ """
+
+ from .<module> import (
+     LigerGELUMulFunction,
+     geglu_forward_vendor as geglu_forward,  # Rename to match default API
+     geglu_backward_vendor as geglu_backward,
+ )
+
+ # Explicitly declare what to export (recommended)
+ __all__ = [
+     "LigerGELUMulFunction",
+     "geglu_forward",
+     "geglu_backward",
+ ]
+ ```
+
+ ## Key Points
+
+ ### Incremental Override
+
+ You **don't need to implement all operators**. Only implement the ones that require vendor-specific adaptations. Unimplemented operators will automatically fall back to the default (CUDA) implementation.
+
+ ### Vendor-Specific Additions
+
+ Vendors can also **add new operators** that don't exist in the default implementation. These will be exported to the `liger_kernel.ops` namespace for users to import.
+
+ ### Naming Convention
+
+ - Use the **same class/function names** as the default implementations for overrides
+ - This allows seamless replacement without changing user code
+ - Use `as` imports to rename if your internal naming differs
+
+ ## Example: Ascend NPU Backend
+
+ See the `_ascend/` directory for a complete example of the Ascend NPU backend implementation.
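The README walks through registration and override steps but does not show how to confirm them at runtime. Below is a small verification sketch using only names introduced in this diff (`VENDOR_REGISTRY`, `get_vendor_for_device`, `infer_device`); the printed values are examples, not guaranteed output.

```python
from liger_kernel.ops.backends import VENDOR_REGISTRY, get_vendor_for_device
from liger_kernel.utils import infer_device

# Importing liger_kernel.ops.backends auto-imports every _<vendor> subpackage,
# so the registry already reflects all registered vendors at this point.
print(VENDOR_REGISTRY)  # e.g. {'npu': VendorInfo(vendor='ascend', device='npu')}

info = get_vendor_for_device(infer_device())
if info is None:
    print("No vendor backend registered for this device; default ops stay in place.")
else:
    print(f"Vendor '{info.vendor}' ops will be loaded from {info.module_path}")
```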
liger_kernel/ops/backends/__init__.py
@@ -0,0 +1,13 @@
+ import importlib
+ import pkgutil
+
+ from liger_kernel.ops.backends.registry import VENDOR_REGISTRY  # noqa: F401
+ from liger_kernel.ops.backends.registry import VendorInfo  # noqa: F401
+ from liger_kernel.ops.backends.registry import get_vendor_for_device  # noqa: F401
+ from liger_kernel.ops.backends.registry import register_vendor  # noqa: F401
+
+ # Auto-import all _<vendor> subpackages to trigger registration
+ # Each vendor's __init__.py calls register_vendor() when imported
+ for _, modname, ispkg in pkgutil.iter_modules(__path__):
+     if ispkg and modname.startswith("_"):
+         importlib.import_module(f"{__name__}.{modname}")
liger_kernel/ops/backends/_ascend/__init__.py
@@ -0,0 +1,5 @@
+ from liger_kernel.ops.backends.registry import VendorInfo
+ from liger_kernel.ops.backends.registry import register_vendor
+
+ # Register Ascend vendor for NPU device
+ register_vendor(VendorInfo(vendor="ascend", device="npu"))
liger_kernel/ops/backends/_ascend/ops/__init__.py
@@ -0,0 +1,15 @@
+ """
+ Ascend NPU operator implementations.
+
+ This module exports Ascend NPU-optimized implementations that will automatically
+ replace the default implementations when running on NPU devices.
+
+ Both Function classes and kernel functions can be exported here.
+
+ To add a new operator:
+ 1. Create the implementation file (e.g., rms_norm.py)
+ 2. Import the Function class and/or kernel functions here
+ 3. Optionally add to __all__ for explicit control
+
+ If __all__ is not defined, all public symbols will be auto-discovered.
+ """
liger_kernel/ops/backends/registry.py
@@ -0,0 +1,61 @@
+ """
+ Vendor registry for Liger-Kernel multi-backend support.
+
+ This module defines VendorInfo and the registry for vendor registration.
+ Each vendor registers itself by calling register_vendor() in its __init__.py.
+ """
+
+ from dataclasses import dataclass
+ from typing import Optional
+
+ # Dynamically get backends package path to avoid hardcoding
+ _BACKENDS_PACKAGE = __name__.rsplit(".", 1)[0]  # "liger_kernel.ops.backends"
+
+
+ @dataclass
+ class VendorInfo:
+     """
+     Information about a chip vendor and its supported device.
+
+     Attributes:
+         vendor: Vendor name (e.g., "ascend", "intel", "nvidia")
+         device: Device type this vendor supports (e.g., "npu", "xpu")
+     """
+
+     vendor: str
+     device: str
+
+     @property
+     def module_path(self) -> str:
+         """Auto-generated module path based on vendor name."""
+         return f"{_BACKENDS_PACKAGE}._{self.vendor}.ops"
+
+
+ # Registry mapping device types to their vendor info
+ # Vendors register themselves via register_vendor()
+ VENDOR_REGISTRY: dict[str, VendorInfo] = {}
+
+
+ def register_vendor(vendor_info: VendorInfo) -> None:
+     """
+     Register a vendor's info in the global registry.
+
+     This should be called in each vendor's __init__.py to register itself.
+
+     Args:
+         vendor_info: VendorInfo instance to register
+     """
+     VENDOR_REGISTRY[vendor_info.device] = vendor_info
+
+
+ def get_vendor_for_device(device: str) -> Optional[VendorInfo]:
+     """
+     Get the VendorInfo for a given device type.
+
+     Args:
+         device: Device type (e.g., "npu", "xpu")
+
+     Returns:
+         VendorInfo if found, None otherwise
+     """
+     return VENDOR_REGISTRY.get(device)
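Because the registry is a plain module-level dict keyed by device, its behavior can be exercised in isolation. A minimal sketch follows; `examplecorp` and `fake_device` are placeholders, not a real backend shipped in this package.

```python
from liger_kernel.ops.backends.registry import VendorInfo
from liger_kernel.ops.backends.registry import get_vendor_for_device
from liger_kernel.ops.backends.registry import register_vendor

# Register a placeholder vendor entry for a made-up device type.
register_vendor(VendorInfo(vendor="examplecorp", device="fake_device"))

info = get_vendor_for_device("fake_device")
print(info.vendor)       # "examplecorp"
print(info.module_path)  # "liger_kernel.ops.backends._examplecorp.ops"

# Devices with no registered vendor simply fall through to the defaults.
print(get_vendor_for_device("some_other_device"))  # None
```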