liger-kernel 0.4.0__py3-none-any.whl → 0.4.2__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -21,8 +21,10 @@ def fused_linear_cross_entropy_forward(
     target,
     bias=None,
     ignore_index=-100,
+    lse_square_scale=0.0,
     label_smoothing=0.0,
     reduction="mean",
+    softcap=None,
 ):
     dtype = _input.dtype
     device = _input.device
@@ -86,12 +88,17 @@ def fused_linear_cross_entropy_forward(
             Y_ptr=target_chunk,
             Y_stride=target_chunk.stride(-1),  # always 1
             loss_ptr=loss_1d_slice,
+            z_loss_ptr=loss_1d_slice,  # dummy ptr, not used
             loss_stride=loss_1d_slice.stride(-1),  # always 1
             n_cols=V,
             n_non_ignore=n_non_ignore,
             ignore_index=ignore_index,
+            lse_square_scale=lse_square_scale,
             label_smoothing=label_smoothing,
             reduction=reduction,
+            softcap=softcap if softcap is not None else 0.0,
+            RETURN_Z_LOSS=0,  # False
+            HAS_SOFTCAPPING=True if softcap is not None else False,
             BLOCK_SIZE=BLOCK_SIZE,
             num_warps=32 if not is_hip() else 16,
         )
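
For context on the two new options threaded into the kernel launch above: `softcap` applies tanh soft-capping to the logits before the loss (the Gemma-2 style `softcap * tanh(logits / softcap)`), and `lse_square_scale` weights a z-loss penalty on the squared log-sum-exp of the logits, as in PaLM. A minimal eager-mode sketch of that math, not taken from the package and ignoring `ignore_index` handling for brevity:

import torch
import torch.nn.functional as F

def ce_with_softcap_and_z_loss(logits, target, softcap=None, lse_square_scale=0.0):
    # Soft-capping squashes logits into (-softcap, softcap) with a tanh.
    if softcap is not None:
        logits = softcap * torch.tanh(logits / softcap)
    ce = F.cross_entropy(logits, target)
    # z-loss penalizes the squared log-sum-exp ("z") of the (capped) logits.
    z = torch.logsumexp(logits.float(), dim=-1)
    z_loss = lse_square_scale * (z**2).mean()
    return ce + z_loss, z_loss
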
@@ -200,8 +207,10 @@ class LigerFusedLinearCrossEntropyFunction(torch.autograd.Function):
         target,
         bias=None,
         ignore_index=-100,
+        lse_square_scale=0.0,
         label_smoothing=0.0,
         reduction="mean",
+        softcap=None,
     ):
         """
         Fusing the last linear layer with cross-entropy loss
@@ -220,8 +229,17 @@ class LigerFusedLinearCrossEntropyFunction(torch.autograd.Function):
         label_smoothing (float): The amount of smoothing when computing the loss, where 0.0 means no smoothing.
         reduction: reduction to apply
         """
+
         loss, grad_input, grad_weight, grad_bias = fused_linear_cross_entropy_forward(
-            _input, weight, target, bias, ignore_index, label_smoothing, reduction
+            _input,
+            weight,
+            target,
+            bias,
+            ignore_index,
+            lse_square_scale,
+            label_smoothing,
+            reduction,
+            softcap,
         )
         # downcast to dtype and store for backward
         ctx.save_for_backward(
@@ -238,4 +256,4 @@ class LigerFusedLinearCrossEntropyFunction(torch.autograd.Function):
         grad_input, grad_weight, grad_bias = fused_linear_cross_entropy_backward(
             grad_output, grad_input, grad_weight, grad_bias
         )
-        return (grad_input, grad_weight, None, grad_bias, None, None, None)
+        return (grad_input, grad_weight, None, grad_bias, None, None, None, None, None)
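
Because `torch.autograd.Function.backward` must return one entry per `forward` argument, the return tuple grows from seven to nine elements to cover the two new non-tensor arguments, whose gradient slots are `None`. A call with the new nine-argument `apply` signature looks roughly like this (shapes and values are illustrative, and a CUDA device is assumed):

import torch
from liger_kernel.ops.fused_linear_cross_entropy import (
    LigerFusedLinearCrossEntropyFunction,
)

hidden_states = torch.randn(64, 2048, device="cuda", requires_grad=True)   # (B*T, H)
lm_head_weight = torch.randn(32000, 2048, device="cuda", requires_grad=True)  # (V, H)
labels = torch.randint(0, 32000, (64,), device="cuda")

loss = LigerFusedLinearCrossEntropyFunction.apply(
    hidden_states,
    lm_head_weight,
    labels,
    None,    # bias
    -100,    # ignore_index
    0.0,     # lse_square_scale
    0.0,     # label_smoothing
    "mean",  # reduction
    30.0,    # softcap (None disables soft-capping)
)
loss.backward()
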
@@ -0,0 +1,322 @@
+import operator
+
+import torch
+import triton
+import triton.language as tl
+
+from liger_kernel.ops.utils import compare_version, ensure_contiguous
+
+if compare_version("triton", operator.ge, "3.0.0"):
+    try:
+        # typical import path with dispatch available
+        from triton.language.extra.libdevice import rsqrt
+    except ModuleNotFoundError:
+        # for working with NGC containers
+        from triton.language.extra.cuda.libdevice import rsqrt
+else:
+    from triton.language.math import rsqrt
+
+MAX_FUSED_SIZE = 65536
+
+
+@triton.jit
+def _group_norm_forward_kernel(
+    Y_ptr,  # pointer to output, shape (n_rows, n_groups, hidden_size)
+    Y_row_stride,  # stride of each row in output
+    Y_col_stride,  # stride of each column in output
+    X_ptr,  # pointer to input, shape (n_rows, n_groups, hidden_size)
+    X_row_stride,  # stride of each row in input
+    X_col_stride,  # stride of each column in input
+    Mean_ptr,  # pointer to mean, shape (n_rows, n_groups)
+    Mean_row_stride,  # stride of each row in mean
+    Mean_col_stride,  # stride of each column in mean
+    RSTD_ptr,  # pointer to rstd, shape (n_rows, n_groups)
+    RSTD_row_stride,  # stride of each row in rstd
+    RSTD_col_stride,  # stride of each column in rstd
+    W_ptr,  # pointer to W
+    B_ptr,  # pointer to B
+    hidden_size,  # hidden size of X
+    channels_per_group,  # the number of channels per group
+    eps,
+    BLOCK_SIZE: tl.constexpr,
+):
+    """
+    References:
+    https://nn.labml.ai/normalization/group_norm/index.html
+    """
+    batch_idx = tl.program_id(0)
+    group_idx = tl.program_id(1)
+
+    X_ptr += batch_idx * X_row_stride + group_idx * X_col_stride
+    Y_ptr += batch_idx * Y_row_stride + group_idx * Y_col_stride
+
+    block_range = tl.arange(0, BLOCK_SIZE)
+
+    # Compute mean and variance using the online algorithm
+    s = 0.0
+    squared_sum = 0.0
+    for i in tl.range(0, hidden_size, BLOCK_SIZE):
+        hidden_size_offsets = i + block_range
+        mask = hidden_size_offsets < hidden_size
+        X = tl.load(X_ptr + hidden_size_offsets, mask=mask, other=0.0)
+        s += tl.sum(X)
+        # X**2
+        squared_sum += tl.sum(X * X)
+
+    m = s / hidden_size
+
+    # variance = E[X**2] - E[X]**2
+    variance = (squared_sum / hidden_size) - (m * m)
+
+    # 1/std
+    rstd = rsqrt(variance + eps)
+
+    # Normalize
+    hidden_size_per_channel = hidden_size // channels_per_group
+    for channel_idx in tl.range(
+        group_idx * channels_per_group, (group_idx + 1) * channels_per_group
+    ):
+        W = tl.load(W_ptr + channel_idx)
+        B = tl.load(B_ptr + channel_idx)
+        for i in range(0, hidden_size_per_channel, BLOCK_SIZE):
+            hidden_size_offsets = i + block_range
+            mask = hidden_size_offsets < hidden_size_per_channel
+            X = tl.load(X_ptr + hidden_size_offsets, mask=mask, other=m)
+            Y = (X - m) * rstd * W + B
+            tl.store(Y_ptr + hidden_size_offsets, Y, mask=mask)
+
+        X_ptr += hidden_size_per_channel
+        Y_ptr += hidden_size_per_channel
+
+    tl.store(Mean_ptr + batch_idx * Mean_row_stride + group_idx * Mean_col_stride, m)
+    tl.store(RSTD_ptr + batch_idx * RSTD_row_stride + group_idx * RSTD_col_stride, rstd)
+
+
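
For reference, each program of the forward kernel above normalizes one (batch, group) slice with that group's mean and inverse standard deviation, then applies the per-channel affine parameters. A plain PyTorch restatement of the same computation, written here only as an illustration (names and shapes are not from the package):

import torch

def group_norm_reference(x, weight, bias, num_groups, eps):
    # x: (N, C, *spatial), weight/bias: (C,)
    n, c = x.shape[:2]
    xg = x.reshape(n, num_groups, -1).float()
    mean = xg.mean(dim=-1, keepdim=True)
    var = xg.var(dim=-1, unbiased=False, keepdim=True)  # E[X**2] - E[X]**2
    y = (xg - mean) * torch.rsqrt(var + eps)
    affine_shape = (1, c) + (1,) * (x.dim() - 2)
    return y.reshape(x.shape) * weight.view(affine_shape) + bias.view(affine_shape)
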
+@triton.jit
+def _group_norm_backward_kernel(
+    X_ptr,  # pointer to input, shape (n_rows, n_channels, hidden_size)
+    X_row_stride,  # stride of each row in input
+    X_col_stride,  # stride of each column in input
+    W_ptr,  # pointer to weights, shape (n_channels)
+    Mean_ptr,  # pointer to mean, shape (n_rows, n_groups)
+    Mean_ptr_row_stride,  # stride of each row in mean
+    Mean_ptr_col_stride,  # stride of each column in mean
+    RSTD_ptr,  # pointer to rstd, shape (n_rows, n_groups)
+    DX_ptr,  # pointer to input grad, shape (n_rows, n_groups, hidden_size)
+    DW_ptr,  # pointer to weights grad, shape (n_channels)
+    DB_ptr,  # pointer to bias grad, shape (n_channels)
+    UPSTREAM_ptr,  # pointer to output grad, shape (n_rows, n_channels, hidden_size)
+    hidden_size: tl.constexpr,  # hidden size
+    channels_per_group: tl.constexpr,  # number of channels per group
+    BLOCK_SIZE: tl.constexpr,
+    dtype: tl.constexpr,
+):
+    """
+    References:
+    https://nn.labml.ai/normalization/group_norm/index.html
+    https://github.com/karpathy/llm.c/blob/master/doc/layernorm/layernorm.md
+
+    The backprop equations are the same for group_norm and layer_norm
+    the only difference here is that we load the Mean, Rstd corresponding to the
+    group we're computing gradients for and the mean and rstd are computed over n-channels
+    so the total number of elements we compute the mean over is num_channels_per_group * hidden_size
+
+    We also need to load the Weights corresponding to the current channel to compute the gradients.
+    """
+    batch_idx = tl.program_id(0)
+    group_idx = tl.program_id(1)
+
+    # Move the pointers to the correct batch
+    X_ptr += batch_idx * X_row_stride
+    DX_ptr += batch_idx * X_row_stride
+    UPSTREAM_ptr += batch_idx * X_row_stride
+
+    # Mean and rstd are the same shape so have the same strides
+    mean = tl.load(
+        Mean_ptr + batch_idx * Mean_ptr_row_stride + group_idx * Mean_ptr_col_stride
+    )
+    rstd = tl.load(
+        RSTD_ptr + batch_idx * Mean_ptr_row_stride + group_idx * Mean_ptr_col_stride
+    )
+
+    c1 = 0.0
+    c2 = 0.0
+    block_range = tl.arange(0, BLOCK_SIZE)
+
+    # We need to compute the sum terms of the backprop equations across all channels in the group
+    for channel_idx in range(
+        group_idx * channels_per_group, (group_idx + 1) * channels_per_group
+    ):
+        dW = 0.0
+        dB = 0.0
+        # Move the pointers to the correct channel
+        W = tl.load(W_ptr + channel_idx)
+        for i in tl.range(0, hidden_size, BLOCK_SIZE):
+            hidden_size_offsets = i + block_range
+            mask = hidden_size_offsets < hidden_size
+            X = tl.load(
+                X_ptr + channel_idx * X_col_stride + hidden_size_offsets,
+                mask=mask,
+                other=0.0,
+            )
+            UPSTREAM_grad = tl.load(
+                UPSTREAM_ptr + channel_idx * X_col_stride + hidden_size_offsets,
+                mask=mask,
+                other=0.0,
+            )
+
+            x_hat = (X - mean) * rstd
+            dW += tl.sum(UPSTREAM_grad * x_hat)
+            dB += tl.sum(UPSTREAM_grad)
+
+            wdy = W * UPSTREAM_grad
+            c1 += tl.sum(x_hat * wdy)
+            c2 += tl.sum(wdy)
+
+        # Need to ensure additions to the same channel are atomic
+        tl.atomic_add(DW_ptr + channel_idx, dW.to(dtype))
+        tl.atomic_add(DB_ptr + channel_idx, dB.to(dtype))
+
+    N = hidden_size * channels_per_group
+    c1 = c1 / N
+    c2 = c2 / N
+
+    for channel_idx in tl.range(
+        group_idx * channels_per_group, (group_idx + 1) * channels_per_group
+    ):
+        # Move the pointers to the correct channel
+        W = tl.load(W_ptr + channel_idx)
+        for i in range(0, hidden_size, BLOCK_SIZE):
+            hidden_size_offsets = i + block_range
+            mask = hidden_size_offsets < hidden_size
+            X = tl.load(
+                X_ptr + channel_idx * X_col_stride + hidden_size_offsets,
+                mask=mask,
+                other=0.0,
+            )
+            UPSTREAM_grad = tl.load(
+                UPSTREAM_ptr + channel_idx * X_col_stride + hidden_size_offsets,
+                mask=mask,
+                other=0.0,
+            )
+
+            x_hat = (X - mean) * rstd
+            wdy = W * UPSTREAM_grad
+            dx = (wdy - (x_hat * c1 + c2)) * rstd
+            tl.store(
+                DX_ptr + channel_idx * X_col_stride + hidden_size_offsets, dx, mask=mask
+            )
+
+
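
As the docstring notes, the backward pass above is the standard layer-norm backprop applied per group: with x_hat = (X - mean) * rstd and wdy = W * dY, each program accumulates c1 = mean(x_hat * wdy) and c2 = mean(wdy) over the whole group (N = channels_per_group * hidden_size elements) and writes dX = (wdy - (x_hat * c1 + c2)) * rstd, while dW and dB are per-channel sums of dY * x_hat and dY, combined across batch programs with atomics. A quick autograd-based cross-check of those gradients, using `torch.nn.functional.group_norm` as the reference (not part of the package):

import torch
import torch.nn.functional as F

def group_norm_grads_reference(x, weight, bias, dy, num_groups, eps):
    x = x.detach().float().requires_grad_(True)
    weight = weight.detach().float().requires_grad_(True)
    bias = bias.detach().float().requires_grad_(True)
    y = F.group_norm(x, num_groups, weight, bias, eps)
    # dX, dW, dB via autograd; the Triton kernel should closely match these.
    return torch.autograd.grad(y, (x, weight, bias), grad_outputs=dy.float())
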
+def group_norm_forward(X, num_channels, num_groups, W, B, eps):
+    shape = X.shape
+    batch_size = shape[0]
+    channels_per_group = num_channels // num_groups
+    # Reshape X so that the mean and std are computed across the groups
+    X = X.view(batch_size, num_groups, -1).contiguous()
+    hidden_size = X.shape[-1]
+    BLOCK_SIZE = min(MAX_FUSED_SIZE, triton.next_power_of_2(hidden_size))
+    Y = torch.empty(
+        (batch_size, num_groups, hidden_size), dtype=X.dtype, device=X.device
+    )
+    Mean = torch.zeros((batch_size, num_groups), dtype=X.dtype, device=X.device)
+    RSTD = torch.zeros((batch_size, num_groups), dtype=X.dtype, device=X.device)
+
+    _group_norm_forward_kernel[(batch_size, num_groups)](
+        Y,
+        Y.stride(0),
+        Y.stride(1),
+        X,
+        X.stride(0),
+        X.stride(1),
+        Mean,
+        Mean.stride(0),
+        Mean.stride(1),
+        RSTD,
+        RSTD.stride(0),
+        RSTD.stride(1),
+        W,
+        B,
+        hidden_size,
+        channels_per_group,
+        eps,
+        BLOCK_SIZE=BLOCK_SIZE,
+    )
+    # Return tensors in the original shape
+    return Y.view(*shape), X.view(*shape), Mean, RSTD, BLOCK_SIZE
+
+
+def group_norm_backward(dY, X, W, B, Mean, RSTD, num_channels, num_groups):
+    shape = dY.shape
+    batch_size = shape[0]
+    hidden_size = dY.shape[-1]
+    channels_per_group = num_channels // num_groups
+    dY = dY.view(batch_size, num_groups, -1)
+    DX = torch.empty(
+        (batch_size, num_groups, hidden_size * channels_per_group),
+        dtype=X.dtype,
+        device=X.device,
+    )
+    DW = torch.zeros((num_channels), dtype=W.dtype, device=W.device)
+    DB = torch.zeros((num_channels), dtype=B.dtype, device=B.device)
+    triton_dtype = tl.float32 if X.dtype == torch.float32 else tl.bfloat16
+
+    BLOCK_SIZE = min(MAX_FUSED_SIZE, triton.next_power_of_2(hidden_size))
+    _group_norm_backward_kernel[(batch_size, num_groups)](
+        X,
+        X.stride(0),
+        X.stride(1),
+        W,
+        Mean,
+        Mean.stride(0),
+        Mean.stride(1),
+        RSTD,
+        DX,
+        DW,
+        DB,
+        dY,
+        hidden_size,
+        channels_per_group,
+        BLOCK_SIZE=BLOCK_SIZE,
+        dtype=triton_dtype,
+    )
+
+    # Return tensors in the original shape
+    return DX.view(*shape), DW, DB
+
+
+class LigerGroupNormFunction(torch.autograd.Function):
+    @staticmethod
+    @ensure_contiguous
+    def forward(
+        ctx,
+        X,
+        affine_scaling_weight,
+        affine_shifting_bias,
+        num_channels,
+        num_groups,
+        eps,
+    ):
+        Y, X, Mean, RSTD, BLOCK_SIZE = group_norm_forward(
+            X,
+            num_channels,
+            num_groups,
+            affine_scaling_weight,
+            affine_shifting_bias,
+            eps,
+        )
+        ctx.num_channels = num_channels
+        ctx.num_groups = num_groups
+        ctx.save_for_backward(
+            X, affine_scaling_weight, affine_shifting_bias, Mean, RSTD
+        )
+        return Y
+
+    @staticmethod
+    @ensure_contiguous
+    def backward(ctx, dY):
+        X, W, B, Mean, RSTD = ctx.saved_tensors
+        DX, DW, DB = group_norm_backward(
+            dY, X, W, B, Mean, RSTD, ctx.num_channels, ctx.num_groups
+        )
+        return DX, DW, DB, None, None, None
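
Given the `forward` signature above, the new op (exposed from `liger_kernel.ops.group_norm`, per the import added later in this diff) takes (X, weight, bias, num_channels, num_groups, eps) and can be compared against `torch.nn.functional.group_norm`. A small usage sketch (requires a GPU; shapes and eps are made up):

import torch
import torch.nn.functional as F
from liger_kernel.ops.group_norm import LigerGroupNormFunction

num_channels, num_groups, eps = 8, 4, 1e-6
x = torch.randn(2, num_channels, 64, device="cuda", requires_grad=True)
w = torch.ones(num_channels, device="cuda", requires_grad=True)
b = torch.zeros(num_channels, device="cuda", requires_grad=True)

y = LigerGroupNormFunction.apply(x, w, b, num_channels, num_groups, eps)
y.sum().backward()

y_ref = F.group_norm(x.detach(), num_groups, w.detach(), b.detach(), eps)
print(torch.allclose(y, y_ref, atol=1e-4, rtol=1e-4))
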
@@ -116,6 +116,8 @@ def _rms_norm_forward_kernel(
 def _rms_norm_backward_kernel(
     dY_ptr,
     dY_row_stride,
+    dX_ptr,
+    dX_row_stride,
     X_ptr,
     X_row_stride,
     X_dtype: tl.constexpr,
@@ -146,6 +148,8 @@ def _rms_norm_backward_kernel(
     dW_row = tl.zeros((BLOCK_SIZE,), dtype=tl.float32)

     dY_ptr += row_start * dY_row_stride
+    dX_ptr += row_start * dX_row_stride
+
     X_ptr += row_start * X_row_stride
     RSTD_ptr += row_start

@@ -184,9 +188,10 @@ def _rms_norm_backward_kernel(
            # here X_row is already in fp32 (see previous if block)
            dW_row += dY_row * (X_row * rstd_row)

-        tl.store(dY_ptr + col_offsets, dX_row.to(X_dtype), mask=mask)
+        tl.store(dX_ptr + col_offsets, dX_row.to(X_dtype), mask=mask)

         dY_ptr += dY_row_stride
+        dX_ptr += dX_row_stride
         X_ptr += X_row_stride
         RSTD_ptr += RSTD_row_stride

@@ -251,7 +256,9 @@ def rms_norm_forward(X, W, eps, offset, casting_mode):
     return Y.view(*shape), X, RSTD, BLOCK_SIZE, num_warps, casting_mode


-def rms_norm_backward(dY, X, W, RSTD, offset, casting_mode, BLOCK_SIZE, num_warps):
+def rms_norm_backward(
+    dY, X, W, RSTD, offset, casting_mode, BLOCK_SIZE, num_warps, in_place
+):
     shape = dY.shape
     dim = shape[-1]
     dY = dY.view(-1, dim)
@@ -265,10 +272,17 @@ def rms_norm_backward(dY, X, W, RSTD, offset, casting_mode, BLOCK_SIZE, num_warp
         raise RuntimeError("This layer norm doesn't support feature dim >= 64KB.")
     rows_per_program = math.ceil(n_rows / sm_count)
     grid = (sm_count,)
-    # Here we use dY to store the value of dX to save memory
+
+    if in_place is True:
+        dX = dY
+    else:
+        dX = torch.zeros_like(dY)
+
     _rms_norm_backward_kernel[grid](
         dY,
         dY.stride(0),
+        dX,
+        dX.stride(0),
         X,
         X.stride(0),
         torch_to_triton_dtype[X.dtype],
@@ -286,8 +300,9 @@ def rms_norm_backward(dY, X, W, RSTD, offset, casting_mode, BLOCK_SIZE, num_warp
         BLOCK_SIZE=BLOCK_SIZE,
         num_warps=num_warps,
     )
-    dX = dY.view(*shape)
+    dX = dX.view(*shape)
     dW = _dW.sum(dim=0).to(W.dtype)
+
     return dX, dW


@@ -307,11 +322,15 @@ class LigerRMSNormFunction(torch.autograd.Function):
    - 'llama': matches the Llama implementation, where only the inverse RMS is computed on fp32.
    - 'gemma': matches the Gemma implementation, where everything is cast to fp32, then computed, then cast back to the original dtype.
    - 'none': no casting is done. The computation is done in the original dtype. This saves memory and is slightly faster, but has more error w.r.t. the original implementation.
+
+    `in_place` controls whether dY is modified in place to store dX. It defaults to `True` to save memory, but in certain cases it can produce incorrect gradients.
+    For example, gemma2 uses two rmsnorms sequentially with a residual in between. The residual branch still needs dY, so it cannot be modified in place.
+    Therefore, for the patching of RMSNorm in gemma2, we set `in_place` to `False`.
    """

    @staticmethod
    @ensure_contiguous
-    def forward(ctx, X, W, eps, offset=0.0, casting_mode="llama"):
+    def forward(ctx, X, W, eps, offset=0.0, casting_mode="llama", in_place=True):
        """
        X: (B, T, H) or (BxT, H)
        W: (H,)
@@ -321,6 +340,7 @@ class LigerRMSNormFunction(torch.autograd.Function):
         )
         ctx.offset = offset
         ctx.casting_mode = casting_mode
+        ctx.in_place = in_place
         ctx.BLOCK_SIZE = BLOCK_SIZE
         ctx.num_warps = num_warps
         ctx.save_for_backward(X, W, RSTD)
@@ -342,5 +362,6 @@ class LigerRMSNormFunction(torch.autograd.Function):
             ctx.casting_mode,
             ctx.BLOCK_SIZE,
             ctx.num_warps,
+            ctx.in_place,
         )
-        return dX, dW, None, None, None
+        return dX, dW, None, None, None, None
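
Putting the new flag together: `in_place` is now the sixth argument of `LigerRMSNormFunction.apply`, and callers whose upstream gradient is still needed elsewhere in the graph (e.g. a residual branch between two norms, the Gemma-2 pattern described in the docstring) should pass `in_place=False`. An illustrative call, assuming the `liger_kernel.ops.rms_norm` import path, a CUDA device, and made-up sizes:

import torch
from liger_kernel.ops.rms_norm import LigerRMSNormFunction

x = torch.randn(4, 128, 1024, device="cuda", requires_grad=True)
w = torch.ones(1024, device="cuda", requires_grad=True)

# Default: dY is reused as the dX buffer in backward (saves memory).
y = LigerRMSNormFunction.apply(x, w, 1e-6, 0.0, "llama", True)

# Safer when dY is still consumed by another branch of the graph:
y_safe = LigerRMSNormFunction.apply(x, w, 1e-6, 0.0, "llama", False)
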
@@ -1,21 +1,53 @@
-from torch.nn import CrossEntropyLoss
+from typing import Optional
+
+import torch

 from liger_kernel.ops.cross_entropy import LigerCrossEntropyFunction


-class LigerCrossEntropyLoss(CrossEntropyLoss):
-    def __init__(self, *args, **kwargs):
-        super(LigerCrossEntropyLoss, self).__init__(*args, **kwargs)
-        assert (self.label_smoothing >= 0) and (
-            self.label_smoothing <= 1
-        ), f"label_smoothing must be between 0.0 and 1.0. Got: {self.label_smoothing}"
-        assert self.reduction in {
+class LigerCrossEntropyLoss(torch.nn.Module):
+    def __init__(
+        self,
+        ignore_index: int = -100,
+        lse_square_scale: float = 0.0,
+        label_smoothing: float = 0.0,
+        reduction: str = "mean",
+        softcap: Optional[float] = None,
+        return_z_loss: bool = False,
+    ):
+        super().__init__()
+        assert (label_smoothing >= 0) and (
+            label_smoothing <= 1
+        ), f"label_smoothing must be between 0.0 and 1.0. Got: {label_smoothing}"
+        assert (label_smoothing >= 0) and (
+            label_smoothing <= 1
+        ), f"label_smoothing must be between 0.0 and 1.0. Got: {label_smoothing}"
+        assert reduction in {
             "mean",
             "sum",
             "none",
-        }, f"reduction must be one of 'mean', 'sum', or 'none'. Got: {self.reduction}"
+        }, f"reduction must be one of 'mean', 'sum', or 'none'. Got: {reduction}"
+        assert (
+            softcap is None or softcap > 0
+        ), f"softcap must greater than 0.0 or None. Got: {softcap}"
+        self.ignore_index = ignore_index
+        self.lse_square_scale = lse_square_scale
+        self.label_smoothing = label_smoothing
+        self.reduction = reduction
+        self.softcap = softcap
+        self.return_z_loss = return_z_loss

-    def forward(self, _input, target):
-        return LigerCrossEntropyFunction.apply(
-            _input, target, self.ignore_index, self.label_smoothing, self.reduction
+    def forward(self, _input: torch.Tensor, target: torch.Tensor):
+        loss, z_loss = LigerCrossEntropyFunction.apply(
+            _input,
+            target,
+            self.ignore_index,
+            self.lse_square_scale,
+            self.label_smoothing,
+            self.reduction,
+            self.softcap,
+            self.return_z_loss,
         )
+        if not self.return_z_loss:
+            return loss
+        return loss, z_loss
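
With the module rewritten on top of `torch.nn.Module`, the new options become plain constructor arguments, and `return_z_loss=True` switches the forward to returning a `(loss, z_loss)` tuple. A usage sketch, assuming the `liger_kernel.transformers.cross_entropy` module path, a CUDA device, and illustrative sizes:

import torch
from liger_kernel.transformers.cross_entropy import LigerCrossEntropyLoss

logits = torch.randn(8, 32000, device="cuda", requires_grad=True)
targets = torch.randint(0, 32000, (8,), device="cuda")

loss_fn = LigerCrossEntropyLoss(
    ignore_index=-100,
    lse_square_scale=1e-4,  # weight of the z-loss term
    label_smoothing=0.1,
    reduction="mean",
    softcap=30.0,           # tanh logit soft-capping
    return_z_loss=True,
)
loss, z_loss = loss_fn(logits, targets)
loss.backward()
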
@@ -1,9 +1,12 @@
+from typing import Optional
+
 from liger_kernel.ops.cross_entropy import LigerCrossEntropyFunction
 from liger_kernel.ops.fused_linear_cross_entropy import (
     LigerFusedLinearCrossEntropyFunction,
 )
 from liger_kernel.ops.fused_linear_jsd import LigerFusedLinearJSDFunction
 from liger_kernel.ops.geglu import LigerGELUMulFunction
+from liger_kernel.ops.group_norm import LigerGroupNormFunction
 from liger_kernel.ops.jsd import LigerJSDFunction
 from liger_kernel.ops.kl_div import LigerKLDivLossFunction
 from liger_kernel.ops.layer_norm import LigerLayerNormFunction
@@ -12,7 +15,6 @@ from liger_kernel.ops.rope import LigerRopeFunction
 from liger_kernel.ops.swiglu import LigerSiLUMulFunction

 liger_swiglu = LigerSiLUMulFunction.apply
-liger_cross_entropy = LigerCrossEntropyFunction.apply
 liger_fused_linear_cross_entropy = LigerFusedLinearCrossEntropyFunction.apply
 liger_geglu = LigerGELUMulFunction.apply
 liger_rms_norm = LigerRMSNormFunction.apply
@@ -21,3 +23,34 @@ liger_layer_norm = LigerLayerNormFunction.apply
 liger_kl_div = LigerKLDivLossFunction.apply
 liger_jsd = LigerJSDFunction.apply
 liger_fused_linear_jsd = LigerFusedLinearJSDFunction.apply
+liger_group_norm = LigerGroupNormFunction.apply
+
+
+# conform to the function signature in https://pytorch.org/docs/stable/generated/torch.nn.functional.cross_entropy.html
+# `weight` and `size_average` are placeholders and not implemented yet
+def liger_cross_entropy(
+    input,
+    target,
+    weight=None,
+    size_average=None,
+    ignore_index: int = -100,
+    reduce=None,
+    reduction: str = "mean",
+    label_smoothing: float = 0.0,
+    lse_square_scale: float = 0.0,
+    softcap: Optional[float] = None,
+    return_z_loss: bool = False,
+):
+    loss, z_loss = LigerCrossEntropyFunction.apply(
+        input,
+        target,
+        ignore_index,
+        lse_square_scale,
+        label_smoothing,
+        reduction,
+        softcap,
+        return_z_loss,
+    )
+    if not return_z_loss:
+        return loss
+    return loss, z_loss
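
Since the wrapper above mirrors the `torch.nn.functional.cross_entropy` signature (with `weight` and `size_average` left as unimplemented placeholders), it can be dropped in where `F.cross_entropy` is currently called, with the Liger-specific options exposed as extra keywords. A brief sketch, assuming the `liger_kernel.transformers.functional` module path and illustrative shapes:

import torch
from liger_kernel.transformers.functional import liger_cross_entropy

logits = torch.randn(16, 50257, device="cuda", requires_grad=True)
labels = torch.randint(0, 50257, (16,), device="cuda")

loss = liger_cross_entropy(logits, labels, ignore_index=-100, reduction="mean")
loss, z_loss = liger_cross_entropy(
    logits, labels, lse_square_scale=1e-4, return_z_loss=True
)
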
@@ -1,13 +1,38 @@
-from torch.nn import CrossEntropyLoss
+from typing import Optional
+
+import torch

 from liger_kernel.ops.fused_linear_cross_entropy import (
     LigerFusedLinearCrossEntropyFunction,
 )


-class LigerFusedLinearCrossEntropyLoss(CrossEntropyLoss):
-    def __init__(self, *args, **kwargs):
-        super(LigerFusedLinearCrossEntropyLoss, self).__init__(*args, **kwargs)
+class LigerFusedLinearCrossEntropyLoss(torch.nn.Module):
+    def __init__(
+        self,
+        ignore_index: int = -100,
+        lse_square_scale: float = 0.0,
+        label_smoothing: float = 0.0,
+        reduction: str = "mean",
+        softcap: Optional[float] = None,
+    ):
+        super().__init__()
+        assert (label_smoothing >= 0) and (
+            label_smoothing <= 1
+        ), f"label_smoothing must be between 0.0 and 1.0. Got: {label_smoothing}"
+        assert reduction in {
+            "mean",
+            "sum",
+            "none",
+        }, f"reduction must be one of 'mean', 'sum', or 'none'. Got: {reduction}"
+        assert (
+            softcap is None or softcap > 0
+        ), f"softcap must greater than 0.0 or None. Got: {softcap}"
+        self.ignore_index = ignore_index
+        self.lse_square_scale = lse_square_scale
+        self.label_smoothing = label_smoothing
+        self.reduction = reduction
+        self.softcap = softcap

     def forward(self, lin_weight, _input, target, bias=None):
         return LigerFusedLinearCrossEntropyFunction.apply(
@@ -16,6 +41,8 @@ class LigerFusedLinearCrossEntropyLoss(CrossEntropyLoss):
             target,
             bias,
             self.ignore_index,
+            self.lse_square_scale,
             self.label_smoothing,
             self.reduction,
+            self.softcap,
         )
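
The fused-linear variant keeps its `forward(lin_weight, _input, target, bias=None)` calling convention (note that the weight comes first), now forwarding `lse_square_scale` and `softcap` through to the kernel as well. A closing usage sketch, assuming the `liger_kernel.transformers.fused_linear_cross_entropy` module path, a CUDA device, and made-up sizes:

import torch
from liger_kernel.transformers.fused_linear_cross_entropy import (
    LigerFusedLinearCrossEntropyLoss,
)

hidden = torch.randn(64, 4096, device="cuda", requires_grad=True)          # (B*T, H)
lm_head_w = torch.randn(128256, 4096, device="cuda", requires_grad=True)   # (V, H)
labels = torch.randint(0, 128256, (64,), device="cuda")

loss_fn = LigerFusedLinearCrossEntropyLoss(softcap=30.0, lse_square_scale=1e-4)
loss = loss_fn(lm_head_w, hidden, labels)  # weight first, then the hidden states
loss.backward()
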