liger-kernel-nightly 0.5.9.dev20250512213150__py3-none-any.whl → 0.5.9.dev20250515065336__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -0,0 +1,167 @@
1
+ import torch
2
+ import triton
3
+ import triton.language as tl
4
+
5
+ from liger_kernel.ops.utils import calculate_settings
6
+ from liger_kernel.ops.utils import ensure_contiguous
7
+
8
+
9
@triton.jit
def _sparsemax_forward_kernel(
    x_ptr,
    x_stride_row,
    sorted_x_ptr,
    sorted_x_stride_row,
    o_ptr,
    o_stride_row,
    n_cols,
    BLOCK_SIZE: tl.constexpr,
    num_warps: tl.constexpr,
):
    """
    Sparsemax forward pass; each program instance handles one row.

    The threshold ``tau`` is derived from the pre-sorted copy of the row and
    then applied to the unsorted row as ``max(x - tau, 0)``.

    Args:
        x_ptr: base pointer of the unsorted input rows.
        x_stride_row: element stride between consecutive rows of ``x_ptr``.
        sorted_x_ptr: base pointer of the sorted rows (the caller in this file
            passes each row sorted in descending order, in float32).
        sorted_x_stride_row: element stride between rows of ``sorted_x_ptr``.
        o_ptr: base pointer of the output rows.
        o_stride_row: element stride between rows of ``o_ptr``.
        n_cols: number of valid columns in each row.
        BLOCK_SIZE: compile-time tile width. There is no loop over columns, so
            only the first ``BLOCK_SIZE`` columns of a row are processed.
        num_warps: compile-time launch setting; not read inside the body.
    """
    row_idx = tl.program_id(0)
    x_row_ptr = x_ptr + row_idx * x_stride_row
    sorted_row_ptr = sorted_x_ptr + row_idx * sorted_x_stride_row
    out_row_ptr = o_ptr + row_idx * o_stride_row

    col_offsets = tl.arange(0, BLOCK_SIZE)
    in_bounds = col_offsets < n_cols

    # Padding lanes read -inf so they can never satisfy the support test below.
    z_sorted = tl.load(
        sorted_row_ptr + col_offsets,
        mask=in_bounds,
        other=-float("inf"),
        cache_modifier=".ca",
    ).to(tl.float32)

    # Replace the -inf padding by 0 before the prefix sum so it stays finite.
    running_sum = tl.cumsum(tl.where(in_bounds, z_sorted, 0.0), 0)

    # 1-based position of every sorted element; padding lanes divide by 1.
    position = tl.where(in_bounds, (col_offsets + 1).to(tl.float32), 1.0)
    candidate_tau = (running_sum - 1.0) / position

    # An element belongs to the support when it exceeds its candidate threshold.
    in_support = (z_sorted > candidate_tau) & in_bounds

    # Support size, clamped to at least 1 to keep the division below defined.
    support_size = tl.maximum(tl.sum(in_support.to(tl.int32), 0), 1).to(tl.float32)
    support_total = tl.sum(tl.where(in_support, z_sorted, 0.0), 0)

    tau = (support_total - 1.0) / support_size

    x_vals = tl.load(
        x_row_ptr + col_offsets,
        mask=in_bounds,
        other=0.0,
        cache_modifier=".ca",
    ).to(tl.float32)

    projected = tl.maximum(x_vals - tau, 0.0)

    # Cast back to the element type of the output buffer before storing.
    tl.store(
        out_row_ptr + col_offsets,
        projected.to(out_row_ptr.dtype.element_ty),
        mask=in_bounds,
        cache_modifier=".cs",
    )
69
+
70
+
71
@triton.jit
def _sparsemax_backward_kernel(
    o_ptr, go_ptr, gi_ptr, stride, n_cols, BLOCK_SIZE: tl.constexpr, num_warps: tl.constexpr
):
    """
    Sparsemax backward pass; each program instance handles one row.

    For entries whose forward output is positive (the support), the input
    gradient is the output gradient minus the mean output gradient over the
    support; all other entries receive 0.

    Args:
        o_ptr: base pointer of the saved forward output rows.
        go_ptr: base pointer of the incoming gradient rows.
        gi_ptr: base pointer of the input-gradient rows that are written.
        stride: element stride between rows; the same value is applied to all
            three buffers.
        n_cols: number of valid columns in each row.
        BLOCK_SIZE: compile-time tile width used to walk across the row.
        num_warps: compile-time launch setting; not read inside the body.
    """
    row_idx = tl.program_id(0)
    row_start = row_idx * stride
    out_row_ptr = o_ptr + row_start
    grad_out_row_ptr = go_ptr + row_start
    grad_in_row_ptr = gi_ptr + row_start

    lane_offsets = tl.arange(0, BLOCK_SIZE)
    n_blocks = tl.cdiv(n_cols, BLOCK_SIZE)

    support_count = tl.zeros((), tl.float32)
    support_grad_sum = tl.zeros((), tl.float32)

    # Pass 1: accumulate the support size and the gradient sum over the support.
    for blk in tl.range(0, n_blocks):
        cols = blk * BLOCK_SIZE + lane_offsets
        valid = cols < n_cols
        out_vals = tl.load(out_row_ptr + cols, mask=valid, other=0.0, cache_modifier=".ca").to(tl.float32)
        grad_vals = tl.load(grad_out_row_ptr + cols, mask=valid, other=0.0).to(tl.float32)
        # Masked lanes load 0.0 for the output and therefore fall outside the support.
        in_support = out_vals > 0.0
        support_grad_sum += tl.sum(tl.where(in_support, grad_vals, 0.0))
        support_count += tl.sum(in_support.to(tl.float32))

    # Mean gradient over the support. The value is rounded through the
    # input-gradient element type before being widened to float32 again, and
    # the denominator is floored at 1e-6 so an empty support cannot divide by 0.
    mean_grad = tl.cast(
        support_grad_sum / tl.maximum(support_count, 1e-6),
        grad_in_row_ptr.dtype.element_ty,
    ).to(tl.float32)

    # Pass 2: write the centred gradient on the support and 0 everywhere else.
    for blk in tl.range(0, n_blocks):
        cols = blk * BLOCK_SIZE + lane_offsets
        valid = cols < n_cols
        out_vals = tl.load(out_row_ptr + cols, mask=valid, other=0.0, cache_modifier=".ca").to(tl.float32)
        grad_vals = tl.load(grad_out_row_ptr + cols, mask=valid, other=0.0).to(tl.float32)
        in_support = out_vals > 0.0
        grad_in = tl.where(in_support, grad_vals - mean_grad, 0.0)
        tl.store(
            grad_in_row_ptr + cols,
            grad_in.to(grad_in_row_ptr.dtype.element_ty),
            mask=valid,
            cache_modifier=".wb",
        )
106
+
107
+
108
class LigerSparsemaxFunction(torch.autograd.Function):
    """
    Autograd function that evaluates sparsemax along one dimension by
    launching the Triton forward/backward kernels defined in this module.
    """

    @staticmethod
    @ensure_contiguous
    def forward(ctx, x: torch.Tensor, dim: int):
        """
        Apply sparsemax to ``x`` along ``dim``.

        Args:
            ctx: autograd context; receives ``dim`` and the flattened output.
            x: input tensor.
            dim: dimension to normalise over; negative values count from the end.

        Returns:
            A tensor with the same shape as ``x``, produced as a transposed
            view of the kernel's output buffer.
        """
        # Store the non-negative form of the dimension for the backward pass.
        if dim < 0:
            dim = dim + x.dim()
        ctx.dim = dim

        # Bring the target dimension last and collapse the rest into rows.
        x_last = x.transpose(dim, -1).contiguous()
        n_cols = x_last.size(-1)
        n_rows = x_last.numel() // n_cols
        x_2d = x_last.view(n_rows, n_cols)

        BLOCK_SIZE, num_warps = calculate_settings(n_cols)
        out_2d = torch.empty_like(x_2d)

        # The kernel reads each row pre-sorted in descending order, in float32.
        sorted_2d, _ = torch.sort(x_2d.float(), dim=-1, descending=True)

        launch_grid = (n_rows,)
        _sparsemax_forward_kernel[launch_grid](
            x_2d,
            x_2d.stride(0),
            sorted_2d,
            sorted_2d.stride(0),
            out_2d,
            out_2d.stride(0),
            n_cols,
            BLOCK_SIZE=BLOCK_SIZE,
            num_warps=num_warps,
        )

        ctx.save_for_backward(out_2d)
        # Undo the flattening and the transpose to restore the caller's layout.
        return out_2d.view_as(x_last).transpose(dim, -1)

    @staticmethod
    @ensure_contiguous
    def backward(ctx, grad_out: torch.Tensor):
        """
        Propagate ``grad_out`` through sparsemax.

        Args:
            ctx: autograd context populated by ``forward``.
            grad_out: gradient with respect to the forward output.

        Returns:
            A pair ``(grad_input, None)``; ``None`` is the gradient slot for
            the ``dim`` argument.
        """
        (saved_out_2d,) = ctx.saved_tensors
        dim = ctx.dim

        # Same "target dimension last, flatten to rows" layout as the forward pass.
        grad_last = grad_out.transpose(dim, -1).contiguous()
        n_cols = grad_last.size(-1)
        n_rows = grad_last.numel() // n_cols
        grad_out_2d = grad_last.view(n_rows, n_cols)

        BLOCK_SIZE, num_warps = calculate_settings(n_cols)
        grad_in_2d = torch.empty_like(grad_out_2d)

        launch_grid = (n_rows,)
        # A single row stride (taken from the saved output) is passed for all
        # three buffers.
        _sparsemax_backward_kernel[launch_grid](
            saved_out_2d,
            grad_out_2d,
            grad_in_2d,
            saved_out_2d.stride(0),
            n_cols,
            BLOCK_SIZE=BLOCK_SIZE,
            num_warps=num_warps,
        )

        grad_input = grad_in_2d.view_as(grad_last).transpose(dim, -1)
        return grad_input, None
@@ -12,6 +12,7 @@ from liger_kernel.ops.layer_norm import LigerLayerNormFunction
12
12
  from liger_kernel.ops.qwen2vl_mrope import LigerQwen2VLMRopeFunction
13
13
  from liger_kernel.ops.rms_norm import LigerRMSNormFunction
14
14
  from liger_kernel.ops.rope import LigerRopeFunction
15
+ from liger_kernel.ops.sparsemax import LigerSparsemaxFunction
15
16
  from liger_kernel.ops.swiglu import LigerSiLUMulFunction
16
17
  from liger_kernel.ops.tvd import LigerTVDLossFunction
17
18
 
@@ -159,6 +160,13 @@ def liger_kl_div(
159
160
  )
160
161
 
161
162
 
163
def liger_sparsemax(
    input,
    dim: int = -1,
):
    """
    Functional entry point for sparsemax.

    Args:
        input: tensor forwarded unchanged to ``LigerSparsemaxFunction``.
        dim: dimension to apply sparsemax over (defaults to the last one).

    Returns:
        Whatever ``LigerSparsemaxFunction.apply(input, dim)`` returns.
    """
    result = LigerSparsemaxFunction.apply(input, dim)
    return result
168
+
169
+
162
170
  def liger_tvd(
163
171
  input,
164
172
  target,
@@ -0,0 +1,16 @@
1
+ import torch
2
+ import torch.nn as nn
3
+
4
+ from liger_kernel.ops.sparsemax import LigerSparsemaxFunction
5
+
6
+
7
class LigerSparsemax(nn.Module):
    """Module wrapper that applies ``LigerSparsemaxFunction`` along a fixed dimension."""

    def __init__(self, dim: int = -1):
        """
        Args:
            dim: dimension passed to ``LigerSparsemaxFunction`` on every call
                (defaults to the last one).
        """
        super().__init__()
        self.dim = dim

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Return ``LigerSparsemaxFunction.apply(x, self.dim)``."""
        target_dim = self.dim
        return LigerSparsemaxFunction.apply(x, target_dim)

    def extra_repr(self) -> str:
        """Return the ``dim=<value>`` fragment displayed inside the module's repr."""
        return "dim={}".format(self.dim)
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.1
2
2
  Name: liger_kernel_nightly
3
- Version: 0.5.9.dev20250512213150
3
+ Version: 0.5.9.dev20250515065336
4
4
  Summary: Efficient Triton kernels for LLM Training
5
5
  License: BSD 2-CLAUSE LICENSE
6
6
  Copyright 2024 LinkedIn Corporation
@@ -28,6 +28,7 @@ liger_kernel/ops/layer_norm.py,sha256=vWCyOm-F2GMAilB-ozJcFeUQQLCJoTE_uiXq-_0uYu
28
28
  liger_kernel/ops/qwen2vl_mrope.py,sha256=3GExhYpLgB4VUtyZyjRk8XjEur3W4EWF6HQ67ML5vBU,8481
29
29
  liger_kernel/ops/rms_norm.py,sha256=PP27OIBmV9By63i13jot9ylDowW0nuxY_JFIkaPLgL4,12078
30
30
  liger_kernel/ops/rope.py,sha256=ofmBOkUpZZO-Q8Z5B_LOFYYLD-YT-8WnJ4vGOrDYouI,8943
31
+ liger_kernel/ops/sparsemax.py,sha256=t7JWIyzq1piikXUufayFzsfkzVaCYU-hXPuMs7839pk,4850
31
32
  liger_kernel/ops/swiglu.py,sha256=KmgMjaJQnbLLgZn2nEpbwHU_xpnYRweCyrLQSVvM1vA,3015
32
33
  liger_kernel/ops/tvd.py,sha256=FHJtLQI95ijqgg9UtaHpMAjSCiPxB6CduPwPMcGxelc,6405
33
34
  liger_kernel/ops/utils.py,sha256=uoFKQqo-34N2TWQNvXMFywqGiOMMXNEVBxVojzlUAa0,3836
@@ -37,7 +38,7 @@ liger_kernel/transformers/__init__.py,sha256=0KX0rxyy0E_uNWVE0PSTzEVzKqc5KdFHtvd
37
38
  liger_kernel/transformers/auto_model.py,sha256=0qCTRZt280Bj_LcFdzo9hlaR-BWNazawXOGgoCZjgEg,1545
38
39
  liger_kernel/transformers/cross_entropy.py,sha256=z3KTWQnFxr_IZaVjtYt0ZNEWQdDdYThN35xWkHlDGH0,1683
39
40
  liger_kernel/transformers/dyt.py,sha256=QMqqc14pkE0WhpRZvapfnNAun-6C0C_tHExL2ZJuCUA,648
40
- liger_kernel/transformers/functional.py,sha256=4h9Pdx_iINBqfv2Zod_c27qOpYXDDwbdVgatQ9_XBmI,5089
41
+ liger_kernel/transformers/functional.py,sha256=2YBfvtdU1GRZuRpJhHgJXeGYa1RvmO6-qQvrKQrLJK4,5259
41
42
  liger_kernel/transformers/fused_linear_cross_entropy.py,sha256=O8Sg5BT81nTaY9fSGoOY9dOD9ekibwwiuXhdUHaxntQ,1742
42
43
  liger_kernel/transformers/fused_linear_jsd.py,sha256=bZ4otCvWBuOnA5XdQL-FzZVItJlDt-ht9e_pG7PG93E,3999
43
44
  liger_kernel/transformers/geglu.py,sha256=mrgqzIUVd6lN7fkDKLkw5YaESDxDtFgbot430WwPVOQ,1107
@@ -50,6 +51,7 @@ liger_kernel/transformers/monkey_patch.py,sha256=k8WIkx_f3ObG6TjhIiN_4KeOABurB2W
50
51
  liger_kernel/transformers/qwen2vl_mrope.py,sha256=5EwSqrMdsL9MYspeBMXBsNJKvH0MOmRrtJXAJlnnlOI,1047
51
52
  liger_kernel/transformers/rms_norm.py,sha256=GqCEJuGt0YdqqlMcToE0Wp4A8YFquDa4UUSyH2uFW2A,1191
52
53
  liger_kernel/transformers/rope.py,sha256=ZTrTORSAyfcFIKjk6XEeYmk4ROH7xXED9L4g2NFntlE,999
54
+ liger_kernel/transformers/sparsemax.py,sha256=0lQA0UEOs4mu8CMruZ3VLhImxQVXJWhPsAKUsYA7vj8,403
53
55
  liger_kernel/transformers/swiglu.py,sha256=LZ8YeLIdv2k46JleZMjzubGk98smt6t780kSgcVLsQk,3454
54
56
  liger_kernel/transformers/trainer_integration.py,sha256=W3ON51O5GkyzNJsItz0y5rKx-uy2f2cFfveZpqbUdhw,123
55
57
  liger_kernel/transformers/tvd.py,sha256=XrRfyJIqN6HFxXk8MYyFVZM1OLz3mtSbRZvWfZ_JerQ,450
@@ -77,9 +79,9 @@ liger_kernel/transformers/trainer/__init__.py,sha256=p7yQfklV8-467qSz_ZMimkbDF7H
77
79
  liger_kernel/transformers/trainer/orpo_trainer.py,sha256=pdekW7l6Qg_aqa5SYKYlSWUF8m3lkOFvFLcIMEHrz9s,8338
78
80
  liger_kernel/triton/__init__.py,sha256=qCiCamzCRv6lpV8IqpAc9YMdNKC7GKurClWceQPnlis,92
79
81
  liger_kernel/triton/monkey_patch.py,sha256=Rd0hUHAzDkFfHvnX7-PBaNK5EKnZhtfM_h-fgQH9HPY,1568
80
- liger_kernel_nightly-0.5.9.dev20250512213150.dist-info/LICENSE,sha256=OhzLDHJ0to4a8sodVLELZiCFylZ1NAAYLs-HrjPy0ag,1312
81
- liger_kernel_nightly-0.5.9.dev20250512213150.dist-info/METADATA,sha256=iVTABeE0sZWm8MpMWZGh5nRvdRUTm0CKIvo1lSrX7c8,23874
82
- liger_kernel_nightly-0.5.9.dev20250512213150.dist-info/NOTICE,sha256=njwnoPZLh9AN8SJQzxvCGLHi-8X__AvWRze6joNXIY8,2066
83
- liger_kernel_nightly-0.5.9.dev20250512213150.dist-info/WHEEL,sha256=iAkIy5fosb7FzIOwONchHf19Qu7_1wCWyFNR5gu9nU0,91
84
- liger_kernel_nightly-0.5.9.dev20250512213150.dist-info/top_level.txt,sha256=2eghu4hA3LnkM7ElW92tQ8zegWKgSbeo-k-aGe1YnvY,13
85
- liger_kernel_nightly-0.5.9.dev20250512213150.dist-info/RECORD,,
82
+ liger_kernel_nightly-0.5.9.dev20250515065336.dist-info/LICENSE,sha256=OhzLDHJ0to4a8sodVLELZiCFylZ1NAAYLs-HrjPy0ag,1312
83
+ liger_kernel_nightly-0.5.9.dev20250515065336.dist-info/METADATA,sha256=IK7MV888DLovn85_Xto_NFKgXq4SILvZB7HDXyeP2uc,23874
84
+ liger_kernel_nightly-0.5.9.dev20250515065336.dist-info/NOTICE,sha256=njwnoPZLh9AN8SJQzxvCGLHi-8X__AvWRze6joNXIY8,2066
85
+ liger_kernel_nightly-0.5.9.dev20250515065336.dist-info/WHEEL,sha256=iAkIy5fosb7FzIOwONchHf19Qu7_1wCWyFNR5gu9nU0,91
86
+ liger_kernel_nightly-0.5.9.dev20250515065336.dist-info/top_level.txt,sha256=2eghu4hA3LnkM7ElW92tQ8zegWKgSbeo-k-aGe1YnvY,13
87
+ liger_kernel_nightly-0.5.9.dev20250515065336.dist-info/RECORD,,