liger-kernel-nightly 0.5.2.dev20241228022953__py3-none-any.whl → 0.5.2.dev20241229131950__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
liger_kernel/ops/cross_entropy.py

@@ -30,11 +30,14 @@ def liger_cross_entropy_kernel(
     X_stride,
     Y_ptr,
     Y_stride,
+    weight_ptr,
     loss_ptr,
     z_loss_ptr,
     loss_stride,
     n_cols,
     n_non_ignore,
+    sum_non_ignore_weight,
+    weight_sum,
     ignore_index,
     lse_square_scale: tl.constexpr,
     label_smoothing: tl.constexpr,
@@ -42,6 +45,7 @@ def liger_cross_entropy_kernel(
     softcap,
     RETURN_Z_LOSS: tl.constexpr,
     BLOCK_SIZE: tl.constexpr,
+    HAS_WEIGHT: tl.constexpr,
     HAS_SOFTCAPPING: tl.constexpr,
 ):
     """
@@ -53,18 +57,22 @@ def liger_cross_entropy_kernel(
     X_stride (int): The stride of the input tensor.
     Y_ptr: Pointer to target tensor.
     Y_stride (int): The stride of the target tensor.
+    weight_ptr: Pointer to weight tensor.
     loss_ptr: Pointer to tensor to store the loss.
     z_loss_ptr: Pointer to tensor to store the z loss. No operation if RETURN_Z_LOSS is 0.
     loss_stride (int): The stride of the loss tensor.
     n_cols (int): The number of columns in the input tensor.
-    n_non_ignore (int): The number of non-ignored elements in the batch.
+    n_non_ignore (float): The number of non-ignored elements in the batch.
+    sum_non_ignore_weight (float): The sum of weights for the non-ignored targets in the batch.
+    weight_sum (float): The sum of the weight tensor.
     ignore_index (int): The index to ignore in the target.
     label_smoothing (float): The amount of smoothing when computing the loss, where 0.0 means no smoothing.
     lse_square_scale (float): The scaler of (logsumexp(_input)) ^ 2 adding to the loss for the stability of training.
-    RETURN_Z_LOSS (int): The boolean value to decide whether storing z loss to z_loss_ptr or not. It must be 0 or 1.
     reduction (str): The string for the reduction to apply
     softcap (float): The upper threshold for scaling logits to the range (-softcap, +softcap).
+    RETURN_Z_LOSS (int): The boolean value to decide whether storing z loss to z_loss_ptr or not. It must be 0 or 1.
     BLOCK_SIZE (int): The block size for Triton operations.
+    HAS_WEIGHT (bool): The boolean value to determine whether a weight is assigned to each class.
     HAS_SOFTCAPPING (bool): The boolean value to determine whether applying soft-capping or not.
     """
 
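
The three reduction scalars documented above are precomputed on the host. A minimal sketch of how they relate, assuming torch.nn.functional.cross_entropy semantics (the tensors and shapes here are illustrative, not from the package):

    import torch

    logits = torch.randn(8, 16)          # (BT, V)
    target = torch.randint(0, 16, (8,))  # (BT,)
    weight = torch.rand(16)              # (V,), per-class weights
    ignore_index = -100

    mask = target != ignore_index
    n_non_ignore = float(mask.sum())                    # tokens kept in the loss
    sum_non_ignore_weight = weight[target[mask]].sum()  # denominator for weighted "mean"
    weight_sum = weight.sum()                           # used by the label-smoothing terms
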
@@ -89,6 +97,9 @@ def liger_cross_entropy_kernel(
     loss_ptr += program_id * loss_stride
     z_loss_ptr += program_id * loss_stride
 
+    if HAS_WEIGHT:
+        weight_y = tl.load(weight_ptr + y).cast(tl.float32)
+
     # Online softmax: 2 loads + 1 store (compared with 3 loads + 1 store for the safe softmax)
     # Refer to Algorithm 3 in the paper: https://arxiv.org/pdf/1805.02867
 
@@ -117,7 +128,11 @@ def liger_cross_entropy_kernel(
         block_max = tl.max(X_block)
         if label_smoothing > 0:
             # scale X beforehand to avoid overflow
-            scaled_x_sum += tl.sum(tl.where(X_offsets < n_cols, -eps * X_block, 0.0))
+            if HAS_WEIGHT:
+                weight_block = tl.load(weight_ptr + X_offsets, mask=X_offsets < n_cols)
+                scaled_x_sum += tl.sum(tl.where(X_offsets < n_cols, -eps * X_block * weight_block, 0.0))
+            else:
+                scaled_x_sum += tl.sum(tl.where(X_offsets < n_cols, -eps * X_block, 0.0))
         m_new = tl.maximum(m, block_max)
         d = d * tl.exp(m - m_new) + tl.sum(tl.exp(X_block - m_new))
         m = m_new
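
The weighted accumulator mirrors the unweighted one: scaled_x_sum collects -eps * w_j * x_j online, and the lse-dependent part of the smoothing term is folded in after the loop (see the smooth_loss hunk further down). A sketch of the underlying algebra, with made-up values:

    import torch

    # Identity behind smooth_loss = scaled_x_sum + eps * lse * weight_sum:
    #   sum_j eps * w_j * (lse - x_j) = eps * lse * sum_j w_j - eps * sum_j w_j * x_j
    x = torch.randn(16)
    w = torch.rand(16)
    eps, lse = 0.1 / 16, torch.logsumexp(torch.randn(16), dim=0)

    lhs = (eps * w * (lse - x)).sum()
    rhs = eps * lse * w.sum() + (-eps * w * x).sum()  # second term is scaled_x_sum
    assert torch.allclose(lhs, rhs)
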
@@ -153,18 +168,41 @@ def liger_cross_entropy_kernel(
         if HAS_SOFTCAPPING:
             intermediate = tanh(X_block / softcap)
             X_block = softcap * intermediate
-        # softmax(x_i)
-        X_block = tl.exp(X_block - m) / d
-        # derivative of z-loss: 2 * lse_square_scale * lse * softmax(x_i)
-        X_block += 2 * lse_square_scale * lse * X_block
-        # smoothing term
-        X_block += -eps
-        # special handle dx_y
-        X_block = tl.where(X_offsets != y, X_block, X_block - (1 - label_smoothing))
-        # reduction scale
-        if reduction == "mean":
-            X_block = X_block / (n_non_ignore)
-        # chain rule
+
+        if not HAS_WEIGHT:
+            # softmax(x_i)
+            X_block = tl.exp(X_block - m) / d
+            # derivative of z-loss: 2 * lse_square_scale * lse * softmax(x_i)
+            X_block += 2 * lse_square_scale * lse * X_block
+            # smoothing term
+            X_block += -eps
+            # special handle dx_y
+            X_block = tl.where(X_offsets != y, X_block, X_block - (1 - label_smoothing))
+            # reduction scale
+            if reduction == "mean":
+                X_block = X_block / n_non_ignore
+        else:
+            weight_block = tl.load(weight_ptr + X_offsets, mask=X_offsets < n_cols)
+            softmax_X = tl.exp(X_block - m) / d
+            # derivative of original_loss
+            dloss_ori = (1 - label_smoothing) * softmax_X
+            # specially handle dx_y
+            dloss_ori = tl.where(X_offsets != y, dloss_ori, dloss_ori - (1 - label_smoothing))
+            dloss_ori = dloss_ori * weight_y
+            # derivative of smooth_loss
+            dloss_smooth = eps * (-weight_block + softmax_X * weight_sum)
+            # derivative of z-loss
+            dz_loss = 2 * lse_square_scale * lse * softmax_X
+            # reduction scale
+            if reduction == "mean":
+                dloss_ori = dloss_ori / sum_non_ignore_weight
+                dloss_smooth = dloss_smooth / sum_non_ignore_weight
+                # TODO: Implement weighted z_loss. Currently, z_loss is not scaled by weight.
+                dz_loss = dz_loss / n_non_ignore
+            # derivative of total_loss
+            X_block = dloss_ori + dloss_smooth + dz_loss
+
+        # chain rule softcapping
         # d(softcap * tanh(x / softcap)) = (1 - tanh^2(x / softcap))
         if HAS_SOFTCAPPING:
             X_block = X_block * (1 - intermediate * intermediate)
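
The weighted branch decomposes the gradient into dloss_ori, dloss_smooth, and dz_loss. A quick autograd cross-check of the first two terms, assuming the kernel is meant to match torch.nn.functional.cross_entropy for weight combined with label_smoothing (lse_square_scale is left at zero here, so dz_loss drops out):

    import torch
    import torch.nn.functional as F

    logits = torch.randn(4, 10, requires_grad=True)
    target = torch.randint(0, 10, (4,))
    weight = torch.rand(10)

    loss = F.cross_entropy(logits, target, weight=weight, label_smoothing=0.1)
    loss.backward()
    # logits.grad should now equal (dloss_ori + dloss_smooth) / sum_non_ignore_weight,
    # the quantity the kernel writes back into X_block in place.
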
@@ -183,6 +221,8 @@ def liger_cross_entropy_kernel(
     # sum(e ^ (X - max(X))) must >= 1 because the max term is e ^ 0 = 1
     # So we can safely calculate log (softmax(X_y)) without overflow
     loss = lse - ori_X_y
+    if HAS_WEIGHT:
+        loss = weight_y * loss
 
     # Original loss = H(q, p), with label smoothing regularization = H(q', p) and (label_smoothing / V) = eps
     # H(q', p) = (1 - label_smoothing) * H(q, p) + label_smoothing * H(u, p)
@@ -193,17 +233,24 @@ def liger_cross_entropy_kernel(
     # pytorch: https://github.com/pytorch/pytorch/blob/2981534f54d49fa3a9755c9b0855e7929c2527f0/aten/src/ATen/native/LossNLL.cpp#L516
     # See full derivation at https://github.com/linkedin/Liger-Kernel/pull/198#issuecomment-2333753087
     if label_smoothing > 0:
-        smooth_loss = scaled_x_sum + label_smoothing * lse
+        if HAS_WEIGHT:
+            smooth_loss = scaled_x_sum + eps * lse * weight_sum
+        else:
+            smooth_loss = scaled_x_sum + label_smoothing * lse
         loss = loss * (1 - label_smoothing) + smooth_loss
 
     # An auxiliary loss, z_loss
     # Refer to Page14 Loss function section in the paper PaLM: https://www.jmlr.org/papers/v24/22-1144.html
     z_loss = lse_square_scale * lse * lse
-    loss += z_loss
     # Normalize the loss by the number of non-ignored elements if reduction is "mean"
     if reduction == "mean":
+        if HAS_WEIGHT:
+            loss = loss / sum_non_ignore_weight
+        else:
+            loss = loss / n_non_ignore
+        # TODO: Implement weighted z_loss. Currently, z_loss is not scaled by weight.
+        z_loss = z_loss / n_non_ignore
-        loss = loss / n_non_ignore
+    loss += z_loss
 
     tl.store(loss_ptr, loss)
     if RETURN_Z_LOSS == _TRUE:
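
Note the asymmetry the TODO calls out: under reduction="mean" the weighted loss is divided by sum_non_ignore_weight while z_loss is still divided by n_non_ignore. A tiny numeric illustration with made-up values:

    import torch

    per_token_loss = torch.tensor([2.0, 1.0])  # two kept tokens
    w_y = torch.tensor([0.4, 1.6])             # their class weights

    weighted_mean = (w_y * per_token_loss).sum() / w_y.sum()  # 2.4 / 2.0 = 1.2
    unweighted_mean = per_token_loss.sum() / 2                # 1.5, the z_loss denominator
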
@@ -225,6 +272,7 @@ _bool_to_return_z_loss = {
 def cross_entropy_forward(
     _input,
     target,
+    weight,
     ignore_index,
     lse_square_scale,
     label_smoothing,
@@ -250,7 +298,20 @@ def cross_entropy_forward(
     else:
         z_loss_1d = loss_1d  # dummy ptr when return_z_loss == False
 
-    n_non_ignore = (target != ignore_index).sum().item()
+    target_mask = target != ignore_index
+    n_non_ignore = target_mask.sum().item()
+    sum_non_ignore_weight = n_non_ignore
+    weight_sum = 0.0
+    if weight is not None:
+        assert weight.shape[0] == V, f"If given, weight has to be a Tensor of size V. Got: {weight.shape}"
+        assert torch.is_floating_point(
+            weight
+        ), f"If given, weight has to be a Tensor of floating point dtype. Got: {weight.dtype}"
+        sum_non_ignore_weight = torch.gather(weight, dim=0, index=target.masked_select(target_mask)).sum().item()
+        weight_sum = weight.sum().item()
+        # ensure weight is contiguous
+        if weight.stride(-1) != 1:
+            weight = weight.contiguous()
 
     # ensure _input and target are contiguous in the last dimension
     if _input.stride(-1) != 1:
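
The gather over the masked targets is equivalent to plain fancy indexing over the kept target classes; a small sketch with hypothetical shapes:

    import torch

    target = torch.tensor([0, 4, -100, 2, 2, -100])
    weight = torch.rand(5)
    target_mask = target != -100

    a = torch.gather(weight, dim=0, index=target.masked_select(target_mask)).sum()
    b = weight[target[target_mask]].sum()
    assert torch.allclose(a, b)
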
@@ -264,18 +325,22 @@ def cross_entropy_forward(
         X_stride=_input.stride(-2),
         Y_ptr=target,
         Y_stride=target.stride(-1),  # always 1
+        weight_ptr=weight if weight is not None else _input,  # dummy if None
         loss_ptr=loss_1d,
         z_loss_ptr=z_loss_1d,
         loss_stride=loss_1d.stride(-1),  # always 1
         n_cols=V,
         n_non_ignore=n_non_ignore,
+        sum_non_ignore_weight=sum_non_ignore_weight,
         ignore_index=ignore_index,
+        weight_sum=weight_sum,
         lse_square_scale=lse_square_scale,
         label_smoothing=label_smoothing,
         reduction=reduction,
         softcap=softcap if softcap is not None else 0.0,
         RETURN_Z_LOSS=return_z_loss,
         BLOCK_SIZE=BLOCK_SIZE,
+        HAS_WEIGHT=True if weight is not None else False,
         HAS_SOFTCAPPING=True if softcap is not None else False,
         # TODO: 32 seems to give the best performance
         # Performance is quite sensitive to num_warps
@@ -327,6 +392,7 @@ class LigerCrossEntropyFunction(torch.autograd.Function):
         ctx,
         _input: torch.Tensor,
         target: torch.Tensor,
+        weight: Optional[torch.FloatTensor],
         ignore_index: int = -100,
         lse_square_scale: float = 0.0,
         label_smoothing: float = 0.0,
@@ -341,6 +407,7 @@ class LigerCrossEntropyFunction(torch.autograd.Function):
         ctx : The context object.
         _input (tensor): The input tensor of shape (BT, V) where B is batch size, T is sequence length, V is vocab size.
         target (tensor): The target tensor of shape (BT) where each value is in [0, V-1].
+        weight (Tensor, optional): A manual rescaling weight given to each class. If given, has to be a Tensor of size V and floating point dtype.
         ignore_index (int): The index to ignore in the target.
         lse_square_scale (float): The scaler of (logsumexp(_input)) ^ 2 adding to the loss for the stability of training.
         label_smoothing (float): The amount of smoothing when computing the loss, where 0.0 means no smoothing.
@@ -354,6 +421,7 @@ class LigerCrossEntropyFunction(torch.autograd.Function):
         loss, z_loss, _input = cross_entropy_forward(
             _input,
             target,
+            weight,
             ignore_index,
             lse_square_scale,
             label_smoothing,
@@ -395,4 +463,5 @@ class LigerCrossEntropyFunction(torch.autograd.Function):
             None,
             None,
             None,
+            None,
         )
liger_kernel/ops/fused_linear_cross_entropy.py

@@ -17,6 +17,7 @@ def fused_linear_cross_entropy_forward(
     _input,
     weight,
     target,
+    ce_weight=None,
     bias=None,
     ignore_index=-100,
     lse_square_scale=0.0,
@@ -47,8 +48,22 @@ def fused_linear_cross_entropy_forward(
     # we use fp32 for loss accumulator
     loss_1d = torch.zeros(BT, dtype=torch.float32, device=device)
 
-    # NOTE: skip .item() here to avoid CUDA synchronization
-    total_n_non_ignore = (target != ignore_index).sum()
+    # TODO: evaluate how CUDA synchronization caused by .item() affects the speed
+    target_mask = target != ignore_index
+    total_n_non_ignore = target_mask.sum().item()
+    total_sum_non_ignore_ce_weight = total_n_non_ignore
+    ce_weight_sum = 0.0
+    if ce_weight is not None:
+        assert ce_weight.shape[0] == V, f"If given, weight has to be a Tensor of size V. Got: {ce_weight.shape}"
+        assert torch.is_floating_point(
+            ce_weight
+        ), f"If given, weight has to be a Tensor of floating point dtype. Got: {ce_weight.dtype}"
+        total_sum_non_ignore_ce_weight = (
+            torch.gather(ce_weight, dim=0, index=target.masked_select(target_mask)).sum().item()
+        )
+        ce_weight_sum = ce_weight.sum().item()
+        if ce_weight.stride(-1) != 1:
+            ce_weight = ce_weight.contiguous()
 
     for chunk_id in range(num_chunks):
         start_idx = chunk_id * chunk_size
@@ -59,13 +74,13 @@ def fused_linear_cross_entropy_forward(
         logits_chunk = _input_chunk @ weight.t()  # chunk_size x V
         if bias is not None:
             logits_chunk = logits_chunk + bias
+
         target_chunk = target[start_idx:end_idx]  # chunk_size,
 
         n_rows = logits_chunk.shape[0]
 
         # unreduced loss
         loss_1d_slice = loss_1d[start_idx:end_idx]  # chunk_size,
-        n_non_ignore = (target_chunk != ignore_index).sum().item()
 
         # ensure _input and target are contiguous
         logits_chunk = logits_chunk.contiguous()
@@ -77,45 +92,40 @@ def fused_linear_cross_entropy_forward(
             X_stride=logits_chunk.stride(-2),
             Y_ptr=target_chunk,
             Y_stride=target_chunk.stride(-1),  # always 1
+            weight_ptr=ce_weight if ce_weight is not None else _input,  # dummy if None
             loss_ptr=loss_1d_slice,
             z_loss_ptr=loss_1d_slice,  # dummy ptr, not used
             loss_stride=loss_1d_slice.stride(-1),  # always 1
             n_cols=V,
-            n_non_ignore=n_non_ignore,
+            n_non_ignore=total_n_non_ignore,
+            sum_non_ignore_weight=total_sum_non_ignore_ce_weight,
+            weight_sum=ce_weight_sum,
             ignore_index=ignore_index,
             lse_square_scale=lse_square_scale,
             label_smoothing=label_smoothing,
             reduction=reduction,
             softcap=softcap if softcap is not None else 0.0,
             RETURN_Z_LOSS=0,  # False
+            HAS_WEIGHT=True if ce_weight is not None else False,
             HAS_SOFTCAPPING=True if softcap is not None else False,
             BLOCK_SIZE=BLOCK_SIZE,
             num_warps=32 if not is_hip() else 16,
         )
 
-        # gradient of logits_chunk is computed in-place by the above triton kernel and is of shape: chunk_size x V
-        # thus grad_input[start_idx: end_idx] should be of shape: chunk_size x H
-        # additionally, since we are chunking the inputs, observe that the loss and gradients are calculated only
-        # on `n_non_ignore` tokens. However, the gradient of the input should be calculated for all tokens.
-        # Thus, we need an additional scaling factor of (n_non_ignore/total_n_non_ignore) to scale the gradients.
-
-        if reduction == "mean":
-            alpha = n_non_ignore / total_n_non_ignore if total_n_non_ignore > 0 else 0.0
-        else:
-            alpha = 1.0
-
-        loss_1d[start_idx:end_idx] = loss_1d_slice * alpha
-        grad_logits_chunk = logits_chunk * alpha  # chunk_size x V
+        loss_1d[start_idx:end_idx] = loss_1d_slice
+        grad_logits_chunk = logits_chunk  # chunk_size x V
 
         grad_input[start_idx:end_idx] = grad_logits_chunk @ weight
 
         if grad_weight is not None:
            torch.addmm(
                 input=grad_weight,
-                mat1=logits_chunk.t(),
+                mat1=logits_chunk.t().to(
+                    _input_chunk.dtype
+                ),  # In an autocast scenario without bias, differing logits_chunk data types will cause an addmm operation error.
                 mat2=_input_chunk,
                 out=grad_weight,
-                alpha=alpha,
+                alpha=1.0,
                 beta=1.0,
             )
 
@@ -124,7 +134,7 @@ def fused_linear_cross_entropy_forward(
                 input=grad_bias,
                 other=logits_chunk.sum(dim=0),
                 out=grad_bias,
-                alpha=alpha,
+                alpha=1.0,
             )
 
     if reduction == "none":
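
Dropping alpha works because the kernel now divides every token's loss and gradient by the global denominator (total_n_non_ignore, or total_sum_non_ignore_ce_weight when weighted) instead of a per-chunk count, so chunk results are directly additive. A toy check of that identity, with made-up per-token losses:

    import torch

    chunk1, chunk2 = torch.tensor([3.0, 6.0]), torch.tensor([9.0])  # per-token losses
    total = 3  # kept tokens across both chunks

    old = (chunk1.sum() / 2) * (2 / total) + (chunk2.sum() / 1) * (1 / total)
    new = (chunk1 / total).sum() + (chunk2 / total).sum()
    assert torch.allclose(old, new)  # both equal 6.0
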
@@ -190,6 +200,7 @@ class LigerFusedLinearCrossEntropyFunction(torch.autograd.Function):
         weight,
         target,
         bias=None,
+        ce_weight=None,
         ignore_index=-100,
         lse_square_scale=0.0,
         label_smoothing=0.0,
@@ -209,21 +220,23 @@ class LigerFusedLinearCrossEntropyFunction(torch.autograd.Function):
         target: (B*T) where each value is in [0, V-1]
         weight: (V, H) where V is the number of classes
         bias: (V) where V is the number of classes
+        ce_weight: a manual rescaling weight given to each class. If given, has to be a Tensor of size V and floating point dtype
         ignore_index: the index to ignore in the target
         label_smoothing (float): The amount of smoothing when computing the loss, where 0.0 means no smoothing.
         reduction: reduction to apply
         """
 
         loss, grad_input, grad_weight, grad_bias = fused_linear_cross_entropy_forward(
-            _input,
-            weight,
-            target,
-            bias,
-            ignore_index,
-            lse_square_scale,
-            label_smoothing,
-            reduction,
-            softcap,
+            _input=_input,
+            weight=weight,
+            target=target,
+            bias=bias,
+            ce_weight=ce_weight,
+            ignore_index=ignore_index,
+            lse_square_scale=lse_square_scale,
+            label_smoothing=label_smoothing,
+            reduction=reduction,
+            softcap=softcap,
         )
         # downcast to dtype and store for backward
         ctx.save_for_backward(
@@ -240,4 +253,15 @@ class LigerFusedLinearCrossEntropyFunction(torch.autograd.Function):
         grad_input, grad_weight, grad_bias = fused_linear_cross_entropy_backward(
             grad_output, grad_input, grad_weight, grad_bias
         )
-        return (grad_input, grad_weight, None, grad_bias, None, None, None, None, None)
+        return (
+            grad_input,
+            grad_weight,
+            None,
+            grad_bias,
+            None,
+            None,
+            None,
+            None,
+            None,
+            None,
+        )
liger_kernel/transformers/cross_entropy.py

@@ -8,6 +8,7 @@ from liger_kernel.ops.cross_entropy import LigerCrossEntropyFunction
 class LigerCrossEntropyLoss(torch.nn.Module):
     def __init__(
         self,
+        weight: Optional[torch.FloatTensor] = None,
         ignore_index: int = -100,
         lse_square_scale: float = 0.0,
         label_smoothing: float = 0.0,
@@ -28,6 +29,7 @@ class LigerCrossEntropyLoss(torch.nn.Module):
             "none",
         }, f"reduction must be one of 'mean', 'sum', or 'none'. Got: {reduction}"
         assert softcap is None or softcap > 0, f"softcap must greater than 0.0 or None. Got: {softcap}"
+        self.weight = weight
         self.ignore_index = ignore_index
         self.lse_square_scale = lse_square_scale
         self.label_smoothing = label_smoothing
@@ -39,6 +41,7 @@ class LigerCrossEntropyLoss(torch.nn.Module):
         loss, z_loss = LigerCrossEntropyFunction.apply(
             _input,
             target,
+            self.weight,
             self.ignore_index,
             self.lse_square_scale,
             self.label_smoothing,
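
With this change the module mirrors torch.nn.CrossEntropyLoss(weight=...). A hypothetical usage sketch (shapes and device are illustrative, not from the package):

    import torch
    from liger_kernel.transformers import LigerCrossEntropyLoss

    V = 32000
    ce = LigerCrossEntropyLoss(weight=torch.rand(V, device="cuda"))
    logits = torch.randn(8, V, device="cuda")
    target = torch.randint(0, V, (8,), device="cuda")
    loss = ce(logits, target)  # a (loss, z_loss) pair is returned when return_z_loss is set
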
liger_kernel/transformers/functional.py

@@ -32,6 +32,7 @@ def liger_cross_entropy(
     loss, z_loss = LigerCrossEntropyFunction.apply(
         input,
         target,
+        weight,
         ignore_index,
         lse_square_scale,
         label_smoothing,
@@ -49,6 +50,7 @@ def liger_fused_linear_cross_entropy(
     weight,
     target,
     bias=None,
+    ce_weight=None,
     ignore_index: int = -100,
     lse_square_scale: float = 0.0,
     label_smoothing: float = 0.0,
@@ -60,6 +62,7 @@ def liger_fused_linear_cross_entropy(
         weight,
         target,
         bias,
+        ce_weight,
         ignore_index,
         lse_square_scale,
         label_smoothing,
liger_kernel/transformers/fused_linear_cross_entropy.py

@@ -8,6 +8,7 @@ from liger_kernel.ops.fused_linear_cross_entropy import LigerFusedLinearCrossEnt
 class LigerFusedLinearCrossEntropyLoss(torch.nn.Module):
     def __init__(
         self,
+        ce_weight: Optional[torch.FloatTensor] = None,
         ignore_index: int = -100,
         lse_square_scale: float = 0.0,
         label_smoothing: float = 0.0,
@@ -24,6 +25,7 @@ class LigerFusedLinearCrossEntropyLoss(torch.nn.Module):
             "none",
         }, f"reduction must be one of 'mean', 'sum', or 'none'. Got: {reduction}"
         assert softcap is None or softcap > 0, f"softcap must greater than 0.0 or None. Got: {softcap}"
+        self.ce_weight = ce_weight
         self.ignore_index = ignore_index
         self.lse_square_scale = lse_square_scale
         self.label_smoothing = label_smoothing
@@ -36,6 +38,7 @@ class LigerFusedLinearCrossEntropyLoss(torch.nn.Module):
             lin_weight,
             target,
             bias,
+            self.ce_weight,
             self.ignore_index,
             self.lse_square_scale,
             self.label_smoothing,
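
For the fused module, ce_weight carries the per-class loss weights while weight remains the lm_head projection. A hypothetical usage sketch, assuming the module's forward takes (lin_weight, _input, target) in that order as suggested by the apply call above (shapes are illustrative):

    import torch
    from liger_kernel.transformers import LigerFusedLinearCrossEntropyLoss

    H, V = 4096, 32000
    flce = LigerFusedLinearCrossEntropyLoss(ce_weight=torch.rand(V, device="cuda"))
    hidden = torch.randn(8, H, device="cuda")          # last hidden states
    lm_head_weight = torch.randn(V, H, device="cuda")  # projection to vocab
    target = torch.randint(0, V, (8,), device="cuda")
    loss = flce(lm_head_weight, hidden, target)
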
liger_kernel_nightly-0.5.2.dev20241228022953.dist-info/METADATA → liger_kernel_nightly-0.5.2.dev20241229131950.dist-info/METADATA

@@ -1,6 +1,6 @@
 Metadata-Version: 2.1
 Name: liger_kernel_nightly
-Version: 0.5.2.dev20241228022953
+Version: 0.5.2.dev20241229131950
 Summary: Efficient Triton kernels for LLM Training
 License: BSD 2-CLAUSE LICENSE
 Copyright 2024 LinkedIn Corporation
liger_kernel_nightly-0.5.2.dev20241228022953.dist-info/RECORD → liger_kernel_nightly-0.5.2.dev20241229131950.dist-info/RECORD

@@ -11,8 +11,8 @@ liger_kernel/chunked_loss/fused_linear_preference.py,sha256=25sTgvphLKAR0jyJcrsJ
 liger_kernel/chunked_loss/orpo_loss.py,sha256=jbZxx-EjPK71A6CSyNzTOAIEQgAUjfvwSViw6R_pPXQ,3510
 liger_kernel/chunked_loss/simpo_loss.py,sha256=ZvDIjT9EQrbwzH2LNZMhv84SPsOHGi_Ywk95vgA0b_o,3736
 liger_kernel/ops/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
-liger_kernel/ops/cross_entropy.py,sha256=2OPIkSXeQAIfSCODYK45Jf8xrz7HoGqFHr1MHS_pijE,15895
-liger_kernel/ops/fused_linear_cross_entropy.py,sha256=LR0zLL8JYMhk9e22jmBxU4lwEYic3YqMAG3837yaHmM,9418
+liger_kernel/ops/cross_entropy.py,sha256=4zSPzdPl-d2tB3ZOj7uRMpzI4RzZMNLUzkh6eMkH5kU,19179
+liger_kernel/ops/fused_linear_cross_entropy.py,sha256=j7cgR95rFAwtPsWZ00PfMwis5F7dtO3EVEw0rZ1GPJk,10231
 liger_kernel/ops/fused_linear_jsd.py,sha256=eKqaADj7LgWfoYqyH03tjrmhNTfJOF1Dhx_bWzBTnTU,9600
 liger_kernel/ops/geglu.py,sha256=axGvCIvlBzuluoAIrWTsp2iZM4BFKNInkPov8YVvH9E,4126
 liger_kernel/ops/group_norm.py,sha256=qD4D4lSjSgVtO52EBNLC2iTseALRgPgqXE50U2woggk,10837
@@ -28,9 +28,9 @@ liger_kernel/ops/experimental/embedding.py,sha256=tolj3tItkzpSb30zWqDN2_yX4ectfl
 liger_kernel/ops/experimental/mm_int8int2.py,sha256=TrS9lpwekrik_w5qE7AhMJD1bcq-OidjtbsW80oZ6IM,13314
 liger_kernel/transformers/__init__.py,sha256=QPmYkL6hosBPpPqCUGqvIvAtD9XzLgvZqZxUyYMZeVk,2008
 liger_kernel/transformers/auto_model.py,sha256=0qCTRZt280Bj_LcFdzo9hlaR-BWNazawXOGgoCZjgEg,1545
-liger_kernel/transformers/cross_entropy.py,sha256=s5-ZM1NBMDjG-KKJKBtIkmArj1jCUjDnpL-2QKhKYho,1734
-liger_kernel/transformers/functional.py,sha256=hxReSBDEUZkOnZgURD8sf6ETYvf9yqCOOMU2k9Ywh90,4435
-liger_kernel/transformers/fused_linear_cross_entropy.py,sha256=K4tfpoNPUJpWv7rCHEcs5xhJLg5td8GcpJrAryF5NMk,1451
+liger_kernel/transformers/cross_entropy.py,sha256=s931h9UW_tV4QMRme1HYjS_R2_C5nD6VFmZIXtjJoYo,1840
+liger_kernel/transformers/functional.py,sha256=B1wkHWLx-YNhxvXBEXB4Ch1yEwF3mjwTPCeXA5aCV_c,4490
+liger_kernel/transformers/fused_linear_cross_entropy.py,sha256=LAN8-pjUI2Erz_MnfMer-0ZmxJ0JlKxGzdZGJY-N65g,1569
 liger_kernel/transformers/fused_linear_jsd.py,sha256=bZ4otCvWBuOnA5XdQL-FzZVItJlDt-ht9e_pG7PG93E,3999
 liger_kernel/transformers/geglu.py,sha256=mrgqzIUVd6lN7fkDKLkw5YaESDxDtFgbot430WwPVOQ,1107
 liger_kernel/transformers/group_norm.py,sha256=URmjkQFsrbMffzcJiGpX7ckxWlpL95AiJS-80hwAWPk,2173
@@ -58,9 +58,9 @@ liger_kernel/transformers/trainer/__init__.py,sha256=p7yQfklV8-467qSz_ZMimkbDF7H
 liger_kernel/transformers/trainer/orpo_trainer.py,sha256=MId1S_MfA3pPVQA1rkiKxp-jZDNz8VmvZzXC-Kugol4,7662
 liger_kernel/triton/__init__.py,sha256=qCiCamzCRv6lpV8IqpAc9YMdNKC7GKurClWceQPnlis,92
 liger_kernel/triton/monkey_patch.py,sha256=Rd0hUHAzDkFfHvnX7-PBaNK5EKnZhtfM_h-fgQH9HPY,1568
-liger_kernel_nightly-0.5.2.dev20241228022953.dist-info/LICENSE,sha256=OhzLDHJ0to4a8sodVLELZiCFylZ1NAAYLs-HrjPy0ag,1312
-liger_kernel_nightly-0.5.2.dev20241228022953.dist-info/METADATA,sha256=Z5fzI-xpYPtjwawEGwIw-LRJUIeY1VEdDUK9wgklR7w,21055
-liger_kernel_nightly-0.5.2.dev20241228022953.dist-info/NOTICE,sha256=njwnoPZLh9AN8SJQzxvCGLHi-8X__AvWRze6joNXIY8,2066
-liger_kernel_nightly-0.5.2.dev20241228022953.dist-info/WHEEL,sha256=P9jw-gEje8ByB7_hXoICnHtVCrEwMQh-630tKvQWehc,91
-liger_kernel_nightly-0.5.2.dev20241228022953.dist-info/top_level.txt,sha256=2eghu4hA3LnkM7ElW92tQ8zegWKgSbeo-k-aGe1YnvY,13
-liger_kernel_nightly-0.5.2.dev20241228022953.dist-info/RECORD,,
+liger_kernel_nightly-0.5.2.dev20241229131950.dist-info/LICENSE,sha256=OhzLDHJ0to4a8sodVLELZiCFylZ1NAAYLs-HrjPy0ag,1312
+liger_kernel_nightly-0.5.2.dev20241229131950.dist-info/METADATA,sha256=iOyPsdNf1GL3Z3Ng0CS3xoOq6iiTb8eFXAMwqDT1UZM,21055
+liger_kernel_nightly-0.5.2.dev20241229131950.dist-info/NOTICE,sha256=njwnoPZLh9AN8SJQzxvCGLHi-8X__AvWRze6joNXIY8,2066
+liger_kernel_nightly-0.5.2.dev20241229131950.dist-info/WHEEL,sha256=P9jw-gEje8ByB7_hXoICnHtVCrEwMQh-630tKvQWehc,91
+liger_kernel_nightly-0.5.2.dev20241229131950.dist-info/top_level.txt,sha256=2eghu4hA3LnkM7ElW92tQ8zegWKgSbeo-k-aGe1YnvY,13
+liger_kernel_nightly-0.5.2.dev20241229131950.dist-info/RECORD,,