liger-kernel-nightly 0.0.1.dev20240819184814__py3-none-any.whl → 0.6.4.dev20251212103629__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (126)
  1. liger_kernel/__init__.py +0 -0
  2. liger_kernel/chunked_loss/README.md +25 -0
  3. liger_kernel/chunked_loss/__init__.py +8 -0
  4. liger_kernel/chunked_loss/cosine_similarity_loss.py +136 -0
  5. liger_kernel/chunked_loss/cpo_loss.py +157 -0
  6. liger_kernel/chunked_loss/dpo_loss.py +229 -0
  7. liger_kernel/chunked_loss/functional.py +17 -0
  8. liger_kernel/chunked_loss/fused_linear_distillation.py +292 -0
  9. liger_kernel/chunked_loss/fused_linear_ppo.py +366 -0
  10. liger_kernel/chunked_loss/fused_linear_preference.py +433 -0
  11. liger_kernel/chunked_loss/fused_linear_unpaired_preference.py +341 -0
  12. liger_kernel/chunked_loss/grpo_loss.py +307 -0
  13. liger_kernel/chunked_loss/jsd_loss.py +200 -0
  14. liger_kernel/chunked_loss/kto_loss.py +210 -0
  15. liger_kernel/chunked_loss/orpo_loss.py +144 -0
  16. liger_kernel/chunked_loss/simpo_loss.py +165 -0
  17. liger_kernel/env_report.py +63 -0
  18. liger_kernel/ops/__init__.py +141 -0
  19. liger_kernel/ops/backends/README.md +151 -0
  20. liger_kernel/ops/backends/__init__.py +13 -0
  21. liger_kernel/ops/backends/_ascend/__init__.py +5 -0
  22. liger_kernel/ops/backends/_ascend/ops/__init__.py +15 -0
  23. liger_kernel/ops/backends/registry.py +61 -0
  24. liger_kernel/ops/cross_entropy.py +383 -114
  25. liger_kernel/ops/dyt.py +160 -0
  26. liger_kernel/ops/experimental/embedding.py +141 -0
  27. liger_kernel/ops/experimental/mm_int8int2.py +349 -0
  28. liger_kernel/ops/fused_add_rms_norm.py +416 -0
  29. liger_kernel/ops/fused_linear_cross_entropy.py +346 -132
  30. liger_kernel/ops/fused_linear_jsd.py +228 -0
  31. liger_kernel/ops/fused_neighborhood_attention.py +1022 -0
  32. liger_kernel/ops/geglu.py +66 -64
  33. liger_kernel/ops/group_norm.py +306 -0
  34. liger_kernel/ops/grpo_loss.py +312 -0
  35. liger_kernel/ops/jsd.py +201 -0
  36. liger_kernel/ops/kl_div.py +262 -0
  37. liger_kernel/ops/layer_norm.py +320 -0
  38. liger_kernel/ops/llama4_rope.py +225 -0
  39. liger_kernel/ops/multi_token_attention.py +207 -0
  40. liger_kernel/ops/poly_norm.py +390 -0
  41. liger_kernel/ops/qwen2vl_mrope.py +222 -0
  42. liger_kernel/ops/rms_norm.py +484 -88
  43. liger_kernel/ops/rope.py +122 -117
  44. liger_kernel/ops/softmax.py +201 -0
  45. liger_kernel/ops/sparsemax.py +179 -0
  46. liger_kernel/ops/swiglu.py +68 -65
  47. liger_kernel/ops/tiled_mlp.py +136 -0
  48. liger_kernel/ops/tvd.py +207 -0
  49. liger_kernel/ops/utils.py +82 -3
  50. liger_kernel/transformers/__init__.py +218 -6
  51. liger_kernel/transformers/auto_model.py +38 -0
  52. liger_kernel/transformers/cross_entropy.py +52 -7
  53. liger_kernel/transformers/dyt.py +22 -0
  54. liger_kernel/transformers/experimental/__init__.py +5 -0
  55. liger_kernel/transformers/experimental/embedding.py +26 -0
  56. liger_kernel/transformers/fsdp.py +55 -0
  57. liger_kernel/transformers/functional.py +301 -0
  58. liger_kernel/transformers/fused_add_rms_norm.py +39 -0
  59. liger_kernel/transformers/fused_linear_cross_entropy.py +59 -10
  60. liger_kernel/transformers/fused_linear_jsd.py +95 -0
  61. liger_kernel/transformers/fused_neighborhood_attention.py +234 -0
  62. liger_kernel/transformers/geglu.py +6 -7
  63. liger_kernel/transformers/group_norm.py +50 -0
  64. liger_kernel/transformers/grpo_loss.py +153 -0
  65. liger_kernel/transformers/jsd.py +70 -0
  66. liger_kernel/transformers/kl_div.py +12 -0
  67. liger_kernel/transformers/layer_norm.py +24 -0
  68. liger_kernel/transformers/llama4_rope.py +93 -0
  69. liger_kernel/transformers/model/falcon_h1.py +122 -0
  70. liger_kernel/transformers/model/gemma.py +261 -0
  71. liger_kernel/transformers/model/gemma2.py +283 -0
  72. liger_kernel/transformers/model/gemma3.py +332 -0
  73. liger_kernel/transformers/model/glm4.py +141 -0
  74. liger_kernel/transformers/model/glm4v.py +163 -0
  75. liger_kernel/transformers/model/glm4v_moe.py +172 -0
  76. liger_kernel/transformers/model/gpt_oss.py +211 -0
  77. liger_kernel/transformers/model/hunyuan_v1.py +134 -0
  78. liger_kernel/transformers/model/internvl.py +157 -0
  79. liger_kernel/transformers/model/llama.py +221 -41
  80. liger_kernel/transformers/model/llama4.py +121 -0
  81. liger_kernel/transformers/model/llava.py +344 -0
  82. liger_kernel/transformers/model/loss_utils.py +95 -0
  83. liger_kernel/transformers/model/mistral.py +145 -0
  84. liger_kernel/transformers/model/mixtral.py +293 -0
  85. liger_kernel/transformers/model/mllama.py +269 -0
  86. liger_kernel/transformers/model/olmo2.py +141 -0
  87. liger_kernel/transformers/model/olmo3.py +142 -0
  88. liger_kernel/transformers/model/output_classes.py +147 -0
  89. liger_kernel/transformers/model/paligemma.py +433 -0
  90. liger_kernel/transformers/model/phi3.py +120 -0
  91. liger_kernel/transformers/model/qwen2.py +259 -0
  92. liger_kernel/transformers/model/qwen2_5_vl.py +163 -0
  93. liger_kernel/transformers/model/qwen2_vl.py +159 -0
  94. liger_kernel/transformers/model/qwen3.py +136 -0
  95. liger_kernel/transformers/model/qwen3_moe.py +152 -0
  96. liger_kernel/transformers/model/qwen3_next.py +146 -0
  97. liger_kernel/transformers/model/qwen3_vl.py +150 -0
  98. liger_kernel/transformers/model/qwen3_vl_moe.py +126 -0
  99. liger_kernel/transformers/model/smollm3.py +199 -0
  100. liger_kernel/transformers/model/smolvlm.py +158 -0
  101. liger_kernel/transformers/monkey_patch.py +2816 -21
  102. liger_kernel/transformers/multi_token_attention.py +64 -0
  103. liger_kernel/transformers/poly_norm.py +42 -0
  104. liger_kernel/transformers/qwen2vl_mrope.py +20 -0
  105. liger_kernel/transformers/rms_norm.py +75 -5
  106. liger_kernel/transformers/rope.py +47 -3
  107. liger_kernel/transformers/softmax.py +12 -0
  108. liger_kernel/transformers/sparsemax.py +16 -0
  109. liger_kernel/transformers/swiglu.py +62 -6
  110. liger_kernel/transformers/tiled_mlp.py +133 -0
  111. liger_kernel/transformers/trainer/__init__.py +4 -0
  112. liger_kernel/transformers/trainer/orpo_trainer.py +130 -0
  113. liger_kernel/transformers/trainer_integration.py +2 -45
  114. liger_kernel/transformers/tvd.py +13 -0
  115. liger_kernel/triton/__init__.py +1 -3
  116. liger_kernel/triton/monkey_patch.py +1 -5
  117. liger_kernel/utils.py +96 -0
  118. liger_kernel_nightly-0.6.4.dev20251212103629.dist-info/METADATA +447 -0
  119. liger_kernel_nightly-0.6.4.dev20251212103629.dist-info/NOTICE +58 -0
  120. liger_kernel_nightly-0.6.4.dev20251212103629.dist-info/RECORD +124 -0
  121. {liger_kernel_nightly-0.0.1.dev20240819184814.dist-info → liger_kernel_nightly-0.6.4.dev20251212103629.dist-info}/WHEEL +1 -1
  122. liger_kernel_nightly-0.0.1.dev20240819184814.dist-info/METADATA +0 -21
  123. liger_kernel_nightly-0.0.1.dev20240819184814.dist-info/NOTICE +0 -4
  124. liger_kernel_nightly-0.0.1.dev20240819184814.dist-info/RECORD +0 -27
  125. {liger_kernel_nightly-0.0.1.dev20240819184814.dist-info → liger_kernel_nightly-0.6.4.dev20251212103629.dist-info}/LICENSE +0 -0
  126. {liger_kernel_nightly-0.0.1.dev20240819184814.dist-info → liger_kernel_nightly-0.6.4.dev20251212103629.dist-info}/top_level.txt +0 -0
liger_kernel/ops/rope.py CHANGED
@@ -13,8 +13,9 @@ def _triton_rope(
13
13
  cos_row_stride,
14
14
  sin,
15
15
  sin_row_stride,
16
+ sl,
16
17
  bs: tl.constexpr,
17
- sl: tl.constexpr,
18
+ cos_bs: tl.constexpr,
18
19
  n_qh: tl.constexpr,
19
20
  n_kh: tl.constexpr,
20
21
  hd: tl.constexpr,
@@ -29,9 +30,9 @@ def _triton_rope(
29
30
  # k size: (bsz, seq_len, num_kv_heads, head_dim)
30
31
  # k stride: (seq_len * num_kv_heads * head_dim, num_kv_heads * head_dim, head_dim, 1)
31
32
 
32
- # cos size: (1, seq_len, head_dim)
33
+ # cos size: (1, seq_len, head_dim) or (bsz, seq_len, head_dim)
33
34
  # stride: (seq_len * head_dim, head_dim, 1)
34
- pid = tl.program_id(0)
35
+ pid = tl.program_id(0).to(tl.int64)
35
36
 
36
37
  # locate start address
37
38
  q_ptr = q_ptr + pid * q_row_stride
@@ -48,9 +49,19 @@ def _triton_rope(
48
49
  # and pid % sl to get the sequence index.
49
50
  # 2. We only need the left half of cos and sin matrix because the right half is just
50
51
  # a clone of the left half.
51
- cos_row_idx = pid % (sl)
52
- cos = cos + cos_row_idx * cos_row_stride
53
- sin = sin + cos_row_idx * sin_row_stride
52
+ batch_idx = pid // sl
53
+ cos_row_idx = pid % sl
54
+ cos = cos + tl.where(
55
+ cos_bs == 1,
56
+ cos_row_idx * cos_row_stride,
57
+ batch_idx * (sl * cos_row_stride) + cos_row_idx * cos_row_stride,
58
+ )
59
+ sin = sin + tl.where(
60
+ cos_bs == 1,
61
+ cos_row_idx * sin_row_stride,
62
+ batch_idx * (sl * sin_row_stride) + cos_row_idx * sin_row_stride,
63
+ )
64
+
54
65
  cos_offsets = tl.arange(0, pad_hd // 2)
55
66
  cos_mask = cos_offsets < hd // 2
56
67
  cos_row = tl.load(cos + cos_offsets, mask=cos_mask, other=0)
@@ -61,36 +72,20 @@ def _triton_rope(
61
72
  # program instance (i.e. for the current token) separately
62
73
  # ####################################################################
63
74
  # left half of the head
64
- first_half_q_offsets = (
65
- tl.arange(0, pad_n_qh)[:, None] * hd + tl.arange(0, pad_hd // 2)[None, :]
66
- )
67
- first_half_k_offsets = (
68
- tl.arange(0, pad_n_kh)[:, None] * hd + tl.arange(0, pad_hd // 2)[None, :]
69
- )
70
- first_q_mask = (tl.arange(0, pad_n_qh)[:, None] < n_qh) & (
71
- tl.arange(0, pad_hd // 2)[None, :] < hd // 2
72
- )
73
- first_k_mask = (tl.arange(0, pad_n_kh)[:, None] < n_kh) & (
74
- tl.arange(0, pad_hd // 2)[None, :] < hd // 2
75
- )
76
- q_tile_1 = tl.load(q_ptr + first_half_q_offsets, mask=first_q_mask, other=0).to(
77
- sin_row.dtype
78
- )
79
- k_tile_1 = tl.load(k_ptr + first_half_k_offsets, mask=first_k_mask, other=0).to(
80
- sin_row.dtype
81
- )
75
+ first_half_q_offsets = tl.arange(0, pad_n_qh)[:, None] * hd + tl.arange(0, pad_hd // 2)[None, :]
76
+ first_half_k_offsets = tl.arange(0, pad_n_kh)[:, None] * hd + tl.arange(0, pad_hd // 2)[None, :]
77
+ first_q_mask = (tl.arange(0, pad_n_qh)[:, None] < n_qh) & (tl.arange(0, pad_hd // 2)[None, :] < hd // 2)
78
+ first_k_mask = (tl.arange(0, pad_n_kh)[:, None] < n_kh) & (tl.arange(0, pad_hd // 2)[None, :] < hd // 2)
79
+ q_tile_1 = tl.load(q_ptr + first_half_q_offsets, mask=first_q_mask, other=0).to(sin_row.dtype)
80
+ k_tile_1 = tl.load(k_ptr + first_half_k_offsets, mask=first_k_mask, other=0).to(sin_row.dtype)
82
81
 
83
82
  # right half of the head
84
83
  second_half_q_offsets = first_half_q_offsets + (hd // 2)
85
84
  second_half_k_offsets = first_half_k_offsets + (hd // 2)
86
85
  second_q_mask = first_q_mask
87
86
  second_k_mask = first_k_mask
88
- q_tile_2 = tl.load(q_ptr + second_half_q_offsets, mask=second_q_mask, other=0).to(
89
- sin_row.dtype
90
- )
91
- k_tile_2 = tl.load(k_ptr + second_half_k_offsets, mask=second_k_mask, other=0).to(
92
- sin_row.dtype
93
- )
87
+ q_tile_2 = tl.load(q_ptr + second_half_q_offsets, mask=second_q_mask, other=0).to(sin_row.dtype)
88
+ k_tile_2 = tl.load(k_ptr + second_half_k_offsets, mask=second_k_mask, other=0).to(sin_row.dtype)
94
89
 
95
90
  if not BACKWARD_PASS:
96
91
  # y = [x1, x2] * [cos, cos] + [-x2, x1] * [sin, sin]
@@ -117,6 +112,95 @@ def _triton_rope(
117
112
  tl.store(k_ptr + second_half_k_offsets, new_k_tile_2, mask=second_k_mask)
118
113
 
119
114
 
115
def rope_forward(q, k, cos, sin):
    """Rotate queries and keys with the Triton RoPE kernel (forward direction).

    Args:
        q: query tensor in (bsz, n_q_head, seq_len, head_dim) layout.
        k: key tensor in (bsz, n_kv_head, seq_len, head_dim) layout.
        cos: cosine table, (1, seq_len, head_dim) or (bsz, seq_len, head_dim).
        sin: sine table with the same shape as ``cos``.

    Returns:
        ``(q_out, k_out, cos, sin)``: the rotated q/k back in the input layout,
        plus the contiguous cos/sin tensors that were handed to the kernel.

    NOTE: the kernel stores its results through the q/k pointers. When
    ``q.transpose(1, 2)`` (or k's) is already contiguous, ``.contiguous()`` is a
    no-op, so the rotation is written into the caller's storage in place.
    """
    # Go back to the physical (bsz, seq_len, heads, head_dim) order that the
    # Triton kernel indexes; .contiguous() is a no-op when it already is.
    q_phys = q.transpose(1, 2).contiguous()
    k_phys = k.transpose(1, 2).contiguous()
    cos = cos.contiguous()
    sin = sin.contiguous()

    batch_size, seq_len, n_q_head, head_dim = q_phys.shape
    n_kv_head = k_phys.shape[2]

    pad_n_q_head = triton.next_power_of_2(n_q_head)
    pad_n_kv_head = triton.next_power_of_2(n_kv_head)
    pad_hd = triton.next_power_of_2(head_dim)

    # One program per (batch, position) row.
    grid = (batch_size * seq_len,)
    _triton_rope[grid](
        q_phys,
        q_phys.stride(1),
        k_phys,
        k_phys.stride(1),
        cos,
        cos.stride(-2),
        sin,
        sin.stride(-2),
        seq_len,
        batch_size,
        cos.shape[0],  # 1 => table broadcast over the batch
        n_q_head,
        n_kv_head,
        head_dim,
        pad_n_q_head,
        pad_n_kv_head,
        pad_hd,
        BLOCK_SIZE=max(pad_n_q_head, pad_n_kv_head),
        BACKWARD_PASS=False,
    )
    return q_phys.transpose(1, 2), k_phys.transpose(1, 2), cos, sin
159
+
160
+
161
def rope_backward(dq, dk, cos, sin):
    """Back-propagate through RoPE by re-running the Triton kernel in backward mode.

    Args:
        dq: gradient w.r.t. the rotated queries, (bsz, n_q_head, seq_len, head_dim).
        dk: gradient w.r.t. the rotated keys, (bsz, n_kv_head, seq_len, head_dim).
        cos: cosine table used in the forward pass (passed to the kernel unchanged).
        sin: sine table used in the forward pass (passed to the kernel unchanged).

    Returns:
        ``(dq_in, dk_in)`` in the same (bsz, heads, seq_len, head_dim) layout.
    """
    # Same physical (bsz, seq_len, heads, head_dim) order as the forward pass.
    grad_q = dq.transpose(1, 2).contiguous()
    grad_k = dk.transpose(1, 2).contiguous()

    batch_size, seq_len, n_q_head, head_dim = grad_q.shape
    n_kv_head = grad_k.shape[2]

    pad_n_q_head = triton.next_power_of_2(n_q_head)
    pad_n_kv_head = triton.next_power_of_2(n_kv_head)
    pad_hd = triton.next_power_of_2(head_dim)

    # The forward kernel is reused; BACKWARD_PASS=True selects the variant that
    # differs only in a few swapped ops. Results are stored into grad_q/grad_k.
    grid = (batch_size * seq_len,)
    _triton_rope[grid](
        grad_q,
        grad_q.stride(1),
        grad_k,
        grad_k.stride(1),
        cos,
        cos.stride(-2),
        sin,
        sin.stride(-2),
        seq_len,
        batch_size,
        cos.shape[0],
        n_q_head,
        n_kv_head,
        head_dim,
        pad_n_q_head,
        pad_n_kv_head,
        pad_hd,
        BLOCK_SIZE=max(pad_n_q_head, pad_n_kv_head),
        BACKWARD_PASS=True,
    )
    return grad_q.transpose(1, 2), grad_k.transpose(1, 2)
202
+
203
+
120
204
  class LigerRopeFunction(torch.autograd.Function):
121
205
  """
122
206
  Triton implementation of the Rotary Positional Embedding (RoPE) operation. Please note that
@@ -135,100 +219,21 @@ class LigerRopeFunction(torch.autograd.Function):
135
219
  """
136
220
  q size: (bsz, n_q_head, seq_len, head_dim)
137
221
  k size: (bsz, n_kv_head, seq_len, head_dim)
138
- cos size: (1, seq_len, head_dim)
139
- sin size: (1, seq_len, head_dim)
222
+ cos size: (1, seq_len, head_dim) or (bsz, seq_len, head_dim)
223
+ sin size: (1, seq_len, head_dim) or (bsz, seq_len, head_dim)
140
224
  """
141
-
142
- # transpose it back to the physical shape because Triton looks at the physical storage
143
- # note: q and k are incontiguous before the transformation and will become contiguous after transpose
144
- q = q.transpose(1, 2)
145
- k = k.transpose(1, 2)
146
-
147
- batch_size, seq_len, n_q_head, head_dim = q.shape
148
- n_kv_head = k.shape[2]
149
- pad_hd = triton.next_power_of_2(head_dim)
150
- pad_n_q_head = triton.next_power_of_2(n_q_head)
151
- pad_n_kv_head = triton.next_power_of_2(n_kv_head)
152
- BLOCK_SIZE = max(pad_n_q_head, pad_n_kv_head)
153
-
154
- n_row = batch_size * seq_len
155
-
156
- # ensure tensors passed into the kernel are contiguous. It will be no-op if they are already contiguous
157
- q = q.contiguous()
158
- k = k.contiguous()
159
- cos = cos.contiguous()
160
- sin = sin.contiguous()
161
-
162
- _triton_rope[(n_row,)](
163
- q,
164
- q.stride(1),
165
- k,
166
- k.stride(1),
167
- cos,
168
- cos.stride(-2),
169
- sin,
170
- sin.stride(-2),
171
- batch_size,
172
- seq_len,
173
- n_q_head,
174
- n_kv_head,
175
- head_dim,
176
- pad_n_q_head,
177
- pad_n_kv_head,
178
- pad_hd,
179
- BLOCK_SIZE=BLOCK_SIZE,
180
- BACKWARD_PASS=False,
181
- )
182
-
225
+ q, k, cos, sin = rope_forward(q, k, cos, sin)
183
226
  ctx.save_for_backward(cos, sin)
184
- return q.transpose(1, 2), k.transpose(1, 2)
227
+ return q, k
185
228
 
186
229
  def backward(ctx, dq, dk):
187
230
  """
188
231
  dq size: (bsz, n_q_head, seq_len, head_dim)
189
232
  dk size: (bsz, n_kv_head, seq_len, head_dim)
190
- cos size: (1, seq_len, head_dim)
191
- sin size: (1, seq_len, head_dim)
233
+ cos size: (1, seq_len, head_dim) or (bsz, seq_len, head_dim)
234
+ sin size: (1, seq_len, head_dim) or (bsz, seq_len, head_dim)
192
235
  """
193
236
 
194
237
  cos, sin = ctx.saved_tensors
195
-
196
- dq = dq.transpose(1, 2)
197
- dk = dk.transpose(1, 2)
198
-
199
- batch_size, seq_len, n_q_head, head_dim = dq.shape
200
- n_kv_head = dk.shape[2]
201
- pad_hd = triton.next_power_of_2(head_dim)
202
- pad_n_q_head = triton.next_power_of_2(n_q_head)
203
- pad_n_kv_head = triton.next_power_of_2(n_kv_head)
204
- BLOCK_SIZE = max(pad_n_q_head, pad_n_kv_head)
205
-
206
- n_row = batch_size * seq_len
207
-
208
- # ensure dq and dk are contiguous
209
- dq = dq.contiguous()
210
- dk = dk.contiguous()
211
-
212
- # backward is similar to forward except swapping few ops
213
- _triton_rope[(n_row,)](
214
- dq,
215
- dq.stride(1),
216
- dk,
217
- dk.stride(1),
218
- cos,
219
- cos.stride(-2),
220
- sin,
221
- sin.stride(-2),
222
- batch_size,
223
- seq_len,
224
- n_q_head,
225
- n_kv_head,
226
- head_dim,
227
- pad_n_q_head,
228
- pad_n_kv_head,
229
- pad_hd,
230
- BLOCK_SIZE=BLOCK_SIZE,
231
- BACKWARD_PASS=True,
232
- )
233
-
234
- return dq.transpose(1, 2), dk.transpose(1, 2), None, None, None, None
238
+ dq, dk = rope_backward(dq, dk, cos, sin)
239
+ return dq, dk, None, None, None, None
@@ -0,0 +1,201 @@
1
+ from typing import Tuple
2
+
3
+ import torch
4
+ import triton
5
+ import triton.language as tl
6
+
7
+ from liger_kernel.ops.utils import calculate_settings
8
+ from liger_kernel.ops.utils import ensure_contiguous
9
+
10
+
11
@triton.jit
def _softmax_single_block_forward_kernel(
    Y_ptr,
    Y_row_stride,
    X_ptr,
    X_row_stride,
    n_cols,
    BLOCK_SIZE: tl.constexpr,
):
    """Softmax of one row that fits in a single BLOCK_SIZE tile.

    The row index is ``program_id(0)``; columns ``>= n_cols`` are masked out.
    """
    row = tl.program_id(0)
    cols = tl.arange(0, BLOCK_SIZE)
    in_bounds = cols < n_cols
    x_row_ptr = X_ptr + row * X_row_stride
    y_row_ptr = Y_ptr + row * Y_row_stride

    # Padding lanes read -inf, so exp() turns them into 0 in the denominator.
    logits = tl.load(x_row_ptr + cols, mask=in_bounds, other=-float("inf"), cache_modifier=".ca")
    # Shift by the row max for numerical stability.
    shifted_exp = tl.exp(logits - tl.max(logits, axis=0))
    probs = shifted_exp / tl.sum(shifted_exp, axis=0)
    tl.store(y_row_ptr + cols, probs, mask=in_bounds, cache_modifier=".cs")
30
+
31
+
32
@triton.jit
def _softmax_multi_block_forward_kernel(
    Y_ptr,
    Y_row_stride,
    X_ptr,
    X_row_stride,
    n_cols,
    BLOCK_SIZE: tl.constexpr,
):
    """Softmax of one row wider than BLOCK_SIZE, processed tile by tile.

    Pass 1 streams the row once, tracking the running max ``m`` and the
    normalizer ``d = sum(exp(x - m))``. Whenever the max grows, ``d`` is
    rescaled (online softmax). Pass 2 re-reads the row and writes
    ``exp(x - m) / d``. The row index is ``program_id(0)``.
    """
    row_id = tl.program_id(0)
    offs = tl.arange(0, BLOCK_SIZE)

    # 0-d fp32 accumulators built with tl.full / tl.zeros, so the loop-carried
    # values have a fixed scalar-tensor type across iterations.
    m = tl.full((), float("-inf"), tl.float32)
    d = tl.zeros((), tl.float32)
    for start in tl.range(0, n_cols, BLOCK_SIZE):
        idx = start + offs
        mask = idx < n_cols
        # Cast to fp32 so m/d keep the same dtype whatever the input dtype is.
        xblk = tl.load(
            X_ptr + row_id * X_row_stride + idx, mask=mask, other=-float("inf"), cache_modifier=".ca"
        ).to(tl.float32)
        blk_max = tl.max(xblk, axis=0)
        # Element-wise max of two scalars. tl.max is a reduction whose second
        # positional argument is `axis`, so it cannot be used here.
        new_m = tl.maximum(m, blk_max)
        # Rescale the previous normalizer to the new max, then add this tile.
        d = d * tl.exp(m - new_m) + tl.sum(tl.exp(xblk - new_m), axis=0)
        m = new_m

    for start in tl.range(0, n_cols, BLOCK_SIZE):
        idx = start + offs
        mask = idx < n_cols
        xblk = tl.load(X_ptr + row_id * X_row_stride + idx, mask=mask, other=-float("inf"), cache_modifier=".ca")
        yblk = tl.exp(xblk - m) / d
        tl.store(Y_ptr + row_id * Y_row_stride + idx, yblk, mask=mask, cache_modifier=".cs")
61
+
62
+
63
@triton.jit
def _softmax_single_block_backward_kernel(
    dy_ptr,
    dy_stride,
    y_ptr,
    y_stride,
    dx_ptr,
    dx_stride,
    n_cols,
    BLOCK_SIZE: tl.constexpr,
):
    """Softmax backward for one row in a single tile: dx = y * (dy - sum(dy * y))."""
    row = tl.program_id(0)
    cols = tl.arange(0, BLOCK_SIZE)
    in_bounds = cols < n_cols

    grad_out = tl.load(dy_ptr + row * dy_stride + cols, mask=in_bounds, other=0.0)
    probs = tl.load(y_ptr + row * y_stride + cols, mask=in_bounds, other=0.0, cache_modifier=".ca")
    # Padding lanes load 0 for both operands, so they add nothing to the dot.
    row_dot = tl.sum(grad_out * probs, axis=0)
    grad_in = probs * (grad_out - row_dot)
    tl.store(dx_ptr + row * dx_stride + cols, grad_in, mask=in_bounds, cache_modifier=".wb")
83
+
84
+
85
@triton.jit
def _softmax_multi_block_backward_kernel(
    dy_ptr,
    dy_stride,
    y_ptr,
    y_stride,
    dx_ptr,
    dx_stride,
    n_cols,
    BLOCK_SIZE: tl.constexpr,
):
    """Softmax backward for a row wider than BLOCK_SIZE, in two tiled passes.

    Pass 1 accumulates ``acc = sum(dy * y)`` over the whole row.
    Pass 2 writes ``dx = y * (dy - acc)``. The row index is ``program_id(0)``.
    """
    row_id = tl.program_id(0)
    offs = tl.arange(0, BLOCK_SIZE)
    # 0-d fp32 accumulator (same construction as the other kernels use).
    acc = tl.zeros((), tl.float32)

    for start in tl.range(0, n_cols, BLOCK_SIZE):
        idx = start + offs
        mask = idx < n_cols
        dy_blk = tl.load(dy_ptr + row_id * dy_stride + idx, mask=mask, other=0.0)
        y_blk = tl.load(y_ptr + row_id * y_stride + idx, mask=mask, other=0.0, cache_modifier=".ca")
        # Cast so the loop-carried accumulator keeps its fp32 dtype.
        acc += tl.sum(dy_blk * y_blk, axis=0).to(tl.float32)

    for start in tl.range(0, n_cols, BLOCK_SIZE):
        idx = start + offs
        mask = idx < n_cols
        dy_blk = tl.load(dy_ptr + row_id * dy_stride + idx, mask=mask, other=0.0)
        y_blk = tl.load(y_ptr + row_id * y_stride + idx, mask=mask, other=0.0, cache_modifier=".ca")
        dx_blk = y_blk * (dy_blk - acc)
        tl.store(dx_ptr + row_id * dx_stride + idx, dx_blk, mask=mask, cache_modifier=".wb")
114
+
115
+
116
def _softmax_forward(x: torch.Tensor) -> Tuple[torch.Tensor, int, int, bool]:
    """Compute softmax over the last dimension of ``x`` with a Triton kernel.

    Returns:
        ``(y, BLOCK_SIZE, num_warps, multi_block_launch)``. ``y`` has the shape
        of ``x``; the other values describe the launch so backward can reuse it.
    """
    *lead_dims, n_cols = x.shape
    x_rows = x.contiguous().view(-1, n_cols)
    y_rows = torch.empty_like(x_rows)

    BLOCK_SIZE, num_warps = calculate_settings(n_cols)
    # Rows that do not fit in a single tile need the two-pass tiled kernel.
    multi_block_launch = n_cols > BLOCK_SIZE
    kernel = _softmax_multi_block_forward_kernel if multi_block_launch else _softmax_single_block_forward_kernel
    kernel[(x_rows.shape[0],)](
        y_rows,
        y_rows.stride(0),
        x_rows,
        x_rows.stride(0),
        n_cols,
        BLOCK_SIZE=BLOCK_SIZE,
        num_warps=num_warps,
    )

    return y_rows.view(*lead_dims, n_cols), BLOCK_SIZE, num_warps, multi_block_launch
136
+
137
+
138
def _softmax_backward(
    dy: torch.Tensor,
    y: torch.Tensor,
    BLOCK_SIZE: int,
    num_warps: int,
    multi_block_launch: bool,
) -> torch.Tensor:
    """Gradient of softmax (last dimension) given the upstream grad and the forward output.

    Args:
        dy: upstream gradient, same shape as ``y``.
        y: softmax output saved from the forward pass.
        BLOCK_SIZE, num_warps, multi_block_launch: launch settings from the forward pass.

    Returns:
        The input gradient, shaped like ``dy``.
    """
    *lead_dims, n_cols = dy.shape
    dy_rows = dy.contiguous().view(-1, n_cols)
    y_rows = y.contiguous().view(-1, n_cols)
    dx_rows = torch.empty_like(dy_rows)

    fits_one_tile = (not multi_block_launch) and n_cols <= BLOCK_SIZE
    kernel = _softmax_single_block_backward_kernel if fits_one_tile else _softmax_multi_block_backward_kernel
    kernel[(dy_rows.shape[0],)](
        dy_rows,
        dy_rows.stride(0),
        y_rows,
        y_rows.stride(0),
        dx_rows,
        dx_rows.stride(0),
        n_cols,
        BLOCK_SIZE=BLOCK_SIZE,
        num_warps=num_warps,
    )

    return dx_rows.view(*lead_dims, n_cols)
177
+
178
+
179
class LigerSoftmaxFunction(torch.autograd.Function):
    """Autograd wrapper around the Triton softmax over the last dimension."""

    @staticmethod
    @ensure_contiguous
    def forward(ctx, input_: torch.Tensor):
        """Return softmax(input_) and remember the output plus the launch settings."""
        output, block_size, n_warps, used_multi_block = _softmax_forward(input_)
        ctx.save_for_backward(output)
        # Backward reuses the forward launch configuration.
        ctx.BLOCK_SIZE = block_size
        ctx.num_warps = n_warps
        ctx.multi_block_launch = used_multi_block
        return output

    @staticmethod
    @ensure_contiguous
    def backward(ctx, grad_output):
        """Return the gradient w.r.t. ``input_``."""
        (output,) = ctx.saved_tensors
        return _softmax_backward(
            grad_output,
            output,
            ctx.BLOCK_SIZE,
            ctx.num_warps,
            ctx.multi_block_launch,
        )
@@ -0,0 +1,179 @@
1
+ from typing import Tuple
2
+
3
+ import torch
4
+ import triton
5
+ import triton.language as tl
6
+
7
+ from liger_kernel.ops.utils import calculate_settings
8
+ from liger_kernel.ops.utils import ensure_contiguous
9
+
10
+
11
@triton.jit
def _sparsemax_forward_kernel(
    x_ptr,
    x_stride_row,
    sorted_x_ptr,
    sorted_x_stride_row,
    o_ptr,
    o_stride_row,
    n_cols,
    BLOCK_SIZE: tl.constexpr,
    num_warps: tl.constexpr,  # also consumed as a launch option; unused in the body
):
    """Sparsemax of one row: y = max(x - tau, 0).

    The row index is ``program_id(0)``. ``sorted_x_ptr`` must hold the same row
    sorted in descending order. Let z_(r) be the r-th sorted value (1-based) and
    cssv_r its prefix sum. The support is {r : z_(r) > (cssv_r - 1) / r}, and
    tau = (sum of supported z - 1) / max(|support|, 1). The computation is done
    in fp32 and cast to the output element type on store.

    Only the first BLOCK_SIZE columns are read, with a single tile and no loop,
    so the row must satisfy n_cols <= BLOCK_SIZE.
    """
    pid_row = tl.program_id(0)
    ptr_x_data_row = x_ptr + pid_row * x_stride_row
    ptr_sorted_x_data_row = sorted_x_ptr + pid_row * sorted_x_stride_row
    ptr_output_row = o_ptr + pid_row * o_stride_row

    offs = tl.arange(0, BLOCK_SIZE)
    mask = offs < n_cols

    # Padding lanes read -inf so they can never pass the support test below.
    z_sorted_block = tl.load(
        ptr_sorted_x_data_row + offs,
        mask=mask,
        other=-float("inf"),
        cache_modifier=".ca",
    ).to(tl.float32)

    # Zero the padding lanes before the prefix sum so -inf does not poison it.
    z_valid = tl.where(mask, z_sorted_block, 0.0)
    cssv = tl.cumsum(z_valid, 0)

    # 1-based rank of each sorted position; padding lanes use 1.0 to keep the
    # division finite (they are excluded from the support by `& mask`).
    r = (offs + 1).to(tl.float32)
    safe_r = tl.where(mask, r, 1.0)

    # Candidate threshold for each prefix length r.
    t_vec = (cssv - 1.0) / safe_r

    support = (z_sorted_block > t_vec) & mask

    # Support size, clamped to >= 1 so tau's division is always defined.
    k_int = tl.sum(support.to(tl.int32), 0)
    k_clamped_int = tl.maximum(k_int, 1)
    k = k_clamped_int.to(tl.float32)

    s = tl.sum(tl.where(support, z_sorted_block, 0.0), 0)

    tau = (s - 1.0) / k

    # Unsorted input row; padding lanes are masked out again on store.
    x_block = tl.load(
        ptr_x_data_row + offs,
        mask=mask,
        other=0.0,
        cache_modifier=".ca",
    ).to(tl.float32)

    y = tl.maximum(x_block - tau, 0.0)

    tl.store(
        ptr_output_row + offs,
        y.to(ptr_output_row.dtype.element_ty),
        mask=mask,
        cache_modifier=".cs",
    )
71
+
72
+
73
@triton.jit
def _sparsemax_backward_kernel(
    o_ptr, go_ptr, gi_ptr, stride, n_cols, BLOCK_SIZE: tl.constexpr, num_warps: tl.constexpr
):
    """Sparsemax backward for one row (row index = ``program_id(0)``).

    With S = {j : o_j > 0} (the support of the forward output ``o``), the input
    gradient is gi_j = go_j - mean_{i in S}(go_i) for j in S, and 0 otherwise.
    The output, the upstream grad and the input grad all use the same row
    ``stride``. The row is walked in BLOCK_SIZE tiles over two passes.
    ``num_warps`` is also a launch option and is unused in the body.
    """
    row = tl.program_id(0)
    o_row = o_ptr + row * stride
    go_row = go_ptr + row * stride
    gi_row = gi_ptr + row * stride

    offs = tl.arange(0, BLOCK_SIZE)

    # 0-d fp32 accumulators: support size and sum of upstream grads on the support.
    supp_cnt = tl.zeros((), tl.float32)
    go_sum = tl.zeros((), tl.float32)

    # Pass 1: reduce over the support.
    for i in tl.range(0, tl.cdiv(n_cols, BLOCK_SIZE)):
        offs_iter = i * BLOCK_SIZE + offs
        mask_iter = offs_iter < n_cols
        o_val = tl.load(o_row + offs_iter, mask=mask_iter, other=0.0, cache_modifier=".ca").to(tl.float32)
        go_val = tl.load(go_row + offs_iter, mask=mask_iter, other=0.0).to(tl.float32)
        supp = o_val > 0.0
        go_sum += tl.sum(tl.where(supp, go_val, 0.0))
        supp_cnt += tl.sum(supp.to(tl.float32))

    # Pass 2: subtract the support mean on the support and write zeros elsewhere.
    for i in tl.range(0, tl.cdiv(n_cols, BLOCK_SIZE)):
        offs_iter = i * BLOCK_SIZE + offs
        mask_iter = offs_iter < n_cols
        o_val = tl.load(o_row + offs_iter, mask=mask_iter, other=0.0, cache_modifier=".ca").to(tl.float32)
        go_val = tl.load(go_row + offs_iter, mask=mask_iter, other=0.0).to(tl.float32)
        supp = o_val > 0.0
        # The mean is rounded to gi's element type and then widened back to fp32
        # before the subtraction. tl.maximum(..., 1e-6) avoids dividing by zero
        # when the support is empty.
        gi_val = tl.where(
            supp,
            go_val - tl.cast(go_sum / tl.maximum(supp_cnt, 1e-6), gi_row.dtype.element_ty).to(tl.float32),
            0.0,
        )
        tl.store(gi_row + offs_iter, gi_val.to(gi_row.dtype.element_ty), mask=mask_iter, cache_modifier=".wb")
108
+
109
+
110
def _sparsemax_forward(x: torch.Tensor, dim: int) -> Tuple[torch.Tensor, torch.Tensor]:
    """Apply sparsemax to ``x`` along ``dim`` using the Triton kernel.

    Returns:
        ``(y, out_flat)``. ``y`` has the shape of ``x``. ``out_flat`` is the same
        result as a contiguous (rows, n_cols) tensor with ``dim`` moved last;
        it is kept for the backward pass.
    """
    axis = dim + x.dim() if dim < 0 else dim
    moved = x.transpose(axis, -1).contiguous()
    n_cols = moved.size(-1)
    x_flat = moved.view(moved.numel() // n_cols, n_cols)
    # The kernel derives the threshold from each row sorted in descending
    # order; the sorted copy is computed in float32.
    x_sorted_flat = torch.sort(x_flat.float(), dim=-1, descending=True).values

    BLOCK_SIZE, num_warps = calculate_settings(n_cols)
    out_flat = torch.empty_like(x_flat)
    _sparsemax_forward_kernel[(x_flat.shape[0],)](
        x_flat,
        x_flat.stride(0),
        x_sorted_flat,
        x_sorted_flat.stride(0),
        out_flat,
        out_flat.stride(0),
        n_cols,
        BLOCK_SIZE=BLOCK_SIZE,
        num_warps=num_warps,
    )

    return out_flat.view_as(moved).transpose(axis, -1), out_flat
136
+
137
+
138
def _sparsemax_backward(
    grad_out: torch.Tensor,
    out_flat: torch.Tensor,
    dim: int,
) -> torch.Tensor:
    """Gradient of sparsemax along ``dim``.

    Args:
        grad_out: upstream gradient, shaped like the forward output.
        out_flat: flattened (rows, n_cols) forward output, with ``dim`` moved last.
        dim: the dimension sparsemax was applied along.

    Returns:
        The input gradient, shaped like ``grad_out``.
    """
    go_moved = grad_out.transpose(dim, -1).contiguous()
    n_cols = go_moved.size(-1)
    go_flat = go_moved.view(go_moved.numel() // n_cols, n_cols)
    dx_flat = torch.empty_like(go_flat)

    BLOCK_SIZE, num_warps = calculate_settings(n_cols)
    # The kernel takes a single row stride for all three buffers; out_flat's
    # stride is used (assumed to match the contiguous go_flat/dx_flat).
    _sparsemax_backward_kernel[(go_flat.shape[0],)](
        out_flat,
        go_flat,
        dx_flat,
        out_flat.stride(0),
        n_cols,
        BLOCK_SIZE=BLOCK_SIZE,
        num_warps=num_warps,
    )

    return dx_flat.view_as(go_moved).transpose(dim, -1)
163
+
164
+
165
class LigerSparsemaxFunction(torch.autograd.Function):
    """Autograd wrapper for the Triton sparsemax along a chosen dimension."""

    @staticmethod
    @ensure_contiguous
    def forward(ctx, x: torch.Tensor, dim: int):
        """Return sparsemax(x) along ``dim``; saves the flattened output for backward."""
        y, out_flat = _sparsemax_forward(x, dim)
        ctx.dim = dim
        ctx.save_for_backward(out_flat)
        return y

    @staticmethod
    @ensure_contiguous
    def backward(ctx, grad_out: torch.Tensor):
        """Return the gradient for ``x`` and ``None`` for the non-tensor ``dim``."""
        (out_flat,) = ctx.saved_tensors
        grad_x = _sparsemax_backward(grad_out, out_flat, ctx.dim)
        return grad_x, None
+ return dx, None