liger-kernel-nightly 0.4.0.dev20241107052928__py3-none-any.whl → 0.6.3.dev20251121010306__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of liger-kernel-nightly might be problematic. Click here for more details.

Files changed (114) hide show
  1. liger_kernel/__init__.py +0 -0
  2. liger_kernel/chunked_loss/README.md +25 -0
  3. liger_kernel/chunked_loss/__init__.py +8 -0
  4. liger_kernel/chunked_loss/cosine_similarity_loss.py +136 -0
  5. liger_kernel/chunked_loss/cpo_loss.py +157 -0
  6. liger_kernel/chunked_loss/dpo_loss.py +229 -0
  7. liger_kernel/chunked_loss/functional.py +17 -0
  8. liger_kernel/chunked_loss/fused_linear_distillation.py +292 -0
  9. liger_kernel/chunked_loss/fused_linear_ppo.py +350 -0
  10. liger_kernel/chunked_loss/fused_linear_preference.py +433 -0
  11. liger_kernel/chunked_loss/fused_linear_unpaired_preference.py +341 -0
  12. liger_kernel/chunked_loss/grpo_loss.py +304 -0
  13. liger_kernel/chunked_loss/jsd_loss.py +200 -0
  14. liger_kernel/chunked_loss/kto_loss.py +210 -0
  15. liger_kernel/chunked_loss/orpo_loss.py +144 -0
  16. liger_kernel/chunked_loss/simpo_loss.py +165 -0
  17. liger_kernel/env_report.py +21 -4
  18. liger_kernel/ops/cross_entropy.py +235 -84
  19. liger_kernel/ops/dyt.py +157 -0
  20. liger_kernel/ops/experimental/embedding.py +1 -3
  21. liger_kernel/ops/experimental/mm_int8int2.py +3 -9
  22. liger_kernel/ops/fused_add_rms_norm.py +412 -0
  23. liger_kernel/ops/fused_linear_cross_entropy.py +197 -75
  24. liger_kernel/ops/fused_linear_jsd.py +17 -34
  25. liger_kernel/ops/fused_neighborhood_attention.py +1022 -0
  26. liger_kernel/ops/geglu.py +7 -18
  27. liger_kernel/ops/group_norm.py +305 -0
  28. liger_kernel/ops/grpo_loss.py +310 -0
  29. liger_kernel/ops/jsd.py +46 -21
  30. liger_kernel/ops/kl_div.py +23 -19
  31. liger_kernel/ops/layer_norm.py +150 -86
  32. liger_kernel/ops/llama4_rope.py +225 -0
  33. liger_kernel/ops/multi_token_attention.py +207 -0
  34. liger_kernel/ops/poly_norm.py +386 -0
  35. liger_kernel/ops/qwen2vl_mrope.py +222 -0
  36. liger_kernel/ops/rms_norm.py +314 -84
  37. liger_kernel/ops/rope.py +32 -34
  38. liger_kernel/ops/softmax.py +201 -0
  39. liger_kernel/ops/sparsemax.py +179 -0
  40. liger_kernel/ops/swiglu.py +5 -9
  41. liger_kernel/ops/tiled_mlp.py +136 -0
  42. liger_kernel/ops/tvd.py +207 -0
  43. liger_kernel/ops/utils.py +8 -4
  44. liger_kernel/transformers/__init__.py +199 -24
  45. liger_kernel/transformers/auto_model.py +6 -13
  46. liger_kernel/transformers/cross_entropy.py +33 -20
  47. liger_kernel/transformers/dyt.py +22 -0
  48. liger_kernel/transformers/experimental/__init__.py +5 -0
  49. liger_kernel/transformers/experimental/embedding.py +1 -3
  50. liger_kernel/transformers/fsdp.py +55 -0
  51. liger_kernel/transformers/functional.py +291 -13
  52. liger_kernel/transformers/fused_add_rms_norm.py +39 -0
  53. liger_kernel/transformers/fused_linear_cross_entropy.py +43 -14
  54. liger_kernel/transformers/fused_linear_jsd.py +1 -4
  55. liger_kernel/transformers/fused_neighborhood_attention.py +234 -0
  56. liger_kernel/transformers/geglu.py +1 -4
  57. liger_kernel/transformers/group_norm.py +50 -0
  58. liger_kernel/transformers/grpo_loss.py +98 -0
  59. liger_kernel/transformers/jsd.py +2 -7
  60. liger_kernel/transformers/kl_div.py +1 -3
  61. liger_kernel/transformers/layer_norm.py +3 -9
  62. liger_kernel/transformers/llama4_rope.py +93 -0
  63. liger_kernel/transformers/model/falcon_h1.py +122 -0
  64. liger_kernel/transformers/model/gemma.py +77 -77
  65. liger_kernel/transformers/model/gemma2.py +283 -0
  66. liger_kernel/transformers/model/gemma3.py +331 -0
  67. liger_kernel/transformers/model/glm4.py +141 -0
  68. liger_kernel/transformers/model/glm4v.py +163 -0
  69. liger_kernel/transformers/model/glm4v_moe.py +172 -0
  70. liger_kernel/transformers/model/internvl.py +157 -0
  71. liger_kernel/transformers/model/llama.py +128 -79
  72. liger_kernel/transformers/model/llama4.py +121 -0
  73. liger_kernel/transformers/model/llava.py +344 -0
  74. liger_kernel/transformers/model/loss_utils.py +95 -0
  75. liger_kernel/transformers/model/mistral.py +68 -64
  76. liger_kernel/transformers/model/mixtral.py +75 -91
  77. liger_kernel/transformers/model/mllama.py +63 -68
  78. liger_kernel/transformers/model/olmo2.py +141 -0
  79. liger_kernel/transformers/model/output_classes.py +147 -0
  80. liger_kernel/transformers/model/paligemma.py +432 -0
  81. liger_kernel/transformers/model/phi3.py +59 -213
  82. liger_kernel/transformers/model/qwen2.py +75 -72
  83. liger_kernel/transformers/model/qwen2_5_vl.py +163 -0
  84. liger_kernel/transformers/model/qwen2_vl.py +78 -98
  85. liger_kernel/transformers/model/qwen3.py +136 -0
  86. liger_kernel/transformers/model/qwen3_moe.py +152 -0
  87. liger_kernel/transformers/model/qwen3_next.py +146 -0
  88. liger_kernel/transformers/model/qwen3_vl.py +150 -0
  89. liger_kernel/transformers/model/qwen3_vl_moe.py +126 -0
  90. liger_kernel/transformers/model/smollm3.py +199 -0
  91. liger_kernel/transformers/model/smolvlm.py +158 -0
  92. liger_kernel/transformers/monkey_patch.py +2106 -289
  93. liger_kernel/transformers/multi_token_attention.py +64 -0
  94. liger_kernel/transformers/poly_norm.py +42 -0
  95. liger_kernel/transformers/qwen2vl_mrope.py +20 -0
  96. liger_kernel/transformers/rms_norm.py +57 -6
  97. liger_kernel/transformers/rope.py +45 -2
  98. liger_kernel/transformers/softmax.py +12 -0
  99. liger_kernel/transformers/sparsemax.py +16 -0
  100. liger_kernel/transformers/swiglu.py +23 -8
  101. liger_kernel/transformers/tiled_mlp.py +133 -0
  102. liger_kernel/transformers/trainer/__init__.py +4 -0
  103. liger_kernel/transformers/trainer/orpo_trainer.py +130 -0
  104. liger_kernel/transformers/tvd.py +13 -0
  105. liger_kernel/triton/__init__.py +1 -3
  106. liger_kernel/triton/monkey_patch.py +1 -3
  107. liger_kernel/utils.py +71 -0
  108. {liger_kernel_nightly-0.4.0.dev20241107052928.dist-info → liger_kernel_nightly-0.6.3.dev20251121010306.dist-info}/METADATA +150 -137
  109. liger_kernel_nightly-0.6.3.dev20251121010306.dist-info/RECORD +116 -0
  110. {liger_kernel_nightly-0.4.0.dev20241107052928.dist-info → liger_kernel_nightly-0.6.3.dev20251121010306.dist-info}/WHEEL +1 -1
  111. liger_kernel_nightly-0.4.0.dev20241107052928.dist-info/RECORD +0 -48
  112. {liger_kernel_nightly-0.4.0.dev20241107052928.dist-info → liger_kernel_nightly-0.6.3.dev20251121010306.dist-info}/LICENSE +0 -0
  113. {liger_kernel_nightly-0.4.0.dev20241107052928.dist-info → liger_kernel_nightly-0.6.3.dev20251121010306.dist-info}/NOTICE +0 -0
  114. {liger_kernel_nightly-0.4.0.dev20241107052928.dist-info → liger_kernel_nightly-0.6.3.dev20251121010306.dist-info}/top_level.txt +0 -0
@@ -0,0 +1,201 @@
1
+ from typing import Tuple
2
+
3
+ import torch
4
+ import triton
5
+ import triton.language as tl
6
+
7
+ from liger_kernel.ops.utils import calculate_settings
8
+ from liger_kernel.ops.utils import ensure_contiguous
9
+
10
+
11
@triton.jit
def _softmax_single_block_forward_kernel(
    Y_ptr,  # output tensor pointer (n_rows x n_cols)
    Y_row_stride,
    X_ptr,  # input tensor pointer
    X_row_stride,
    n_cols,
    BLOCK_SIZE: tl.constexpr,
):
    """Row-wise softmax where one block covers a whole row (n_cols <= BLOCK_SIZE).

    One program instance handles one row: load the row, subtract the row max
    for numerical stability, exponentiate, normalize by the row sum, store.
    """
    row_id = tl.program_id(0)
    offs = tl.arange(0, BLOCK_SIZE)
    mask = offs < n_cols

    # Out-of-range lanes read -inf so they contribute exp(-inf) = 0 to the sum.
    x = tl.load(X_ptr + row_id * X_row_stride + offs, mask=mask, other=-float("inf"), cache_modifier=".ca")
    m = tl.max(x, axis=0)  # row max, subtracted for numerical stability
    e = tl.exp(x - m)
    d = tl.sum(e, axis=0)  # normalizer
    y = e / d
    tl.store(Y_ptr + row_id * Y_row_stride + offs, y, mask=mask, cache_modifier=".cs")
30
+
31
+
32
@triton.jit
def _softmax_multi_block_forward_kernel(
    Y_ptr,  # output tensor pointer (n_rows x n_cols)
    Y_row_stride,
    X_ptr,  # input tensor pointer
    X_row_stride,
    n_cols,
    BLOCK_SIZE: tl.constexpr,
):
    """Row-wise softmax for rows longer than one block (n_cols > BLOCK_SIZE).

    Pass 1 streams over the row computing the running max `m` and the running
    normalizer `d` with the online-softmax recurrence; pass 2 re-reads the
    input and writes exp(x - m) / d.
    """
    row_id = tl.program_id(0)
    offs = tl.arange(0, BLOCK_SIZE)

    # NOTE(review): using `tl.float32(...)` as a scalar constructor is unusual
    # Triton API (dtypes are normally not called directly) — confirm this
    # compiles on all supported Triton versions.
    m = tl.float32(-float("inf"))  # running row max
    d = tl.float32(0.0)  # running sum of exponentials, rescaled as m grows
    for start in tl.range(0, n_cols, BLOCK_SIZE):
        idx = start + offs
        mask = idx < n_cols
        xblk = tl.load(X_ptr + row_id * X_row_stride + idx, mask=mask, other=-float("inf"), cache_modifier=".ca")
        blk_max = tl.max(xblk, axis=0)
        new_m = tl.max(m, blk_max)
        # Rescale the old partial sum to the new max before adding this block.
        d = d * tl.exp(m - new_m) + tl.sum(tl.exp(xblk - new_m), axis=0)
        m = new_m

    # Second pass: normalize with the final (m, d) and write the output.
    for start in tl.range(0, n_cols, BLOCK_SIZE):
        idx = start + offs
        mask = idx < n_cols
        xblk = tl.load(X_ptr + row_id * X_row_stride + idx, mask=mask, other=-float("inf"), cache_modifier=".ca")
        yblk = tl.exp(xblk - m) / d
        tl.store(Y_ptr + row_id * Y_row_stride + idx, yblk, mask=mask, cache_modifier=".cs")
61
+
62
+
63
@triton.jit
def _softmax_single_block_backward_kernel(
    dy_ptr,  # incoming gradient pointer
    dy_stride,
    y_ptr,  # saved forward output (softmax probabilities)
    y_stride,
    dx_ptr,  # outgoing gradient pointer
    dx_stride,
    n_cols,
    BLOCK_SIZE: tl.constexpr,
):
    """Softmax backward for rows fitting in one block.

    Implements the softmax Jacobian-vector product:
        dx = y * (dy - sum(dy * y))
    One program instance handles one row.
    """
    row_id = tl.program_id(0)
    offs = tl.arange(0, BLOCK_SIZE)
    mask = offs < n_cols

    dy = tl.load(dy_ptr + row_id * dy_stride + offs, mask=mask, other=0.0)
    y = tl.load(y_ptr + row_id * y_stride + offs, mask=mask, other=0.0, cache_modifier=".ca")
    dot = tl.sum(dy * y, axis=0)  # row-wise <dy, y>
    dx = y * (dy - dot)
    tl.store(dx_ptr + row_id * dx_stride + offs, dx, mask=mask, cache_modifier=".wb")
83
+
84
+
85
@triton.jit
def _softmax_multi_block_backward_kernel(
    dy_ptr,  # incoming gradient pointer
    dy_stride,
    y_ptr,  # saved forward output (softmax probabilities)
    y_stride,
    dx_ptr,  # outgoing gradient pointer
    dx_stride,
    n_cols,
    BLOCK_SIZE: tl.constexpr,
):
    """Softmax backward for rows longer than one block.

    Two streaming passes per row: pass 1 accumulates acc = sum(dy * y);
    pass 2 re-reads both inputs and writes dx = y * (dy - acc).
    """
    row_id = tl.program_id(0)
    offs = tl.arange(0, BLOCK_SIZE)
    acc = tl.float32(0.0)  # running <dy, y> for this row

    for start in tl.range(0, n_cols, BLOCK_SIZE):
        idx = start + offs
        mask = idx < n_cols
        dy_blk = tl.load(dy_ptr + row_id * dy_stride + idx, mask=mask, other=0.0)
        y_blk = tl.load(y_ptr + row_id * y_stride + idx, mask=mask, other=0.0, cache_modifier=".ca")
        acc += tl.sum(dy_blk * y_blk, axis=0)

    for start in tl.range(0, n_cols, BLOCK_SIZE):
        idx = start + offs
        mask = idx < n_cols
        dy_blk = tl.load(dy_ptr + row_id * dy_stride + idx, mask=mask, other=0.0)
        y_blk = tl.load(y_ptr + row_id * y_stride + idx, mask=mask, other=0.0, cache_modifier=".ca")
        dx_blk = y_blk * (dy_blk - acc)
        tl.store(dx_ptr + row_id * dx_stride + idx, dx_blk, mask=mask, cache_modifier=".wb")
114
+
115
+
116
def _softmax_forward(x: torch.Tensor) -> Tuple[torch.Tensor, int, int, bool]:
    """Launch the Triton softmax forward kernel over the last dimension of `x`.

    Flattens all leading dimensions into rows, picks a launch configuration
    from the row width, and dispatches the single- or multi-block kernel.

    Returns:
        (output, BLOCK_SIZE, num_warps, multi_block_launch) — the launch
        settings are returned so the backward pass can reuse them.
    """
    *leading_dims, n_cols = x.shape
    flat_in = x.contiguous().view(-1, n_cols)
    flat_out = torch.empty_like(flat_in)

    BLOCK_SIZE, num_warps = calculate_settings(n_cols)
    grid = (flat_in.shape[0],)

    # A single block can cover the row iff n_cols fits in BLOCK_SIZE.
    multi_block_launch = n_cols > BLOCK_SIZE
    kernel = _softmax_multi_block_forward_kernel if multi_block_launch else _softmax_single_block_forward_kernel
    kernel[grid](
        flat_out,
        flat_out.stride(0),
        flat_in,
        flat_in.stride(0),
        n_cols,
        BLOCK_SIZE=BLOCK_SIZE,
        num_warps=num_warps,
    )

    return flat_out.view(*leading_dims, n_cols), BLOCK_SIZE, num_warps, multi_block_launch
136
+
137
+
138
def _softmax_backward(
    dy: torch.Tensor,
    y: torch.Tensor,
    BLOCK_SIZE: int,
    num_warps: int,
    multi_block_launch: bool,
) -> torch.Tensor:
    """Launch the Triton softmax backward kernel, mirroring the forward mode.

    Args:
        dy: incoming gradient w.r.t. the softmax output.
        y: saved softmax output from the forward pass.
        BLOCK_SIZE, num_warps, multi_block_launch: launch settings produced
            by `_softmax_forward`.

    Returns:
        Gradient w.r.t. the softmax input, same shape as `dy`.
    """
    *leading_dims, n_cols = dy.shape
    grad_out_2d = dy.contiguous().view(-1, n_cols)
    y_2d = y.contiguous().view(-1, n_cols)
    grad_in_2d = torch.empty_like(grad_out_2d)
    grid = (grad_out_2d.shape[0],)

    # Use the cheap single-block path only when the forward did.
    use_single_block = not multi_block_launch and n_cols <= BLOCK_SIZE
    kernel = _softmax_single_block_backward_kernel if use_single_block else _softmax_multi_block_backward_kernel
    kernel[grid](
        grad_out_2d,
        grad_out_2d.stride(0),
        y_2d,
        y_2d.stride(0),
        grad_in_2d,
        grad_in_2d.stride(0),
        n_cols,
        BLOCK_SIZE=BLOCK_SIZE,
        num_warps=num_warps,
    )

    return grad_in_2d.view(*leading_dims, n_cols)
177
+
178
+
179
class LigerSoftmaxFunction(torch.autograd.Function):
    """Autograd wrapper around the Triton softmax kernels (last-dim softmax)."""

    @staticmethod
    @ensure_contiguous
    def forward(ctx, input_: torch.Tensor):
        """Compute softmax over the last dimension and stash launch settings."""
        out, block_size, warps, multi_block = _softmax_forward(input_)
        # Save the output (not the input): backward only needs y and dy.
        ctx.save_for_backward(out)
        ctx.BLOCK_SIZE = block_size
        ctx.num_warps = warps
        ctx.multi_block_launch = multi_block
        return out

    @staticmethod
    @ensure_contiguous
    def backward(ctx, grad_output):
        """Apply the softmax Jacobian-vector product with the saved settings."""
        (out,) = ctx.saved_tensors
        return _softmax_backward(
            grad_output,
            out,
            ctx.BLOCK_SIZE,
            ctx.num_warps,
            ctx.multi_block_launch,
        )
@@ -0,0 +1,179 @@
1
+ from typing import Tuple
2
+
3
+ import torch
4
+ import triton
5
+ import triton.language as tl
6
+
7
+ from liger_kernel.ops.utils import calculate_settings
8
+ from liger_kernel.ops.utils import ensure_contiguous
9
+
10
+
11
@triton.jit
def _sparsemax_forward_kernel(
    x_ptr,  # original row values
    x_stride_row,
    sorted_x_ptr,  # same rows pre-sorted descending (done on the host)
    sorted_x_stride_row,
    o_ptr,  # output rows
    o_stride_row,
    n_cols,
    BLOCK_SIZE: tl.constexpr,
    num_warps: tl.constexpr,
):
    """Row-wise sparsemax given each row and its descending-sorted copy.

    Uses the threshold form of sparsemax: from the sorted values find the
    support size k, derive the threshold tau, and emit max(x - tau, 0).
    One program instance handles one row; requires n_cols <= BLOCK_SIZE.
    """
    pid_row = tl.program_id(0)
    ptr_x_data_row = x_ptr + pid_row * x_stride_row
    ptr_sorted_x_data_row = sorted_x_ptr + pid_row * sorted_x_stride_row
    ptr_output_row = o_ptr + pid_row * o_stride_row

    offs = tl.arange(0, BLOCK_SIZE)
    mask = offs < n_cols

    # Sorted row; padded lanes read -inf so they can never enter the support.
    z_sorted_block = tl.load(
        ptr_sorted_x_data_row + offs,
        mask=mask,
        other=-float("inf"),
        cache_modifier=".ca",
    ).to(tl.float32)

    # Cumulative sums over valid lanes only (padding contributes 0).
    z_valid = tl.where(mask, z_sorted_block, 0.0)
    cssv = tl.cumsum(z_valid, 0)

    # r = 1-based rank; padded lanes use 1.0 to avoid dividing by garbage.
    r = (offs + 1).to(tl.float32)
    safe_r = tl.where(mask, r, 1.0)

    # Candidate thresholds t_j = (sum_{i<=j} z_i - 1) / j.
    t_vec = (cssv - 1.0) / safe_r

    # Support: sorted positions whose value still exceeds their threshold.
    support = (z_sorted_block > t_vec) & mask

    # k = |support|, clamped to >= 1 so tau is always well-defined.
    k_int = tl.sum(support.to(tl.int32), 0)
    k_clamped_int = tl.maximum(k_int, 1)
    k = k_clamped_int.to(tl.float32)

    # Sum of the supported (top-k) sorted values.
    s = tl.sum(tl.where(support, z_sorted_block, 0.0), 0)

    # tau such that the thresholded row sums to 1 (simplex projection).
    tau = (s - 1.0) / k

    # Apply the threshold to the original (unsorted) row.
    x_block = tl.load(
        ptr_x_data_row + offs,
        mask=mask,
        other=0.0,
        cache_modifier=".ca",
    ).to(tl.float32)

    y = tl.maximum(x_block - tau, 0.0)

    tl.store(
        ptr_output_row + offs,
        y.to(ptr_output_row.dtype.element_ty),
        mask=mask,
        cache_modifier=".cs",
    )
+ )
71
+
72
+
73
@triton.jit
def _sparsemax_backward_kernel(
    o_ptr, go_ptr, gi_ptr, stride, n_cols, BLOCK_SIZE: tl.constexpr, num_warps: tl.constexpr
):
    """Row-wise sparsemax backward.

    On the support (o > 0): gi = go - mean(go over support); elsewhere gi = 0.
    Pass 1 streams the row accumulating the support size and the gradient sum
    over the support; pass 2 writes the centered gradient. All three tensors
    share the single row `stride`.
    """
    row = tl.program_id(0)
    o_row = o_ptr + row * stride
    go_row = go_ptr + row * stride
    gi_row = gi_ptr + row * stride

    offs = tl.arange(0, BLOCK_SIZE)

    supp_cnt = tl.zeros((), tl.float32)  # |{j : o_j > 0}|
    go_sum = tl.zeros((), tl.float32)  # sum of go over the support

    for i in tl.range(0, tl.cdiv(n_cols, BLOCK_SIZE)):
        offs_iter = i * BLOCK_SIZE + offs
        mask_iter = offs_iter < n_cols
        o_val = tl.load(o_row + offs_iter, mask=mask_iter, other=0.0, cache_modifier=".ca").to(tl.float32)
        go_val = tl.load(go_row + offs_iter, mask=mask_iter, other=0.0).to(tl.float32)
        supp = o_val > 0.0
        go_sum += tl.sum(tl.where(supp, go_val, 0.0))
        supp_cnt += tl.sum(supp.to(tl.float32))

    for i in tl.range(0, tl.cdiv(n_cols, BLOCK_SIZE)):
        offs_iter = i * BLOCK_SIZE + offs
        mask_iter = offs_iter < n_cols
        o_val = tl.load(o_row + offs_iter, mask=mask_iter, other=0.0, cache_modifier=".ca").to(tl.float32)
        go_val = tl.load(go_row + offs_iter, mask=mask_iter, other=0.0).to(tl.float32)
        supp = o_val > 0.0
        # NOTE(review): the support mean is rounded to the output dtype before
        # the subtraction (cast down, then back to fp32) — presumably to match
        # stored-precision semantics; confirm this is intentional. The 1e-6
        # floor guards against an empty support.
        gi_val = tl.where(
            supp,
            go_val - tl.cast(go_sum / tl.maximum(supp_cnt, 1e-6), gi_row.dtype.element_ty).to(tl.float32),
            0.0,
        )
        tl.store(gi_row + offs_iter, gi_val.to(gi_row.dtype.element_ty), mask=mask_iter, cache_modifier=".wb")
108
+
109
+
110
def _sparsemax_forward(x: torch.Tensor, dim: int) -> Tuple[torch.Tensor, torch.Tensor]:
    """Run the sparsemax kernel along `dim`.

    Moves `dim` to the last axis, flattens to 2-D rows, sorts each row
    descending in fp32 (the kernel consumes the sorted copy), and launches
    one program per row.

    Returns:
        (y, out_flat): `y` has the original layout of `x`; `out_flat` is the
        flat 2-D output kept around so backward can reuse it directly.
    """
    if dim < 0:
        dim += x.dim()
    # Put the reduction dim last so each kernel program sees a contiguous row.
    swapped = x.transpose(dim, -1).contiguous()
    n_cols = swapped.size(-1)
    n_rows = swapped.numel() // n_cols
    flat = swapped.view(n_rows, n_cols)
    sorted_flat = torch.sort(flat.float(), dim=-1, descending=True).values

    BLOCK_SIZE, num_warps = calculate_settings(n_cols)
    out_flat = torch.empty_like(flat)
    _sparsemax_forward_kernel[(n_rows,)](
        flat,
        flat.stride(0),
        sorted_flat,
        sorted_flat.stride(0),
        out_flat,
        out_flat.stride(0),
        n_cols,
        BLOCK_SIZE=BLOCK_SIZE,
        num_warps=num_warps,
    )

    return out_flat.view_as(swapped).transpose(dim, -1), out_flat
136
+
137
+
138
def _sparsemax_backward(
    grad_out: torch.Tensor,
    out_flat: torch.Tensor,
    dim: int,
) -> torch.Tensor:
    """Run the sparsemax backward kernel.

    Args:
        grad_out: incoming gradient, laid out like the forward input.
        out_flat: flat 2-D forward output saved by `_sparsemax_forward`.
        dim: the dimension sparsemax was applied over.

    Returns:
        Gradient w.r.t. the sparsemax input, same layout as `grad_out`.
    """
    grad_swapped = grad_out.transpose(dim, -1).contiguous()
    n_cols = grad_swapped.size(-1)
    n_rows = grad_swapped.numel() // n_cols
    grad_flat = grad_swapped.view(n_rows, n_cols)

    BLOCK_SIZE, num_warps = calculate_settings(n_cols)
    grad_in_flat = torch.empty_like(grad_flat)
    # All three tensors are contiguous with identical shape, so one row
    # stride serves the kernel for output, grad-out, and grad-in alike.
    _sparsemax_backward_kernel[(n_rows,)](
        out_flat,
        grad_flat,
        grad_in_flat,
        out_flat.stride(0),
        n_cols,
        BLOCK_SIZE=BLOCK_SIZE,
        num_warps=num_warps,
    )

    return grad_in_flat.view_as(grad_swapped).transpose(dim, -1)
163
+
164
+
165
class LigerSparsemaxFunction(torch.autograd.Function):
    """Autograd wrapper around the Triton sparsemax kernels."""

    @staticmethod
    @ensure_contiguous
    def forward(ctx, x: torch.Tensor, dim: int):
        """Compute sparsemax along `dim`, saving the flat output for backward."""
        y, flat_out = _sparsemax_forward(x, dim)
        ctx.save_for_backward(flat_out)
        ctx.dim = dim
        return y

    @staticmethod
    @ensure_contiguous
    def backward(ctx, grad_out: torch.Tensor):
        """Propagate gradients; the `dim` argument receives no gradient."""
        (flat_out,) = ctx.saved_tensors
        return _sparsemax_backward(grad_out, flat_out, ctx.dim), None
@@ -2,7 +2,8 @@ import torch
2
2
  import triton
3
3
  import triton.language as tl
4
4
 
5
- from liger_kernel.ops.utils import calculate_settings, ensure_contiguous
5
+ from liger_kernel.ops.utils import calculate_settings
6
+ from liger_kernel.ops.utils import ensure_contiguous
6
7
 
7
8
 
8
9
  @triton.jit
@@ -11,9 +12,7 @@ def silu(x):
11
12
 
12
13
 
13
14
  @triton.jit
14
- def _swiglu_forward_kernel(
15
- a_ptr, b_ptr, c_ptr, stride, n_cols: tl.constexpr, BLOCK_SIZE: tl.constexpr
16
- ):
15
+ def _swiglu_forward_kernel(a_ptr, b_ptr, c_ptr, stride, n_cols: tl.constexpr, BLOCK_SIZE: tl.constexpr):
17
16
  program_id = tl.program_id(0).to(tl.int64)
18
17
 
19
18
  # locate start index
@@ -27,14 +26,12 @@ def _swiglu_forward_kernel(
27
26
  # sigmoid requires type float32
28
27
  a_row = tl.load(a_ptr + col_offsets, mask=mask, other=0).to(tl.float32)
29
28
  b_row = tl.load(b_ptr + col_offsets, mask=mask, other=0)
30
- c_row = silu(a_row) * b_row
29
+ c_row = silu(a_row).cast(b_row.dtype) * b_row
31
30
  tl.store(c_ptr + col_offsets, c_row, mask=mask)
32
31
 
33
32
 
34
33
  @triton.jit
35
- def _swiglu_backward_kernel(
36
- dc_ptr, a_ptr, b_ptr, stride, n_cols: tl.constexpr, BLOCK_SIZE: tl.constexpr
37
- ):
34
+ def _swiglu_backward_kernel(dc_ptr, a_ptr, b_ptr, stride, n_cols: tl.constexpr, BLOCK_SIZE: tl.constexpr):
38
35
  program_id = tl.program_id(0).to(tl.int64)
39
36
 
40
37
  # locate start index
@@ -84,7 +81,6 @@ def swiglu_forward(a, b):
84
81
 
85
82
 
86
83
  def swiglu_backward(a, b, dc):
87
-
88
84
  ori_shape = dc.shape
89
85
  n_cols = ori_shape[-1]
90
86
  dc = dc.view(-1, n_cols)
@@ -0,0 +1,136 @@
1
+ import math
2
+
3
+ from typing import Callable
4
+ from typing import List
5
+ from typing import Optional
6
+
7
+ import torch
8
+
9
+ from liger_kernel.ops.utils import ensure_contiguous
10
+
11
+
12
class LigerTiledMLPFunction(torch.autograd.Function):
    """
    Based on DeepSpeed's TiledMLP:
    https://github.com/deepspeedai/DeepSpeed/blob/v0.18.2/deepspeed/runtime/sequence_parallel/ulysses_sp.py#L838

    Perform a tiled MLP computation to massively reduce memory usage needed to compute MLP
    when using very long sequence lengths.

    This module re-computes `forward` in the `backward`. So the `forward` occurs twice each iteration.
    And if you're using activation checkpointing it then occurs thrice.

    Args:
        fn: the function to call on sharded inputs (e.g., mlp.forward)
        mlp_module: the MLP nn.Module object
        x: the input to MLP.forward (hidden_states)
        shards: how many shards to use
        compute_params: a list of weights engaged in the compute

    Returns:
        the computed hidden_states
    """

    @staticmethod
    @ensure_contiguous
    def forward(
        ctx,
        fn: Callable,
        mlp_module: torch.nn.Module,
        x: torch.Tensor,
        shards: int,
        compute_params: Optional[List[torch.nn.Parameter]] = None,
    ) -> torch.Tensor:
        # Stash everything needed to re-run the forward during backward.
        # NOTE(review): `compute_params` is accepted but never read here —
        # presumably kept for API parity with the DeepSpeed original; confirm.
        ctx.fn = fn
        ctx.mlp_module = mlp_module
        ctx.shards = shards
        ctx.save_for_backward(x)

        # x.shape could be [bs, seqlen, hidden_size] or [seqlen, hidden_size] (moe experts)
        x_shards = list(torch.chunk(x, chunks=shards, dim=-2))
        # no_grad: shard activations are discarded immediately; backward
        # recomputes each shard's forward instead of storing activations.
        with torch.no_grad():
            output_shards = [fn(mlp_module, x_shard) for x_shard in x_shards]
        output_unsharded = torch.cat(output_shards, dim=-2)

        return output_unsharded

    @staticmethod
    @ensure_contiguous
    def backward(ctx, *grads) -> tuple:
        fn = ctx.fn
        (x,) = ctx.saved_tensors
        mlp_module = ctx.mlp_module
        shards = ctx.shards

        x_requires_grad = x.requires_grad
        x = x.detach()
        # detach() unsets x.requires_grad, so restore it
        x.requires_grad_(x_requires_grad)

        # x.shape could be [bs, seqlen, hidden_size] or [seqlen, hidden_size] (moe experts)
        hidden_size = x.shape[-1]
        x_shape_orig = x.shape

        # flatten bs+seqlen to avoid having stride issues when narrowing into seqlen w/ bs>1
        x = x.view(-1, hidden_size)
        incoming_grad = grads[0].view(-1, hidden_size)
        # Pre-allocated gradient buffer; per-shard `.grad` views below point
        # into it so autograd writes results in place, shard by shard.
        x_grad = torch.zeros_like(x)

        x_shards = list(torch.chunk(x, chunks=shards, dim=0))

        for i, x_shard in enumerate(x_shards):
            x_shard.requires_grad_(x_requires_grad)

            # if seqlen is not exactly divisible by shards the last step will be shorter than shard_step
            shard_step = x_shards[i].shape[0]
            shard_offset = i * x_shards[0].shape[0]

            # Alias this shard's .grad to the matching slice of the buffer so
            # the backward pass accumulates into x_grad without extra copies.
            x_shard.grad = x_grad.narrow(0, shard_offset, shard_step).view_as(x_shard)
            incoming_grad_shard = incoming_grad.narrow(0, shard_offset, shard_step).view_as(x_shard)

            # Recompute this shard's forward with grad enabled, then backprop
            # only this shard; weight grads accumulate across the loop.
            with torch.enable_grad():
                output = fn(mlp_module, x_shard)
            torch.autograd.backward(output, incoming_grad_shard)

        # unflatten
        x_grad = x_grad.view(x_shape_orig)

        # Grads only for `x`; fn / mlp_module / shards / compute_params get None.
        return (None, None, x_grad, None, None)
99
+
100
+
101
def apply_tiled_mlp(
    fn: Callable,
    mlp_module: torch.nn.Module,
    x: torch.Tensor,
    num_shards: Optional[int] = None,
    compute_params: Optional[List[torch.nn.Parameter]] = None,
) -> torch.Tensor:
    """
    Apply tiled MLP computation for memory efficiency.

    Args:
        fn: the function to call on sharded inputs (e.g., lambda module, x: module(x))
        mlp_module: the MLP nn.Module object
        x: the input tensor with shape [bs, seqlen, hidden_size] or [seqlen, hidden_size]
        num_shards: number of shards to use. If None, automatically calculated as ceil(seqlen / hidden_size)
        compute_params: list of parameters for DeepSpeed ZeRO optimization

    Returns:
        output tensor with the same shape as input
    """
    if num_shards is None:
        # Heuristic default: one shard per `hidden_size` rows of sequence,
        # i.e. ceil(seqlen / hidden_size); works for both 2-D and 3-D inputs.
        num_shards = math.ceil(x.shape[-2] / x.shape[-1])

    # Guard against degenerate values (e.g. seqlen < hidden_size gives 1).
    num_shards = max(1, num_shards)

    return LigerTiledMLPFunction.apply(fn, mlp_module, x, num_shards, compute_params)