liger-kernel 0.5.10__py3-none-any.whl → 0.6.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- liger_kernel/chunked_loss/__init__.py +1 -0
- liger_kernel/chunked_loss/cosine_similarity_loss.py +127 -0
- liger_kernel/chunked_loss/functional.py +2 -0
- liger_kernel/ops/dyt.py +0 -2
- liger_kernel/ops/fused_neighborhood_attention.py +1022 -0
- liger_kernel/ops/geglu.py +1 -1
- liger_kernel/ops/multi_token_attention.py +207 -0
- liger_kernel/ops/rms_norm.py +265 -54
- liger_kernel/ops/softmax.py +201 -0
- liger_kernel/ops/sparsemax.py +62 -50
- liger_kernel/ops/swiglu.py +1 -1
- liger_kernel/transformers/__init__.py +3 -0
- liger_kernel/transformers/functional.py +62 -0
- liger_kernel/transformers/fused_neighborhood_attention.py +234 -0
- liger_kernel/transformers/model/gemma.py +25 -8
- liger_kernel/transformers/model/gemma2.py +27 -8
- liger_kernel/transformers/model/gemma3.py +62 -98
- liger_kernel/transformers/model/glm4.py +16 -7
- liger_kernel/transformers/model/llama.py +25 -7
- liger_kernel/transformers/model/llama4.py +108 -0
- liger_kernel/transformers/model/llava.py +95 -124
- liger_kernel/transformers/model/mistral.py +13 -8
- liger_kernel/transformers/model/mixtral.py +16 -7
- liger_kernel/transformers/model/mllama.py +16 -7
- liger_kernel/transformers/model/olmo2.py +16 -7
- liger_kernel/transformers/model/paligemma.py +8 -1
- liger_kernel/transformers/model/phi3.py +25 -8
- liger_kernel/transformers/model/qwen2.py +24 -7
- liger_kernel/transformers/model/qwen2_5_vl.py +41 -91
- liger_kernel/transformers/model/qwen2_vl.py +38 -100
- liger_kernel/transformers/model/qwen3.py +11 -3
- liger_kernel/transformers/model/qwen3_moe.py +10 -6
- liger_kernel/transformers/monkey_patch.py +304 -70
- liger_kernel/transformers/multi_token_attention.py +64 -0
- liger_kernel/transformers/rms_norm.py +40 -4
- liger_kernel/transformers/softmax.py +12 -0
- {liger_kernel-0.5.10.dist-info → liger_kernel-0.6.0.dist-info}/METADATA +8 -2
- {liger_kernel-0.5.10.dist-info → liger_kernel-0.6.0.dist-info}/RECORD +42 -35
- {liger_kernel-0.5.10.dist-info → liger_kernel-0.6.0.dist-info}/WHEEL +1 -1
- liger_kernel/transformers/gema3_rms.py +0 -8
- {liger_kernel-0.5.10.dist-info → liger_kernel-0.6.0.dist-info}/licenses/LICENSE +0 -0
- {liger_kernel-0.5.10.dist-info → liger_kernel-0.6.0.dist-info}/licenses/NOTICE +0 -0
- {liger_kernel-0.5.10.dist-info → liger_kernel-0.6.0.dist-info}/top_level.txt +0 -0
|
@@ -0,0 +1,64 @@
|
|
|
1
|
+
import math
|
|
2
|
+
|
|
3
|
+
import torch
|
|
4
|
+
import torch.nn as nn
|
|
5
|
+
|
|
6
|
+
from torch.nn.modules.utils import _pair
|
|
7
|
+
|
|
8
|
+
from liger_kernel.ops.multi_token_attention import LigerMultiTokenAttentionFunction
|
|
9
|
+
|
|
10
|
+
|
|
11
|
+
class LigerMultiTokenAttention(nn.Module):
|
|
12
|
+
"""
|
|
13
|
+
Multi-Token Attention:
|
|
14
|
+
out = mask_{0}(conv2d(softmax(mask_{-\inf}(scores))))
|
|
15
|
+
|
|
16
|
+
Reference: https://arxiv.org/pdf/2504.00927
|
|
17
|
+
"""
|
|
18
|
+
|
|
19
|
+
def __init__(
|
|
20
|
+
self,
|
|
21
|
+
in_channels: int,
|
|
22
|
+
out_channels: int,
|
|
23
|
+
kernel_size: int,
|
|
24
|
+
stride: int = 1,
|
|
25
|
+
padding: int = 0,
|
|
26
|
+
dilation: int = 1,
|
|
27
|
+
groups: int = 1,
|
|
28
|
+
bias: bool = True,
|
|
29
|
+
sparse: bool = False,
|
|
30
|
+
):
|
|
31
|
+
super().__init__()
|
|
32
|
+
self.in_channels = in_channels
|
|
33
|
+
self.out_channels = out_channels
|
|
34
|
+
self.kernel_size = _pair(kernel_size)
|
|
35
|
+
self.stride = _pair(stride)
|
|
36
|
+
self.padding = _pair(padding)
|
|
37
|
+
self.dilation = _pair(dilation)
|
|
38
|
+
self.groups = groups
|
|
39
|
+
self.sparse = sparse
|
|
40
|
+
|
|
41
|
+
self.weight = nn.Parameter(torch.empty(out_channels, in_channels // groups, *self.kernel_size))
|
|
42
|
+
if bias:
|
|
43
|
+
self.bias = nn.Parameter(torch.empty(out_channels))
|
|
44
|
+
else:
|
|
45
|
+
self.register_parameter("bias", None)
|
|
46
|
+
|
|
47
|
+
self.reset_parameters()
|
|
48
|
+
|
|
49
|
+
def reset_parameters(self):
|
|
50
|
+
nn.init.kaiming_uniform_(self.weight, a=math.sqrt(5))
|
|
51
|
+
if self.bias is not None:
|
|
52
|
+
nn.init.zeros_(self.bias)
|
|
53
|
+
|
|
54
|
+
def forward(self, scores: torch.Tensor) -> torch.Tensor:
|
|
55
|
+
return LigerMultiTokenAttentionFunction.apply(
|
|
56
|
+
scores,
|
|
57
|
+
self.weight,
|
|
58
|
+
self.bias,
|
|
59
|
+
self.stride,
|
|
60
|
+
self.padding,
|
|
61
|
+
self.dilation,
|
|
62
|
+
self.groups,
|
|
63
|
+
self.sparse,
|
|
64
|
+
)
|
|
@@ -13,6 +13,7 @@ class LigerRMSNorm(nn.Module):
|
|
|
13
13
|
casting_mode="llama",
|
|
14
14
|
init_fn="ones",
|
|
15
15
|
in_place=True,
|
|
16
|
+
row_mode=None,
|
|
16
17
|
):
|
|
17
18
|
super().__init__()
|
|
18
19
|
assert init_fn in [
|
|
@@ -20,11 +21,12 @@ class LigerRMSNorm(nn.Module):
|
|
|
20
21
|
"zeros",
|
|
21
22
|
], f"init_fn must be either 'ones' or 'zeros', got {init_fn}"
|
|
22
23
|
self.weight = nn.Parameter(torch.ones(hidden_size) if init_fn == "ones" else torch.zeros(hidden_size))
|
|
23
|
-
self.variance_epsilon, self.offset, self.casting_mode, self.in_place = (
|
|
24
|
+
self.variance_epsilon, self.offset, self.casting_mode, self.in_place, self.row_mode = (
|
|
24
25
|
eps,
|
|
25
26
|
offset,
|
|
26
27
|
casting_mode,
|
|
27
28
|
in_place,
|
|
29
|
+
row_mode,
|
|
28
30
|
)
|
|
29
31
|
|
|
30
32
|
def forward(self, hidden_states):
|
|
@@ -35,9 +37,43 @@ class LigerRMSNorm(nn.Module):
|
|
|
35
37
|
self.offset,
|
|
36
38
|
self.casting_mode,
|
|
37
39
|
self.in_place,
|
|
40
|
+
self.row_mode,
|
|
38
41
|
)
|
|
39
42
|
|
|
40
43
|
def extra_repr(self):
|
|
41
|
-
return (
|
|
42
|
-
|
|
43
|
-
|
|
44
|
+
return f"{tuple(self.weight.shape)}, eps={self.variance_epsilon}, offset={self.offset}, in_place={self.in_place}, row_mode={self.row_mode}"
|
|
45
|
+
|
|
46
|
+
|
|
47
|
+
class LigerRMSNormForGemma(LigerRMSNorm):
|
|
48
|
+
def __init__(
|
|
49
|
+
self, hidden_size, eps=1e-6, offset=1.0, casting_mode="gemma", init_fn="zeros", in_place=True, row_mode=None
|
|
50
|
+
):
|
|
51
|
+
super().__init__(hidden_size, eps, offset, casting_mode, init_fn, in_place, row_mode)
|
|
52
|
+
|
|
53
|
+
|
|
54
|
+
class LigerRMSNormForGemma2(LigerRMSNorm):
|
|
55
|
+
def __init__(
|
|
56
|
+
self, hidden_size, eps=1e-6, offset=1.0, casting_mode="gemma", init_fn="zeros", in_place=False, row_mode=None
|
|
57
|
+
):
|
|
58
|
+
super().__init__(hidden_size, eps, offset, casting_mode, init_fn, in_place, row_mode)
|
|
59
|
+
|
|
60
|
+
|
|
61
|
+
class LigerRMSNormForGemma3(LigerRMSNorm):
|
|
62
|
+
"""Gemma3RMSNorm has a dim argument not hidden_size used in q_norm and k_norm."""
|
|
63
|
+
|
|
64
|
+
def __init__(self, dim, eps=0.000001, offset=1.0, casting_mode="gemma", init_fn="zeros", in_place=False):
|
|
65
|
+
super().__init__(dim, eps, offset, casting_mode, init_fn, in_place)
|
|
66
|
+
|
|
67
|
+
|
|
68
|
+
class LigerRMSNormForOlmo2(LigerRMSNorm):
|
|
69
|
+
def __init__(
|
|
70
|
+
self, hidden_size, eps=1e-6, offset=0.0, casting_mode="llama", init_fn="ones", in_place=False, row_mode=None
|
|
71
|
+
):
|
|
72
|
+
super().__init__(hidden_size, eps, offset, casting_mode, init_fn, in_place, row_mode)
|
|
73
|
+
|
|
74
|
+
|
|
75
|
+
class LigerRMSNormForGlm4(LigerRMSNorm):
|
|
76
|
+
def __init__(
|
|
77
|
+
self, hidden_size, eps=1e-6, offset=0.0, casting_mode="llama", init_fn="ones", in_place=False, row_mode=None
|
|
78
|
+
):
|
|
79
|
+
super().__init__(hidden_size, eps, offset, casting_mode, init_fn, in_place, row_mode)
|
|
@@ -0,0 +1,12 @@
|
|
|
1
|
+
import torch
|
|
2
|
+
import torch.nn as nn
|
|
3
|
+
|
|
4
|
+
from liger_kernel.ops.softmax import LigerSoftmaxFunction
|
|
5
|
+
|
|
6
|
+
|
|
7
|
+
class LigerSoftmax(nn.Module):
|
|
8
|
+
def __init__(self):
|
|
9
|
+
super().__init__()
|
|
10
|
+
|
|
11
|
+
def forward(self, x: torch.Tensor):
|
|
12
|
+
return LigerSoftmaxFunction.apply(x)
|
|
@@ -1,6 +1,6 @@
|
|
|
1
1
|
Metadata-Version: 2.4
|
|
2
2
|
Name: liger_kernel
|
|
3
|
-
Version: 0.
|
|
3
|
+
Version: 0.6.0
|
|
4
4
|
Summary: Efficient Triton kernels for LLM Training
|
|
5
5
|
License: BSD 2-CLAUSE LICENSE
|
|
6
6
|
Copyright 2024 LinkedIn Corporation
|
|
@@ -33,7 +33,7 @@ License-File: NOTICE
|
|
|
33
33
|
Requires-Dist: torch>=2.1.2
|
|
34
34
|
Requires-Dist: triton>=2.3.1
|
|
35
35
|
Provides-Extra: dev
|
|
36
|
-
Requires-Dist: transformers>=4.
|
|
36
|
+
Requires-Dist: transformers>=4.49.0; extra == "dev"
|
|
37
37
|
Requires-Dist: matplotlib>=3.7.2; extra == "dev"
|
|
38
38
|
Requires-Dist: flake8>=4.0.1.1; extra == "dev"
|
|
39
39
|
Requires-Dist: black>=24.4.2; extra == "dev"
|
|
@@ -45,6 +45,7 @@ Requires-Dist: datasets>=2.19.2; extra == "dev"
|
|
|
45
45
|
Requires-Dist: seaborn; extra == "dev"
|
|
46
46
|
Requires-Dist: mkdocs; extra == "dev"
|
|
47
47
|
Requires-Dist: mkdocs-material; extra == "dev"
|
|
48
|
+
Requires-Dist: torchvision>=0.20; extra == "dev"
|
|
48
49
|
Dynamic: license-file
|
|
49
50
|
Dynamic: provides-extra
|
|
50
51
|
Dynamic: requires-dist
|
|
@@ -114,6 +115,8 @@ Dynamic: requires-dist
|
|
|
114
115
|
|
|
115
116
|
We've also added optimized Post-Training kernels that deliver **up to 80% memory savings** for alignment and distillation tasks. We support losses like DPO, CPO, ORPO, SimPO, KTO, JSD, and many more. Check out [how we optimize the memory](https://x.com/hsu_byron/status/1866577403918917655).
|
|
116
117
|
|
|
118
|
+
You can view the documentation site for additional installation, usage examples, and API references:https://linkedin.github.io/Liger-Kernel/
|
|
119
|
+
|
|
117
120
|
## Supercharge Your Model with Liger Kernel
|
|
118
121
|
|
|
119
122
|

|
|
@@ -290,6 +293,7 @@ loss.backward()
|
|
|
290
293
|
|
|
291
294
|
| **Model** | **API** | **Supported Operations** |
|
|
292
295
|
|-------------|--------------------------------------------------------------|-------------------------------------------------------------------------|
|
|
296
|
+
| Llama4 (Text) & (Multimodal) | `liger_kernel.transformers.apply_liger_kernel_to_llama4` | RMSNorm, LayerNorm, GeGLU, CrossEntropyLoss, FusedLinearCrossEntropy |
|
|
293
297
|
| LLaMA 2 & 3 | `liger_kernel.transformers.apply_liger_kernel_to_llama` | RoPE, RMSNorm, SwiGLU, CrossEntropyLoss, FusedLinearCrossEntropy |
|
|
294
298
|
| LLaMA 3.2-Vision | `liger_kernel.transformers.apply_liger_kernel_to_mllama` | RoPE, RMSNorm, SwiGLU, CrossEntropyLoss, FusedLinearCrossEntropy |
|
|
295
299
|
| Mistral | `liger_kernel.transformers.apply_liger_kernel_to_mistral` | RoPE, RMSNorm, SwiGLU, CrossEntropyLoss, FusedLinearCrossEntropy |
|
|
@@ -326,6 +330,8 @@ loss.backward()
|
|
|
326
330
|
| GeGLU | `liger_kernel.transformers.LigerGEGLUMLP` |
|
|
327
331
|
| CrossEntropy | `liger_kernel.transformers.LigerCrossEntropyLoss` |
|
|
328
332
|
| Fused Linear CrossEntropy | `liger_kernel.transformers.LigerFusedLinearCrossEntropyLoss`|
|
|
333
|
+
| Multi Token Attention | `liger_kernel.transformers.LigerMultiTokenAttention` |
|
|
334
|
+
| Softmax | `liger_kernel.transformers.LigerSoftmax` |
|
|
329
335
|
| Sparsemax | `liger_kernel.transformers.LigerSparsemax` |
|
|
330
336
|
|
|
331
337
|
|
|
@@ -2,10 +2,11 @@ liger_kernel/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
|
|
|
2
2
|
liger_kernel/env_report.py,sha256=uhdEC8OydxoZlb7B6YYcAaBF3crGFdIck-4cxaW4NJY,1728
|
|
3
3
|
liger_kernel/utils.py,sha256=BQleeZWHSZPNuPcYcoZTOp1kcNEZONZilPP5-AmjgWI,2024
|
|
4
4
|
liger_kernel/chunked_loss/README.md,sha256=0FmkFC3hKBqyoDT5uTlIYmrvRkF-EOCR1y-EBU1LpWU,2248
|
|
5
|
-
liger_kernel/chunked_loss/__init__.py,sha256=
|
|
5
|
+
liger_kernel/chunked_loss/__init__.py,sha256=J5_jNnzZ4gZmA38W5f_4oab7xMoNk1Xy-yh3X_Xlf-s,714
|
|
6
|
+
liger_kernel/chunked_loss/cosine_similarity_loss.py,sha256=pZ07OQ6RI-c8uk96tDRlUXdt31-da7yWhfwircZlKRw,4198
|
|
6
7
|
liger_kernel/chunked_loss/cpo_loss.py,sha256=Gzz1eU4kgcbdubFVRy55e8A1Cr-r45UgNicXwZIjmBU,5454
|
|
7
8
|
liger_kernel/chunked_loss/dpo_loss.py,sha256=tapMiNdI8_ufW55iG0Ud4dmiW39gu1DzlvtoOCHrdGg,6259
|
|
8
|
-
liger_kernel/chunked_loss/functional.py,sha256
|
|
9
|
+
liger_kernel/chunked_loss/functional.py,sha256=-XPDbLml9dHmvoSU2VNTUrBDFehuzvuAGPikVetBMtI,1132
|
|
9
10
|
liger_kernel/chunked_loss/fused_linear_distillation.py,sha256=ooR-qnZCyWJN935oHCSWLaKKKyaYERyhNczRGi1VOiw,11935
|
|
10
11
|
liger_kernel/chunked_loss/fused_linear_ppo.py,sha256=AA19cpv6D8mo5RbSK5GRCcZoOSnpxV_Z1eJlAsC5eic,13434
|
|
11
12
|
liger_kernel/chunked_loss/fused_linear_preference.py,sha256=FIH85uUXAOgYx5Ax8MjFhJHVu-2pKtY7wSegd0zSyyY,18336
|
|
@@ -17,74 +18,80 @@ liger_kernel/chunked_loss/orpo_loss.py,sha256=nu9UYG16dcMw93lvHi4_hYs3Q0FK1KnlmM
|
|
|
17
18
|
liger_kernel/chunked_loss/simpo_loss.py,sha256=fy2w8KbhMrBv7b1jdIeH3bBFxY52bPQPZb3KwBvmurM,5385
|
|
18
19
|
liger_kernel/ops/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
|
|
19
20
|
liger_kernel/ops/cross_entropy.py,sha256=e8THGnhOcy_0SbOLABx67HEM7-B8a8pG7nDKbCRpQKM,19123
|
|
20
|
-
liger_kernel/ops/dyt.py,sha256=
|
|
21
|
+
liger_kernel/ops/dyt.py,sha256=gCLz4S8aul8SY9nvIGaoK67aGb7U9MJRQdo3ONqmQYs,5417
|
|
21
22
|
liger_kernel/ops/fused_linear_cross_entropy.py,sha256=5fbGhN85n3zf0uIdJ7PYHWIRzTf0VTFiS0ARtOmqIP0,11020
|
|
22
23
|
liger_kernel/ops/fused_linear_jsd.py,sha256=CSoprxb-YcJy-YUKiTcYkxN8sb9h2kdk_iHuncvSV5c,9683
|
|
23
|
-
liger_kernel/ops/
|
|
24
|
+
liger_kernel/ops/fused_neighborhood_attention.py,sha256=vPi5xbnh6wxyZehaqo6Tuilqo2fN5SGDiONjnNmIKqs,35556
|
|
25
|
+
liger_kernel/ops/geglu.py,sha256=r0WSq9E93zzynL44Wh8femzOWK07_SseBM_pJUyxT3s,4144
|
|
24
26
|
liger_kernel/ops/group_norm.py,sha256=qD4D4lSjSgVtO52EBNLC2iTseALRgPgqXE50U2woggk,10837
|
|
25
27
|
liger_kernel/ops/grpo_loss.py,sha256=anRnv7k1-AV3pCC6_TqP0GMg78YYUfRAJrbpx6PVhl0,9448
|
|
26
28
|
liger_kernel/ops/jsd.py,sha256=onHp5T3MbvJaVz5Vup7Ww6EQp_HTaZeayTjJk6FgQMY,7042
|
|
27
29
|
liger_kernel/ops/kl_div.py,sha256=ZjGdDLKWksHT9dZ0xF_TDgAkj5cuMTwwT5tr9E-_24o,8734
|
|
28
30
|
liger_kernel/ops/layer_norm.py,sha256=vWCyOm-F2GMAilB-ozJcFeUQQLCJoTE_uiXq-_0uYuI,8356
|
|
31
|
+
liger_kernel/ops/multi_token_attention.py,sha256=Oz_RXDp-OSS_R_HuGmaETHdAJ7Toda_70OfE7TXMUlY,7645
|
|
29
32
|
liger_kernel/ops/qwen2vl_mrope.py,sha256=3GExhYpLgB4VUtyZyjRk8XjEur3W4EWF6HQ67ML5vBU,8481
|
|
30
|
-
liger_kernel/ops/rms_norm.py,sha256
|
|
33
|
+
liger_kernel/ops/rms_norm.py,sha256=-rcgHwWCxlA-Syec2XhdW4jfOeCDt2r7qwjslgXFYDU,18865
|
|
31
34
|
liger_kernel/ops/rope.py,sha256=ofmBOkUpZZO-Q8Z5B_LOFYYLD-YT-8WnJ4vGOrDYouI,8943
|
|
32
|
-
liger_kernel/ops/
|
|
33
|
-
liger_kernel/ops/
|
|
35
|
+
liger_kernel/ops/softmax.py,sha256=tgORx6MK1IDDtZKqGarj0IPIVjqAIEUXXYPiinhRdtI,5864
|
|
36
|
+
liger_kernel/ops/sparsemax.py,sha256=AeWe1xgkHJFEKWTj2vu_0hj7LztGvjqXAps-QTpCY0U,5087
|
|
37
|
+
liger_kernel/ops/swiglu.py,sha256=D7nd4u_LInwsIRNCDdY77lqnTz8-W5dJrpEAt8zEO_A,3033
|
|
34
38
|
liger_kernel/ops/tvd.py,sha256=FHJtLQI95ijqgg9UtaHpMAjSCiPxB6CduPwPMcGxelc,6405
|
|
35
39
|
liger_kernel/ops/utils.py,sha256=uoFKQqo-34N2TWQNvXMFywqGiOMMXNEVBxVojzlUAa0,3836
|
|
36
40
|
liger_kernel/ops/experimental/embedding.py,sha256=tolj3tItkzpSb30zWqDN2_yX4ectflaQ8HMyKyFIQc8,4172
|
|
37
41
|
liger_kernel/ops/experimental/mm_int8int2.py,sha256=TrS9lpwekrik_w5qE7AhMJD1bcq-OidjtbsW80oZ6IM,13314
|
|
38
|
-
liger_kernel/transformers/__init__.py,sha256=
|
|
42
|
+
liger_kernel/transformers/__init__.py,sha256=mWMEhOabqUkPimMOmkg9DawnO-vL9u_u-N4iIqfNZeg,7259
|
|
39
43
|
liger_kernel/transformers/auto_model.py,sha256=0qCTRZt280Bj_LcFdzo9hlaR-BWNazawXOGgoCZjgEg,1545
|
|
40
44
|
liger_kernel/transformers/cross_entropy.py,sha256=z3KTWQnFxr_IZaVjtYt0ZNEWQdDdYThN35xWkHlDGH0,1683
|
|
41
45
|
liger_kernel/transformers/dyt.py,sha256=i-4GPaMrl-jab9TVI5qN0-H9qycn_mCbV82ozU4nbmU,723
|
|
42
46
|
liger_kernel/transformers/fsdp.py,sha256=CUiyjTmjkjY7pLXQv8ly9rnzgXw6529csd9pvtJNMYc,3096
|
|
43
|
-
liger_kernel/transformers/functional.py,sha256=
|
|
47
|
+
liger_kernel/transformers/functional.py,sha256=7Emw7D6VPMg8hfasC33NiolvKmQVF1gV6VayKQCEWJM,7446
|
|
44
48
|
liger_kernel/transformers/fused_linear_cross_entropy.py,sha256=O8Sg5BT81nTaY9fSGoOY9dOD9ekibwwiuXhdUHaxntQ,1742
|
|
45
49
|
liger_kernel/transformers/fused_linear_jsd.py,sha256=bZ4otCvWBuOnA5XdQL-FzZVItJlDt-ht9e_pG7PG93E,3999
|
|
50
|
+
liger_kernel/transformers/fused_neighborhood_attention.py,sha256=TxYDUAt9B6WSP14aJP66C_2Mbds2sSIPGnamhUSTrC8,7957
|
|
46
51
|
liger_kernel/transformers/geglu.py,sha256=mrgqzIUVd6lN7fkDKLkw5YaESDxDtFgbot430WwPVOQ,1107
|
|
47
|
-
liger_kernel/transformers/gema3_rms.py,sha256=LTmZOXe6WEnv6ZroW-kU1TE2B36-z5v8OLmKr3XEVFo,353
|
|
48
52
|
liger_kernel/transformers/group_norm.py,sha256=6qMAWOprr4SzP0YhNVNGQIBpM5aUHplUD2VuGJrMBz0,2173
|
|
49
53
|
liger_kernel/transformers/grpo_loss.py,sha256=uAkUNKSnUGEOqa82L9w2e6AI1kcmG8K45-QxyaT8zhM,3897
|
|
50
54
|
liger_kernel/transformers/jsd.py,sha256=DGqRnxIZxsvxo0_tbbxX3b-sDbDjC_yKufyRIHCcScY,2979
|
|
51
55
|
liger_kernel/transformers/kl_div.py,sha256=WLffFbh1EExD2Eb1F7lN11fo9JJC-0751WJjZAF1Fj8,409
|
|
52
56
|
liger_kernel/transformers/layer_norm.py,sha256=c9pk3PEasOKYR0rhe5e5nNrnYKVCEW4VC8S6LpCq9EQ,906
|
|
53
|
-
liger_kernel/transformers/monkey_patch.py,sha256=
|
|
57
|
+
liger_kernel/transformers/monkey_patch.py,sha256=W7KgJN-rrLZS3pRZ5debO_dSN7zddPegKjqOIP39wR0,85856
|
|
58
|
+
liger_kernel/transformers/multi_token_attention.py,sha256=l9VDICK0dfmifUDW668hGscP8AHq2rYcM2oGUa3baRQ,1751
|
|
54
59
|
liger_kernel/transformers/qwen2vl_mrope.py,sha256=5EwSqrMdsL9MYspeBMXBsNJKvH0MOmRrtJXAJlnnlOI,1047
|
|
55
|
-
liger_kernel/transformers/rms_norm.py,sha256=
|
|
60
|
+
liger_kernel/transformers/rms_norm.py,sha256=vkekcvTeWY8vL4H6hg3t0XeY0Ew_3OFMPHuzqlxPPVw,2719
|
|
56
61
|
liger_kernel/transformers/rope.py,sha256=ZTrTORSAyfcFIKjk6XEeYmk4ROH7xXED9L4g2NFntlE,999
|
|
62
|
+
liger_kernel/transformers/softmax.py,sha256=yadlAgE4V2JByMwrDDa2s5SUBp8Jgd57xwnVvAWoBaI,264
|
|
57
63
|
liger_kernel/transformers/sparsemax.py,sha256=0lQA0UEOs4mu8CMruZ3VLhImxQVXJWhPsAKUsYA7vj8,403
|
|
58
64
|
liger_kernel/transformers/swiglu.py,sha256=LZ8YeLIdv2k46JleZMjzubGk98smt6t780kSgcVLsQk,3454
|
|
59
65
|
liger_kernel/transformers/trainer_integration.py,sha256=W3ON51O5GkyzNJsItz0y5rKx-uy2f2cFfveZpqbUdhw,123
|
|
60
66
|
liger_kernel/transformers/tvd.py,sha256=XrRfyJIqN6HFxXk8MYyFVZM1OLz3mtSbRZvWfZ_JerQ,450
|
|
61
67
|
liger_kernel/transformers/experimental/embedding.py,sha256=2P0QYdlFyFrG5OqTzTa1wcRgDSyjBMv5i1a7BrDPDQw,881
|
|
62
68
|
liger_kernel/transformers/model/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
|
|
63
|
-
liger_kernel/transformers/model/gemma.py,sha256=
|
|
64
|
-
liger_kernel/transformers/model/gemma2.py,sha256=
|
|
65
|
-
liger_kernel/transformers/model/gemma3.py,sha256=
|
|
66
|
-
liger_kernel/transformers/model/glm4.py,sha256=
|
|
67
|
-
liger_kernel/transformers/model/llama.py,sha256=
|
|
68
|
-
liger_kernel/transformers/model/
|
|
69
|
+
liger_kernel/transformers/model/gemma.py,sha256=mNX-mIwV6jI4zfbrUHp0C468pOmjzsL7mjXipGt-eS0,10007
|
|
70
|
+
liger_kernel/transformers/model/gemma2.py,sha256=R_JFPyWTk7RyA7D05ZiIaNO5pX8gWcvfWf-6rdCRMxs,11296
|
|
71
|
+
liger_kernel/transformers/model/gemma3.py,sha256=XbwoqOSPmtS0BPHgT8jZftTzplmiAicgBa6ocNcet8o,12800
|
|
72
|
+
liger_kernel/transformers/model/glm4.py,sha256=GlnEhdGJuDIqp2R9qC54biY3HwV1tWmfpJm6ijoAsrM,5257
|
|
73
|
+
liger_kernel/transformers/model/llama.py,sha256=i8jJgyZsMKWQ-zKloETLugtwFpUOdaWxLDceciFXKd4,12832
|
|
74
|
+
liger_kernel/transformers/model/llama4.py,sha256=IgbB8sTh3dlETQnaNNy1bZLuXy-Nt7qmeAjF27ydGpg,4210
|
|
75
|
+
liger_kernel/transformers/model/llava.py,sha256=bLCioday_SOm69ogMDBhy_4UsVkH2-BSl93-EXY6-7I,15076
|
|
69
76
|
liger_kernel/transformers/model/loss_utils.py,sha256=WWAMdiONPaXpIvxyOim_0igLrYh0yyOok5Q9_L9xvZw,1787
|
|
70
|
-
liger_kernel/transformers/model/mistral.py,sha256=
|
|
71
|
-
liger_kernel/transformers/model/mixtral.py,sha256=
|
|
72
|
-
liger_kernel/transformers/model/mllama.py,sha256=
|
|
73
|
-
liger_kernel/transformers/model/olmo2.py,sha256=
|
|
74
|
-
liger_kernel/transformers/model/paligemma.py,sha256=
|
|
75
|
-
liger_kernel/transformers/model/phi3.py,sha256=
|
|
76
|
-
liger_kernel/transformers/model/qwen2.py,sha256=
|
|
77
|
-
liger_kernel/transformers/model/qwen2_5_vl.py,sha256=
|
|
78
|
-
liger_kernel/transformers/model/qwen2_vl.py,sha256=
|
|
79
|
-
liger_kernel/transformers/model/qwen3.py,sha256=
|
|
80
|
-
liger_kernel/transformers/model/qwen3_moe.py,sha256=
|
|
77
|
+
liger_kernel/transformers/model/mistral.py,sha256=syYNL8dLThX2-4uC13Lu0krEZ5zw3InviDUR3AJmc-I,5500
|
|
78
|
+
liger_kernel/transformers/model/mixtral.py,sha256=VY-y73IyjcCyWyI7ahxXLw0fJrhgjYfr1xwRYtsHX0o,11396
|
|
79
|
+
liger_kernel/transformers/model/mllama.py,sha256=my29NXk-p6ckQaP8qDIN8e318yI_9mQZHt38MV3SqLY,11280
|
|
80
|
+
liger_kernel/transformers/model/olmo2.py,sha256=6L_bo-ZUgO1lYppdJneOtYxNIylQKS6BiGp13g7Uq9E,5259
|
|
81
|
+
liger_kernel/transformers/model/paligemma.py,sha256=xuIx3oOwTgftU3jqLfWOxUxgCLBNJh0yNC21an9qDjo,18773
|
|
82
|
+
liger_kernel/transformers/model/phi3.py,sha256=zAzBVNOA16B16yy2HWsEgOMHhLoYkpWOWPgBT4z95WI,10655
|
|
83
|
+
liger_kernel/transformers/model/qwen2.py,sha256=3fpOTEOkniQmkCfN1KUa3KhseHJVzhj2Ht9FdYPUy-E,9962
|
|
84
|
+
liger_kernel/transformers/model/qwen2_5_vl.py,sha256=zEVVwotCXnAm3RRc8-1Nc8uitSWrwW4B9dYY2uOZDwg,6331
|
|
85
|
+
liger_kernel/transformers/model/qwen2_vl.py,sha256=5vK-vtCDpKZ2w33xYp2BS8kQYWUbKMqaiKvQcI27Mss,5884
|
|
86
|
+
liger_kernel/transformers/model/qwen3.py,sha256=w2jBHuK9kK9EmOr5dnEIXNQXUgUSV_sJUkXSEwxLPHs,4885
|
|
87
|
+
liger_kernel/transformers/model/qwen3_moe.py,sha256=BkpfFH3fOH0yRfA7LF-AoHTLut2GV0Y4MOlkiIYewfU,5511
|
|
81
88
|
liger_kernel/transformers/trainer/__init__.py,sha256=p7yQfklV8-467qSz_ZMimkbDF7HHWHwku25A-GYL0WU,193
|
|
82
89
|
liger_kernel/transformers/trainer/orpo_trainer.py,sha256=tX0h63aOFe3rNqTmk6JpMf75UPo981yzEa6TghnjS0Q,5370
|
|
83
90
|
liger_kernel/triton/__init__.py,sha256=qCiCamzCRv6lpV8IqpAc9YMdNKC7GKurClWceQPnlis,92
|
|
84
91
|
liger_kernel/triton/monkey_patch.py,sha256=Rd0hUHAzDkFfHvnX7-PBaNK5EKnZhtfM_h-fgQH9HPY,1568
|
|
85
|
-
liger_kernel-0.
|
|
86
|
-
liger_kernel-0.
|
|
87
|
-
liger_kernel-0.
|
|
88
|
-
liger_kernel-0.
|
|
89
|
-
liger_kernel-0.
|
|
90
|
-
liger_kernel-0.
|
|
92
|
+
liger_kernel-0.6.0.dist-info/licenses/LICENSE,sha256=OhzLDHJ0to4a8sodVLELZiCFylZ1NAAYLs-HrjPy0ag,1312
|
|
93
|
+
liger_kernel-0.6.0.dist-info/licenses/NOTICE,sha256=njwnoPZLh9AN8SJQzxvCGLHi-8X__AvWRze6joNXIY8,2066
|
|
94
|
+
liger_kernel-0.6.0.dist-info/METADATA,sha256=YQs0IFuj3o4GPiiDJ6K2s_HqIIWTv8SvQLVU_tPRwGY,24578
|
|
95
|
+
liger_kernel-0.6.0.dist-info/WHEEL,sha256=_zCd3N1l69ArxyTb8rzEoP9TpbYXkqRFSNOD5OuxnTs,91
|
|
96
|
+
liger_kernel-0.6.0.dist-info/top_level.txt,sha256=2eghu4hA3LnkM7ElW92tQ8zegWKgSbeo-k-aGe1YnvY,13
|
|
97
|
+
liger_kernel-0.6.0.dist-info/RECORD,,
|
|
@@ -1,8 +0,0 @@
|
|
|
1
|
-
from .rms_norm import LigerRMSNorm
|
|
2
|
-
|
|
3
|
-
|
|
4
|
-
class LigerRMSNormForGemma3(LigerRMSNorm):
|
|
5
|
-
"""Gemma3RMSNorm has a dim argument not hidden_size used in q_norm and k_norm."""
|
|
6
|
-
|
|
7
|
-
def __init__(self, dim, eps=0.000001, offset=1.0, casting_mode="gemma", init_fn="zeros", in_place=False):
|
|
8
|
-
super().__init__(dim, eps, offset, casting_mode, init_fn, in_place)
|
|
File without changes
|
|
File without changes
|
|
File without changes
|