liger-kernel-nightly 0.5.2.dev20241228022953__py3-none-any.whl → 0.5.2.dev20241229131950__py3-none-any.whl

--- a/liger_kernel/ops/cross_entropy.py
+++ b/liger_kernel/ops/cross_entropy.py
@@ -30,11 +30,14 @@ def liger_cross_entropy_kernel(
     X_stride,
     Y_ptr,
     Y_stride,
+    weight_ptr,
     loss_ptr,
     z_loss_ptr,
     loss_stride,
     n_cols,
     n_non_ignore,
+    sum_non_ignore_weight,
+    weight_sum,
     ignore_index,
     lse_square_scale: tl.constexpr,
     label_smoothing: tl.constexpr,
@@ -42,6 +45,7 @@ def liger_cross_entropy_kernel(
     softcap,
     RETURN_Z_LOSS: tl.constexpr,
     BLOCK_SIZE: tl.constexpr,
+    HAS_WEIGHT: tl.constexpr,
     HAS_SOFTCAPPING: tl.constexpr,
 ):
     """
@@ -53,18 +57,22 @@ def liger_cross_entropy_kernel(
     X_stride (int): The stride of the input tensor.
     Y_ptr: Pointer to target tensor.
     Y_stride (int): The stride of the target tensor.
+    weight_ptr: Pointer to weight tensor.
     loss_ptr: Pointer to tensor to store the loss.
     z_loss_ptr: Pointer to tensor to store the z loss. No operation if RETURN_Z_LOSS is 0.
     loss_stride (int): The stride of the loss tensor.
     n_cols (int): The number of columns in the input tensor.
-    n_non_ignore (int): The number of non-ignored elements in the batch.
+    n_non_ignore (float): The number of non-ignored elements in the batch.
+    sum_non_ignore_weight (float): The sum of the non-ignored targets' weights in the batch.
+    weight_sum (float): The sum of the weight tensor.
     ignore_index (int): The index to ignore in the target.
     label_smoothing (float): The amount of smoothing when computing the loss, where 0.0 means no smoothing.
     lse_square_scale (float): The scaler of (logsumexp(_input)) ^ 2 adding to the loss for the stability of training.
-    RETURN_Z_LOSS (int): The boolean value to decide whether storing z loss to z_loss_ptr or not. It must be 0 or 1.
     reduction (str): The string for the reduction to apply
     softcap (float): The upper threshold for scaling logits to the range (-softcap, +softcap).
+    RETURN_Z_LOSS (int): The boolean value to decide whether to store the z loss to z_loss_ptr or not. It must be 0 or 1.
     BLOCK_SIZE (int): The block size for Triton operations.
+    HAS_WEIGHT (bool): The boolean value to determine whether a weight is assigned to each class.
     HAS_SOFTCAPPING (bool): The boolean value to determine whether applying soft-capping or not.
     """

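For orientation, the new weight_ptr / sum_non_ignore_weight / weight_sum arguments mirror the semantics of torch.nn.CrossEntropyLoss's weight argument: with reduction="mean", each sample's loss is scaled by the weight of its target class, and the sum is normalized by the summed weights of the non-ignored targets rather than by their count. A minimal PyTorch sketch of that reference behavior (shapes and values here are illustrative):

    import torch
    import torch.nn.functional as F

    V = 5
    logits = torch.randn(4, V)
    target = torch.tensor([0, 2, -100, 4])  # -100 is the default ignore_index
    weight = torch.rand(V)

    mask = target != -100
    # With weight given, reduction="none" already scales each sample by weight[target].
    per_sample = F.cross_entropy(logits[mask], target[mask], weight=weight, reduction="none")
    # "mean" normalizes by the total weight of the non-ignored targets, not their count:
    manual = per_sample.sum() / weight[target[mask]].sum()
    builtin = F.cross_entropy(logits, target, weight=weight, ignore_index=-100, reduction="mean")
    assert torch.allclose(manual, builtin)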
@@ -89,6 +97,9 @@ def liger_cross_entropy_kernel(
     loss_ptr += program_id * loss_stride
     z_loss_ptr += program_id * loss_stride
 
+    if HAS_WEIGHT:
+        weight_y = tl.load(weight_ptr + y).cast(tl.float32)
+
     # Online softmax: 2 loads + 1 store (compared with 3 loads + 1 store for the safe softmax)
     # Refer to Algorithm 3 in the paper: https://arxiv.org/pdf/1805.02867
 
@@ -117,7 +128,11 @@ def liger_cross_entropy_kernel(
         block_max = tl.max(X_block)
         if label_smoothing > 0:
             # scale X beforehand to avoid overflow
-            scaled_x_sum += tl.sum(tl.where(X_offsets < n_cols, -eps * X_block, 0.0))
+            if HAS_WEIGHT:
+                weight_block = tl.load(weight_ptr + X_offsets, mask=X_offsets < n_cols)
+                scaled_x_sum += tl.sum(tl.where(X_offsets < n_cols, -eps * X_block * weight_block, 0.0))
+            else:
+                scaled_x_sum += tl.sum(tl.where(X_offsets < n_cols, -eps * X_block, 0.0))
         m_new = tl.maximum(m, block_max)
         d = d * tl.exp(m - m_new) + tl.sum(tl.exp(X_block - m_new))
         m = m_new
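A note on what the weighted branch accumulates (with ε = label_smoothing / V and w_i the class weights): the blockwise sum builds only the x-dependent half of the smoothing term; the lse-dependent half is completed later from weight_sum, since

    \varepsilon \sum_{i=1}^{V} w_i\,(\mathrm{lse} - x_i)
      = \underbrace{-\varepsilon \sum_i w_i x_i}_{\texttt{scaled\_x\_sum}}
        + \varepsilon \cdot \mathrm{lse} \cdot \underbrace{\sum_i w_i}_{\texttt{weight\_sum}}

which is exactly what the smooth_loss hunk further down assembles.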
@@ -153,18 +168,41 @@ def liger_cross_entropy_kernel(
         if HAS_SOFTCAPPING:
             intermediate = tanh(X_block / softcap)
             X_block = softcap * intermediate
-        # softmax(x_i)
-        X_block = tl.exp(X_block - m) / d
-        # derivative of z-loss: 2 * lse_square_scale * lse * softmax(x_i)
-        X_block += 2 * lse_square_scale * lse * X_block
-        # smoothing term
-        X_block += -eps
-        # special handle dx_y
-        X_block = tl.where(X_offsets != y, X_block, X_block - (1 - label_smoothing))
-        # reduction scale
-        if reduction == "mean":
-            X_block = X_block / (n_non_ignore)
-        # chain rule
+
+        if not HAS_WEIGHT:
+            # softmax(x_i)
+            X_block = tl.exp(X_block - m) / d
+            # derivative of z-loss: 2 * lse_square_scale * lse * softmax(x_i)
+            X_block += 2 * lse_square_scale * lse * X_block
+            # smoothing term
+            X_block += -eps
+            # special handle dx_y
+            X_block = tl.where(X_offsets != y, X_block, X_block - (1 - label_smoothing))
+            # reduction scale
+            if reduction == "mean":
+                X_block = X_block / n_non_ignore
+        else:
+            weight_block = tl.load(weight_ptr + X_offsets, mask=X_offsets < n_cols)
+            softmax_X = tl.exp(X_block - m) / d
+            # derivative of original_loss
+            dloss_ori = (1 - label_smoothing) * softmax_X
+            # specially handle dx_y
+            dloss_ori = tl.where(X_offsets != y, dloss_ori, dloss_ori - (1 - label_smoothing))
+            dloss_ori = dloss_ori * weight_y
+            # derivative of smooth_loss
+            dloss_smooth = eps * (-weight_block + softmax_X * weight_sum)
+            # derivative of z-loss
+            dz_loss = 2 * lse_square_scale * lse * softmax_X
+            # reduction scale
+            if reduction == "mean":
+                dloss_ori = dloss_ori / sum_non_ignore_weight
+                dloss_smooth = dloss_smooth / sum_non_ignore_weight
+                # TODO: Implement weighted z_loss. Currently, z_loss is not scaled by weight.
+                dz_loss = dz_loss / n_non_ignore
+            # derivative of total_loss
+            X_block = dloss_ori + dloss_smooth + dz_loss
+
+        # chain rule softcapping
         # d(softcap * tanh(x / softcap)) = (1 - tanh^2(x / softcap))
         if HAS_SOFTCAPPING:
             X_block = X_block * (1 - intermediate * intermediate)
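To sanity-check the weighted gradient branch against PyTorch, here is a small autograd comparison with no smoothing and no z-loss, so only the dloss_ori path is active (all names illustrative):

    import torch
    import torch.nn.functional as F

    torch.manual_seed(0)
    n, V = 3, 7
    x = torch.randn(n, V, requires_grad=True)
    target = torch.randint(0, V, (n,))
    weight = torch.rand(V)

    F.cross_entropy(x, target, weight=weight, reduction="mean").backward()

    # Closed form of the weighted-mean gradient (label_smoothing = 0, lse_square_scale = 0):
    # dX = w_y * (softmax(x) - onehot(y)) / sum_non_ignore_weight
    softmax_x = torch.softmax(x, dim=-1)
    onehot = F.one_hot(target, V).float()
    manual = weight[target].unsqueeze(-1) * (softmax_x - onehot) / weight[target].sum()
    assert torch.allclose(x.grad, manual, atol=1e-6)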
@@ -183,6 +221,8 @@ def liger_cross_entropy_kernel(
     # sum(e ^ (X - max(X))) must >= 1 because the max term is e ^ 0 = 1
     # So we can safely calculate log (softmax(X_y)) without overflow
     loss = lse - ori_X_y
+    if HAS_WEIGHT:
+        loss = weight_y * loss
 
     # Original loss = H(q, p), with label smoothing regularization = H(q', p) and (label_smoothing / V) = eps
     # H(q', p) = (1 - label_smoothing) * H(q, p) + label_smoothing * H(u, p)
@@ -193,17 +233,24 @@ def liger_cross_entropy_kernel(
     # pytorch: https://github.com/pytorch/pytorch/blob/2981534f54d49fa3a9755c9b0855e7929c2527f0/aten/src/ATen/native/LossNLL.cpp#L516
     # See full derivation at https://github.com/linkedin/Liger-Kernel/pull/198#issuecomment-2333753087
     if label_smoothing > 0:
-        smooth_loss = scaled_x_sum + label_smoothing * lse
+        if HAS_WEIGHT:
+            smooth_loss = scaled_x_sum + eps * lse * weight_sum
+        else:
+            smooth_loss = scaled_x_sum + label_smoothing * lse
         loss = loss * (1 - label_smoothing) + smooth_loss
 
     # An auxiliary loss, z_loss
     # Refer to Page14 Loss function section in the paper PaLM: https://www.jmlr.org/papers/v24/22-1144.html
     z_loss = lse_square_scale * lse * lse
-    loss += z_loss
     # Normalize the loss by the number of non-ignored elements if reduction is "mean"
     if reduction == "mean":
+        if HAS_WEIGHT:
+            loss = loss / sum_non_ignore_weight
+        else:
+            loss = loss / n_non_ignore
+        # TODO: Implement weighted z_loss. Currently, z_loss is not scaled by weight.
         z_loss = z_loss / n_non_ignore
-        loss = loss / n_non_ignore
+    loss += z_loss
 
     tl.store(loss_ptr, loss)
     if RETURN_Z_LOSS == _TRUE:
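Putting the weighted branches of the forward pass together, for reduction="mean" the kernel now computes (α = label_smoothing, λ = lse_square_scale, N the set of non-ignored rows, lse_n the logsumexp of row n):

    \mathcal{L} = \frac{1}{\sum_{n\in\mathcal{N}} w_{y_n}}
        \sum_{n\in\mathcal{N}} \Big[ (1-\alpha)\, w_{y_n}\,(\mathrm{lse}_n - x_{n,y_n})
        + \frac{\alpha}{V} \sum_{i=1}^{V} w_i\,(\mathrm{lse}_n - x_{n,i}) \Big]
        + \frac{\lambda}{|\mathcal{N}|} \sum_{n\in\mathcal{N}} \mathrm{lse}_n^2

This reduces to the unweighted formula when all w_i = 1, since then Σ_n w_{y_n} = |N| and Σ_i w_i = V. Per the TODO above, the z-loss term is still normalized by |N| rather than by the weight sum.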
@@ -225,6 +272,7 @@ _bool_to_return_z_loss = {
 def cross_entropy_forward(
     _input,
     target,
+    weight,
     ignore_index,
     lse_square_scale,
     label_smoothing,
@@ -250,7 +298,20 @@ def cross_entropy_forward(
     else:
         z_loss_1d = loss_1d  # dummy ptr when return_z_loss == False
 
-    n_non_ignore = (target != ignore_index).sum().item()
+    target_mask = target != ignore_index
+    n_non_ignore = target_mask.sum().item()
+    sum_non_ignore_weight = n_non_ignore
+    weight_sum = 0.0
+    if weight is not None:
+        assert weight.shape[0] == V, f"If given, weight has to be a Tensor of size V. Got: {weight.shape}"
+        assert torch.is_floating_point(
+            weight
+        ), f"If given, weight has to be a Tensor of floating point dtype. Got: {weight.dtype}"
+        sum_non_ignore_weight = torch.gather(weight, dim=0, index=target.masked_select(target_mask)).sum().item()
+        weight_sum = weight.sum().item()
+        # ensure weight is contiguous
+        if weight.stride(-1) != 1:
+            weight = weight.contiguous()
 
     # ensure _input and target are contiguous in the last dimension
     if _input.stride(-1) != 1:
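The three host-side scalars feed straight into the kernel; a toy run (values illustrative) shows what each one holds:

    import torch

    weight = torch.tensor([0.5, 1.0, 2.0])   # per-class weights, V = 3
    target = torch.tensor([0, 2, -100, 2])   # -100 = ignore_index
    target_mask = target != -100

    n_non_ignore = target_mask.sum().item()  # 3
    # weights of the non-ignored targets: [0.5, 2.0, 2.0] -> 4.5
    sum_non_ignore_weight = torch.gather(
        weight, dim=0, index=target.masked_select(target_mask)
    ).sum().item()
    weight_sum = weight.sum().item()         # 3.5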
@@ -264,18 +325,22 @@ def cross_entropy_forward(
         X_stride=_input.stride(-2),
         Y_ptr=target,
         Y_stride=target.stride(-1),  # always 1
+        weight_ptr=weight if weight is not None else _input,  # dummy if None
         loss_ptr=loss_1d,
         z_loss_ptr=z_loss_1d,
         loss_stride=loss_1d.stride(-1),  # always 1
         n_cols=V,
         n_non_ignore=n_non_ignore,
+        sum_non_ignore_weight=sum_non_ignore_weight,
         ignore_index=ignore_index,
+        weight_sum=weight_sum,
         lse_square_scale=lse_square_scale,
         label_smoothing=label_smoothing,
         reduction=reduction,
         softcap=softcap if softcap is not None else 0.0,
         RETURN_Z_LOSS=return_z_loss,
         BLOCK_SIZE=BLOCK_SIZE,
+        HAS_WEIGHT=True if weight is not None else False,
         HAS_SOFTCAPPING=True if softcap is not None else False,
         # TODO: 32 seems to give the best performance
         # Performance is quite sensitive to num_warps
@@ -327,6 +392,7 @@ class LigerCrossEntropyFunction(torch.autograd.Function):
         ctx,
         _input: torch.Tensor,
         target: torch.Tensor,
+        weight: Optional[torch.FloatTensor],
         ignore_index: int = -100,
         lse_square_scale: float = 0.0,
         label_smoothing: float = 0.0,
@@ -341,6 +407,7 @@ class LigerCrossEntropyFunction(torch.autograd.Function):
         ctx : The context object.
         _input (tensor): The input tensor of shape (BT, V) where B is batch size, T is sequence length, V is vocab size.
         target (tensor): The target tensor of shape (BT) where each value is in [0, V-1].
+        weight (Tensor, optional): A manual rescaling weight given to each class. If given, it has to be a Tensor of size V and of floating point dtype.
         ignore_index (int): The index to ignore in the target.
         lse_square_scale (float): The scaler of (logsumexp(_input)) ^ 2 adding to the loss for the stability of training.
         label_smoothing (float): The amount of smoothing when computing the loss, where 0.0 means no smoothing.
@@ -354,6 +421,7 @@ class LigerCrossEntropyFunction(torch.autograd.Function):
         loss, z_loss, _input = cross_entropy_forward(
             _input,
             target,
+            weight,
             ignore_index,
             lse_square_scale,
             label_smoothing,
@@ -395,4 +463,5 @@ class LigerCrossEntropyFunction(torch.autograd.Function):
             None,
             None,
             None,
+            None,
         )
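The extra None in backward's return tuple pairs with the new weight argument to forward: torch.autograd.Function requires one gradient slot per forward input, and non-differentiable inputs get None. A reduced sketch of the pattern (the WeightedScale function is hypothetical, not part of this patch):

    import torch

    class WeightedScale(torch.autograd.Function):
        @staticmethod
        def forward(ctx, x, weight):
            ctx.save_for_backward(weight)
            return x * weight

        @staticmethod
        def backward(ctx, grad_out):
            (weight,) = ctx.saved_tensors
            # One slot per forward input; `weight` is treated as non-differentiable.
            return grad_out * weight, None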
--- a/liger_kernel/ops/fused_linear_cross_entropy.py
+++ b/liger_kernel/ops/fused_linear_cross_entropy.py
@@ -17,6 +17,7 @@ def fused_linear_cross_entropy_forward(
     _input,
     weight,
     target,
+    ce_weight=None,
     bias=None,
     ignore_index=-100,
     lse_square_scale=0.0,
@@ -47,8 +48,22 @@ def fused_linear_cross_entropy_forward(
     # we use fp32 for loss accumulator
     loss_1d = torch.zeros(BT, dtype=torch.float32, device=device)
 
-    # NOTE: skip .item() here to avoid CUDA synchronization
-    total_n_non_ignore = (target != ignore_index).sum()
+    # TODO: evaluate how the CUDA synchronization caused by .item() affects speed
+    target_mask = target != ignore_index
+    total_n_non_ignore = target_mask.sum().item()
+    total_sum_non_ignore_ce_weight = total_n_non_ignore
+    ce_weight_sum = 0.0
+    if ce_weight is not None:
+        assert ce_weight.shape[0] == V, f"If given, weight has to be a Tensor of size V. Got: {ce_weight.shape}"
+        assert torch.is_floating_point(
+            ce_weight
+        ), f"If given, weight has to be a Tensor of floating point dtype. Got: {ce_weight.dtype}"
+        total_sum_non_ignore_ce_weight = (
+            torch.gather(ce_weight, dim=0, index=target.masked_select(target_mask)).sum().item()
+        )
+        ce_weight_sum = ce_weight.sum().item()
+        if ce_weight.stride(-1) != 1:
+            ce_weight = ce_weight.contiguous()
 
     for chunk_id in range(num_chunks):
         start_idx = chunk_id * chunk_size
@@ -59,13 +74,13 @@ def fused_linear_cross_entropy_forward(
         logits_chunk = _input_chunk @ weight.t()  # chunk_size x V
         if bias is not None:
             logits_chunk = logits_chunk + bias
+
         target_chunk = target[start_idx:end_idx]  # chunk_size,
 
         n_rows = logits_chunk.shape[0]
 
         # unreduced loss
         loss_1d_slice = loss_1d[start_idx:end_idx]  # chunk_size,
-        n_non_ignore = (target_chunk != ignore_index).sum().item()
 
         # ensure _input and target are contiguous
         logits_chunk = logits_chunk.contiguous()
@@ -77,45 +92,40 @@ def fused_linear_cross_entropy_forward(
             X_stride=logits_chunk.stride(-2),
             Y_ptr=target_chunk,
             Y_stride=target_chunk.stride(-1),  # always 1
+            weight_ptr=ce_weight if ce_weight is not None else _input,  # dummy if None
             loss_ptr=loss_1d_slice,
             z_loss_ptr=loss_1d_slice,  # dummy ptr, not used
             loss_stride=loss_1d_slice.stride(-1),  # always 1
             n_cols=V,
-            n_non_ignore=n_non_ignore,
+            n_non_ignore=total_n_non_ignore,
+            sum_non_ignore_weight=total_sum_non_ignore_ce_weight,
+            weight_sum=ce_weight_sum,
             ignore_index=ignore_index,
             lse_square_scale=lse_square_scale,
             label_smoothing=label_smoothing,
             reduction=reduction,
             softcap=softcap if softcap is not None else 0.0,
             RETURN_Z_LOSS=0,  # False
+            HAS_WEIGHT=True if ce_weight is not None else False,
             HAS_SOFTCAPPING=True if softcap is not None else False,
             BLOCK_SIZE=BLOCK_SIZE,
             num_warps=32 if not is_hip() else 16,
         )
 
-        # gradient of logits_chunk is computed in-place by the above triton kernel and is of shape: chunk_size x V
-        # thus grad_input[start_idx: end_idx] should be of shape: chunk_size x H
-        # additionally, since we are chunking the inputs, observe that the loss and gradients are calculated only
-        # on `n_non_ignore` tokens. However, the gradient of the input should be calculated for all tokens.
-        # Thus, we need an additional scaling factor of (n_non_ignore/total_n_non_ignore) to scale the gradients.
-
-        if reduction == "mean":
-            alpha = n_non_ignore / total_n_non_ignore if total_n_non_ignore > 0 else 0.0
-        else:
-            alpha = 1.0
-
-        loss_1d[start_idx:end_idx] = loss_1d_slice * alpha
-        grad_logits_chunk = logits_chunk * alpha  # chunk_size x V
+        loss_1d[start_idx:end_idx] = loss_1d_slice
+        grad_logits_chunk = logits_chunk  # chunk_size x V
 
         grad_input[start_idx:end_idx] = grad_logits_chunk @ weight
 
         if grad_weight is not None:
             torch.addmm(
                 input=grad_weight,
-                mat1=logits_chunk.t(),
+                mat1=logits_chunk.t().to(
+                    _input_chunk.dtype
+                ),  # In an autocast scenario without bias, differing logits_chunk data types will cause an addmm operation error.
                 mat2=_input_chunk,
                 out=grad_weight,
-                alpha=alpha,
+                alpha=1.0,
                 beta=1.0,
             )
 
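The .to(_input_chunk.dtype) cast guards a mixed-precision corner case: under autocast, the matmul that produced logits_chunk may run in bf16/fp16 while grad_weight stays fp32, and torch.addmm requires matching dtypes. A sketch of the failure mode it avoids (shapes illustrative, CUDA assumed):

    import torch

    a = torch.randn(4, 8, device="cuda")        # fp32, stands in for _input_chunk
    w = torch.randn(16, 8, device="cuda")       # fp32, stands in for weight
    with torch.autocast("cuda", dtype=torch.bfloat16):
        logits = a @ w.t()                      # autocast runs this matmul in bf16

    grad_w = torch.zeros(16, 8, device="cuda")  # fp32 accumulator
    # torch.addmm(grad_w, logits.t(), a, out=grad_w) raises a dtype-mismatch error;
    # casting mat1 back to the input dtype fixes it:
    torch.addmm(grad_w, logits.t().to(a.dtype), a, out=grad_w)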
@@ -124,7 +134,7 @@ def fused_linear_cross_entropy_forward(
                 input=grad_bias,
                 other=logits_chunk.sum(dim=0),
                 out=grad_bias,
-                alpha=alpha,
+                alpha=1.0,
             )
 
     if reduction == "none":
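With the kernel now normalizing by the global total_n_non_ignore (or total_sum_non_ignore_ce_weight) instead of each chunk's local count, the old per-chunk rescaling factor alpha = n_non_ignore / total_n_non_ignore collapses to 1.0 everywhere it was applied. A scalar sketch of why the two schemes agree:

    # Sum of per-token losses and non-ignored counts for two chunks (made-up numbers):
    chunk_losses = [6.0, 4.0]
    chunk_counts = [3, 2]
    total = sum(chunk_counts)  # 5

    # Old scheme: kernel divides by the local count, host rescales by local/total.
    old = sum((loss / cnt) * (cnt / total) for loss, cnt in zip(chunk_losses, chunk_counts))
    # New scheme: kernel divides by the global count directly.
    new = sum(loss / total for loss in chunk_losses)
    assert abs(old - new) < 1e-12  # both equal 2.0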
@@ -190,6 +200,7 @@ class LigerFusedLinearCrossEntropyFunction(torch.autograd.Function):
         weight,
         target,
         bias=None,
+        ce_weight=None,
         ignore_index=-100,
         lse_square_scale=0.0,
         label_smoothing=0.0,
@@ -209,21 +220,23 @@ class LigerFusedLinearCrossEntropyFunction(torch.autograd.Function):
         target: (B*T) where each value is in [0, V-1]
         weight: (V, H) where V is the number of classes
         bias: (V) where V is the number of classes
+        ce_weight: a manual rescaling weight given to each class. If given, it has to be a Tensor of size V and of floating point dtype
         ignore_index: the index to ignore in the target
         label_smoothing (float): The amount of smoothing when computing the loss, where 0.0 means no smoothing.
         reduction: reduction to apply
         """
 
         loss, grad_input, grad_weight, grad_bias = fused_linear_cross_entropy_forward(
-            _input,
-            weight,
-            target,
-            bias,
-            ignore_index,
-            lse_square_scale,
-            label_smoothing,
-            reduction,
-            softcap,
+            _input=_input,
+            weight=weight,
+            target=target,
+            bias=bias,
+            ce_weight=ce_weight,
+            ignore_index=ignore_index,
+            lse_square_scale=lse_square_scale,
+            label_smoothing=label_smoothing,
+            reduction=reduction,
+            softcap=softcap,
         )
         # downcast to dtype and store for backward
         ctx.save_for_backward(
@@ -240,4 +253,15 @@ class LigerFusedLinearCrossEntropyFunction(torch.autograd.Function):
         grad_input, grad_weight, grad_bias = fused_linear_cross_entropy_backward(
             grad_output, grad_input, grad_weight, grad_bias
         )
-        return (grad_input, grad_weight, None, grad_bias, None, None, None, None, None)
+        return (
+            grad_input,
+            grad_weight,
+            None,
+            grad_bias,
+            None,
+            None,
+            None,
+            None,
+            None,
+            None,
+        )
--- a/liger_kernel/transformers/cross_entropy.py
+++ b/liger_kernel/transformers/cross_entropy.py
@@ -8,6 +8,7 @@ from liger_kernel.ops.cross_entropy import LigerCrossEntropyFunction
 class LigerCrossEntropyLoss(torch.nn.Module):
     def __init__(
         self,
+        weight: Optional[torch.FloatTensor] = None,
         ignore_index: int = -100,
         lse_square_scale: float = 0.0,
         label_smoothing: float = 0.0,
@@ -28,6 +29,7 @@ class LigerCrossEntropyLoss(torch.nn.Module):
             "none",
         }, f"reduction must be one of 'mean', 'sum', or 'none'. Got: {reduction}"
         assert softcap is None or softcap > 0, f"softcap must greater than 0.0 or None. Got: {softcap}"
+        self.weight = weight
         self.ignore_index = ignore_index
         self.lse_square_scale = lse_square_scale
         self.label_smoothing = label_smoothing
@@ -39,6 +41,7 @@ class LigerCrossEntropyLoss(torch.nn.Module):
         loss, z_loss = LigerCrossEntropyFunction.apply(
             _input,
             target,
+            self.weight,
             self.ignore_index,
             self.lse_square_scale,
             self.label_smoothing,
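With the new parameter threaded through, the module now mirrors torch.nn.CrossEntropyLoss construction; a minimal usage sketch (shapes and device illustrative):

    import torch
    from liger_kernel.transformers import LigerCrossEntropyLoss

    V = 32000
    weight = torch.rand(V, device="cuda")  # per-class rescaling weights, size V
    loss_fn = LigerCrossEntropyLoss(weight=weight, ignore_index=-100)

    logits = torch.randn(8, V, device="cuda", requires_grad=True)
    target = torch.randint(0, V, (8,), device="cuda")
    loss = loss_fn(logits, target)
    loss.backward()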
--- a/liger_kernel/transformers/functional.py
+++ b/liger_kernel/transformers/functional.py
@@ -32,6 +32,7 @@ def liger_cross_entropy(
     loss, z_loss = LigerCrossEntropyFunction.apply(
         input,
         target,
+        weight,
         ignore_index,
         lse_square_scale,
         label_smoothing,
@@ -49,6 +50,7 @@ def liger_fused_linear_cross_entropy(
     weight,
     target,
     bias=None,
+    ce_weight=None,
     ignore_index: int = -100,
     lse_square_scale: float = 0.0,
     label_smoothing: float = 0.0,
@@ -60,6 +62,7 @@ def liger_fused_linear_cross_entropy(
         weight,
         target,
         bias,
+        ce_weight,
         ignore_index,
         lse_square_scale,
         label_smoothing,
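At the functional layer the new argument slots in after bias and is keyword-friendly; a sketch of a call (all tensors illustrative):

    import torch
    from liger_kernel.transformers.functional import liger_fused_linear_cross_entropy

    B_T, H, V = 16, 4096, 32000
    hidden = torch.randn(B_T, H, device="cuda", requires_grad=True)
    lm_head_weight = torch.randn(V, H, device="cuda", requires_grad=True)
    labels = torch.randint(0, V, (B_T,), device="cuda")
    class_weights = torch.rand(V, device="cuda")

    loss = liger_fused_linear_cross_entropy(
        hidden, lm_head_weight, labels, ce_weight=class_weights
    )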
--- a/liger_kernel/transformers/fused_linear_cross_entropy.py
+++ b/liger_kernel/transformers/fused_linear_cross_entropy.py
@@ -8,6 +8,7 @@ from liger_kernel.ops.fused_linear_cross_entropy import LigerFusedLinearCrossEnt
 class LigerFusedLinearCrossEntropyLoss(torch.nn.Module):
     def __init__(
         self,
+        ce_weight: Optional[torch.FloatTensor] = None,
         ignore_index: int = -100,
         lse_square_scale: float = 0.0,
         label_smoothing: float = 0.0,
@@ -24,6 +25,7 @@ class LigerFusedLinearCrossEntropyLoss(torch.nn.Module):
             "none",
         }, f"reduction must be one of 'mean', 'sum', or 'none'. Got: {reduction}"
         assert softcap is None or softcap > 0, f"softcap must greater than 0.0 or None. Got: {softcap}"
+        self.ce_weight = ce_weight
         self.ignore_index = ignore_index
         self.lse_square_scale = lse_square_scale
         self.label_smoothing = label_smoothing
@@ -36,6 +38,7 @@ class LigerFusedLinearCrossEntropyLoss(torch.nn.Module):
             lin_weight,
             target,
             bias,
+            self.ce_weight,
             self.ignore_index,
             self.lse_square_scale,
             self.label_smoothing,
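And the fused module, which takes the lm_head weight at call time (assuming the module's usual forward ordering of lin_weight, then input, then target); a minimal sketch:

    import torch
    from liger_kernel.transformers import LigerFusedLinearCrossEntropyLoss

    H, V = 4096, 32000
    ce_weight = torch.rand(V, device="cuda")  # per-class CE weights
    loss_fn = LigerFusedLinearCrossEntropyLoss(ce_weight=ce_weight)

    lm_head_weight = torch.randn(V, H, device="cuda", requires_grad=True)
    hidden = torch.randn(16, H, device="cuda", requires_grad=True)  # (B*T, H)
    labels = torch.randint(0, V, (16,), device="cuda")

    loss = loss_fn(lm_head_weight, hidden, labels)
    loss.backward()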
--- liger_kernel_nightly-0.5.2.dev20241228022953.dist-info/METADATA
+++ liger_kernel_nightly-0.5.2.dev20241229131950.dist-info/METADATA
@@ -1,6 +1,6 @@
 Metadata-Version: 2.1
 Name: liger_kernel_nightly
-Version: 0.5.2.dev20241228022953
+Version: 0.5.2.dev20241229131950
 Summary: Efficient Triton kernels for LLM Training
 License: BSD 2-CLAUSE LICENSE
         Copyright 2024 LinkedIn Corporation
--- liger_kernel_nightly-0.5.2.dev20241228022953.dist-info/RECORD
+++ liger_kernel_nightly-0.5.2.dev20241229131950.dist-info/RECORD
@@ -11,8 +11,8 @@ liger_kernel/chunked_loss/fused_linear_preference.py,sha256=25sTgvphLKAR0jyJcrsJ
 liger_kernel/chunked_loss/orpo_loss.py,sha256=jbZxx-EjPK71A6CSyNzTOAIEQgAUjfvwSViw6R_pPXQ,3510
 liger_kernel/chunked_loss/simpo_loss.py,sha256=ZvDIjT9EQrbwzH2LNZMhv84SPsOHGi_Ywk95vgA0b_o,3736
 liger_kernel/ops/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
-liger_kernel/ops/cross_entropy.py,sha256=2OPIkSXeQAIfSCODYK45Jf8xrz7HoGqFHr1MHS_pijE,15895
-liger_kernel/ops/fused_linear_cross_entropy.py,sha256=LR0zLL8JYMhk9e22jmBxU4lwEYic3YqMAG3837yaHmM,9418
+liger_kernel/ops/cross_entropy.py,sha256=4zSPzdPl-d2tB3ZOj7uRMpzI4RzZMNLUzkh6eMkH5kU,19179
+liger_kernel/ops/fused_linear_cross_entropy.py,sha256=j7cgR95rFAwtPsWZ00PfMwis5F7dtO3EVEw0rZ1GPJk,10231
 liger_kernel/ops/fused_linear_jsd.py,sha256=eKqaADj7LgWfoYqyH03tjrmhNTfJOF1Dhx_bWzBTnTU,9600
 liger_kernel/ops/geglu.py,sha256=axGvCIvlBzuluoAIrWTsp2iZM4BFKNInkPov8YVvH9E,4126
 liger_kernel/ops/group_norm.py,sha256=qD4D4lSjSgVtO52EBNLC2iTseALRgPgqXE50U2woggk,10837
@@ -28,9 +28,9 @@ liger_kernel/ops/experimental/embedding.py,sha256=tolj3tItkzpSb30zWqDN2_yX4ectfl
 liger_kernel/ops/experimental/mm_int8int2.py,sha256=TrS9lpwekrik_w5qE7AhMJD1bcq-OidjtbsW80oZ6IM,13314
 liger_kernel/transformers/__init__.py,sha256=QPmYkL6hosBPpPqCUGqvIvAtD9XzLgvZqZxUyYMZeVk,2008
 liger_kernel/transformers/auto_model.py,sha256=0qCTRZt280Bj_LcFdzo9hlaR-BWNazawXOGgoCZjgEg,1545
-liger_kernel/transformers/cross_entropy.py,sha256=s5-ZM1NBMDjG-KKJKBtIkmArj1jCUjDnpL-2QKhKYho,1734
-liger_kernel/transformers/functional.py,sha256=hxReSBDEUZkOnZgURD8sf6ETYvf9yqCOOMU2k9Ywh90,4435
-liger_kernel/transformers/fused_linear_cross_entropy.py,sha256=K4tfpoNPUJpWv7rCHEcs5xhJLg5td8GcpJrAryF5NMk,1451
+liger_kernel/transformers/cross_entropy.py,sha256=s931h9UW_tV4QMRme1HYjS_R2_C5nD6VFmZIXtjJoYo,1840
+liger_kernel/transformers/functional.py,sha256=B1wkHWLx-YNhxvXBEXB4Ch1yEwF3mjwTPCeXA5aCV_c,4490
+liger_kernel/transformers/fused_linear_cross_entropy.py,sha256=LAN8-pjUI2Erz_MnfMer-0ZmxJ0JlKxGzdZGJY-N65g,1569
 liger_kernel/transformers/fused_linear_jsd.py,sha256=bZ4otCvWBuOnA5XdQL-FzZVItJlDt-ht9e_pG7PG93E,3999
 liger_kernel/transformers/geglu.py,sha256=mrgqzIUVd6lN7fkDKLkw5YaESDxDtFgbot430WwPVOQ,1107
 liger_kernel/transformers/group_norm.py,sha256=URmjkQFsrbMffzcJiGpX7ckxWlpL95AiJS-80hwAWPk,2173
@@ -58,9 +58,9 @@ liger_kernel/transformers/trainer/__init__.py,sha256=p7yQfklV8-467qSz_ZMimkbDF7H
 liger_kernel/transformers/trainer/orpo_trainer.py,sha256=MId1S_MfA3pPVQA1rkiKxp-jZDNz8VmvZzXC-Kugol4,7662
 liger_kernel/triton/__init__.py,sha256=qCiCamzCRv6lpV8IqpAc9YMdNKC7GKurClWceQPnlis,92
 liger_kernel/triton/monkey_patch.py,sha256=Rd0hUHAzDkFfHvnX7-PBaNK5EKnZhtfM_h-fgQH9HPY,1568
-liger_kernel_nightly-0.5.2.dev20241228022953.dist-info/LICENSE,sha256=OhzLDHJ0to4a8sodVLELZiCFylZ1NAAYLs-HrjPy0ag,1312
-liger_kernel_nightly-0.5.2.dev20241228022953.dist-info/METADATA,sha256=Z5fzI-xpYPtjwawEGwIw-LRJUIeY1VEdDUK9wgklR7w,21055
-liger_kernel_nightly-0.5.2.dev20241228022953.dist-info/NOTICE,sha256=njwnoPZLh9AN8SJQzxvCGLHi-8X__AvWRze6joNXIY8,2066
-liger_kernel_nightly-0.5.2.dev20241228022953.dist-info/WHEEL,sha256=P9jw-gEje8ByB7_hXoICnHtVCrEwMQh-630tKvQWehc,91
-liger_kernel_nightly-0.5.2.dev20241228022953.dist-info/top_level.txt,sha256=2eghu4hA3LnkM7ElW92tQ8zegWKgSbeo-k-aGe1YnvY,13
-liger_kernel_nightly-0.5.2.dev20241228022953.dist-info/RECORD,,
+liger_kernel_nightly-0.5.2.dev20241229131950.dist-info/LICENSE,sha256=OhzLDHJ0to4a8sodVLELZiCFylZ1NAAYLs-HrjPy0ag,1312
+liger_kernel_nightly-0.5.2.dev20241229131950.dist-info/METADATA,sha256=iOyPsdNf1GL3Z3Ng0CS3xoOq6iiTb8eFXAMwqDT1UZM,21055
+liger_kernel_nightly-0.5.2.dev20241229131950.dist-info/NOTICE,sha256=njwnoPZLh9AN8SJQzxvCGLHi-8X__AvWRze6joNXIY8,2066
+liger_kernel_nightly-0.5.2.dev20241229131950.dist-info/WHEEL,sha256=P9jw-gEje8ByB7_hXoICnHtVCrEwMQh-630tKvQWehc,91
+liger_kernel_nightly-0.5.2.dev20241229131950.dist-info/top_level.txt,sha256=2eghu4hA3LnkM7ElW92tQ8zegWKgSbeo-k-aGe1YnvY,13
+liger_kernel_nightly-0.5.2.dev20241229131950.dist-info/RECORD,,