liger-kernel 0.5.9__py3-none-any.whl → 0.5.10__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (39)
  1. liger_kernel/chunked_loss/dpo_loss.py +1 -1
  2. liger_kernel/chunked_loss/fused_linear_preference.py +0 -1
  3. liger_kernel/chunked_loss/jsd_loss.py +2 -2
  4. liger_kernel/ops/dyt.py +113 -179
  5. liger_kernel/ops/grpo_loss.py +310 -0
  6. liger_kernel/ops/sparsemax.py +167 -0
  7. liger_kernel/transformers/__init__.py +5 -0
  8. liger_kernel/transformers/dyt.py +5 -3
  9. liger_kernel/transformers/fsdp.py +55 -0
  10. liger_kernel/transformers/functional.py +8 -0
  11. liger_kernel/transformers/grpo_loss.py +98 -0
  12. liger_kernel/transformers/model/gemma.py +0 -8
  13. liger_kernel/transformers/model/gemma2.py +0 -6
  14. liger_kernel/transformers/model/gemma3.py +0 -8
  15. liger_kernel/transformers/model/glm4.py +0 -6
  16. liger_kernel/transformers/model/llama.py +56 -11
  17. liger_kernel/transformers/model/llava.py +0 -8
  18. liger_kernel/transformers/model/mistral.py +0 -6
  19. liger_kernel/transformers/model/mixtral.py +0 -8
  20. liger_kernel/transformers/model/mllama.py +0 -7
  21. liger_kernel/transformers/model/olmo2.py +0 -6
  22. liger_kernel/transformers/model/paligemma.py +0 -8
  23. liger_kernel/transformers/model/phi3.py +0 -8
  24. liger_kernel/transformers/model/qwen2.py +0 -8
  25. liger_kernel/transformers/model/qwen2_5_vl.py +0 -6
  26. liger_kernel/transformers/model/qwen2_vl.py +0 -6
  27. liger_kernel/transformers/model/qwen3.py +0 -6
  28. liger_kernel/transformers/model/qwen3_moe.py +128 -0
  29. liger_kernel/transformers/monkey_patch.py +122 -13
  30. liger_kernel/transformers/sparsemax.py +16 -0
  31. liger_kernel/transformers/swiglu.py +21 -0
  32. liger_kernel/transformers/trainer/orpo_trainer.py +1 -53
  33. liger_kernel/utils.py +11 -0
  34. {liger_kernel-0.5.9.dist-info → liger_kernel-0.5.10.dist-info}/METADATA +34 -20
  35. {liger_kernel-0.5.9.dist-info → liger_kernel-0.5.10.dist-info}/RECORD +39 -33
  36. {liger_kernel-0.5.9.dist-info → liger_kernel-0.5.10.dist-info}/WHEEL +1 -1
  37. {liger_kernel-0.5.9.dist-info → liger_kernel-0.5.10.dist-info}/licenses/LICENSE +0 -0
  38. {liger_kernel-0.5.9.dist-info → liger_kernel-0.5.10.dist-info}/licenses/NOTICE +0 -0
  39. {liger_kernel-0.5.9.dist-info → liger_kernel-0.5.10.dist-info}/top_level.txt +0 -0
@@ -128,7 +128,7 @@ class LigerFusedLinearDPOLoss(torch.nn.Module):
128
128
  compute_nll_loss: bool = False,
129
129
  compiled: bool = True,
130
130
  use_ref_model: bool = True,
131
- average_log_prob: bool = True,
131
+ average_log_prob: bool = False,
132
132
  chunk_size: int = 1,
133
133
  ):
134
134
  """
@@ -222,7 +222,6 @@ class LigerFusedLinearPreferenceBase(torch.autograd.Function):
222
222
  (_ref_chosen_input_chunks if use_ref_model else [None] * len(_chosen_input_chunks)),
223
223
  (_ref_rejected_input_chunks if use_ref_model else [None] * len(_rejected_input_chunks)),
224
224
  (_chosen_nll_target_chunks if nll_target is not None else [None] * len(_chosen_input_chunks)),
225
- strict=False,
226
225
  ):
227
226
  input_chunk = torch.cat([chosen_input_chunk, rejected_input_chunk], dim=0)
228
227
  ref_input_chunk = (
@@ -150,8 +150,8 @@ class LigerFusedLinearJSDLoss(torch.nn.Module):
150
150
  teacher_input: torch.Tensor,
151
151
  teacher_weight: torch.Tensor,
152
152
  true_labels: torch.LongTensor,
153
- student_bias: torch.Tensor,
154
- teacher_bias: torch.Tensor,
153
+ student_bias: torch.Tensor = None,
154
+ teacher_bias: torch.Tensor = None,
155
155
  ) -> torch.Tensor:
156
156
  """
157
157
  Compute the JSD distillation loss.
liger_kernel/ops/dyt.py CHANGED
@@ -4,7 +4,8 @@ import torch
4
4
  import triton
5
5
  import triton.language as tl
6
6
 
7
- from liger_kernel.ops.utils import calculate_settings
7
+ from triton.language.extra.libdevice import tanh
8
+
8
9
  from liger_kernel.ops.utils import compare_version
9
10
  from liger_kernel.ops.utils import ensure_contiguous
10
11
  from liger_kernel.ops.utils import infer_device
@@ -20,187 +21,126 @@ else:
20
21
  from triton.language.math import tanh
21
22
 
22
23
 
24
+ # @triton.autotune([triton.Config({"BLOCK_N":bn}, num_stages=ns, num_warps=nw)
25
+ # for bn in [1024, 2048, 4096]
26
+ # for ns in [1,2,4]
27
+ # for nw in [4, 8, 16, 32]
28
+ # ],
29
+ # key=['N'])
23
30
  @triton.jit
24
- def _dyt_fwd_kernel(
25
- x_ptr,
26
- x_row_stride,
27
- alpha_ptr,
28
- gamma_ptr,
29
- beta_ptr,
30
- y_ptr,
31
- y_row_stride,
32
- n_cols,
33
- BLOCK_SIZE: tl.constexpr,
34
- ):
35
- """
36
- Reference:
37
- https://arxiv.org/abs/2503.10622
38
-
39
- Shapes:
40
- - x: (BT, C)
41
- - alpha: (1)
42
- - gamma: (C)
43
- - beta: (C)
44
- """
45
- row_idx = tl.program_id(0)
46
- offsets = tl.arange(0, BLOCK_SIZE)
47
- mask = offsets < n_cols
48
-
49
- x_ptr += row_idx * x_row_stride
50
- y_ptr += row_idx * y_row_stride
51
-
52
- alpha = tl.load(alpha_ptr)
53
- gamma = tl.load(gamma_ptr + offsets, mask=mask)
54
- beta = tl.load(beta_ptr + offsets, mask=mask)
55
- x = tl.load(x_ptr + offsets, mask=mask)
56
- y = gamma * tanh((alpha * x).cast(tl.float32)) + beta
57
- tl.store(y_ptr + offsets, y, mask=mask)
31
+ def _dyt_fwd_kernel(X, Y, Alpha, Gamma, Beta, HAVE_BETA: tl.constexpr, N: tl.constexpr, BLOCK_N: tl.constexpr = 1024):
32
+ col = tl.cast(tl.program_id(0), tl.int64) * BLOCK_N + tl.arange(0, BLOCK_N)
33
+ mask = col < N
34
+ row_id = tl.cast(tl.program_id(1), tl.int64)
35
+
36
+ X += row_id * N
37
+ Y += row_id * N
38
+ alpha = tl.load(Alpha).to(tl.float32)
39
+
40
+ gamma = tl.load(Gamma + col, mask=mask, other=0.0).to(tl.float32)
58
41
 
42
+ x = tl.load(X + col, mask=mask, other=0.0).to(tl.float32)
59
43
 
44
+ tanh_x = tanh(alpha * x)
45
+ y = tanh_x * gamma
46
+ if HAVE_BETA:
47
+ beta = tl.load(Beta + col, mask=mask, other=0.0).to(tl.float32)
48
+ y += beta
49
+ tl.store(Y + col, y, mask=mask)
50
+
51
+
52
+ # @triton.autotune([triton.Config({"BLOCK_N":bn}, num_stages=ns, num_warps=nw)
53
+ # for bn in [1024, 2048, 4096]
54
+ # for ns in [1,2,4]
55
+ # for nw in [4, 8, 16]
56
+ # ],
57
+ # key=['N'])
60
58
  @triton.jit
61
59
  def _dyt_bwd_kernel(
62
- x_ptr,
63
- x_row_stride,
64
- dy_ptr,
65
- dy_row_stride,
66
- dx_ptr,
67
- dx_row_stride,
68
- alpha_ptr,
69
- dalpha_ptr,
70
- gamma_ptr,
71
- dgamma_ptr,
72
- dgamma_row_stride,
73
- n_cols,
74
- n_rows,
75
- ROWS_PER_PROGRAM: tl.constexpr,
76
- BLOCK_SIZE: tl.constexpr,
60
+ DY, DX, DA, DG, DB, X, Alpha, Gamma, HAVE_BETA: tl.constexpr, M, N: tl.constexpr, BLOCK_N: tl.constexpr = 1024
77
61
  ):
78
- """
79
- Reference:
80
- https://arxiv.org/abs/2503.10622
81
-
82
- Shapes:
83
- - x: (BT, C)
84
- - alpha: (1)
85
- - gamma: (C)
86
- - dx: (BT, C)
87
- - dy: (BT, C)
88
- - dgamma: (sm_count, C)
89
- - dalpha: (sm_count,)
90
- """
91
- # d(gamma * tanh(alpha * x) + beta) / dx
92
- # = gamma * (1 - tanh^2(alpha * x)) * alpha
93
- # d(gamma * tanh(alpha * x) + beta) / dalpha
94
- # = gamma * (1 - tanh^2(alpha * x)) * x
95
- # d(gamma * tanh(alpha * x) + beta) / dgamma
96
- # = tanh(alpha * x)
97
- # d(gamma * tanh(alpha * x)) / dbeta = 1
98
- pid = tl.program_id(0)
99
-
100
- row_start = pid * ROWS_PER_PROGRAM
101
- row_end = min((pid + 1) * ROWS_PER_PROGRAM, n_rows)
102
- offsets = tl.arange(0, BLOCK_SIZE)
103
- mask = offsets < n_cols
104
-
105
- dalpha = 0.0
106
- dgamma = tl.zeros((BLOCK_SIZE,), dtype=tl.float32)
107
-
108
- x_ptr += row_start * x_row_stride
109
- dx_ptr += row_start * dx_row_stride
110
- dy_ptr += row_start * dy_row_stride
111
- alpha = tl.load(alpha_ptr)
112
- gamma = tl.load(gamma_ptr + offsets, mask=mask, other=0.0)
113
-
114
- for _ in tl.range(row_start, row_end):
115
- dy = tl.load(dy_ptr + offsets, mask=mask, other=0.0)
116
- x = tl.load(x_ptr + offsets, mask=mask, other=0.0)
117
- tanh_ax = tanh((alpha * x).cast(tl.float32))
118
- sech2_ax = 1 - tanh_ax * tanh_ax
119
-
120
- dx = dy * gamma * sech2_ax * alpha
121
- dalpha += tl.sum(dy * gamma * sech2_ax * x)
122
- dgamma += dy * tanh_ax
123
- tl.store(dx_ptr + offsets, dx, mask=mask)
124
-
125
- dy_ptr += dy_row_stride
126
- x_ptr += x_row_stride
127
- dx_ptr += dx_row_stride
128
-
129
- tl.store(dgamma_ptr + pid * dgamma_row_stride + offsets, dgamma, mask=mask)
130
- tl.store(dalpha_ptr + pid, dalpha)
131
-
132
- pass
62
+ col = tl.cast(tl.program_id(0), tl.int64) * BLOCK_N + tl.arange(0, BLOCK_N)
63
+ mask = col < N
64
+ start_row_id = tl.cast(tl.program_id(1), tl.int64)
65
+
66
+ alpha = tl.load(Alpha).to(tl.float32)
67
+ da = 0.0
68
+ gamma = tl.load(Gamma + col, mask=mask, other=0.0).to(tl.float32)
69
+ dg = tl.zeros((BLOCK_N,), dtype=tl.float32)
70
+ if HAVE_BETA:
71
+ db = tl.zeros((BLOCK_N,), dtype=tl.float32)
72
+ for row_id in range(start_row_id, M, tl.num_programs(1)):
73
+ x = tl.load(X + row_id * N + col, mask=mask, other=0.0).to(tl.float32)
74
+ dy = tl.load(DY + row_id * N + col, mask=mask, other=0.0).to(tl.float32)
75
+ tanh_x = tanh(alpha * x)
76
+ if HAVE_BETA:
77
+ db += dy
78
+ dg += dy * tanh_x
79
+ tmp = (1 - tanh_x * tanh_x) * dy * gamma
80
+ da += tl.sum(x * tmp, 0)
81
+ dx = alpha * tmp
82
+ tl.store(DX + row_id * N + col, dx, mask=mask)
83
+
84
+ tl.store(DG + start_row_id * N + col, dg, mask=mask)
85
+ if HAVE_BETA:
86
+ tl.store(DB + start_row_id * N + col, db, mask=mask)
87
+ tl.store(DA + start_row_id * tl.cdiv(N, 512) + tl.program_id(0), da)
133
88
 
134
89
 
135
90
  def liger_dyt_fwd(x, alpha, gamma, beta):
136
- shape = x.shape
137
- dim = shape[-1]
138
- x = x.view(-1, dim)
139
- n_rows, n_cols = x.shape
91
+ assert x.is_contiguous()
92
+ HAVE_BETA = True if beta is not None else False
93
+ input_shape = x.shape
94
+ x = x.view(-1, input_shape[-1])
95
+ M, N = x.shape
96
+
140
97
  y = torch.empty_like(x)
141
- BLOCK_SIZE, num_warps = calculate_settings(n_cols)
142
- _dyt_fwd_kernel[(n_rows,)](
143
- x_ptr=x,
144
- alpha_ptr=alpha,
145
- gamma_ptr=gamma,
146
- beta_ptr=beta,
147
- y_ptr=y,
148
- x_row_stride=x.stride(0),
149
- y_row_stride=y.stride(0),
150
- n_cols=n_cols,
151
- BLOCK_SIZE=BLOCK_SIZE,
152
- num_warps=num_warps,
98
+
99
+ if N >= 4096:
100
+ kwargs = {"BLOCK_N": min(triton.next_power_of_2(N), 2048), "num_warps": 4, "num_stages": 1}
101
+ else:
102
+ kwargs = {"BLOCK_N": min(triton.next_power_of_2(N), 1024), "num_warps": 4, "num_stages": 1}
103
+
104
+ grid = lambda meta: (triton.cdiv(N, meta["BLOCK_N"]), M)
105
+ _dyt_fwd_kernel[(grid)](
106
+ x,
107
+ y,
108
+ alpha,
109
+ gamma,
110
+ beta,
111
+ HAVE_BETA,
112
+ N,
113
+ **kwargs,
153
114
  )
154
- return y.view(*shape)
155
-
156
-
157
- def liger_dyt_bwd(dy, x, alpha, gamma):
158
- shape = dy.shape
159
- dtype = x.dtype
160
- dim = shape[-1]
161
- dy = dy.view(-1, dim)
162
- x = x.view(-1, dim)
163
- n_rows, n_cols = dy.shape
164
- BLOCK_SIZE, num_warps = calculate_settings(n_cols)
165
- sm_count = 1
115
+ return y.view(input_shape)
116
+
117
+
118
+ def liger_dyt_bwd(dy, x, alpha, gamma, beta):
119
+ assert dy.is_contiguous()
120
+ input_shape = x.shape
121
+ x = x.view(-1, input_shape[-1])
122
+ M, N = x.shape
123
+ HAVE_BETA = True if beta is not None else False
124
+
166
125
  device = infer_device()
167
126
  if device == "cuda":
168
- sm_count = torch.cuda.get_device_properties(x.device).multi_processor_count
127
+ NUM_SMS = torch.cuda.get_device_properties(x.device).multi_processor_count
169
128
  elif device == "xpu":
170
- sm_count = torch.xpu.get_device_properties(x.device).gpu_subslice_count
171
- if n_cols > BLOCK_SIZE:
172
- raise RuntimeError(
173
- f"Feature dimension {dim} exceeds maximum supported size of {BLOCK_SIZE}. Consider using a smaller feature dimension."
174
- )
175
-
176
- dx = torch.empty_like(x, dtype=torch.float32)
177
- _dalpha = torch.empty((sm_count,), dtype=torch.float32, device=x.device)
178
- _dgamma = torch.empty((sm_count, n_cols), dtype=torch.float32, device=x.device)
179
-
180
- grid = (sm_count,)
181
- rows_per_program = triton.cdiv(n_rows, sm_count)
182
- _dyt_bwd_kernel[grid](
183
- x_ptr=x,
184
- x_row_stride=x.stride(0),
185
- dy_ptr=dy,
186
- dy_row_stride=dy.stride(0),
187
- dx_ptr=dx,
188
- dx_row_stride=dx.stride(0),
189
- alpha_ptr=alpha,
190
- dalpha_ptr=_dalpha,
191
- gamma_ptr=gamma,
192
- dgamma_ptr=_dgamma,
193
- dgamma_row_stride=_dgamma.stride(0),
194
- n_cols=n_cols,
195
- n_rows=n_rows,
196
- ROWS_PER_PROGRAM=rows_per_program,
197
- BLOCK_SIZE=BLOCK_SIZE,
198
- num_warps=num_warps,
199
- )
200
- dalpha = _dalpha.sum(dim=0, keepdim=True).to(dtype)
201
- dgamma = _dgamma.sum(dim=0).to(dtype)
202
- dbeta = dy.sum(dim=0).to(dtype)
203
- return dx.view(*shape), dalpha, dgamma, dbeta
129
+ NUM_SMS = torch.xpu.get_device_properties(x.device).gpu_subslice_count
130
+
131
+ da = torch.zeros(NUM_SMS, triton.cdiv(N, 512), dtype=torch.float32, device=x.device)
132
+ dg = torch.empty(NUM_SMS, N, dtype=torch.float32, device=x.device)
133
+ db = torch.empty(NUM_SMS, N, dtype=torch.float32, device=x.device) if HAVE_BETA else None
134
+ dx = torch.empty_like(dy)
135
+
136
+ kwargs = {"BLOCK_N": min(triton.next_power_of_2(N), 1024), "num_warps": 8, "num_stages": 2}
137
+ grid = lambda meta: (triton.cdiv(N, meta["BLOCK_N"]), NUM_SMS)
138
+ _dyt_bwd_kernel[grid](dy, dx, da, dg, db, x, alpha, gamma, HAVE_BETA, M, N, **kwargs)
139
+ if HAVE_BETA:
140
+ db = db.sum(0).to(x.dtype)
141
+ dg = dg.sum(0).to(gamma.dtype)
142
+ da = da.sum().to(x.dtype).unsqueeze(0)
143
+ return dx.view(input_shape), da, dg, db
204
144
 
205
145
 
206
146
  class LigerDyTFunction(torch.autograd.Function):
@@ -208,18 +148,12 @@ class LigerDyTFunction(torch.autograd.Function):
208
148
  @ensure_contiguous
209
149
  def forward(ctx, x, alpha, gamma, beta):
210
150
  y = liger_dyt_fwd(x, alpha, gamma, beta)
211
- ctx.save_for_backward(x, alpha, gamma)
151
+ ctx.save_for_backward(x, alpha, gamma, beta)
212
152
  return y
213
153
 
214
154
  @staticmethod
215
155
  @ensure_contiguous
216
- def backward(ctx, grad_output):
217
- x, alpha, gamma = ctx.saved_tensors
218
- dx, dalpha, dgamma, dbeta = liger_dyt_bwd(
219
- grad_output,
220
- x,
221
- alpha,
222
- gamma,
223
- )
224
-
225
- return (dx, dalpha, dgamma, dbeta)
156
+ def backward(ctx, dy):
157
+ x, alpha, gamma, beta = ctx.saved_tensors
158
+ dx, dalpha, dgamma, dbeta = liger_dyt_bwd(dy, x, alpha, gamma, beta)
159
+ return dx, dalpha, dgamma, dbeta
@@ -0,0 +1,310 @@
1
+ import torch
2
+ import triton
3
+ import triton.language as tl
4
+
5
+
6
+ @triton.jit
7
+ def _selective_log_softmax_kernel(
8
+ LOGITS,
9
+ INPUT_IDS,
10
+ LOG_P,
11
+ MASK,
12
+ TEMPERATURE,
13
+ stride_input_ids_b,
14
+ L: tl.constexpr,
15
+ N: tl.constexpr,
16
+ BLOCK_N: tl.constexpr = 4096,
17
+ ):
18
+ off_b = tl.program_id(0).cast(tl.int64)
19
+ off_l = tl.program_id(1).cast(tl.int64)
20
+
21
+ LOGITS += off_b * (L + 1) * N + off_l * N
22
+ INPUT_IDS += off_b * stride_input_ids_b + off_l
23
+ LOG_P += off_b * L + off_l
24
+
25
+ if MASK is not None:
26
+ MASK += off_b * stride_input_ids_b + off_l
27
+ not_skip = tl.load(MASK)
28
+ if not_skip == 0:
29
+ return
30
+
31
+ m_i = float("-inf")
32
+ l_i = 0.0
33
+ for start in range(0, N, BLOCK_N):
34
+ cols = start + tl.arange(0, BLOCK_N)
35
+ logits = tl.load(LOGITS + cols, mask=cols < N, other=float("-inf")).to(tl.float32) / TEMPERATURE
36
+ new_m_i = tl.maximum(m_i, tl.max(logits))
37
+ alpha = tl.exp(m_i - new_m_i)
38
+ l_i = l_i * alpha + tl.sum(tl.exp(logits - new_m_i))
39
+ m_i = new_m_i
40
+ lse = m_i + tl.log(l_i)
41
+
42
+ ids = tl.load(INPUT_IDS)
43
+ x = tl.load(LOGITS + ids).to(tl.float32) / TEMPERATURE
44
+ logp = x - lse
45
+ tl.store(LOG_P, logp)
46
+
47
+
48
+ # Compute old_logp and ref_logp; this reduces peak memory by about 10G. It does not require grad.
49
+ @torch.no_grad
50
+ def fused_selective_log_softmax(logits: torch.Tensor, input_ids: torch.Tensor, temperature: float = 0.9, mask=None):
51
+ assert logits.is_contiguous()
52
+ B, L_ADD_1, N = logits.shape
53
+ L = L_ADD_1 - 1
54
+ input_ids = input_ids[:, -L:]
55
+ if mask is not None:
56
+ mask = mask[:, -L:]
57
+ log_p = torch.zeros(B, L, dtype=torch.float32, device=logits.device)
58
+ kwargs = {"BLOCK_N": 2048, "num_stages": 4, "num_warps": 1}
59
+ _selective_log_softmax_kernel[(B, L)](
60
+ logits, input_ids, log_p, mask, temperature, input_ids.stride(0), L, N, **kwargs
61
+ )
62
+ return log_p
63
+
64
+
65
+ # @triton.autotune([triton.Config({"BLOCK_N":BLOCK_N}, num_stages=ns, num_warps=nw)
66
+ # for BLOCK_N in [2048, 4096, 8192]
67
+ # for ns in [1, 2, 4]
68
+ # for nw in [1, 2, 4, 8, 16]],
69
+ # key=['N'])
70
+ @triton.jit
71
+ def _grpo_loss_fwd_kernel(
72
+ LOGITS,
73
+ OLD_LOGP,
74
+ REF_LOGP,
75
+ INPUT_IDS,
76
+ COMPLETION_MASK,
77
+ ADVANTAGES,
78
+ LOSS,
79
+ LSE,
80
+ KL,
81
+ IS_CLIPPED,
82
+ TEMPERATURE,
83
+ BETA: tl.constexpr,
84
+ EPS_LOW,
85
+ EPS_HIGH,
86
+ L: tl.constexpr,
87
+ N: tl.constexpr,
88
+ BLOCK_N: tl.constexpr = 4096,
89
+ ):
90
+ off_b = tl.program_id(0).cast(tl.int64)
91
+ off_l = tl.program_id(1).cast(tl.int64)
92
+
93
+ if COMPLETION_MASK is not None:
94
+ COMPLETION_MASK += off_b * L + off_l
95
+ not_skip = tl.load(COMPLETION_MASK)
96
+ if not_skip == 0:
97
+ return
98
+
99
+ LOGITS += off_b * (L + 1) * N + off_l * N
100
+ INPUT_IDS += off_b * L + off_l
101
+ ADVANTAGES += off_b
102
+ LOSS += off_b * L + off_l
103
+ LSE += off_b * L + off_l
104
+ IS_CLIPPED += off_b * L + off_l
105
+
106
+ m_i = float("-inf")
107
+ l_i = 0.0
108
+ for start in range(0, N, BLOCK_N):
109
+ cols = start + tl.arange(0, BLOCK_N)
110
+ logits = tl.load(LOGITS + cols, mask=cols < N, other=float("-inf")).to(tl.float32) / TEMPERATURE
111
+ new_m_i = tl.maximum(m_i, tl.max(logits))
112
+ alpha = tl.exp(m_i - new_m_i)
113
+ l_i = l_i * alpha + tl.sum(tl.exp(logits - new_m_i))
114
+ m_i = new_m_i
115
+ lse = m_i + tl.log(l_i)
116
+
117
+ idx = tl.load(INPUT_IDS)
118
+ x = tl.load(LOGITS + idx).to(tl.float32) / TEMPERATURE
119
+ logp = x - lse
120
+ if OLD_LOGP is None:
121
+ old_logp = logp
122
+ else:
123
+ OLD_LOGP += off_b * L + off_l
124
+ old_logp = tl.load(OLD_LOGP).to(tl.float32)
125
+ coef_1 = tl.exp(logp - old_logp)
126
+ coef_2 = tl.clamp(coef_1, 1 - EPS_LOW, 1 + EPS_HIGH)
127
+ advantage = tl.load(ADVANTAGES).to(tl.float32)
128
+ per_token_loss1 = coef_1 * advantage
129
+ per_token_loss2 = coef_2 * advantage
130
+ per_token_loss = -tl.minimum(per_token_loss1, per_token_loss2)
131
+ is_clipped = per_token_loss1 < per_token_loss2
132
+
133
+ if BETA != 0.0:
134
+ REF_LOGP += off_b * L + off_l
135
+ KL += off_b * L + off_l
136
+ ref_logp = tl.load(REF_LOGP).to(tl.float32)
137
+ kl = tl.exp(ref_logp - logp) - (ref_logp - logp) - 1
138
+ per_token_loss += BETA * kl
139
+ tl.store(KL, kl)
140
+
141
+ tl.store(LOSS, per_token_loss)
142
+ tl.store(LSE, lse)
143
+ tl.store(IS_CLIPPED, is_clipped)
144
+
145
+
146
+ # @triton.autotune([triton.Config({"BLOCK_N":BLOCK_N}, num_stages=ns, num_warps=nw)
147
+ # for BLOCK_N in [2048, 4096, 8192]
148
+ # for ns in [1, 2, 4]
149
+ # for nw in [1, 2, 4, 8, 16]],
150
+ # key=['N'])
151
+ @triton.jit
152
+ def _grpo_loss_bwd_kernel(
153
+ DLOSS,
154
+ DLOGITS,
155
+ LOGITS,
156
+ OLD_LOGP,
157
+ REF_LOGP,
158
+ INPUT_IDS,
159
+ ADVANTAGES,
160
+ COMPLETION_MASK,
161
+ LSE,
162
+ TEMPERATURE,
163
+ BETA: tl.constexpr,
164
+ EPS_LOW,
165
+ EPS_HIGH,
166
+ loss_stride0,
167
+ loss_stride1,
168
+ L: tl.constexpr,
169
+ N: tl.constexpr,
170
+ BLOCK_N: tl.constexpr = 4096,
171
+ ):
172
+ off_b = tl.program_id(0).cast(tl.int64)
173
+ off_l = tl.program_id(1).cast(tl.int64)
174
+
175
+ DLOGITS += off_b * (L + 1) * N + off_l * N
176
+ if COMPLETION_MASK is not None:
177
+ COMPLETION_MASK += off_b * L + off_l
178
+ not_skip = tl.load(COMPLETION_MASK)
179
+ if not_skip == 0:
180
+ for start in range(0, N, BLOCK_N):
181
+ cols = tl.arange(0, BLOCK_N) + start
182
+ tl.store(DLOGITS + cols, 0.0, mask=cols < N)
183
+ return
184
+
185
+ LOGITS += off_b * (L + 1) * N + off_l * N
186
+ DLOSS += off_b * loss_stride0 + off_l * loss_stride1
187
+ INPUT_IDS += off_b * L + off_l
188
+ ADVANTAGES += off_b
189
+ LSE += off_b * L + off_l
190
+
191
+ dloss = tl.load(DLOSS).to(tl.float32)
192
+ lse = tl.load(LSE).to(tl.float32)
193
+
194
+ idx = tl.load(INPUT_IDS)
195
+ x = tl.load(LOGITS + idx).to(tl.float32) / TEMPERATURE
196
+ logp = x - lse
197
+ if OLD_LOGP is None:
198
+ old_logp = logp
199
+ else:
200
+ OLD_LOGP += off_b * L + off_l
201
+ old_logp = tl.load(OLD_LOGP).to(tl.float32)
202
+ coef_1 = tl.exp(logp - old_logp)
203
+ coef_2 = tl.clamp(coef_1, 1 - EPS_LOW, 1 + EPS_HIGH)
204
+ advantage = tl.load(ADVANTAGES).to(tl.float32)
205
+ per_token_loss1 = coef_1 * advantage
206
+ per_token_loss2 = coef_2 * advantage
207
+ mask = per_token_loss2 >= per_token_loss1
208
+
209
+ dlogp = -per_token_loss1 * mask
210
+ if BETA != 0.0:
211
+ REF_LOGP += off_b * L + off_l
212
+ ref_logp = tl.load(REF_LOGP).to(tl.float32)
213
+ dlogp += BETA * (1 - tl.exp(ref_logp - logp))
214
+
215
+ dlogp = dlogp * dloss / TEMPERATURE
216
+ tl.debug_barrier()
217
+ for start_n in tl.range(0, N, BLOCK_N):
218
+ cols = start_n + tl.arange(0, BLOCK_N)
219
+ logits = tl.load(LOGITS + cols, mask=cols < N, other=-float("inf")).to(tl.float32) / TEMPERATURE
220
+ probs = tl.exp(logits - lse)
221
+ dlogits = tl.where(cols == idx, 1 - probs, -probs) * dlogp
222
+ tl.store(DLOGITS + cols, dlogits, mask=cols < N)
223
+
224
+
225
+ class GrpoLossFunction(torch.autograd.Function):
226
+ @staticmethod
227
+ def forward(
228
+ ctx,
229
+ logits,
230
+ old_logp,
231
+ ref_logp,
232
+ completion_ids,
233
+ advantages,
234
+ completion_mask,
235
+ temperature,
236
+ beta,
237
+ eps_low,
238
+ eps_high,
239
+ inplace,
240
+ ):
241
+ assert logits.is_contiguous() and completion_ids.is_contiguous()
242
+ assert old_logp is None or old_logp.is_contiguous()
243
+ assert (ref_logp is not None and ref_logp.is_contiguous()) if beta != 0.0 else True
244
+
245
+ B, L_ADD_1, N = logits.shape
246
+ L = L_ADD_1 - 1
247
+
248
+ if completion_mask is not None:
249
+ assert completion_mask.is_contiguous()
250
+
251
+ loss = torch.zeros(B, L, device=logits.device, dtype=torch.float32)
252
+ lse = torch.zeros_like(loss)
253
+ is_clipped = torch.zeros_like(loss)
254
+ kl = torch.zeros_like(loss) if beta != 0.0 else None
255
+ kwargs = {"BLOCK_N": 2048, "num_stages": 2, "num_warps": 1}
256
+ _grpo_loss_fwd_kernel[(B, L)](
257
+ logits,
258
+ old_logp,
259
+ ref_logp,
260
+ completion_ids,
261
+ completion_mask,
262
+ advantages,
263
+ loss,
264
+ lse,
265
+ kl,
266
+ is_clipped,
267
+ temperature,
268
+ beta,
269
+ eps_low,
270
+ eps_high,
271
+ L,
272
+ N,
273
+ **kwargs,
274
+ )
275
+ ctx.save_for_backward(logits, old_logp, ref_logp, completion_ids, advantages, completion_mask, lse)
276
+ ctx.infos = (temperature, beta, eps_low, eps_high, inplace)
277
+ # return loss
278
+ return loss, kl, is_clipped
279
+
280
+ @staticmethod
281
+ def backward(ctx, *args):
282
+ dloss = args[0]
283
+ # print(dloss.shape)
284
+ logits, old_logp, ref_logp, completion_ids, advantages, completion_mask, lse = ctx.saved_tensors
285
+ temperature, beta, eps_low, eps_high, inplace = ctx.infos
286
+ B, L_ADD_1, N = logits.shape
287
+ L = L_ADD_1 - 1
288
+ dlogits = logits.data if inplace else torch.empty_like(logits)
289
+ kwargs = {"BLOCK_N": 4096, "num_stages": 1, "num_warps": 16}
290
+ _grpo_loss_bwd_kernel[(B, L)](
291
+ dloss,
292
+ dlogits,
293
+ logits,
294
+ old_logp,
295
+ ref_logp,
296
+ completion_ids,
297
+ advantages,
298
+ completion_mask,
299
+ lse,
300
+ temperature,
301
+ beta,
302
+ eps_low,
303
+ eps_high,
304
+ *dloss.stride(),
305
+ L,
306
+ N,
307
+ **kwargs,
308
+ )
309
+ dlogits[:, -1, :] = 0
310
+ return dlogits, None, None, None, None, None, None, None, None, None, None