liger-kernel-nightly 0.4.0.dev20241107052928__py3-none-any.whl → 0.6.3.dev20251121010306__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of liger-kernel-nightly might be problematic; see the registry's advisory page for more details.

Files changed (114)
  1. liger_kernel/__init__.py +0 -0
  2. liger_kernel/chunked_loss/README.md +25 -0
  3. liger_kernel/chunked_loss/__init__.py +8 -0
  4. liger_kernel/chunked_loss/cosine_similarity_loss.py +136 -0
  5. liger_kernel/chunked_loss/cpo_loss.py +157 -0
  6. liger_kernel/chunked_loss/dpo_loss.py +229 -0
  7. liger_kernel/chunked_loss/functional.py +17 -0
  8. liger_kernel/chunked_loss/fused_linear_distillation.py +292 -0
  9. liger_kernel/chunked_loss/fused_linear_ppo.py +350 -0
  10. liger_kernel/chunked_loss/fused_linear_preference.py +433 -0
  11. liger_kernel/chunked_loss/fused_linear_unpaired_preference.py +341 -0
  12. liger_kernel/chunked_loss/grpo_loss.py +304 -0
  13. liger_kernel/chunked_loss/jsd_loss.py +200 -0
  14. liger_kernel/chunked_loss/kto_loss.py +210 -0
  15. liger_kernel/chunked_loss/orpo_loss.py +144 -0
  16. liger_kernel/chunked_loss/simpo_loss.py +165 -0
  17. liger_kernel/env_report.py +21 -4
  18. liger_kernel/ops/cross_entropy.py +235 -84
  19. liger_kernel/ops/dyt.py +157 -0
  20. liger_kernel/ops/experimental/embedding.py +1 -3
  21. liger_kernel/ops/experimental/mm_int8int2.py +3 -9
  22. liger_kernel/ops/fused_add_rms_norm.py +412 -0
  23. liger_kernel/ops/fused_linear_cross_entropy.py +197 -75
  24. liger_kernel/ops/fused_linear_jsd.py +17 -34
  25. liger_kernel/ops/fused_neighborhood_attention.py +1022 -0
  26. liger_kernel/ops/geglu.py +7 -18
  27. liger_kernel/ops/group_norm.py +305 -0
  28. liger_kernel/ops/grpo_loss.py +310 -0
  29. liger_kernel/ops/jsd.py +46 -21
  30. liger_kernel/ops/kl_div.py +23 -19
  31. liger_kernel/ops/layer_norm.py +150 -86
  32. liger_kernel/ops/llama4_rope.py +225 -0
  33. liger_kernel/ops/multi_token_attention.py +207 -0
  34. liger_kernel/ops/poly_norm.py +386 -0
  35. liger_kernel/ops/qwen2vl_mrope.py +222 -0
  36. liger_kernel/ops/rms_norm.py +314 -84
  37. liger_kernel/ops/rope.py +32 -34
  38. liger_kernel/ops/softmax.py +201 -0
  39. liger_kernel/ops/sparsemax.py +179 -0
  40. liger_kernel/ops/swiglu.py +5 -9
  41. liger_kernel/ops/tiled_mlp.py +136 -0
  42. liger_kernel/ops/tvd.py +207 -0
  43. liger_kernel/ops/utils.py +8 -4
  44. liger_kernel/transformers/__init__.py +199 -24
  45. liger_kernel/transformers/auto_model.py +6 -13
  46. liger_kernel/transformers/cross_entropy.py +33 -20
  47. liger_kernel/transformers/dyt.py +22 -0
  48. liger_kernel/transformers/experimental/__init__.py +5 -0
  49. liger_kernel/transformers/experimental/embedding.py +1 -3
  50. liger_kernel/transformers/fsdp.py +55 -0
  51. liger_kernel/transformers/functional.py +291 -13
  52. liger_kernel/transformers/fused_add_rms_norm.py +39 -0
  53. liger_kernel/transformers/fused_linear_cross_entropy.py +43 -14
  54. liger_kernel/transformers/fused_linear_jsd.py +1 -4
  55. liger_kernel/transformers/fused_neighborhood_attention.py +234 -0
  56. liger_kernel/transformers/geglu.py +1 -4
  57. liger_kernel/transformers/group_norm.py +50 -0
  58. liger_kernel/transformers/grpo_loss.py +98 -0
  59. liger_kernel/transformers/jsd.py +2 -7
  60. liger_kernel/transformers/kl_div.py +1 -3
  61. liger_kernel/transformers/layer_norm.py +3 -9
  62. liger_kernel/transformers/llama4_rope.py +93 -0
  63. liger_kernel/transformers/model/falcon_h1.py +122 -0
  64. liger_kernel/transformers/model/gemma.py +77 -77
  65. liger_kernel/transformers/model/gemma2.py +283 -0
  66. liger_kernel/transformers/model/gemma3.py +331 -0
  67. liger_kernel/transformers/model/glm4.py +141 -0
  68. liger_kernel/transformers/model/glm4v.py +163 -0
  69. liger_kernel/transformers/model/glm4v_moe.py +172 -0
  70. liger_kernel/transformers/model/internvl.py +157 -0
  71. liger_kernel/transformers/model/llama.py +128 -79
  72. liger_kernel/transformers/model/llama4.py +121 -0
  73. liger_kernel/transformers/model/llava.py +344 -0
  74. liger_kernel/transformers/model/loss_utils.py +95 -0
  75. liger_kernel/transformers/model/mistral.py +68 -64
  76. liger_kernel/transformers/model/mixtral.py +75 -91
  77. liger_kernel/transformers/model/mllama.py +63 -68
  78. liger_kernel/transformers/model/olmo2.py +141 -0
  79. liger_kernel/transformers/model/output_classes.py +147 -0
  80. liger_kernel/transformers/model/paligemma.py +432 -0
  81. liger_kernel/transformers/model/phi3.py +59 -213
  82. liger_kernel/transformers/model/qwen2.py +75 -72
  83. liger_kernel/transformers/model/qwen2_5_vl.py +163 -0
  84. liger_kernel/transformers/model/qwen2_vl.py +78 -98
  85. liger_kernel/transformers/model/qwen3.py +136 -0
  86. liger_kernel/transformers/model/qwen3_moe.py +152 -0
  87. liger_kernel/transformers/model/qwen3_next.py +146 -0
  88. liger_kernel/transformers/model/qwen3_vl.py +150 -0
  89. liger_kernel/transformers/model/qwen3_vl_moe.py +126 -0
  90. liger_kernel/transformers/model/smollm3.py +199 -0
  91. liger_kernel/transformers/model/smolvlm.py +158 -0
  92. liger_kernel/transformers/monkey_patch.py +2106 -289
  93. liger_kernel/transformers/multi_token_attention.py +64 -0
  94. liger_kernel/transformers/poly_norm.py +42 -0
  95. liger_kernel/transformers/qwen2vl_mrope.py +20 -0
  96. liger_kernel/transformers/rms_norm.py +57 -6
  97. liger_kernel/transformers/rope.py +45 -2
  98. liger_kernel/transformers/softmax.py +12 -0
  99. liger_kernel/transformers/sparsemax.py +16 -0
  100. liger_kernel/transformers/swiglu.py +23 -8
  101. liger_kernel/transformers/tiled_mlp.py +133 -0
  102. liger_kernel/transformers/trainer/__init__.py +4 -0
  103. liger_kernel/transformers/trainer/orpo_trainer.py +130 -0
  104. liger_kernel/transformers/tvd.py +13 -0
  105. liger_kernel/triton/__init__.py +1 -3
  106. liger_kernel/triton/monkey_patch.py +1 -3
  107. liger_kernel/utils.py +71 -0
  108. {liger_kernel_nightly-0.4.0.dev20241107052928.dist-info → liger_kernel_nightly-0.6.3.dev20251121010306.dist-info}/METADATA +150 -137
  109. liger_kernel_nightly-0.6.3.dev20251121010306.dist-info/RECORD +116 -0
  110. {liger_kernel_nightly-0.4.0.dev20241107052928.dist-info → liger_kernel_nightly-0.6.3.dev20251121010306.dist-info}/WHEEL +1 -1
  111. liger_kernel_nightly-0.4.0.dev20241107052928.dist-info/RECORD +0 -48
  112. {liger_kernel_nightly-0.4.0.dev20241107052928.dist-info → liger_kernel_nightly-0.6.3.dev20251121010306.dist-info}/LICENSE +0 -0
  113. {liger_kernel_nightly-0.4.0.dev20241107052928.dist-info → liger_kernel_nightly-0.6.3.dev20251121010306.dist-info}/NOTICE +0 -0
  114. {liger_kernel_nightly-0.4.0.dev20241107052928.dist-info → liger_kernel_nightly-0.6.3.dev20251121010306.dist-info}/top_level.txt +0 -0
liger_kernel/ops/geglu.py CHANGED
@@ -4,11 +4,9 @@ import torch
4
4
  import triton
5
5
  import triton.language as tl
6
6
 
7
- from liger_kernel.ops.utils import (
8
- calculate_settings,
9
- compare_version,
10
- ensure_contiguous,
11
- )
7
+ from liger_kernel.ops.utils import calculate_settings
8
+ from liger_kernel.ops.utils import compare_version
9
+ from liger_kernel.ops.utils import ensure_contiguous
12
10
 
13
11
  if compare_version("triton", operator.ge, "3.0.0"):
14
12
  try:
@@ -22,9 +20,7 @@ else:
22
20
 
23
21
 
24
22
  @triton.jit
25
- def _geglu_tanh_forward_kernel(
26
- a, b, c, stride, n_cols: tl.constexpr, BLOCK_SIZE: tl.constexpr
27
- ):
23
+ def _geglu_tanh_forward_kernel(a, b, c, stride, n_cols: tl.constexpr, BLOCK_SIZE: tl.constexpr):
28
24
  program_id = tl.program_id(0).to(tl.int64)
29
25
 
30
26
  # locate start index
@@ -44,14 +40,12 @@ def _geglu_tanh_forward_kernel(
44
40
  tanh_arg = sqrt_2_over_pi * (a_row + 0.044715 * a_cubed)
45
41
  tanh_result = tanh(tanh_arg)
46
42
  geglu_a = 0.5 * a_row * (1 + tanh_result)
47
- c_row = geglu_a * b_row
43
+ c_row = geglu_a.cast(b_row.dtype) * b_row
48
44
  tl.store(c + col_offsets, c_row, mask=mask)
49
45
 
50
46
 
51
47
  @triton.jit
52
- def _geglu_tanh_backward_kernel(
53
- dc, a, b, stride, n_cols: tl.constexpr, BLOCK_SIZE: tl.constexpr
54
- ):
48
+ def _geglu_tanh_backward_kernel(dc, a, b, stride, n_cols: tl.constexpr, BLOCK_SIZE: tl.constexpr):
55
49
  program_id = tl.program_id(0).to(tl.int64)
56
50
 
57
51
  # locate start index
@@ -80,12 +74,7 @@ def _geglu_tanh_backward_kernel(
80
74
  # where z = sqrt(2/pi) * (a + 0.044715 * a^3)
81
75
  term1 = 0.5 * (1 + tanh_result)
82
76
  tanh_sq = tanh_result * tanh_result
83
- term2 = (
84
- 0.5
85
- * a_row
86
- * (1 - tanh_sq)
87
- * (sqrt_2_over_pi * (1 + 3 * 0.044715 * a_row * a_row))
88
- )
77
+ term2 = 0.5 * a_row * (1 - tanh_sq) * (sqrt_2_over_pi * (1 + 3 * 0.044715 * a_row * a_row))
89
78
  da_row = dc_row * b_row * (term1 + term2)
90
79
 
91
80
  tl.store(a + col_offsets, da_row, mask=mask)
@@ -0,0 +1,305 @@
1
+ import operator
2
+
3
+ import torch
4
+ import triton
5
+ import triton.language as tl
6
+
7
+ from liger_kernel.ops.utils import compare_version
8
+ from liger_kernel.ops.utils import ensure_contiguous
9
+
10
+ if compare_version("triton", operator.ge, "3.0.0"):
11
+ try:
12
+ # typical import path with dispatch available
13
+ from triton.language.extra.libdevice import rsqrt
14
+ except ModuleNotFoundError:
15
+ # for working with NGC containers
16
+ from triton.language.extra.cuda.libdevice import rsqrt
17
+ else:
18
+ from triton.language.math import rsqrt
19
+
20
+ MAX_FUSED_SIZE = 65536
21
+
22
+
23
@triton.jit
def _group_norm_forward_kernel(
    Y_ptr,  # pointer to output, shape (n_rows, n_groups, hidden_size)
    Y_row_stride,  # stride of each row in output
    Y_col_stride,  # stride of each column in output
    X_ptr,  # pointer to input, shape (n_rows, n_groups, hidden_size)
    X_row_stride,  # stride of each row in input
    X_col_stride,  # stride of each column in input
    Mean_ptr,  # pointer to mean, shape (n_rows, n_groups)
    Mean_row_stride,  # stride of each row in mean
    Mean_col_stride,  # stride of each column in mean
    RSTD_ptr,  # pointer to rstd, shape (n_rows, n_groups)
    RSTD_row_stride,  # stride of each row in rstd
    RSTD_col_stride,  # stride of each column in rstd
    W_ptr,  # pointer to W (per-channel affine scale)
    B_ptr,  # pointer to B (per-channel affine shift)
    hidden_size,  # number of elements in one group's row of X
    channels_per_group,  # the number of channels per group
    eps,  # epsilon added to variance for numerical stability
    BLOCK_SIZE: tl.constexpr,
):
    """
    Forward group norm: one program per (batch row, group).

    Computes the group's mean/rstd over its `hidden_size` elements, writes the
    normalized + affine-transformed output to Y, and stores mean/rstd for the
    backward pass.

    References:
    https://nn.labml.ai/normalization/group_norm/index.html
    """
    batch_idx = tl.program_id(0)
    group_idx = tl.program_id(1)

    # Advance to this program's (batch, group) slice of input/output.
    X_ptr += batch_idx * X_row_stride + group_idx * X_col_stride
    Y_ptr += batch_idx * Y_row_stride + group_idx * Y_col_stride

    block_range = tl.arange(0, BLOCK_SIZE)

    # Accumulate sum and sum of squares in a single streaming pass.
    # NOTE(review): despite the comment below, this is the plain
    # E[X^2] - E[X]^2 formula, not Welford's online algorithm — it can lose
    # precision when variance is tiny relative to the mean.
    # Compute mean and variance using the online algorithm
    s = 0.0
    squared_sum = 0.0
    for i in tl.range(0, hidden_size, BLOCK_SIZE):
        hidden_size_offsets = i + block_range
        mask = hidden_size_offsets < hidden_size
        # other=0.0 so masked lanes contribute nothing to either sum
        X = tl.load(X_ptr + hidden_size_offsets, mask=mask, other=0.0)
        s += tl.sum(X)
        # X**2
        squared_sum += tl.sum(X * X)

    m = s / hidden_size

    # variance = E[X**2] - E[X]**2
    variance = (squared_sum / hidden_size) - (m * m)

    # 1/std
    rstd = rsqrt(variance + eps)

    # Normalize: each channel within the group shares (m, rstd) but has its
    # own affine parameters W[channel], B[channel].
    hidden_size_per_channel = hidden_size // channels_per_group
    for channel_idx in tl.range(group_idx * channels_per_group, (group_idx + 1) * channels_per_group):
        W = tl.load(W_ptr + channel_idx)
        B = tl.load(B_ptr + channel_idx)
        # NOTE(review): plain `range` here vs `tl.range` in the stats loop above;
        # both unroll in Triton, but the inconsistency is worth confirming.
        for i in range(0, hidden_size_per_channel, BLOCK_SIZE):
            hidden_size_offsets = i + block_range
            mask = hidden_size_offsets < hidden_size_per_channel
            # other=m makes masked lanes normalize to exactly 0 before affine
            X = tl.load(X_ptr + hidden_size_offsets, mask=mask, other=m)
            Y = (X - m) * rstd * W + B
            tl.store(Y_ptr + hidden_size_offsets, Y, mask=mask)

        # Step to the next channel's slice inside this group.
        X_ptr += hidden_size_per_channel
        Y_ptr += hidden_size_per_channel

    # Persist per-(batch, group) statistics for the backward kernel.
    tl.store(Mean_ptr + batch_idx * Mean_row_stride + group_idx * Mean_col_stride, m)
    tl.store(RSTD_ptr + batch_idx * RSTD_row_stride + group_idx * RSTD_col_stride, rstd)
92
+
93
+
94
@triton.jit
def _group_norm_backward_kernel(
    X_ptr,  # pointer to input, shape (n_rows, n_channels, hidden_size)
    X_row_stride,  # stride of each row in input
    X_col_stride,  # stride of each column in input
    W_ptr,  # pointer to weights, shape (n_channels)
    Mean_ptr,  # pointer to mean, shape (n_rows, n_groups)
    Mean_ptr_row_stride,  # stride of each row in mean
    Mean_ptr_col_stride,  # stride of each column in mean
    RSTD_ptr,  # pointer to rstd, shape (n_rows, n_groups)
    DX_ptr,  # pointer to input grad, shape (n_rows, n_groups, hidden_size)
    DW_ptr,  # pointer to weights grad, shape (n_channels)
    DB_ptr,  # pointer to bias grad, shape (n_channels)
    UPSTREAM_ptr,  # pointer to output grad, shape (n_rows, n_channels, hidden_size)
    hidden_size: tl.constexpr,  # number of elements per channel
    channels_per_group: tl.constexpr,  # number of channels per group
    BLOCK_SIZE: tl.constexpr,
    dtype: tl.constexpr,  # accumulator dtype for the DW/DB atomic adds
):
    """
    Backward group norm: one program per (batch row, group).

    Pass 1 accumulates, over every channel in the group, the per-channel
    gradients dW/dB (flushed with atomic adds, since other batch rows write
    the same channel) and the two normalization sums c1 = sum(x_hat * w*dy),
    c2 = sum(w*dy). Pass 2 re-reads the inputs and emits dX.

    References:
    https://nn.labml.ai/normalization/group_norm/index.html
    https://github.com/karpathy/llm.c/blob/master/doc/layernorm/layernorm.md

    The backprop equations are the same for group_norm and layer_norm
    the only difference here is that we load the Mean, Rstd corresponding to the
    group we're computing gradients for and the mean and rstd are computed over n-channels
    so the total number of elements we compute the mean over is num_channels_per_group * hidden_size

    We also need to load the Weights corresponding to the current channel to compute the gradients.
    """
    batch_idx = tl.program_id(0)
    group_idx = tl.program_id(1)

    # Move the pointers to the correct batch
    X_ptr += batch_idx * X_row_stride
    DX_ptr += batch_idx * X_row_stride
    UPSTREAM_ptr += batch_idx * X_row_stride

    # Mean and rstd are the same shape so have the same strides
    mean = tl.load(Mean_ptr + batch_idx * Mean_ptr_row_stride + group_idx * Mean_ptr_col_stride)
    rstd = tl.load(RSTD_ptr + batch_idx * Mean_ptr_row_stride + group_idx * Mean_ptr_col_stride)

    # c1/c2 accumulate across ALL channels of the group (they feed the shared
    # mean/variance gradient terms); dW/dB reset per channel below.
    c1 = 0.0
    c2 = 0.0
    block_range = tl.arange(0, BLOCK_SIZE)

    # We need to compute the sum terms of the backprop equations across all channels in the group
    for channel_idx in range(group_idx * channels_per_group, (group_idx + 1) * channels_per_group):
        dW = 0.0
        dB = 0.0
        # Move the pointers to the correct channel
        W = tl.load(W_ptr + channel_idx)
        for i in tl.range(0, hidden_size, BLOCK_SIZE):
            hidden_size_offsets = i + block_range
            mask = hidden_size_offsets < hidden_size
            X = tl.load(
                X_ptr + channel_idx * X_col_stride + hidden_size_offsets,
                mask=mask,
                other=0.0,
            )
            UPSTREAM_grad = tl.load(
                UPSTREAM_ptr + channel_idx * X_col_stride + hidden_size_offsets,
                mask=mask,
                other=0.0,
            )

            # x_hat = normalized input; dW = sum(dy * x_hat), dB = sum(dy)
            x_hat = (X - mean) * rstd
            dW += tl.sum(UPSTREAM_grad * x_hat)
            dB += tl.sum(UPSTREAM_grad)

            wdy = W * UPSTREAM_grad
            c1 += tl.sum(x_hat * wdy)
            c2 += tl.sum(wdy)

        # Need to ensure additions to the same channel are atomic
        # (programs for different batch rows all contribute to DW/DB[channel])
        tl.atomic_add(DW_ptr + channel_idx, dW.to(dtype))
        tl.atomic_add(DB_ptr + channel_idx, dB.to(dtype))

    # Normalize the sums by the total element count of the group.
    N = hidden_size * channels_per_group
    c1 = c1 / N
    c2 = c2 / N

    # Pass 2: dx = (w*dy - (x_hat*c1 + c2)) * rstd, per layer-norm backprop.
    for channel_idx in tl.range(group_idx * channels_per_group, (group_idx + 1) * channels_per_group):
        # Move the pointers to the correct channel
        W = tl.load(W_ptr + channel_idx)
        for i in range(0, hidden_size, BLOCK_SIZE):
            hidden_size_offsets = i + block_range
            mask = hidden_size_offsets < hidden_size
            X = tl.load(
                X_ptr + channel_idx * X_col_stride + hidden_size_offsets,
                mask=mask,
                other=0.0,
            )
            UPSTREAM_grad = tl.load(
                UPSTREAM_ptr + channel_idx * X_col_stride + hidden_size_offsets,
                mask=mask,
                other=0.0,
            )

            x_hat = (X - mean) * rstd
            wdy = W * UPSTREAM_grad
            dx = (wdy - (x_hat * c1 + c2)) * rstd
            tl.store(DX_ptr + channel_idx * X_col_stride + hidden_size_offsets, dx, mask=mask)
198
+
199
+
200
def group_norm_forward(X, num_channels, num_groups, W, B, eps):
    """Launch the group-norm forward kernel.

    Args:
        X: input tensor; first dim is the batch, channels grouped into
            `num_groups` along the remaining dims.
        num_channels: total channel count.
        num_groups: number of normalization groups (divides num_channels).
        W, B: per-channel affine scale and shift, shape (num_channels,).
        eps: variance epsilon.

    Returns:
        (Y, X, Mean, RSTD, BLOCK_SIZE) where Y is the normalized output in
        the original shape, X is the (contiguous-ified) input in the original
        shape, and Mean/RSTD are the per-(row, group) statistics saved for
        backward.
    """
    orig_shape = X.shape
    n_rows = orig_shape[0]
    # Fold each group's channels into one contiguous trailing dimension so the
    # kernel reduces over a single axis.
    X = X.view(n_rows, num_groups, -1).contiguous()
    group_size = X.shape[-1]
    BLOCK_SIZE = min(MAX_FUSED_SIZE, triton.next_power_of_2(group_size))

    Y = torch.empty((n_rows, num_groups, group_size), dtype=X.dtype, device=X.device)
    Mean = torch.zeros((n_rows, num_groups), dtype=X.dtype, device=X.device)
    RSTD = torch.zeros((n_rows, num_groups), dtype=X.dtype, device=X.device)

    # One program per (batch row, group).
    grid = (n_rows, num_groups)
    _group_norm_forward_kernel[grid](
        Y,
        Y.stride(0),
        Y.stride(1),
        X,
        X.stride(0),
        X.stride(1),
        Mean,
        Mean.stride(0),
        Mean.stride(1),
        RSTD,
        RSTD.stride(0),
        RSTD.stride(1),
        W,
        B,
        group_size,
        num_channels // num_groups,
        eps,
        BLOCK_SIZE=BLOCK_SIZE,
    )
    # Hand tensors back in the caller's original shape.
    return Y.view(*orig_shape), X.view(*orig_shape), Mean, RSTD, BLOCK_SIZE
234
+
235
+
236
def group_norm_backward(dY, X, W, B, Mean, RSTD, num_channels, num_groups):
    """Launch the group-norm backward kernel.

    Args:
        dY: upstream gradient, same shape as the forward input.
        X: forward input (original shape).
        W, B: per-channel affine scale/shift, shape (num_channels,).
        Mean, RSTD: per-(row, group) statistics from the forward pass.
        num_channels, num_groups: as in the forward call.

    Returns:
        (DX, DW, DB): input gradient in dY's shape, plus per-channel weight
        and bias gradients.
    """
    out_shape = dY.shape
    n_rows = out_shape[0]
    # Per-channel trailing dimension — assumes a (batch, channels, hidden)
    # layout, matching the forward reshape; TODO confirm for >3D inputs.
    per_channel = dY.shape[-1]
    group_channels = num_channels // num_groups
    dY = dY.view(n_rows, num_groups, -1)

    DX = torch.empty(
        (n_rows, num_groups, per_channel * group_channels),
        dtype=X.dtype,
        device=X.device,
    )
    # Zero-init: the kernel accumulates into DW/DB with atomic adds.
    DW = torch.zeros((num_channels), dtype=W.dtype, device=W.device)
    DB = torch.zeros((num_channels), dtype=B.dtype, device=B.device)
    # Accumulator dtype for the atomics; bf16 fallback for non-fp32 inputs.
    acc_dtype = tl.float32 if X.dtype == torch.float32 else tl.bfloat16

    BLOCK_SIZE = min(MAX_FUSED_SIZE, triton.next_power_of_2(per_channel))
    _group_norm_backward_kernel[(n_rows, num_groups)](
        X,
        X.stride(0),
        X.stride(1),
        W,
        Mean,
        Mean.stride(0),
        Mean.stride(1),
        RSTD,
        DX,
        DW,
        DB,
        dY,
        per_channel,
        group_channels,
        BLOCK_SIZE=BLOCK_SIZE,
        dtype=acc_dtype,
    )

    # Hand the input gradient back in the caller's original shape.
    return DX.view(*out_shape), DW, DB
273
+
274
+
275
class LigerGroupNormFunction(torch.autograd.Function):
    """Autograd wrapper around the Triton group-norm kernels.

    Forward returns the normalized, affine-transformed input; backward
    produces gradients for the input and both affine parameters.
    """

    @staticmethod
    @ensure_contiguous
    def forward(
        ctx,
        X,
        affine_scaling_weight,
        affine_shifting_bias,
        num_channels,
        num_groups,
        eps,
    ):
        # BLOCK_SIZE from the launch is not needed for backward.
        Y, X, Mean, RSTD, _ = group_norm_forward(
            X,
            num_channels,
            num_groups,
            affine_scaling_weight,
            affine_shifting_bias,
            eps,
        )
        # Stash group geometry on ctx; tensors go through save_for_backward.
        ctx.num_channels = num_channels
        ctx.num_groups = num_groups
        ctx.save_for_backward(X, affine_scaling_weight, affine_shifting_bias, Mean, RSTD)
        return Y

    @staticmethod
    @ensure_contiguous
    def backward(ctx, dY):
        X, W, B, Mean, RSTD = ctx.saved_tensors
        DX, DW, DB = group_norm_backward(dY, X, W, B, Mean, RSTD, ctx.num_channels, ctx.num_groups)
        # None gradients for the non-tensor args (num_channels, num_groups, eps).
        return DX, DW, DB, None, None, None
@@ -0,0 +1,310 @@
1
+ import torch
2
+ import triton
3
+ import triton.language as tl
4
+
5
+
6
@triton.jit
def _selective_log_softmax_kernel(
    LOGITS,  # (B, L+1, N) logits; last position has no next-token target
    INPUT_IDS,  # token ids whose log-probability is selected
    LOG_P,  # (B, L) float32 output of selected log-probs
    MASK,  # optional mask; 0 entries leave LOG_P untouched (stays 0)
    TEMPERATURE,  # logits are divided by this before the softmax
    stride_input_ids_b,  # batch stride of INPUT_IDS (also used for MASK)
    L: tl.constexpr,  # sequence length of the output
    N: tl.constexpr,  # vocabulary size
    BLOCK_N: tl.constexpr = 4096,
):
    """
    One program per (batch, position): computes
    log_softmax(logits / T)[input_id] via a streaming log-sum-exp, so the
    full softmax row is never materialized.
    """
    off_b = tl.program_id(0).cast(tl.int64)
    off_l = tl.program_id(1).cast(tl.int64)

    # Advance to this program's logits row / target id / output slot.
    LOGITS += off_b * (L + 1) * N + off_l * N
    INPUT_IDS += off_b * stride_input_ids_b + off_l
    LOG_P += off_b * L + off_l

    # Masked-out positions exit early; their LOG_P stays at the zero fill.
    if MASK is not None:
        MASK += off_b * stride_input_ids_b + off_l
        not_skip = tl.load(MASK)
        if not_skip == 0:
            return

    # Streaming (online) log-sum-exp over the vocabulary in BLOCK_N chunks:
    # m_i tracks the running max, l_i the running sum of exp(x - m_i).
    m_i = float("-inf")
    l_i = 0.0
    for start in range(0, N, BLOCK_N):
        cols = start + tl.arange(0, BLOCK_N)
        logits = tl.load(LOGITS + cols, mask=cols < N, other=float("-inf")).to(tl.float32) / TEMPERATURE
        new_m_i = tl.maximum(m_i, tl.max(logits))
        alpha = tl.exp(m_i - new_m_i)  # rescale old sum to the new max
        l_i = l_i * alpha + tl.sum(tl.exp(logits - new_m_i))
        m_i = new_m_i
    lse = m_i + tl.log(l_i)

    # log p(id) = logit(id)/T - logsumexp(logits/T)
    ids = tl.load(INPUT_IDS)
    x = tl.load(LOGITS + ids).to(tl.float32) / TEMPERATURE
    logp = x - lse
    tl.store(LOG_P, logp)
46
+
47
+
48
# Computes old_logp / ref_logp without autograd; saves ~10 GB of peak memory.
@torch.no_grad
def fused_selective_log_softmax(logits: torch.Tensor, input_ids: torch.Tensor, temperature: float = 0.9, mask=None):
    """Return log_softmax(logits / temperature) gathered at `input_ids`.

    Args:
        logits: contiguous (B, L+1, N) tensor; the final position is dropped
            since it has no next-token target.
        input_ids: token ids; only the last L positions are used.
        temperature: softmax temperature divisor.
        mask: optional mask (last L positions used); zero entries are skipped
            and their output stays 0.

    Returns:
        (B, L) float32 tensor of selected log-probabilities.
    """
    assert logits.is_contiguous()
    batch, seq_plus_one, vocab = logits.shape
    seq_len = seq_plus_one - 1
    # Align ids/mask with the L positions that actually have predictions.
    input_ids = input_ids[:, -seq_len:]
    mask = mask[:, -seq_len:] if mask is not None else None
    # Zero-filled so masked positions come back as 0.
    log_p = torch.zeros(batch, seq_len, dtype=torch.float32, device=logits.device)
    launch_opts = {"BLOCK_N": 2048, "num_stages": 4, "num_warps": 1}
    _selective_log_softmax_kernel[(batch, seq_len)](
        logits,
        input_ids,
        log_p,
        mask,
        temperature,
        input_ids.stride(0),
        seq_len,
        vocab,
        **launch_opts,
    )
    return log_p
63
+
64
+
65
+ # @triton.autotune([triton.Config({"BLOCK_N":BLOCK_N}, num_stages=ns, num_warps=nw)
66
+ # for BLOCK_N in [2048, 4096, 8192]
67
+ # for ns in [1, 2, 4]
68
+ # for nw in [1, 2, 4, 8, 16]],
69
+ # key=['N'])
70
@triton.jit
def _grpo_loss_fwd_kernel(
    LOGITS,  # (B, L+1, N) policy logits
    OLD_LOGP,  # (B, L) behavior-policy log-probs, or None (on-policy)
    REF_LOGP,  # (B, L) reference-policy log-probs; required when BETA != 0
    INPUT_IDS,  # (B, L) completion token ids
    COMPLETION_MASK,  # optional (B, L); 0 entries are skipped entirely
    ADVANTAGES,  # (B,) per-sequence advantage
    LOSS,  # (B, L) per-token loss output
    LSE,  # (B, L) log-sum-exp output, reused by the backward kernel
    KL,  # (B, L) KL output; only written when BETA != 0
    IS_CLIPPED,  # (B, L) flag: 1 where the PPO ratio was clipped
    TEMPERATURE,  # softmax temperature divisor
    BETA: tl.constexpr,  # KL penalty coefficient
    EPS_LOW,  # lower clip bound offset (ratio >= 1 - EPS_LOW)
    EPS_HIGH,  # upper clip bound offset (ratio <= 1 + EPS_HIGH)
    L: tl.constexpr,  # completion length
    N: tl.constexpr,  # vocabulary size
    BLOCK_N: tl.constexpr = 4096,
):
    """
    GRPO forward: one program per (batch, position). Computes the clipped
    PPO-style surrogate loss per token, optionally plus a BETA-weighted
    k3 KL penalty against the reference policy.
    """
    off_b = tl.program_id(0).cast(tl.int64)
    off_l = tl.program_id(1).cast(tl.int64)

    # Skip masked tokens; their LOSS/LSE/KL/IS_CLIPPED stay at the zero fill.
    if COMPLETION_MASK is not None:
        COMPLETION_MASK += off_b * L + off_l
        not_skip = tl.load(COMPLETION_MASK)
        if not_skip == 0:
            return

    # Advance pointers to this (batch, position).
    LOGITS += off_b * (L + 1) * N + off_l * N
    INPUT_IDS += off_b * L + off_l
    ADVANTAGES += off_b
    LOSS += off_b * L + off_l
    LSE += off_b * L + off_l
    IS_CLIPPED += off_b * L + off_l

    # Streaming log-sum-exp over the vocabulary (running max m_i, sum l_i).
    m_i = float("-inf")
    l_i = 0.0
    for start in range(0, N, BLOCK_N):
        cols = start + tl.arange(0, BLOCK_N)
        logits = tl.load(LOGITS + cols, mask=cols < N, other=float("-inf")).to(tl.float32) / TEMPERATURE
        new_m_i = tl.maximum(m_i, tl.max(logits))
        alpha = tl.exp(m_i - new_m_i)
        l_i = l_i * alpha + tl.sum(tl.exp(logits - new_m_i))
        m_i = new_m_i
    lse = m_i + tl.log(l_i)

    # Current-policy log-prob of the sampled token.
    idx = tl.load(INPUT_IDS)
    x = tl.load(LOGITS + idx).to(tl.float32) / TEMPERATURE
    logp = x - lse
    # Without stored old log-probs the ratio is exp(0) = 1 (on-policy step).
    if OLD_LOGP is None:
        old_logp = logp
    else:
        OLD_LOGP += off_b * L + off_l
        old_logp = tl.load(OLD_LOGP).to(tl.float32)
    # PPO clipped-surrogate: ratio and its clipped counterpart.
    coef_1 = tl.exp(logp - old_logp)
    coef_2 = tl.clamp(coef_1, 1 - EPS_LOW, 1 + EPS_HIGH)
    advantage = tl.load(ADVANTAGES).to(tl.float32)
    per_token_loss1 = coef_1 * advantage
    per_token_loss2 = coef_2 * advantage
    per_token_loss = -tl.minimum(per_token_loss1, per_token_loss2)
    is_clipped = per_token_loss1 < per_token_loss2

    # Optional KL penalty (k3 estimator: exp(d) - d - 1, d = ref_logp - logp).
    if BETA != 0.0:
        REF_LOGP += off_b * L + off_l
        KL += off_b * L + off_l
        ref_logp = tl.load(REF_LOGP).to(tl.float32)
        kl = tl.exp(ref_logp - logp) - (ref_logp - logp) - 1
        per_token_loss += BETA * kl
        tl.store(KL, kl)

    tl.store(LOSS, per_token_loss)
    tl.store(LSE, lse)
    tl.store(IS_CLIPPED, is_clipped)
144
+
145
+
146
+ # @triton.autotune([triton.Config({"BLOCK_N":BLOCK_N}, num_stages=ns, num_warps=nw)
147
+ # for BLOCK_N in [2048, 4096, 8192]
148
+ # for ns in [1, 2, 4]
149
+ # for nw in [1, 2, 4, 8, 16]],
150
+ # key=['N'])
151
@triton.jit
def _grpo_loss_bwd_kernel(
    DLOSS,  # (B, L) upstream gradient of the per-token loss
    DLOGITS,  # (B, L+1, N) output gradient buffer (may alias LOGITS)
    LOGITS,  # (B, L+1, N) forward logits
    OLD_LOGP,  # (B, L) behavior-policy log-probs, or None
    REF_LOGP,  # (B, L) reference-policy log-probs; used when BETA != 0
    INPUT_IDS,  # (B, L) completion token ids
    ADVANTAGES,  # (B,) per-sequence advantage
    COMPLETION_MASK,  # optional (B, L); 0 entries get zero gradients
    LSE,  # (B, L) log-sum-exp saved by the forward kernel
    TEMPERATURE,  # softmax temperature divisor
    BETA: tl.constexpr,  # KL penalty coefficient
    EPS_LOW,  # lower clip bound offset
    EPS_HIGH,  # upper clip bound offset
    loss_stride0,  # batch stride of DLOSS
    loss_stride1,  # position stride of DLOSS
    L: tl.constexpr,  # completion length
    N: tl.constexpr,  # vocabulary size
    BLOCK_N: tl.constexpr = 4096,
):
    """
    GRPO backward: one program per (batch, position). Recomputes the token's
    log-prob from LOGITS + saved LSE, forms dL/dlogp, then scatters the
    softmax Jacobian (dlogp * (1[j == id] - softmax_j)) across the vocabulary.
    """
    off_b = tl.program_id(0).cast(tl.int64)
    off_l = tl.program_id(1).cast(tl.int64)

    DLOGITS += off_b * (L + 1) * N + off_l * N
    # Masked tokens: explicitly zero their gradient row (DLOGITS may alias
    # LOGITS when running inplace, so stale values must be overwritten).
    if COMPLETION_MASK is not None:
        COMPLETION_MASK += off_b * L + off_l
        not_skip = tl.load(COMPLETION_MASK)
        if not_skip == 0:
            for start in range(0, N, BLOCK_N):
                cols = tl.arange(0, BLOCK_N) + start
                tl.store(DLOGITS + cols, 0.0, mask=cols < N)
            return

    LOGITS += off_b * (L + 1) * N + off_l * N
    DLOSS += off_b * loss_stride0 + off_l * loss_stride1
    INPUT_IDS += off_b * L + off_l
    ADVANTAGES += off_b
    LSE += off_b * L + off_l

    dloss = tl.load(DLOSS).to(tl.float32)
    lse = tl.load(LSE).to(tl.float32)

    # Recompute logp exactly as in the forward pass.
    idx = tl.load(INPUT_IDS)
    x = tl.load(LOGITS + idx).to(tl.float32) / TEMPERATURE
    logp = x - lse
    if OLD_LOGP is None:
        old_logp = logp
    else:
        OLD_LOGP += off_b * L + off_l
        old_logp = tl.load(OLD_LOGP).to(tl.float32)
    coef_1 = tl.exp(logp - old_logp)
    coef_2 = tl.clamp(coef_1, 1 - EPS_LOW, 1 + EPS_HIGH)
    advantage = tl.load(ADVANTAGES).to(tl.float32)
    per_token_loss1 = coef_1 * advantage
    per_token_loss2 = coef_2 * advantage
    # Gradient flows only where the unclipped branch was the active min;
    # the clipped branch is constant in logp.
    mask = per_token_loss2 >= per_token_loss1

    # d(-min)/dlogp = -coef_1 * advantage on the unclipped branch
    # (d(coef_1)/dlogp = coef_1).
    dlogp = -per_token_loss1 * mask
    if BETA != 0.0:
        REF_LOGP += off_b * L + off_l
        ref_logp = tl.load(REF_LOGP).to(tl.float32)
        # d(kl)/dlogp for k3: 1 - exp(ref_logp - logp)
        dlogp += BETA * (1 - tl.exp(ref_logp - logp))

    # Chain rule through the upstream grad and the temperature scaling.
    dlogp = dlogp * dloss / TEMPERATURE
    tl.debug_barrier()
    # Scatter through the softmax Jacobian; safe even when DLOGITS aliases
    # LOGITS because each row is read before it is overwritten.
    for start_n in tl.range(0, N, BLOCK_N):
        cols = start_n + tl.arange(0, BLOCK_N)
        logits = tl.load(LOGITS + cols, mask=cols < N, other=-float("inf")).to(tl.float32) / TEMPERATURE
        probs = tl.exp(logits - lse)
        dlogits = tl.where(cols == idx, 1 - probs, -probs) * dlogp
        tl.store(DLOGITS + cols, dlogits, mask=cols < N)
223
+
224
+
225
class GrpoLossFunction(torch.autograd.Function):
    """Autograd wrapper for the fused GRPO loss kernels.

    forward returns (loss, kl, is_clipped) per token; backward produces the
    gradient only w.r.t. `logits`. When `inplace` is True the backward pass
    reuses the logits storage for the gradient to save memory.
    """

    @staticmethod
    def forward(
        ctx,
        logits,  # (B, L+1, N), contiguous
        old_logp,  # (B, L) or None for the on-policy case
        ref_logp,  # (B, L); required when beta != 0
        completion_ids,  # (B, L) token ids, contiguous
        advantages,  # (B,) per-sequence advantage
        completion_mask,  # optional (B, L); 0 entries produce zero loss/grad
        temperature,  # softmax temperature divisor
        beta,  # KL penalty coefficient
        eps_low,  # lower PPO clip offset
        eps_high,  # upper PPO clip offset
        inplace,  # reuse logits storage for dlogits in backward
    ):
        # The kernels index with raw offsets, so contiguity is required.
        assert logits.is_contiguous() and completion_ids.is_contiguous()
        assert old_logp is None or old_logp.is_contiguous()
        assert (ref_logp is not None and ref_logp.is_contiguous()) if beta != 0.0 else True

        B, L_ADD_1, N = logits.shape
        L = L_ADD_1 - 1

        if completion_mask is not None:
            assert completion_mask.is_contiguous()

        # Zero-filled so masked positions read back as 0.
        loss = torch.zeros(B, L, device=logits.device, dtype=torch.float32)
        lse = torch.zeros_like(loss)
        is_clipped = torch.zeros_like(loss)
        kl = torch.zeros_like(loss) if beta != 0.0 else None
        # Hand-tuned launch config (see the commented-out autotune sweep above
        # the kernel definition).
        kwargs = {"BLOCK_N": 2048, "num_stages": 2, "num_warps": 1}
        _grpo_loss_fwd_kernel[(B, L)](
            logits,
            old_logp,
            ref_logp,
            completion_ids,
            completion_mask,
            advantages,
            loss,
            lse,
            kl,
            is_clipped,
            temperature,
            beta,
            eps_low,
            eps_high,
            L,
            N,
            **kwargs,
        )
        # lse is saved so backward can avoid redoing the log-sum-exp reduction.
        ctx.save_for_backward(logits, old_logp, ref_logp, completion_ids, advantages, completion_mask, lse)
        ctx.infos = (temperature, beta, eps_low, eps_high, inplace)
        return loss, kl, is_clipped

    @staticmethod
    def backward(ctx, *args):
        # Only the loss output participates in the graph; grads for kl and
        # is_clipped (args[1:]) are ignored.
        dloss = args[0]
        logits, old_logp, ref_logp, completion_ids, advantages, completion_mask, lse = ctx.saved_tensors
        temperature, beta, eps_low, eps_high, inplace = ctx.infos
        B, L_ADD_1, N = logits.shape
        L = L_ADD_1 - 1
        # inplace: overwrite the logits storage instead of allocating (B, L+1, N).
        dlogits = logits.data if inplace else torch.empty_like(logits)
        kwargs = {"BLOCK_N": 4096, "num_stages": 1, "num_warps": 16}
        _grpo_loss_bwd_kernel[(B, L)](
            dloss,
            dlogits,
            logits,
            old_logp,
            ref_logp,
            completion_ids,
            advantages,
            completion_mask,
            lse,
            temperature,
            beta,
            eps_low,
            eps_high,
            *dloss.stride(),  # dloss may be non-contiguous; pass real strides
            L,
            N,
            **kwargs,
        )
        # The final position predicts beyond the completion — no gradient.
        dlogits[:, -1, :] = 0
        return dlogits, None, None, None, None, None, None, None, None, None, None