liger-kernel-nightly 0.5.9.dev20250515034325__py3-none-any.whl → 0.5.9.dev20250515065336__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- liger_kernel/ops/sparsemax.py +167 -0
- liger_kernel/transformers/functional.py +8 -0
- liger_kernel/transformers/sparsemax.py +16 -0
- {liger_kernel_nightly-0.5.9.dev20250515034325.dist-info → liger_kernel_nightly-0.5.9.dev20250515065336.dist-info}/METADATA +1 -1
- {liger_kernel_nightly-0.5.9.dev20250515034325.dist-info → liger_kernel_nightly-0.5.9.dev20250515065336.dist-info}/RECORD +9 -7
- {liger_kernel_nightly-0.5.9.dev20250515034325.dist-info → liger_kernel_nightly-0.5.9.dev20250515065336.dist-info}/LICENSE +0 -0
- {liger_kernel_nightly-0.5.9.dev20250515034325.dist-info → liger_kernel_nightly-0.5.9.dev20250515065336.dist-info}/NOTICE +0 -0
- {liger_kernel_nightly-0.5.9.dev20250515034325.dist-info → liger_kernel_nightly-0.5.9.dev20250515065336.dist-info}/WHEEL +0 -0
- {liger_kernel_nightly-0.5.9.dev20250515034325.dist-info → liger_kernel_nightly-0.5.9.dev20250515065336.dist-info}/top_level.txt +0 -0
liger_kernel/ops/sparsemax.py (new file)
@@ -0,0 +1,167 @@
+import torch
+import triton
+import triton.language as tl
+
+from liger_kernel.ops.utils import calculate_settings
+from liger_kernel.ops.utils import ensure_contiguous
+
+
+@triton.jit
+def _sparsemax_forward_kernel(
+    x_ptr,
+    x_stride_row,
+    sorted_x_ptr,
+    sorted_x_stride_row,
+    o_ptr,
+    o_stride_row,
+    n_cols,
+    BLOCK_SIZE: tl.constexpr,
+    num_warps: tl.constexpr,
+):
+    pid_row = tl.program_id(0)
+    ptr_x_data_row = x_ptr + pid_row * x_stride_row
+    ptr_sorted_x_data_row = sorted_x_ptr + pid_row * sorted_x_stride_row
+    ptr_output_row = o_ptr + pid_row * o_stride_row
+
+    offs = tl.arange(0, BLOCK_SIZE)
+    mask = offs < n_cols
+
+    z_sorted_block = tl.load(
+        ptr_sorted_x_data_row + offs,
+        mask=mask,
+        other=-float("inf"),
+        cache_modifier=".ca",
+    ).to(tl.float32)
+
+    z_valid = tl.where(mask, z_sorted_block, 0.0)
+    cssv = tl.cumsum(z_valid, 0)
+
+    r = (offs + 1).to(tl.float32)
+    safe_r = tl.where(mask, r, 1.0)
+
+    t_vec = (cssv - 1.0) / safe_r
+
+    support = (z_sorted_block > t_vec) & mask
+
+    k_int = tl.sum(support.to(tl.int32), 0)
+    k_clamped_int = tl.maximum(k_int, 1)
+    k = k_clamped_int.to(tl.float32)
+
+    s = tl.sum(tl.where(support, z_sorted_block, 0.0), 0)
+
+    tau = (s - 1.0) / k
+
+    x_block = tl.load(
+        ptr_x_data_row + offs,
+        mask=mask,
+        other=0.0,
+        cache_modifier=".ca",
+    ).to(tl.float32)
+
+    y = tl.maximum(x_block - tau, 0.0)
+
+    tl.store(
+        ptr_output_row + offs,
+        y.to(ptr_output_row.dtype.element_ty),
+        mask=mask,
+        cache_modifier=".cs",
+    )
+
+
+@triton.jit
+def _sparsemax_backward_kernel(
+    o_ptr, go_ptr, gi_ptr, stride, n_cols, BLOCK_SIZE: tl.constexpr, num_warps: tl.constexpr
+):
+    row = tl.program_id(0)
+    o_row = o_ptr + row * stride
+    go_row = go_ptr + row * stride
+    gi_row = gi_ptr + row * stride
+
+    offs = tl.arange(0, BLOCK_SIZE)
+
+    supp_cnt = tl.zeros((), tl.float32)
+    go_sum = tl.zeros((), tl.float32)
+
+    for i in tl.range(0, tl.cdiv(n_cols, BLOCK_SIZE)):
+        offs_iter = i * BLOCK_SIZE + offs
+        mask_iter = offs_iter < n_cols
+        o_val = tl.load(o_row + offs_iter, mask=mask_iter, other=0.0, cache_modifier=".ca").to(tl.float32)
+        go_val = tl.load(go_row + offs_iter, mask=mask_iter, other=0.0).to(tl.float32)
+        supp = o_val > 0.0
+        go_sum += tl.sum(tl.where(supp, go_val, 0.0))
+        supp_cnt += tl.sum(supp.to(tl.float32))
+
+    for i in tl.range(0, tl.cdiv(n_cols, BLOCK_SIZE)):
+        offs_iter = i * BLOCK_SIZE + offs
+        mask_iter = offs_iter < n_cols
+        o_val = tl.load(o_row + offs_iter, mask=mask_iter, other=0.0, cache_modifier=".ca").to(tl.float32)
+        go_val = tl.load(go_row + offs_iter, mask=mask_iter, other=0.0).to(tl.float32)
+        supp = o_val > 0.0
+        gi_val = tl.where(
+            supp,
+            go_val - tl.cast(go_sum / tl.maximum(supp_cnt, 1e-6), gi_row.dtype.element_ty).to(tl.float32),
+            0.0,
+        )
+        tl.store(gi_row + offs_iter, gi_val.to(gi_row.dtype.element_ty), mask=mask_iter, cache_modifier=".wb")
+
+
+class LigerSparsemaxFunction(torch.autograd.Function):
+    @staticmethod
+    @ensure_contiguous
+    def forward(ctx, x: torch.Tensor, dim: int):
+        if dim < 0:
+            dim += x.dim()
+        ctx.dim = dim
+
+        x_sw = x.transpose(dim, -1).contiguous()
+        n_cols = x_sw.size(-1)
+        n_rows = x_sw.numel() // n_cols
+        x_flat = x_sw.view(n_rows, n_cols)
+
+        BLOCK_SIZE, num_warps = calculate_settings(n_cols)
+        out_flat = torch.empty_like(x_flat)
+        grid = (n_rows,)
+
+        x_sorted_flat = torch.sort(x_flat.float(), dim=-1, descending=True).values
+
+        _sparsemax_forward_kernel[grid](
+            x_flat,
+            x_flat.stride(0),
+            x_sorted_flat,
+            x_sorted_flat.stride(0),
+            out_flat,
+            out_flat.stride(0),
+            n_cols,
+            BLOCK_SIZE=BLOCK_SIZE,
+            num_warps=num_warps,
+        )
+
+        ctx.save_for_backward(out_flat)
+        return out_flat.view_as(x_sw).transpose(dim, -1)
+
+    @staticmethod
+    @ensure_contiguous
+    def backward(ctx, grad_out: torch.Tensor):
+        (out_flat,) = ctx.saved_tensors
+        dim = ctx.dim
+
+        go_sw = grad_out.transpose(dim, -1).contiguous()
+        n_cols = go_sw.size(-1)
+        n_rows = go_sw.numel() // n_cols
+        go_flat = go_sw.view(n_rows, n_cols)
+
+        BLOCK_SIZE, num_warps = calculate_settings(n_cols)
+        gi_flat = torch.empty_like(go_flat)
+        grid = (n_rows,)
+
+        _sparsemax_backward_kernel[grid](
+            out_flat,
+            go_flat,
+            gi_flat,
+            out_flat.stride(0),
+            n_cols,
+            BLOCK_SIZE=BLOCK_SIZE,
+            num_warps=num_warps,
+        )
+
+        return gi_flat.view_as(go_sw).transpose(dim, -1), None
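Not part of the diff, for orientation: sparsemax (Martins & Astudillo, 2016) is the Euclidean projection of a logit vector onto the probability simplex, and the forward kernel above follows the standard sort/cumsum/threshold algorithm. Below is a minimal pure-PyTorch sketch of the same computation, handy as a correctness reference for the Triton kernel; sparsemax_ref is a name chosen here for illustration, not an API of the package.

import torch

def sparsemax_ref(x: torch.Tensor, dim: int = -1) -> torch.Tensor:
    """Euclidean projection of x onto the probability simplex along `dim`."""
    z, _ = torch.sort(x, dim=dim, descending=True)       # mirrors torch.sort(...) in forward()
    cssv = z.cumsum(dim=dim)                             # mirrors tl.cumsum in the kernel
    r = torch.arange(1, z.size(dim) + 1, device=x.device, dtype=z.dtype)
    view = [1] * z.dim()
    view[dim] = -1
    r = r.view(view)                                     # ranks 1..n along `dim`
    support = z > (cssv - 1.0) / r                       # same test as `support` in the kernel
    k = support.sum(dim=dim, keepdim=True).clamp(min=1)  # support size, clamped like k_clamped_int
    s = (z * support).sum(dim=dim, keepdim=True)         # sum of logits over the support
    tau = (s - 1.0) / k                                  # simplex threshold
    return torch.clamp(x - tau, min=0.0)                 # y = max(x - tau, 0)

The backward kernel implements the matching Jacobian-vector product: on the support (entries whose output is positive) the input gradient is the output gradient minus its mean over the support, and it is zero elsewhere, which is why only the saved output tensor is needed.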
liger_kernel/transformers/functional.py
@@ -12,6 +12,7 @@ from liger_kernel.ops.layer_norm import LigerLayerNormFunction
 from liger_kernel.ops.qwen2vl_mrope import LigerQwen2VLMRopeFunction
 from liger_kernel.ops.rms_norm import LigerRMSNormFunction
 from liger_kernel.ops.rope import LigerRopeFunction
+from liger_kernel.ops.sparsemax import LigerSparsemaxFunction
 from liger_kernel.ops.swiglu import LigerSiLUMulFunction
 from liger_kernel.ops.tvd import LigerTVDLossFunction
 
@@ -159,6 +160,13 @@ def liger_kl_div(
 )
 
 
+def liger_sparsemax(
+    input,
+    dim: int = -1,
+):
+    return LigerSparsemaxFunction.apply(input, dim)
+
+
 def liger_tvd(
     input,
     target,
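Not part of the diff: assuming a CUDA device and this nightly wheel installed, the new functional entry point can be smoke-tested against the properties sparsemax guarantees. The shapes and tolerance below are illustrative only.

import torch
from liger_kernel.transformers.functional import liger_sparsemax

x = torch.randn(4, 128, device="cuda", requires_grad=True)
p = liger_sparsemax(x, dim=-1)

# Each row lies on the probability simplex: non-negative, sums to 1,
# and (unlike softmax) typically contains exact zeros.
assert (p >= 0).all()
assert torch.allclose(p.sum(dim=-1), torch.ones(4, device="cuda"), atol=1e-5)

p.sum().backward()  # exercises _sparsemax_backward_kernel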
liger_kernel/transformers/sparsemax.py (new file)
@@ -0,0 +1,16 @@
+import torch
+import torch.nn as nn
+
+from liger_kernel.ops.sparsemax import LigerSparsemaxFunction
+
+
+class LigerSparsemax(nn.Module):
+    def __init__(self, dim: int = -1):
+        super().__init__()
+        self.dim = dim
+
+    def forward(self, x: torch.Tensor) -> torch.Tensor:
+        return LigerSparsemaxFunction.apply(x, self.dim)
+
+    def extra_repr(self) -> str:
+        return f"dim={self.dim}"
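Not part of the diff: the module wrapper slots in wherever nn.Softmax would. A short usage sketch, again assuming a CUDA device; the shapes are illustrative.

import torch
from liger_kernel.transformers.sparsemax import LigerSparsemax

sparse_attn = LigerSparsemax(dim=-1)   # drop-in replacement for nn.Softmax(dim=-1)
scores = torch.randn(2, 8, 64, device="cuda")
weights = sparse_attn(scores)          # many attention weights come out exactly 0
print(sparse_attn)                     # "LigerSparsemax(dim=-1)" via extra_repr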
liger_kernel_nightly-0.5.9.dev20250515065336.dist-info/RECORD
@@ -28,6 +28,7 @@ liger_kernel/ops/layer_norm.py,sha256=vWCyOm-F2GMAilB-ozJcFeUQQLCJoTE_uiXq-_0uYu
 liger_kernel/ops/qwen2vl_mrope.py,sha256=3GExhYpLgB4VUtyZyjRk8XjEur3W4EWF6HQ67ML5vBU,8481
 liger_kernel/ops/rms_norm.py,sha256=PP27OIBmV9By63i13jot9ylDowW0nuxY_JFIkaPLgL4,12078
 liger_kernel/ops/rope.py,sha256=ofmBOkUpZZO-Q8Z5B_LOFYYLD-YT-8WnJ4vGOrDYouI,8943
+liger_kernel/ops/sparsemax.py,sha256=t7JWIyzq1piikXUufayFzsfkzVaCYU-hXPuMs7839pk,4850
 liger_kernel/ops/swiglu.py,sha256=KmgMjaJQnbLLgZn2nEpbwHU_xpnYRweCyrLQSVvM1vA,3015
 liger_kernel/ops/tvd.py,sha256=FHJtLQI95ijqgg9UtaHpMAjSCiPxB6CduPwPMcGxelc,6405
 liger_kernel/ops/utils.py,sha256=uoFKQqo-34N2TWQNvXMFywqGiOMMXNEVBxVojzlUAa0,3836
@@ -37,7 +38,7 @@ liger_kernel/transformers/__init__.py,sha256=0KX0rxyy0E_uNWVE0PSTzEVzKqc5KdFHtvd
 liger_kernel/transformers/auto_model.py,sha256=0qCTRZt280Bj_LcFdzo9hlaR-BWNazawXOGgoCZjgEg,1545
 liger_kernel/transformers/cross_entropy.py,sha256=z3KTWQnFxr_IZaVjtYt0ZNEWQdDdYThN35xWkHlDGH0,1683
 liger_kernel/transformers/dyt.py,sha256=QMqqc14pkE0WhpRZvapfnNAun-6C0C_tHExL2ZJuCUA,648
-liger_kernel/transformers/functional.py,sha256=
+liger_kernel/transformers/functional.py,sha256=2YBfvtdU1GRZuRpJhHgJXeGYa1RvmO6-qQvrKQrLJK4,5259
 liger_kernel/transformers/fused_linear_cross_entropy.py,sha256=O8Sg5BT81nTaY9fSGoOY9dOD9ekibwwiuXhdUHaxntQ,1742
 liger_kernel/transformers/fused_linear_jsd.py,sha256=bZ4otCvWBuOnA5XdQL-FzZVItJlDt-ht9e_pG7PG93E,3999
 liger_kernel/transformers/geglu.py,sha256=mrgqzIUVd6lN7fkDKLkw5YaESDxDtFgbot430WwPVOQ,1107
@@ -50,6 +51,7 @@ liger_kernel/transformers/monkey_patch.py,sha256=k8WIkx_f3ObG6TjhIiN_4KeOABurB2W
 liger_kernel/transformers/qwen2vl_mrope.py,sha256=5EwSqrMdsL9MYspeBMXBsNJKvH0MOmRrtJXAJlnnlOI,1047
 liger_kernel/transformers/rms_norm.py,sha256=GqCEJuGt0YdqqlMcToE0Wp4A8YFquDa4UUSyH2uFW2A,1191
 liger_kernel/transformers/rope.py,sha256=ZTrTORSAyfcFIKjk6XEeYmk4ROH7xXED9L4g2NFntlE,999
+liger_kernel/transformers/sparsemax.py,sha256=0lQA0UEOs4mu8CMruZ3VLhImxQVXJWhPsAKUsYA7vj8,403
 liger_kernel/transformers/swiglu.py,sha256=LZ8YeLIdv2k46JleZMjzubGk98smt6t780kSgcVLsQk,3454
 liger_kernel/transformers/trainer_integration.py,sha256=W3ON51O5GkyzNJsItz0y5rKx-uy2f2cFfveZpqbUdhw,123
 liger_kernel/transformers/tvd.py,sha256=XrRfyJIqN6HFxXk8MYyFVZM1OLz3mtSbRZvWfZ_JerQ,450
@@ -77,9 +79,9 @@ liger_kernel/transformers/trainer/__init__.py,sha256=p7yQfklV8-467qSz_ZMimkbDF7H
 liger_kernel/transformers/trainer/orpo_trainer.py,sha256=pdekW7l6Qg_aqa5SYKYlSWUF8m3lkOFvFLcIMEHrz9s,8338
 liger_kernel/triton/__init__.py,sha256=qCiCamzCRv6lpV8IqpAc9YMdNKC7GKurClWceQPnlis,92
 liger_kernel/triton/monkey_patch.py,sha256=Rd0hUHAzDkFfHvnX7-PBaNK5EKnZhtfM_h-fgQH9HPY,1568
-liger_kernel_nightly-0.5.9.
-liger_kernel_nightly-0.5.9.
-liger_kernel_nightly-0.5.9.
-liger_kernel_nightly-0.5.9.
-liger_kernel_nightly-0.5.9.
-liger_kernel_nightly-0.5.9.
+liger_kernel_nightly-0.5.9.dev20250515065336.dist-info/LICENSE,sha256=OhzLDHJ0to4a8sodVLELZiCFylZ1NAAYLs-HrjPy0ag,1312
+liger_kernel_nightly-0.5.9.dev20250515065336.dist-info/METADATA,sha256=IK7MV888DLovn85_Xto_NFKgXq4SILvZB7HDXyeP2uc,23874
+liger_kernel_nightly-0.5.9.dev20250515065336.dist-info/NOTICE,sha256=njwnoPZLh9AN8SJQzxvCGLHi-8X__AvWRze6joNXIY8,2066
+liger_kernel_nightly-0.5.9.dev20250515065336.dist-info/WHEEL,sha256=iAkIy5fosb7FzIOwONchHf19Qu7_1wCWyFNR5gu9nU0,91
+liger_kernel_nightly-0.5.9.dev20250515065336.dist-info/top_level.txt,sha256=2eghu4hA3LnkM7ElW92tQ8zegWKgSbeo-k-aGe1YnvY,13
+liger_kernel_nightly-0.5.9.dev20250515065336.dist-info/RECORD,,