liger-kernel-nightly 0.5.9.dev20250517045825__py3-none-any.whl → 0.5.9.dev20250519015630__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
liger_kernel/ops/dyt.py CHANGED
@@ -4,7 +4,8 @@ import torch
4
4
  import triton
5
5
  import triton.language as tl
6
6
 
7
- from liger_kernel.ops.utils import calculate_settings
7
+ from triton.language.extra.libdevice import tanh
8
+
8
9
  from liger_kernel.ops.utils import compare_version
9
10
  from liger_kernel.ops.utils import ensure_contiguous
10
11
  from liger_kernel.ops.utils import infer_device
@@ -20,187 +21,126 @@ else:
20
21
  from triton.language.math import tanh
21
22
 
22
23
 
24
+ # @triton.autotune([triton.Config({"BLOCK_N":bn}, num_stages=ns, num_warps=nw)
25
+ # for bn in [1024, 2048, 4096]
26
+ # for ns in [1,2,4]
27
+ # for nw in [4, 8, 16, 32]
28
+ # ],
29
+ # key=['N'])
23
30
  @triton.jit
24
- def _dyt_fwd_kernel(
25
- x_ptr,
26
- x_row_stride,
27
- alpha_ptr,
28
- gamma_ptr,
29
- beta_ptr,
30
- y_ptr,
31
- y_row_stride,
32
- n_cols,
33
- BLOCK_SIZE: tl.constexpr,
34
- ):
35
- """
36
- Reference:
37
- https://arxiv.org/abs/2503.10622
38
-
39
- Shapes:
40
- - x: (BT, C)
41
- - alpha: (1)
42
- - gamma: (C)
43
- - beta: (C)
44
- """
45
- row_idx = tl.program_id(0)
46
- offsets = tl.arange(0, BLOCK_SIZE)
47
- mask = offsets < n_cols
48
-
49
- x_ptr += row_idx * x_row_stride
50
- y_ptr += row_idx * y_row_stride
51
-
52
- alpha = tl.load(alpha_ptr)
53
- gamma = tl.load(gamma_ptr + offsets, mask=mask)
54
- beta = tl.load(beta_ptr + offsets, mask=mask)
55
- x = tl.load(x_ptr + offsets, mask=mask)
56
- y = gamma * tanh((alpha * x).cast(tl.float32)) + beta
57
- tl.store(y_ptr + offsets, y, mask=mask)
31
+ def _dyt_fwd_kernel(X, Y, Alpha, Gamma, Beta, HAVE_BETA: tl.constexpr, N: tl.constexpr, BLOCK_N: tl.constexpr = 1024):
32
+ col = tl.cast(tl.program_id(0), tl.int64) * BLOCK_N + tl.arange(0, BLOCK_N)
33
+ mask = col < N
34
+ row_id = tl.cast(tl.program_id(1), tl.int64)
35
+
36
+ X += row_id * N
37
+ Y += row_id * N
38
+ alpha = tl.load(Alpha).to(tl.float32)
39
+
40
+ gamma = tl.load(Gamma + col, mask=mask, other=0.0).to(tl.float32)
58
41
 
42
+ x = tl.load(X + col, mask=mask, other=0.0).to(tl.float32)
59
43
 
44
+ tanh_x = tanh(alpha * x)
45
+ y = tanh_x * gamma
46
+ if HAVE_BETA:
47
+ beta = tl.load(Beta + col, mask=mask, other=0.0).to(tl.float32)
48
+ y += beta
49
+ tl.store(Y + col, y, mask=mask)
50
+
51
+
52
+ # @triton.autotune([triton.Config({"BLOCK_N":bn}, num_stages=ns, num_warps=nw)
53
+ # for bn in [1024, 2048, 4096]
54
+ # for ns in [1,2,4]
55
+ # for nw in [4, 8, 16]
56
+ # ],
57
+ # key=['N'])
60
58
  @triton.jit
61
59
  def _dyt_bwd_kernel(
62
- x_ptr,
63
- x_row_stride,
64
- dy_ptr,
65
- dy_row_stride,
66
- dx_ptr,
67
- dx_row_stride,
68
- alpha_ptr,
69
- dalpha_ptr,
70
- gamma_ptr,
71
- dgamma_ptr,
72
- dgamma_row_stride,
73
- n_cols,
74
- n_rows,
75
- ROWS_PER_PROGRAM: tl.constexpr,
76
- BLOCK_SIZE: tl.constexpr,
60
+ DY, DX, DA, DG, DB, X, Alpha, Gamma, HAVE_BETA: tl.constexpr, M, N: tl.constexpr, BLOCK_N: tl.constexpr = 1024
77
61
  ):
78
- """
79
- Reference:
80
- https://arxiv.org/abs/2503.10622
81
-
82
- Shapes:
83
- - x: (BT, C)
84
- - alpha: (1)
85
- - gamma: (C)
86
- - dx: (BT, C)
87
- - dy: (BT, C)
88
- - dgamma: (sm_count, C)
89
- - dalpha: (sm_count,)
90
- """
91
- # d(gamma * tanh(alpha * x) + beta) / dx
92
- # = gamma * (1 - tanh^2(alpha * x)) * alpha
93
- # d(gamma * tanh(alpha * x) + beta) / dalpha
94
- # = gamma * (1 - tanh^2(alpha * x)) * x
95
- # d(gamma * tanh(alpha * x) + beta) / dgamma
96
- # = tanh(alpha * x)
97
- # d(gamma * tanh(alpha * x)) / dbeta = 1
98
- pid = tl.program_id(0)
99
-
100
- row_start = pid * ROWS_PER_PROGRAM
101
- row_end = min((pid + 1) * ROWS_PER_PROGRAM, n_rows)
102
- offsets = tl.arange(0, BLOCK_SIZE)
103
- mask = offsets < n_cols
104
-
105
- dalpha = 0.0
106
- dgamma = tl.zeros((BLOCK_SIZE,), dtype=tl.float32)
107
-
108
- x_ptr += row_start * x_row_stride
109
- dx_ptr += row_start * dx_row_stride
110
- dy_ptr += row_start * dy_row_stride
111
- alpha = tl.load(alpha_ptr)
112
- gamma = tl.load(gamma_ptr + offsets, mask=mask, other=0.0)
113
-
114
- for _ in tl.range(row_start, row_end):
115
- dy = tl.load(dy_ptr + offsets, mask=mask, other=0.0)
116
- x = tl.load(x_ptr + offsets, mask=mask, other=0.0)
117
- tanh_ax = tanh((alpha * x).cast(tl.float32))
118
- sech2_ax = 1 - tanh_ax * tanh_ax
119
-
120
- dx = dy * gamma * sech2_ax * alpha
121
- dalpha += tl.sum(dy * gamma * sech2_ax * x)
122
- dgamma += dy * tanh_ax
123
- tl.store(dx_ptr + offsets, dx, mask=mask)
124
-
125
- dy_ptr += dy_row_stride
126
- x_ptr += x_row_stride
127
- dx_ptr += dx_row_stride
128
-
129
- tl.store(dgamma_ptr + pid * dgamma_row_stride + offsets, dgamma, mask=mask)
130
- tl.store(dalpha_ptr + pid, dalpha)
131
-
132
- pass
62
+ col = tl.cast(tl.program_id(0), tl.int64) * BLOCK_N + tl.arange(0, BLOCK_N)
63
+ mask = col < N
64
+ start_row_id = tl.cast(tl.program_id(1), tl.int64)
65
+
66
+ alpha = tl.load(Alpha).to(tl.float32)
67
+ da = 0.0
68
+ gamma = tl.load(Gamma + col, mask=mask, other=0.0).to(tl.float32)
69
+ dg = tl.zeros((BLOCK_N,), dtype=tl.float32)
70
+ if HAVE_BETA:
71
+ db = tl.zeros((BLOCK_N,), dtype=tl.float32)
72
+ for row_id in range(start_row_id, M, tl.num_programs(1)):
73
+ x = tl.load(X + row_id * N + col, mask=mask, other=0.0).to(tl.float32)
74
+ dy = tl.load(DY + row_id * N + col, mask=mask, other=0.0).to(tl.float32)
75
+ tanh_x = tanh(alpha * x)
76
+ if HAVE_BETA:
77
+ db += dy
78
+ dg += dy * tanh_x
79
+ tmp = (1 - tanh_x * tanh_x) * dy * gamma
80
+ da += tl.sum(x * tmp, 0)
81
+ dx = alpha * tmp
82
+ tl.store(DX + row_id * N + col, dx, mask=mask)
83
+
84
+ tl.store(DG + start_row_id * N + col, dg, mask=mask)
85
+ if HAVE_BETA:
86
+ tl.store(DB + start_row_id * N + col, db, mask=mask)
87
+ tl.store(DA + start_row_id * tl.cdiv(N, 512) + tl.program_id(0), da)
133
88
 
134
89
 
135
90
  def liger_dyt_fwd(x, alpha, gamma, beta):
136
- shape = x.shape
137
- dim = shape[-1]
138
- x = x.view(-1, dim)
139
- n_rows, n_cols = x.shape
91
+ assert x.is_contiguous()
92
+ HAVE_BETA = True if beta is not None else False
93
+ input_shape = x.shape
94
+ x = x.view(-1, input_shape[-1])
95
+ M, N = x.shape
96
+
140
97
  y = torch.empty_like(x)
141
- BLOCK_SIZE, num_warps = calculate_settings(n_cols)
142
- _dyt_fwd_kernel[(n_rows,)](
143
- x_ptr=x,
144
- alpha_ptr=alpha,
145
- gamma_ptr=gamma,
146
- beta_ptr=beta,
147
- y_ptr=y,
148
- x_row_stride=x.stride(0),
149
- y_row_stride=y.stride(0),
150
- n_cols=n_cols,
151
- BLOCK_SIZE=BLOCK_SIZE,
152
- num_warps=num_warps,
98
+
99
+ if N >= 4096:
100
+ kwargs = {"BLOCK_N": min(triton.next_power_of_2(N), 2048), "num_warps": 4, "num_stages": 1}
101
+ else:
102
+ kwargs = {"BLOCK_N": min(triton.next_power_of_2(N), 1024), "num_warps": 4, "num_stages": 1}
103
+
104
+ grid = lambda meta: (triton.cdiv(N, meta["BLOCK_N"]), M)
105
+ _dyt_fwd_kernel[(grid)](
106
+ x,
107
+ y,
108
+ alpha,
109
+ gamma,
110
+ beta,
111
+ HAVE_BETA,
112
+ N,
113
+ **kwargs,
153
114
  )
154
- return y.view(*shape)
155
-
156
-
157
- def liger_dyt_bwd(dy, x, alpha, gamma):
158
- shape = dy.shape
159
- dtype = x.dtype
160
- dim = shape[-1]
161
- dy = dy.view(-1, dim)
162
- x = x.view(-1, dim)
163
- n_rows, n_cols = dy.shape
164
- BLOCK_SIZE, num_warps = calculate_settings(n_cols)
165
- sm_count = 1
115
+ return y.view(input_shape)
116
+
117
+
118
+ def liger_dyt_bwd(dy, x, alpha, gamma, beta):
119
+ assert dy.is_contiguous()
120
+ input_shape = x.shape
121
+ x = x.view(-1, input_shape[-1])
122
+ M, N = x.shape
123
+ HAVE_BETA = True if beta is not None else False
124
+
166
125
  device = infer_device()
167
126
  if device == "cuda":
168
- sm_count = torch.cuda.get_device_properties(x.device).multi_processor_count
127
+ NUM_SMS = torch.cuda.get_device_properties(x.device).multi_processor_count
169
128
  elif device == "xpu":
170
- sm_count = torch.xpu.get_device_properties(x.device).gpu_subslice_count
171
- if n_cols > BLOCK_SIZE:
172
- raise RuntimeError(
173
- f"Feature dimension {dim} exceeds maximum supported size of {BLOCK_SIZE}. Consider using a smaller feature dimension."
174
- )
175
-
176
- dx = torch.empty_like(x, dtype=torch.float32)
177
- _dalpha = torch.empty((sm_count,), dtype=torch.float32, device=x.device)
178
- _dgamma = torch.empty((sm_count, n_cols), dtype=torch.float32, device=x.device)
179
-
180
- grid = (sm_count,)
181
- rows_per_program = triton.cdiv(n_rows, sm_count)
182
- _dyt_bwd_kernel[grid](
183
- x_ptr=x,
184
- x_row_stride=x.stride(0),
185
- dy_ptr=dy,
186
- dy_row_stride=dy.stride(0),
187
- dx_ptr=dx,
188
- dx_row_stride=dx.stride(0),
189
- alpha_ptr=alpha,
190
- dalpha_ptr=_dalpha,
191
- gamma_ptr=gamma,
192
- dgamma_ptr=_dgamma,
193
- dgamma_row_stride=_dgamma.stride(0),
194
- n_cols=n_cols,
195
- n_rows=n_rows,
196
- ROWS_PER_PROGRAM=rows_per_program,
197
- BLOCK_SIZE=BLOCK_SIZE,
198
- num_warps=num_warps,
199
- )
200
- dalpha = _dalpha.sum(dim=0, keepdim=True).to(dtype)
201
- dgamma = _dgamma.sum(dim=0).to(dtype)
202
- dbeta = dy.sum(dim=0).to(dtype)
203
- return dx.view(*shape), dalpha, dgamma, dbeta
129
+ NUM_SMS = torch.xpu.get_device_properties(x.device).gpu_subslice_count
130
+
131
+ da = torch.zeros(NUM_SMS, triton.cdiv(N, 512), dtype=torch.float32, device=x.device)
132
+ dg = torch.empty(NUM_SMS, N, dtype=torch.float32, device=x.device)
133
+ db = torch.empty(NUM_SMS, N, dtype=torch.float32, device=x.device) if HAVE_BETA else None
134
+ dx = torch.empty_like(dy)
135
+
136
+ kwargs = {"BLOCK_N": min(triton.next_power_of_2(N), 1024), "num_warps": 8, "num_stages": 2}
137
+ grid = lambda meta: (triton.cdiv(N, meta["BLOCK_N"]), NUM_SMS)
138
+ _dyt_bwd_kernel[grid](dy, dx, da, dg, db, x, alpha, gamma, HAVE_BETA, M, N, **kwargs)
139
+ if HAVE_BETA:
140
+ db = db.sum(0).to(x.dtype)
141
+ dg = dg.sum(0).to(gamma.dtype)
142
+ da = da.sum().to(x.dtype).unsqueeze(0)
143
+ return dx.view(input_shape), da, dg, db
204
144
 
205
145
 
206
146
  class LigerDyTFunction(torch.autograd.Function):
@@ -208,18 +148,12 @@ class LigerDyTFunction(torch.autograd.Function):
208
148
  @ensure_contiguous
209
149
  def forward(ctx, x, alpha, gamma, beta):
210
150
  y = liger_dyt_fwd(x, alpha, gamma, beta)
211
- ctx.save_for_backward(x, alpha, gamma)
151
+ ctx.save_for_backward(x, alpha, gamma, beta)
212
152
  return y
213
153
 
214
154
  @staticmethod
215
155
  @ensure_contiguous
216
- def backward(ctx, grad_output):
217
- x, alpha, gamma = ctx.saved_tensors
218
- dx, dalpha, dgamma, dbeta = liger_dyt_bwd(
219
- grad_output,
220
- x,
221
- alpha,
222
- gamma,
223
- )
224
-
225
- return (dx, dalpha, dgamma, dbeta)
156
+ def backward(ctx, dy):
157
+ x, alpha, gamma, beta = ctx.saved_tensors
158
+ dx, dalpha, dgamma, dbeta = liger_dyt_bwd(dy, x, alpha, gamma, beta)
159
+ return dx, dalpha, dgamma, dbeta
@@ -0,0 +1,310 @@
1
+ import torch
2
+ import triton
3
+ import triton.language as tl
4
+
5
+
6
+ @triton.jit
7
+ def _selective_log_softmax_kernel(
8
+ LOGITS,
9
+ INPUT_IDS,
10
+ LOG_P,
11
+ MASK,
12
+ TEMPERATURE,
13
+ stride_input_ids_b,
14
+ L: tl.constexpr,
15
+ N: tl.constexpr,
16
+ BLOCK_N: tl.constexpr = 4096,
17
+ ):
18
+ off_b = tl.program_id(0).cast(tl.int64)
19
+ off_l = tl.program_id(1).cast(tl.int64)
20
+
21
+ LOGITS += off_b * (L + 1) * N + off_l * N
22
+ INPUT_IDS += off_b * stride_input_ids_b + off_l
23
+ LOG_P += off_b * L + off_l
24
+
25
+ if MASK is not None:
26
+ MASK += off_b * stride_input_ids_b + off_l
27
+ not_skip = tl.load(MASK)
28
+ if not_skip == 0:
29
+ return
30
+
31
+ m_i = float("-inf")
32
+ l_i = 0.0
33
+ for start in range(0, N, BLOCK_N):
34
+ cols = start + tl.arange(0, BLOCK_N)
35
+ logits = tl.load(LOGITS + cols, mask=cols < N, other=float("-inf")).to(tl.float32) / TEMPERATURE
36
+ new_m_i = tl.maximum(m_i, tl.max(logits))
37
+ alpha = tl.exp(m_i - new_m_i)
38
+ l_i = l_i * alpha + tl.sum(tl.exp(logits - new_m_i))
39
+ m_i = new_m_i
40
+ lse = m_i + tl.log(l_i)
41
+
42
+ ids = tl.load(INPUT_IDS)
43
+ x = tl.load(LOGITS + ids).to(tl.float32) / TEMPERATURE
44
+ logp = x - lse
45
+ tl.store(LOG_P, logp)
46
+
47
+
48
+ # Computes old_logp and ref_logp; this reduces peak memory by 10G. It does not require grad.
49
+ @torch.no_grad
50
+ def fused_selective_log_softmax(logits: torch.Tensor, input_ids: torch.Tensor, temperature: float = 0.9, mask=None):
51
+ assert logits.is_contiguous()
52
+ B, L_ADD_1, N = logits.shape
53
+ L = L_ADD_1 - 1
54
+ input_ids = input_ids[:, -L:]
55
+ if mask is not None:
56
+ mask = mask[:, -L:]
57
+ log_p = torch.zeros(B, L, dtype=torch.float32, device=logits.device)
58
+ kwargs = {"BLOCK_N": 2048, "num_stages": 4, "num_warps": 1}
59
+ _selective_log_softmax_kernel[(B, L)](
60
+ logits, input_ids, log_p, mask, temperature, input_ids.stride(0), L, N, **kwargs
61
+ )
62
+ return log_p
63
+
64
+
65
+ # @triton.autotune([triton.Config({"BLOCK_N":BLOCK_N}, num_stages=ns, num_warps=nw)
66
+ # for BLOCK_N in [2048, 4096, 8192]
67
+ # for ns in [1, 2, 4]
68
+ # for nw in [1, 2, 4, 8, 16]],
69
+ # key=['N'])
70
+ @triton.jit
71
+ def _grpo_loss_fwd_kernel(
72
+ LOGITS,
73
+ OLD_LOGP,
74
+ REF_LOGP,
75
+ INPUT_IDS,
76
+ COMPLETION_MASK,
77
+ ADVANTAGES,
78
+ LOSS,
79
+ LSE,
80
+ KL,
81
+ IS_CLIPPED,
82
+ TEMPERATURE,
83
+ BETA: tl.constexpr,
84
+ EPS_LOW,
85
+ EPS_HIGH,
86
+ L: tl.constexpr,
87
+ N: tl.constexpr,
88
+ BLOCK_N: tl.constexpr = 4096,
89
+ ):
90
+ off_b = tl.program_id(0).cast(tl.int64)
91
+ off_l = tl.program_id(1).cast(tl.int64)
92
+
93
+ if COMPLETION_MASK is not None:
94
+ COMPLETION_MASK += off_b * L + off_l
95
+ not_skip = tl.load(COMPLETION_MASK)
96
+ if not_skip == 0:
97
+ return
98
+
99
+ LOGITS += off_b * (L + 1) * N + off_l * N
100
+ INPUT_IDS += off_b * L + off_l
101
+ ADVANTAGES += off_b
102
+ LOSS += off_b * L + off_l
103
+ LSE += off_b * L + off_l
104
+ IS_CLIPPED += off_b * L + off_l
105
+
106
+ m_i = float("-inf")
107
+ l_i = 0.0
108
+ for start in range(0, N, BLOCK_N):
109
+ cols = start + tl.arange(0, BLOCK_N)
110
+ logits = tl.load(LOGITS + cols, mask=cols < N, other=float("-inf")).to(tl.float32) / TEMPERATURE
111
+ new_m_i = tl.maximum(m_i, tl.max(logits))
112
+ alpha = tl.exp(m_i - new_m_i)
113
+ l_i = l_i * alpha + tl.sum(tl.exp(logits - new_m_i))
114
+ m_i = new_m_i
115
+ lse = m_i + tl.log(l_i)
116
+
117
+ idx = tl.load(INPUT_IDS)
118
+ x = tl.load(LOGITS + idx).to(tl.float32) / TEMPERATURE
119
+ logp = x - lse
120
+ if OLD_LOGP is None:
121
+ old_logp = logp
122
+ else:
123
+ OLD_LOGP += off_b * L + off_l
124
+ old_logp = tl.load(OLD_LOGP).to(tl.float32)
125
+ coef_1 = tl.exp(logp - old_logp)
126
+ coef_2 = tl.clamp(coef_1, 1 - EPS_LOW, 1 + EPS_HIGH)
127
+ advantage = tl.load(ADVANTAGES).to(tl.float32)
128
+ per_token_loss1 = coef_1 * advantage
129
+ per_token_loss2 = coef_2 * advantage
130
+ per_token_loss = -tl.minimum(per_token_loss1, per_token_loss2)
131
+ is_clipped = per_token_loss1 < per_token_loss2
132
+
133
+ if BETA != 0.0:
134
+ REF_LOGP += off_b * L + off_l
135
+ KL += off_b * L + off_l
136
+ ref_logp = tl.load(REF_LOGP).to(tl.float32)
137
+ kl = tl.exp(ref_logp - logp) - (ref_logp - logp) - 1
138
+ per_token_loss += BETA * kl
139
+ tl.store(KL, kl)
140
+
141
+ tl.store(LOSS, per_token_loss)
142
+ tl.store(LSE, lse)
143
+ tl.store(IS_CLIPPED, is_clipped)
144
+
145
+
146
+ # @triton.autotune([triton.Config({"BLOCK_N":BLOCK_N}, num_stages=ns, num_warps=nw)
147
+ # for BLOCK_N in [2048, 4096, 8192]
148
+ # for ns in [1, 2, 4]
149
+ # for nw in [1, 2, 4, 8, 16]],
150
+ # key=['N'])
151
+ @triton.jit
152
+ def _grpo_loss_bwd_kernel(
153
+ DLOSS,
154
+ DLOGITS,
155
+ LOGITS,
156
+ OLD_LOGP,
157
+ REF_LOGP,
158
+ INPUT_IDS,
159
+ ADVANTAGES,
160
+ COMPLETION_MASK,
161
+ LSE,
162
+ TEMPERATURE,
163
+ BETA: tl.constexpr,
164
+ EPS_LOW,
165
+ EPS_HIGH,
166
+ loss_stride0,
167
+ loss_stride1,
168
+ L: tl.constexpr,
169
+ N: tl.constexpr,
170
+ BLOCK_N: tl.constexpr = 4096,
171
+ ):
172
+ off_b = tl.program_id(0).cast(tl.int64)
173
+ off_l = tl.program_id(1).cast(tl.int64)
174
+
175
+ DLOGITS += off_b * (L + 1) * N + off_l * N
176
+ if COMPLETION_MASK is not None:
177
+ COMPLETION_MASK += off_b * L + off_l
178
+ not_skip = tl.load(COMPLETION_MASK)
179
+ if not_skip == 0:
180
+ for start in range(0, N, BLOCK_N):
181
+ cols = tl.arange(0, BLOCK_N) + start
182
+ tl.store(DLOGITS + cols, 0.0, mask=cols < N)
183
+ return
184
+
185
+ LOGITS += off_b * (L + 1) * N + off_l * N
186
+ DLOSS += off_b * loss_stride0 + off_l * loss_stride1
187
+ INPUT_IDS += off_b * L + off_l
188
+ ADVANTAGES += off_b
189
+ LSE += off_b * L + off_l
190
+
191
+ dloss = tl.load(DLOSS).to(tl.float32)
192
+ lse = tl.load(LSE).to(tl.float32)
193
+
194
+ idx = tl.load(INPUT_IDS)
195
+ x = tl.load(LOGITS + idx).to(tl.float32) / TEMPERATURE
196
+ logp = x - lse
197
+ if OLD_LOGP is None:
198
+ old_logp = logp
199
+ else:
200
+ OLD_LOGP += off_b * L + off_l
201
+ old_logp = tl.load(OLD_LOGP).to(tl.float32)
202
+ coef_1 = tl.exp(logp - old_logp)
203
+ coef_2 = tl.clamp(coef_1, 1 - EPS_LOW, 1 + EPS_HIGH)
204
+ advantage = tl.load(ADVANTAGES).to(tl.float32)
205
+ per_token_loss1 = coef_1 * advantage
206
+ per_token_loss2 = coef_2 * advantage
207
+ mask = per_token_loss2 >= per_token_loss1
208
+
209
+ dlogp = -per_token_loss1 * mask
210
+ if BETA != 0.0:
211
+ REF_LOGP += off_b * L + off_l
212
+ ref_logp = tl.load(REF_LOGP).to(tl.float32)
213
+ dlogp += BETA * (1 - tl.exp(ref_logp - logp))
214
+
215
+ dlogp = dlogp * dloss / TEMPERATURE
216
+ tl.debug_barrier()
217
+ for start_n in tl.range(0, N, BLOCK_N):
218
+ cols = start_n + tl.arange(0, BLOCK_N)
219
+ logits = tl.load(LOGITS + cols, mask=cols < N, other=-float("inf")).to(tl.float32) / TEMPERATURE
220
+ probs = tl.exp(logits - lse)
221
+ dlogits = tl.where(cols == idx, 1 - probs, -probs) * dlogp
222
+ tl.store(DLOGITS + cols, dlogits, mask=cols < N)
223
+
224
+
225
+ class GrpoLossFunction(torch.autograd.Function):
226
+ @staticmethod
227
+ def forward(
228
+ ctx,
229
+ logits,
230
+ old_logp,
231
+ ref_logp,
232
+ completion_ids,
233
+ advantages,
234
+ completion_mask,
235
+ temperature,
236
+ beta,
237
+ eps_low,
238
+ eps_high,
239
+ inplace,
240
+ ):
241
+ assert logits.is_contiguous() and completion_ids.is_contiguous()
242
+ assert old_logp is None or old_logp.is_contiguous()
243
+ assert (ref_logp is not None and ref_logp.is_contiguous()) if beta != 0.0 else True
244
+
245
+ B, L_ADD_1, N = logits.shape
246
+ L = L_ADD_1 - 1
247
+
248
+ if completion_mask is not None:
249
+ assert completion_mask.is_contiguous()
250
+
251
+ loss = torch.zeros(B, L, device=logits.device, dtype=torch.float32)
252
+ lse = torch.zeros_like(loss)
253
+ is_clipped = torch.zeros_like(loss)
254
+ kl = torch.zeros_like(loss) if beta != 0.0 else None
255
+ kwargs = {"BLOCK_N": 2048, "num_stages": 2, "num_warps": 1}
256
+ _grpo_loss_fwd_kernel[(B, L)](
257
+ logits,
258
+ old_logp,
259
+ ref_logp,
260
+ completion_ids,
261
+ completion_mask,
262
+ advantages,
263
+ loss,
264
+ lse,
265
+ kl,
266
+ is_clipped,
267
+ temperature,
268
+ beta,
269
+ eps_low,
270
+ eps_high,
271
+ L,
272
+ N,
273
+ **kwargs,
274
+ )
275
+ ctx.save_for_backward(logits, old_logp, ref_logp, completion_ids, advantages, completion_mask, lse)
276
+ ctx.infos = (temperature, beta, eps_low, eps_high, inplace)
277
+ # return loss
278
+ return loss, kl, is_clipped
279
+
280
+ @staticmethod
281
+ def backward(ctx, *args):
282
+ dloss = args[0]
283
+ # print(dloss.shape)
284
+ logits, old_logp, ref_logp, completion_ids, advantages, completion_mask, lse = ctx.saved_tensors
285
+ temperature, beta, eps_low, eps_high, inplace = ctx.infos
286
+ B, L_ADD_1, N = logits.shape
287
+ L = L_ADD_1 - 1
288
+ dlogits = logits.data if inplace else torch.empty_like(logits)
289
+ kwargs = {"BLOCK_N": 4096, "num_stages": 1, "num_warps": 16}
290
+ _grpo_loss_bwd_kernel[(B, L)](
291
+ dloss,
292
+ dlogits,
293
+ logits,
294
+ old_logp,
295
+ ref_logp,
296
+ completion_ids,
297
+ advantages,
298
+ completion_mask,
299
+ lse,
300
+ temperature,
301
+ beta,
302
+ eps_low,
303
+ eps_high,
304
+ *dloss.stride(),
305
+ L,
306
+ N,
307
+ **kwargs,
308
+ )
309
+ dlogits[:, -1, :] = 0
310
+ return dlogits, None, None, None, None, None, None, None, None, None, None
@@ -5,16 +5,18 @@ from liger_kernel.ops.dyt import LigerDyTFunction
5
5
 
6
6
 
7
7
  class LigerDyT(nn.Module):
8
- def __init__(self, hidden_size, init_alpha=0.5):
8
+ def __init__(self, hidden_size, beta=True, init_alpha=0.5):
9
9
  super().__init__()
10
10
  self.hidden_size = hidden_size
11
11
  self.init_alpha = init_alpha
12
12
  self.alpha = nn.Parameter(torch.ones(1) * init_alpha)
13
13
  self.gamma = nn.Parameter(torch.ones(hidden_size))
14
- self.beta = nn.Parameter(torch.zeros(hidden_size))
14
+ self.beta = None
15
+ if beta:
16
+ self.beta = nn.Parameter(torch.zeros(hidden_size))
15
17
 
16
18
  def forward(self, x):
17
19
  return LigerDyTFunction.apply(x, self.alpha, self.gamma, self.beta)
18
20
 
19
21
  def extra_repr(self):
20
- return f"{self.hidden_size}, init_alpha={self.init_alpha}"
22
+ return f"{self.hidden_size}, init_alpha={self.init_alpha}, beta={self.beta}"
@@ -0,0 +1,98 @@
1
+ from liger_kernel.ops.grpo_loss import GrpoLossFunction
2
+
3
+
4
+ def triton_grpo_loss(
5
+ logits,
6
+ old_logp,
7
+ ref_logp,
8
+ completion_ids,
9
+ advantages,
10
+ completion_mask=None,
11
+ temperature=0.9,
12
+ beta=0.04,
13
+ eps_low=0.2,
14
+ eps_high=0.4,
15
+ inplace=True,
16
+ ):
17
+ assert logits is not None and completion_ids is not None and advantages is not None, (
18
+ "must provide logits、completion_ids and advantages"
19
+ )
20
+
21
+ return GrpoLossFunction.apply(
22
+ logits,
23
+ old_logp,
24
+ ref_logp,
25
+ completion_ids,
26
+ advantages,
27
+ completion_mask,
28
+ temperature,
29
+ beta,
30
+ eps_low,
31
+ eps_high,
32
+ inplace,
33
+ )
34
+
35
+
36
+ # This is a demo of how to use grpo_loss in GRPOTrainer. The TRL version must be 0.16.
37
+ """
38
+ import torch
39
+ import trl
40
+ assert trl.__version__.startswith("0.16"), "please pip install trl==0.16"
41
+ from trl.extras.profiling import profiling_decorator
42
+
43
+ @profiling_decorator
44
+ def _get_per_token_logps(self, model, input_ids, attention_mask, logits_to_keep):
45
+ # We add 1 to `logits_to_keep` because the last logits of the sequence is later excluded
46
+ logits = model(input_ids=input_ids, attention_mask=attention_mask, logits_to_keep=logits_to_keep + 1).logits
47
+ return fused_selective_log_softmax(logits, input_ids, self.temperature, mask=attention_mask)
48
+
49
+ @profiling_decorator
50
+ def compute_loss(self, model, inputs, return_outputs=False, num_items_in_batch=None):
51
+ if return_outputs:
52
+ raise ValueError("The GRPOTrainer does not support returning outputs")
53
+ # Compute the per-token log probabilities for the model
54
+
55
+ prompt_ids, prompt_mask = inputs["prompt_ids"], inputs["prompt_mask"]
56
+ completion_ids, completion_mask = inputs["completion_ids"], inputs["completion_mask"]
57
+ input_ids = torch.cat([prompt_ids, completion_ids], dim=1)
58
+ attention_mask = torch.cat([prompt_mask, completion_mask], dim=1)
59
+ logits_to_keep = completion_ids.size(1) # we only need to compute the logits for the completion tokens
60
+ logits = model(input_ids=input_ids, attention_mask=attention_mask, logits_to_keep=logits_to_keep + 1).logits
61
+
62
+ ref_per_token_logps = inputs["ref_per_token_logps"]
63
+ advantages = inputs["advantages"]
64
+ old_per_token_logps = inputs["old_per_token_logps"]
65
+
66
+
67
+ per_token_loss, per_token_kl, is_clipped = triton_grpo_loss(logits,
68
+ old_per_token_logps,
69
+ ref_per_token_logps,
70
+ completion_ids,
71
+ advantages,
72
+ completion_mask,
73
+ self.temperature,
74
+ self.beta,
75
+ self.epsilon_low,
76
+ self.epsilon_high,)
77
+ loss = (per_token_loss * completion_mask).sum() / completion_mask.sum()
78
+
79
+ # Log the metrics
80
+ mode = "eval" if self.control.should_evaluate else "train"
81
+
82
+ if self.beta != 0.0:
83
+ mean_kl = (per_token_kl * completion_mask).sum() / completion_mask.sum()
84
+ self._metrics[mode]["kl"].append(self.accelerator.gather_for_metrics(mean_kl).mean().item())
85
+
86
+ clip_ratio = (is_clipped * completion_mask).sum() / completion_mask.sum()
87
+ self._metrics[mode]["clip_ratio"].append(self.accelerator.gather_for_metrics(clip_ratio).mean().item())
88
+ return loss
89
+
90
+ trl.GRPOTrainer._get_per_token_logps = _get_per_token_logps
91
+ trl.GRPOTrainer.compute_loss = compute_loss
92
+ trigger = None
93
+ """
94
+
95
+ # Add this line as the first line of grpo.py in open-r1.
96
+ """
97
+ from liger_kernel.transformers.grpo_loss import trigger
98
+ """
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.1
2
2
  Name: liger_kernel_nightly
3
- Version: 0.5.9.dev20250517045825
3
+ Version: 0.5.9.dev20250519015630
4
4
  Summary: Efficient Triton kernels for LLM Training
5
5
  License: BSD 2-CLAUSE LICENSE
6
6
  Copyright 2024 LinkedIn Corporation
@@ -17,11 +17,12 @@ liger_kernel/chunked_loss/orpo_loss.py,sha256=nu9UYG16dcMw93lvHi4_hYs3Q0FK1KnlmM
17
17
  liger_kernel/chunked_loss/simpo_loss.py,sha256=fy2w8KbhMrBv7b1jdIeH3bBFxY52bPQPZb3KwBvmurM,5385
18
18
  liger_kernel/ops/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
19
19
  liger_kernel/ops/cross_entropy.py,sha256=e8THGnhOcy_0SbOLABx67HEM7-B8a8pG7nDKbCRpQKM,19123
20
- liger_kernel/ops/dyt.py,sha256=YD1-buHz9VmIX838VKzLc-lm5CeUQ4LAskGDWBUMQHA,6187
20
+ liger_kernel/ops/dyt.py,sha256=Y180EIvtUc2z83mhyub0EVOCQHJmWX3JnscqkOJqswk,5467
21
21
  liger_kernel/ops/fused_linear_cross_entropy.py,sha256=5fbGhN85n3zf0uIdJ7PYHWIRzTf0VTFiS0ARtOmqIP0,11020
22
22
  liger_kernel/ops/fused_linear_jsd.py,sha256=CSoprxb-YcJy-YUKiTcYkxN8sb9h2kdk_iHuncvSV5c,9683
23
23
  liger_kernel/ops/geglu.py,sha256=axGvCIvlBzuluoAIrWTsp2iZM4BFKNInkPov8YVvH9E,4126
24
24
  liger_kernel/ops/group_norm.py,sha256=qD4D4lSjSgVtO52EBNLC2iTseALRgPgqXE50U2woggk,10837
25
+ liger_kernel/ops/grpo_loss.py,sha256=anRnv7k1-AV3pCC6_TqP0GMg78YYUfRAJrbpx6PVhl0,9448
25
26
  liger_kernel/ops/jsd.py,sha256=onHp5T3MbvJaVz5Vup7Ww6EQp_HTaZeayTjJk6FgQMY,7042
26
27
  liger_kernel/ops/kl_div.py,sha256=ZjGdDLKWksHT9dZ0xF_TDgAkj5cuMTwwT5tr9E-_24o,8734
27
28
  liger_kernel/ops/layer_norm.py,sha256=vWCyOm-F2GMAilB-ozJcFeUQQLCJoTE_uiXq-_0uYuI,8356
@@ -37,13 +38,14 @@ liger_kernel/ops/experimental/mm_int8int2.py,sha256=TrS9lpwekrik_w5qE7AhMJD1bcq-
37
38
  liger_kernel/transformers/__init__.py,sha256=0KX0rxyy0E_uNWVE0PSTzEVzKqc5KdFHtvdHhJm23Kk,7077
38
39
  liger_kernel/transformers/auto_model.py,sha256=0qCTRZt280Bj_LcFdzo9hlaR-BWNazawXOGgoCZjgEg,1545
39
40
  liger_kernel/transformers/cross_entropy.py,sha256=z3KTWQnFxr_IZaVjtYt0ZNEWQdDdYThN35xWkHlDGH0,1683
40
- liger_kernel/transformers/dyt.py,sha256=QMqqc14pkE0WhpRZvapfnNAun-6C0C_tHExL2ZJuCUA,648
41
+ liger_kernel/transformers/dyt.py,sha256=i-4GPaMrl-jab9TVI5qN0-H9qycn_mCbV82ozU4nbmU,723
41
42
  liger_kernel/transformers/functional.py,sha256=2YBfvtdU1GRZuRpJhHgJXeGYa1RvmO6-qQvrKQrLJK4,5259
42
43
  liger_kernel/transformers/fused_linear_cross_entropy.py,sha256=O8Sg5BT81nTaY9fSGoOY9dOD9ekibwwiuXhdUHaxntQ,1742
43
44
  liger_kernel/transformers/fused_linear_jsd.py,sha256=bZ4otCvWBuOnA5XdQL-FzZVItJlDt-ht9e_pG7PG93E,3999
44
45
  liger_kernel/transformers/geglu.py,sha256=mrgqzIUVd6lN7fkDKLkw5YaESDxDtFgbot430WwPVOQ,1107
45
46
  liger_kernel/transformers/gema3_rms.py,sha256=LTmZOXe6WEnv6ZroW-kU1TE2B36-z5v8OLmKr3XEVFo,353
46
47
  liger_kernel/transformers/group_norm.py,sha256=6qMAWOprr4SzP0YhNVNGQIBpM5aUHplUD2VuGJrMBz0,2173
48
+ liger_kernel/transformers/grpo_loss.py,sha256=uAkUNKSnUGEOqa82L9w2e6AI1kcmG8K45-QxyaT8zhM,3897
47
49
  liger_kernel/transformers/jsd.py,sha256=DGqRnxIZxsvxo0_tbbxX3b-sDbDjC_yKufyRIHCcScY,2979
48
50
  liger_kernel/transformers/kl_div.py,sha256=WLffFbh1EExD2Eb1F7lN11fo9JJC-0751WJjZAF1Fj8,409
49
51
  liger_kernel/transformers/layer_norm.py,sha256=c9pk3PEasOKYR0rhe5e5nNrnYKVCEW4VC8S6LpCq9EQ,906
@@ -79,9 +81,9 @@ liger_kernel/transformers/trainer/__init__.py,sha256=p7yQfklV8-467qSz_ZMimkbDF7H
79
81
  liger_kernel/transformers/trainer/orpo_trainer.py,sha256=pdekW7l6Qg_aqa5SYKYlSWUF8m3lkOFvFLcIMEHrz9s,8338
80
82
  liger_kernel/triton/__init__.py,sha256=qCiCamzCRv6lpV8IqpAc9YMdNKC7GKurClWceQPnlis,92
81
83
  liger_kernel/triton/monkey_patch.py,sha256=Rd0hUHAzDkFfHvnX7-PBaNK5EKnZhtfM_h-fgQH9HPY,1568
82
- liger_kernel_nightly-0.5.9.dev20250517045825.dist-info/LICENSE,sha256=OhzLDHJ0to4a8sodVLELZiCFylZ1NAAYLs-HrjPy0ag,1312
83
- liger_kernel_nightly-0.5.9.dev20250517045825.dist-info/METADATA,sha256=BuBSgxsfevOG0TuitK7JKRWg-2bmPfOoZZoDZwm0AeQ,23970
84
- liger_kernel_nightly-0.5.9.dev20250517045825.dist-info/NOTICE,sha256=njwnoPZLh9AN8SJQzxvCGLHi-8X__AvWRze6joNXIY8,2066
85
- liger_kernel_nightly-0.5.9.dev20250517045825.dist-info/WHEEL,sha256=iAkIy5fosb7FzIOwONchHf19Qu7_1wCWyFNR5gu9nU0,91
86
- liger_kernel_nightly-0.5.9.dev20250517045825.dist-info/top_level.txt,sha256=2eghu4hA3LnkM7ElW92tQ8zegWKgSbeo-k-aGe1YnvY,13
87
- liger_kernel_nightly-0.5.9.dev20250517045825.dist-info/RECORD,,
84
+ liger_kernel_nightly-0.5.9.dev20250519015630.dist-info/LICENSE,sha256=OhzLDHJ0to4a8sodVLELZiCFylZ1NAAYLs-HrjPy0ag,1312
85
+ liger_kernel_nightly-0.5.9.dev20250519015630.dist-info/METADATA,sha256=_HRxosGQvS3kYalXZIxjmOinoXb0PoA0kSVBH3SbuHg,23970
86
+ liger_kernel_nightly-0.5.9.dev20250519015630.dist-info/NOTICE,sha256=njwnoPZLh9AN8SJQzxvCGLHi-8X__AvWRze6joNXIY8,2066
87
+ liger_kernel_nightly-0.5.9.dev20250519015630.dist-info/WHEEL,sha256=iAkIy5fosb7FzIOwONchHf19Qu7_1wCWyFNR5gu9nU0,91
88
+ liger_kernel_nightly-0.5.9.dev20250519015630.dist-info/top_level.txt,sha256=2eghu4hA3LnkM7ElW92tQ8zegWKgSbeo-k-aGe1YnvY,13
89
+ liger_kernel_nightly-0.5.9.dev20250519015630.dist-info/RECORD,,