liger-kernel-nightly 0.5.2.dev20241229035411__py3-none-any.whl → 0.5.2.dev20250101081922__py3-none-any.whl

@@ -8,12 +8,15 @@ from torch.nn import functional as F
 
 class LigerFusedLinearDistillationBase(torch.autograd.Function):
     @abstractmethod
-    def distillation_loss_fn(student_logits, teacher_logits, temperature):
+    def distillation_loss_fn(
+        student_logits,
+        teacher_logits,
+    ):
         """
         Compute distillation loss.
         Args:
-            student_logits (torch.Tensor): Raw logits of student tokens. Shape: (batch_size * seq_len, vocab_size).
-            teacher_logits (torch.Tensor): Raw logits of teacher tokens. Shape: (batch_size * seq_len, vocab_size).
+            student_logits (torch.Tensor): Raw (temperature-scaled) logits of student tokens. Shape: (batch_size * seq_len, vocab_size).
+            teacher_logits (torch.Tensor): Raw (temperature-scaled) logits of teacher tokens. Shape: (batch_size * seq_len, vocab_size).
         """
         raise NotImplementedError("Distillation loss function must be implemented.")
 
@@ -65,7 +68,6 @@ class LigerFusedLinearDistillationBase(torch.autograd.Function):
         distillation_loss_fn=None,
         full_target=None,
         ignore_index=-100,
-        temperature=1.0,
         weight_hard_loss=0.5,
         weight_soft_loss=0.5,
         compute_ce_loss=True,
@@ -107,7 +109,7 @@ class LigerFusedLinearDistillationBase(torch.autograd.Function):
 
         hard_loss /= full_target.shape[0]
 
-        soft_loss = distillation_loss_fn(student_logits_chunk, teacher_logits_chunk, temperature)
+        soft_loss = distillation_loss_fn(student_logits_chunk, teacher_logits_chunk)
         soft_loss /= full_target.shape[0]
 
         loss = weight_hard_loss * hard_loss + weight_soft_loss * soft_loss
@@ -147,10 +149,11 @@ class LigerFusedLinearDistillationBase(torch.autograd.Function):
            teacher_bias (torch.Tensor, optional): Teacher bias tensor. Shape: (vocab_size,).
            loss_fn (callable): Loss function to compute the loss on a chunk of input/target.
            chunk_size (int): Size of a chunk.
-            compute_ce_loss (bool): Whether to compute CE loss.
            ignore_index (int): Index to ignore for loss computation.
            weight_hard_loss (float): Weight for hard/task loss.
            weight_soft_loss (float): Weight for soft/distillation loss.
+            compute_ce_loss (bool): Whether to compute CE loss.
+            temperature (float): Temperature to control the input probability distribution. Default: `1.0` (i.e. no scaling)
            compiled (bool): Whether to use torch compile for chunk accumulation.
            loss_kwargs (dict): Other possible arguments that a loss function might need
        """
@@ -168,7 +171,6 @@ class LigerFusedLinearDistillationBase(torch.autograd.Function):
            weight_hard_loss=weight_hard_loss,
            weight_soft_loss=weight_soft_loss,
            compute_ce_loss=compute_ce_loss,
-            temperature=temperature,
            **loss_kwargs,
        )
 
@@ -223,6 +225,9 @@ class LigerFusedLinearDistillationBase(torch.autograd.Function):
        if compiled:
            accumulate_chunk = torch.compile(accumulate_chunk)
 
+        student_input /= temperature
+        teacher_input /= temperature
+
        num_chunks = max(1, student_input.shape[0] // CHUNK_SIZE)
        _student_input_chunks = torch.chunk(student_input, chunks=num_chunks, dim=0)
        _teacher_input_chunks = torch.chunk(teacher_input, chunks=num_chunks, dim=0)
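
The hunks above move temperature scaling out of `distillation_loss_fn` and into the base class: `student_input` and `teacher_input` are divided by `temperature` once before chunking, so subclasses receive already-scaled logits and the loss function shrinks to two arguments. A minimal sketch of a subclass under the new signature (the class name and the KL-divergence choice are illustrative, not part of this package):

```python
import torch.nn.functional as F

from liger_kernel.chunked_loss.fused_linear_distillation import LigerFusedLinearDistillationBase


class ExampleKLDistillationLoss(LigerFusedLinearDistillationBase):
    @staticmethod
    def distillation_loss_fn(student_logits, teacher_logits):
        # Logits arrive already divided by `temperature` in the base class,
        # so only the soft-label term is computed here.
        student_log_probs = F.log_softmax(student_logits, dim=-1)
        teacher_probs = F.softmax(teacher_logits, dim=-1)
        return F.kl_div(student_log_probs, teacher_probs, reduction="sum")
```
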
@@ -30,11 +30,14 @@ def liger_cross_entropy_kernel(
     X_stride,
     Y_ptr,
     Y_stride,
+    weight_ptr,
     loss_ptr,
     z_loss_ptr,
     loss_stride,
     n_cols,
     n_non_ignore,
+    sum_non_ignore_weight,
+    weight_sum,
     ignore_index,
     lse_square_scale: tl.constexpr,
     label_smoothing: tl.constexpr,
@@ -42,6 +45,7 @@ def liger_cross_entropy_kernel(
     softcap,
     RETURN_Z_LOSS: tl.constexpr,
     BLOCK_SIZE: tl.constexpr,
+    HAS_WEIGHT: tl.constexpr,
     HAS_SOFTCAPPING: tl.constexpr,
 ):
     """
@@ -53,18 +57,22 @@ def liger_cross_entropy_kernel(
     X_stride (int): The stride of the input tensor.
     Y_ptr: Pointer to target tensor.
     Y_stride (int): The stride of the target tensor.
+    weight_ptr: Pointer to weight tensor.
     loss_ptr: Pointer to tensor to store the loss.
     z_loss_ptr: Pointer to tensor to store the z loss. No operation if RETURN_Z_LOSS is 0.
     loss_stride (int): The stride of the loss tensor.
     n_cols (int): The number of columns in the input tensor.
-    n_non_ignore (int): The number of non-ignored elements in the batch.
+    n_non_ignore (float): The number of non-ignored elements in the batch.
+    sum_non_ignore_weight (float): The sum of non-ignored targets' weights in the batch.
+    weight_sum (float): The sum of the weight tensor.
     ignore_index (int): The index to ignore in the target.
     label_smoothing (float): The amount of smoothing when computing the loss, where 0.0 means no smoothing.
     lse_square_scale (float): The scaler of (logsumexp(_input)) ^ 2 adding to the loss for the stability of training.
-    RETURN_Z_LOSS (int): The boolean value to decide whether storing z loss to z_loss_ptr or not. It must be 0 or 1.
     reduction (str): The string for the reduction to apply
     softcap (float): The upper threshold for scaling logits to the range (-softcap, +softcap).
+    RETURN_Z_LOSS (int): The boolean value to decide whether storing z loss to z_loss_ptr or not. It must be 0 or 1.
     BLOCK_SIZE (int): The block size for Triton operations.
+    HAS_WEIGHT (bool): The boolean value to determine whether assigning weight to each of the classes.
     HAS_SOFTCAPPING (bool): The boolean value to determine whether applying soft-capping or not.
     """
 
@@ -89,6 +97,9 @@ def liger_cross_entropy_kernel(
     loss_ptr += program_id * loss_stride
     z_loss_ptr += program_id * loss_stride
 
+    if HAS_WEIGHT:
+        weight_y = tl.load(weight_ptr + y).cast(tl.float32)
+
     # Online softmax: 2 loads + 1 store (compared with 3 loads + 1 store for the safe softmax)
     # Refer to Algorithm 3 in the paper: https://arxiv.org/pdf/1805.02867
 
@@ -117,7 +128,11 @@ def liger_cross_entropy_kernel(
         block_max = tl.max(X_block)
         if label_smoothing > 0:
             # scale X beforehand to avoid overflow
-            scaled_x_sum += tl.sum(tl.where(X_offsets < n_cols, -eps * X_block, 0.0))
+            if HAS_WEIGHT:
+                weight_block = tl.load(weight_ptr + X_offsets, mask=X_offsets < n_cols)
+                scaled_x_sum += tl.sum(tl.where(X_offsets < n_cols, -eps * X_block * weight_block, 0.0))
+            else:
+                scaled_x_sum += tl.sum(tl.where(X_offsets < n_cols, -eps * X_block, 0.0))
         m_new = tl.maximum(m, block_max)
         d = d * tl.exp(m - m_new) + tl.sum(tl.exp(X_block - m_new))
         m = m_new
@@ -153,18 +168,41 @@ def liger_cross_entropy_kernel(
         if HAS_SOFTCAPPING:
             intermediate = tanh(X_block / softcap)
             X_block = softcap * intermediate
-        # softmax(x_i)
-        X_block = tl.exp(X_block - m) / d
-        # derivative of z-loss: 2 * lse_square_scale * lse * softmax(x_i)
-        X_block += 2 * lse_square_scale * lse * X_block
-        # smoothing term
-        X_block += -eps
-        # special handle dx_y
-        X_block = tl.where(X_offsets != y, X_block, X_block - (1 - label_smoothing))
-        # reduction scale
-        if reduction == "mean":
-            X_block = X_block / (n_non_ignore)
-        # chain rule
+
+        if not HAS_WEIGHT:
+            # softmax(x_i)
+            X_block = tl.exp(X_block - m) / d
+            # derivative of z-loss: 2 * lse_square_scale * lse * softmax(x_i)
+            X_block += 2 * lse_square_scale * lse * X_block
+            # smoothing term
+            X_block += -eps
+            # special handle dx_y
+            X_block = tl.where(X_offsets != y, X_block, X_block - (1 - label_smoothing))
+            # reduction scale
+            if reduction == "mean":
+                X_block = X_block / n_non_ignore
+        else:
+            weight_block = tl.load(weight_ptr + X_offsets, mask=X_offsets < n_cols)
+            softmax_X = tl.exp(X_block - m) / d
+            # derivative of original_loss
+            dloss_ori = (1 - label_smoothing) * softmax_X
+            # specially handle dx_y
+            dloss_ori = tl.where(X_offsets != y, dloss_ori, dloss_ori - (1 - label_smoothing))
+            dloss_ori = dloss_ori * weight_y
+            # derivative of smooth_loss
+            dloss_smooth = eps * (-weight_block + softmax_X * weight_sum)
+            # derivative of z-loss
+            dz_loss = 2 * lse_square_scale * lse * softmax_X
+            # reduction scale
+            if reduction == "mean":
+                dloss_ori = dloss_ori / sum_non_ignore_weight
+                dloss_smooth = dloss_smooth / sum_non_ignore_weight
+                # TODO: Implement weighted z_loss. Currently, z_loss is not scaled by weight.
+                dz_loss = dz_loss / n_non_ignore
+            # derivative of total_loss
+            X_block = dloss_ori + dloss_smooth + dz_loss
+
+        # chain rule softcapping
         # d(softcap * tanh(x / softcap)) = (1 - tanh^2(x / softcap))
         if HAS_SOFTCAPPING:
             X_block = X_block * (1 - intermediate * intermediate)
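
As a reading aid for the `HAS_WEIGHT` branch above (restating the code shown, not an external derivation): with p_i = softmax(x)_i, target class y, per-class weights w, eps = label_smoothing / V, and lambda = lse_square_scale, the gradient the kernel stores before the softcap chain rule is

```latex
\frac{\partial \mathcal{L}}{\partial x_i}
  = \underbrace{w_y\,(1 - \text{label\_smoothing})\,\bigl(p_i - \mathbf{1}[i = y]\bigr)}_{\texttt{dloss\_ori}}
  + \underbrace{\varepsilon \Bigl(p_i \textstyle\sum_j w_j - w_i\Bigr)}_{\texttt{dloss\_smooth}}
  + \underbrace{2\lambda\,\mathrm{lse}\;p_i}_{\texttt{dz\_loss}}
```

Under `reduction="mean"`, the first two terms are divided by the summed weights of the non-ignored targets, while the z-loss term is still divided by the plain non-ignored count, as the TODO about weighted z_loss notes.
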
@@ -183,6 +221,8 @@ def liger_cross_entropy_kernel(
     # sum(e ^ (X - max(X))) must >= 1 because the max term is e ^ 0 = 1
     # So we can safely calculate log (softmax(X_y)) without overflow
     loss = lse - ori_X_y
+    if HAS_WEIGHT:
+        loss = weight_y * loss
 
     # Original loss = H(q, p), with label smoothing regularization = H(q', p) and (label_smoothing / V) = eps
     # H(q', p) = (1 - label_smoothing) * H(q, p) + label_smoothing * H(u, p)
@@ -193,17 +233,24 @@ def liger_cross_entropy_kernel(
     # pytorch: https://github.com/pytorch/pytorch/blob/2981534f54d49fa3a9755c9b0855e7929c2527f0/aten/src/ATen/native/LossNLL.cpp#L516
     # See full derivation at https://github.com/linkedin/Liger-Kernel/pull/198#issuecomment-2333753087
     if label_smoothing > 0:
-        smooth_loss = scaled_x_sum + label_smoothing * lse
+        if HAS_WEIGHT:
+            smooth_loss = scaled_x_sum + eps * lse * weight_sum
+        else:
+            smooth_loss = scaled_x_sum + label_smoothing * lse
         loss = loss * (1 - label_smoothing) + smooth_loss
 
     # An auxiliary loss, z_loss
     # Refer to Page14 Loss function section in the paper PaLM: https://www.jmlr.org/papers/v24/22-1144.html
     z_loss = lse_square_scale * lse * lse
-    loss += z_loss
     # Normalize the loss by the number of non-ignored elements if reduction is "mean"
     if reduction == "mean":
+        if HAS_WEIGHT:
+            loss = loss / sum_non_ignore_weight
+        else:
+            loss = loss / n_non_ignore
+        # TODO: Implement weighted z_loss. Currently, z_loss is not scaled by weight.
         z_loss = z_loss / n_non_ignore
-        loss = loss / n_non_ignore
+    loss += z_loss
 
     tl.store(loss_ptr, loss)
     if RETURN_Z_LOSS == _TRUE:
@@ -225,6 +272,7 @@ _bool_to_return_z_loss = {
 def cross_entropy_forward(
     _input,
     target,
+    weight,
     ignore_index,
     lse_square_scale,
     label_smoothing,
@@ -250,7 +298,20 @@ def cross_entropy_forward(
     else:
         z_loss_1d = loss_1d  # dummy ptr when return_z_loss == False
 
-    n_non_ignore = (target != ignore_index).sum().item()
+    target_mask = target != ignore_index
+    n_non_ignore = target_mask.sum().item()
+    sum_non_ignore_weight = n_non_ignore
+    weight_sum = 0.0
+    if weight is not None:
+        assert weight.shape[0] == V, f"If given, weight has to be a Tensor of size V. Got: {weight.shape}"
+        assert torch.is_floating_point(
+            weight
+        ), f"If given, weight has to be a Tensor of floating point dtype. Got: {weight.dtype}"
+        sum_non_ignore_weight = torch.gather(weight, dim=0, index=target.masked_select(target_mask)).sum().item()
+        weight_sum = weight.sum().item()
+        # ensure weight is contiguous
+        if weight.stride(-1) != 1:
+            weight = weight.contiguous()
 
     # ensure _input and target are contiguous in the last dimension
     if _input.stride(-1) != 1:
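
The new bookkeeping in `cross_entropy_forward` mirrors how a weighted mean reduction is normalized in `torch.nn.functional.cross_entropy(..., weight=...)`: the loss is divided by the summed weights of the non-ignored targets rather than by their count. A small sketch of the same three scalars computed in plain PyTorch (the helper name is illustrative, not part of the package):

```python
import torch


def ce_weight_denominators(target: torch.Tensor, weight: torch.Tensor, ignore_index: int = -100):
    """Compute the scalars handed to the kernel: non-ignored count, weighted denominator, total weight."""
    target_mask = target != ignore_index
    n_non_ignore = target_mask.sum().item()
    # Denominator for the "mean" reduction of the weighted loss terms.
    sum_non_ignore_weight = weight[target[target_mask]].sum().item()
    # Total weight mass, used by the label-smoothing terms.
    weight_sum = weight.sum().item()
    return n_non_ignore, sum_non_ignore_weight, weight_sum
```
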
@@ -264,18 +325,22 @@ def cross_entropy_forward(
        X_stride=_input.stride(-2),
        Y_ptr=target,
        Y_stride=target.stride(-1),  # always 1
+        weight_ptr=weight if weight is not None else _input,  # dummy if None
        loss_ptr=loss_1d,
        z_loss_ptr=z_loss_1d,
        loss_stride=loss_1d.stride(-1),  # always 1
        n_cols=V,
        n_non_ignore=n_non_ignore,
+        sum_non_ignore_weight=sum_non_ignore_weight,
        ignore_index=ignore_index,
+        weight_sum=weight_sum,
        lse_square_scale=lse_square_scale,
        label_smoothing=label_smoothing,
        reduction=reduction,
        softcap=softcap if softcap is not None else 0.0,
        RETURN_Z_LOSS=return_z_loss,
        BLOCK_SIZE=BLOCK_SIZE,
+        HAS_WEIGHT=True if weight is not None else False,
        HAS_SOFTCAPPING=True if softcap is not None else False,
        # TODO: 32 seems to give the best performance
        # Performance is quite sensitive to num_warps
@@ -327,6 +392,7 @@ class LigerCrossEntropyFunction(torch.autograd.Function):
        ctx,
        _input: torch.Tensor,
        target: torch.Tensor,
+        weight: Optional[torch.FloatTensor],
        ignore_index: int = -100,
        lse_square_scale: float = 0.0,
        label_smoothing: float = 0.0,
@@ -341,6 +407,7 @@ class LigerCrossEntropyFunction(torch.autograd.Function):
        ctx : The context object.
        _input (tensor): The input tensor of shape (BT, V) where B is batch size, T is sequence length, V is vocab size.
        target (tensor): The target tensor of shape (BT) where each value is in [0, V-1].
+        weight (Tensor, optional): a manual rescaling weight given to each class. If given, has to be a Tensor of size V and floating point dtype.
        ignore_index (int): The index to ignore in the target.
        lse_square_scale (float): The scaler of (logsumexp(_input)) ^ 2 adding to the loss for the stability of training.
        label_smoothing (float): The amount of smoothing when computing the loss, where 0.0 means no smoothing.
@@ -354,6 +421,7 @@ class LigerCrossEntropyFunction(torch.autograd.Function):
        loss, z_loss, _input = cross_entropy_forward(
            _input,
            target,
+            weight,
            ignore_index,
            lse_square_scale,
            label_smoothing,
@@ -395,4 +463,5 @@ class LigerCrossEntropyFunction(torch.autograd.Function):
            None,
            None,
            None,
+            None,
        )
@@ -17,6 +17,7 @@ def fused_linear_cross_entropy_forward(
    _input,
    weight,
    target,
+    ce_weight=None,
    bias=None,
    ignore_index=-100,
    lse_square_scale=0.0,
@@ -47,8 +48,22 @@ def fused_linear_cross_entropy_forward(
    # we use fp32 for loss accumulator
    loss_1d = torch.zeros(BT, dtype=torch.float32, device=device)
 
-    # NOTE: skip .item() here to avoid CUDA synchronization
-    total_n_non_ignore = (target != ignore_index).sum()
+    # TODO: evaluate how CUDA synchronization caused by .item() affects the speed
+    target_mask = target != ignore_index
+    total_n_non_ignore = target_mask.sum().item()
+    total_sum_non_ignore_ce_weight = total_n_non_ignore
+    ce_weight_sum = 0.0
+    if ce_weight is not None:
+        assert ce_weight.shape[0] == V, f"If given, weight has to be a Tensor of size V. Got: {ce_weight.shape}"
+        assert torch.is_floating_point(
+            ce_weight
+        ), f"If given, weight has to be a Tensor of floating point dtype. Got: {ce_weight.dtype}"
+        total_sum_non_ignore_ce_weight = (
+            torch.gather(ce_weight, dim=0, index=target.masked_select(target_mask)).sum().item()
+        )
+        ce_weight_sum = ce_weight.sum().item()
+        if ce_weight.stride(-1) != 1:
+            ce_weight = ce_weight.contiguous()
 
    for chunk_id in range(num_chunks):
        start_idx = chunk_id * chunk_size
@@ -66,7 +81,6 @@ def fused_linear_cross_entropy_forward(
 
        # unreduced loss
        loss_1d_slice = loss_1d[start_idx:end_idx]  # chunk_size,
-        n_non_ignore = (target_chunk != ignore_index).sum().item()
 
        # ensure _input and target are contiguous
        logits_chunk = logits_chunk.contiguous()
@@ -78,35 +92,28 @@ def fused_linear_cross_entropy_forward(
            X_stride=logits_chunk.stride(-2),
            Y_ptr=target_chunk,
            Y_stride=target_chunk.stride(-1),  # always 1
+            weight_ptr=ce_weight if ce_weight is not None else _input,  # dummy if None
            loss_ptr=loss_1d_slice,
            z_loss_ptr=loss_1d_slice,  # dummy ptr, not used
            loss_stride=loss_1d_slice.stride(-1),  # always 1
            n_cols=V,
-            n_non_ignore=n_non_ignore,
+            n_non_ignore=total_n_non_ignore,
+            sum_non_ignore_weight=total_sum_non_ignore_ce_weight,
+            weight_sum=ce_weight_sum,
            ignore_index=ignore_index,
            lse_square_scale=lse_square_scale,
            label_smoothing=label_smoothing,
            reduction=reduction,
            softcap=softcap if softcap is not None else 0.0,
            RETURN_Z_LOSS=0,  # False
+            HAS_WEIGHT=True if ce_weight is not None else False,
            HAS_SOFTCAPPING=True if softcap is not None else False,
            BLOCK_SIZE=BLOCK_SIZE,
            num_warps=32 if not is_hip() else 16,
        )
 
-        # gradient of logits_chunk is computed in-place by the above triton kernel and is of shape: chunk_size x V
-        # thus grad_input[start_idx: end_idx] should be of shape: chunk_size x H
-        # additionally, since we are chunking the inputs, observe that the loss and gradients are calculated only
-        # on `n_non_ignore` tokens. However, the gradient of the input should be calculated for all tokens.
-        # Thus, we need an additional scaling factor of (n_non_ignore/total_n_non_ignore) to scale the gradients.
-
-        if reduction == "mean":
-            alpha = n_non_ignore / total_n_non_ignore if total_n_non_ignore > 0 else 0.0
-        else:
-            alpha = 1.0
-
-        loss_1d[start_idx:end_idx] = loss_1d_slice * alpha
-        grad_logits_chunk = logits_chunk * alpha  # chunk_size x V
+        loss_1d[start_idx:end_idx] = loss_1d_slice
+        grad_logits_chunk = logits_chunk  # chunk_size x V
 
        grad_input[start_idx:end_idx] = grad_logits_chunk @ weight
 
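Note on the removed `alpha` rescaling: previously each chunk's kernel call normalized by the chunk-local `n_non_ignore`, so the chunk loss and gradients had to be rescaled by `alpha = n_non_ignore / total_n_non_ignore` to recover the global mean. Since the kernel is now handed the global denominators (`total_n_non_ignore` and `total_sum_non_ignore_ce_weight`) directly, every chunk already comes out on the global scale, and `alpha` degenerates to `1.0` in the `torch.addmm` accumulations in the next hunks.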
@@ -118,7 +125,7 @@ def fused_linear_cross_entropy_forward(
            ),  # In an autocast scenario without bias, differing logits_chunk data types will cause an addmm operation error.
            mat2=_input_chunk,
            out=grad_weight,
-            alpha=alpha,
+            alpha=1.0,
            beta=1.0,
        )
 
@@ -127,7 +134,7 @@ def fused_linear_cross_entropy_forward(
                input=grad_bias,
                other=logits_chunk.sum(dim=0),
                out=grad_bias,
-                alpha=alpha,
+                alpha=1.0,
            )
 
    if reduction == "none":
@@ -193,6 +200,7 @@ class LigerFusedLinearCrossEntropyFunction(torch.autograd.Function):
        weight,
        target,
        bias=None,
+        ce_weight=None,
        ignore_index=-100,
        lse_square_scale=0.0,
        label_smoothing=0.0,
@@ -212,21 +220,23 @@ class LigerFusedLinearCrossEntropyFunction(torch.autograd.Function):
        target: (B*T) where each value is in [0, V-1]
        weight: (V, H) where V is the number of classes
        bias: (V) where V is the number of classes
+        ce_weight: a manual rescaling weight given to each class. If given, has to be a Tensor of size V and floating point dtype
        ignore_index: the index to ignore in the target
        label_smoothing (float): The amount of smoothing when computing the loss, where 0.0 means no smoothing.
        reduction: reduction to apply
        """
 
        loss, grad_input, grad_weight, grad_bias = fused_linear_cross_entropy_forward(
-            _input,
-            weight,
-            target,
-            bias,
-            ignore_index,
-            lse_square_scale,
-            label_smoothing,
-            reduction,
-            softcap,
+            _input=_input,
+            weight=weight,
+            target=target,
+            bias=bias,
+            ce_weight=ce_weight,
+            ignore_index=ignore_index,
+            lse_square_scale=lse_square_scale,
+            label_smoothing=label_smoothing,
+            reduction=reduction,
+            softcap=softcap,
        )
        # downcast to dtype and store for backward
        ctx.save_for_backward(
@@ -243,4 +253,15 @@ class LigerFusedLinearCrossEntropyFunction(torch.autograd.Function):
        grad_input, grad_weight, grad_bias = fused_linear_cross_entropy_backward(
            grad_output, grad_input, grad_weight, grad_bias
        )
-        return (grad_input, grad_weight, None, grad_bias, None, None, None, None, None)
+        return (
+            grad_input,
+            grad_weight,
+            None,
+            grad_bias,
+            None,
+            None,
+            None,
+            None,
+            None,
+            None,
+        )
@@ -8,6 +8,7 @@ from liger_kernel.ops.cross_entropy import LigerCrossEntropyFunction
 class LigerCrossEntropyLoss(torch.nn.Module):
     def __init__(
         self,
+        weight: Optional[torch.FloatTensor] = None,
         ignore_index: int = -100,
         lse_square_scale: float = 0.0,
         label_smoothing: float = 0.0,
@@ -28,6 +29,7 @@ class LigerCrossEntropyLoss(torch.nn.Module):
            "none",
        }, f"reduction must be one of 'mean', 'sum', or 'none'. Got: {reduction}"
        assert softcap is None or softcap > 0, f"softcap must greater than 0.0 or None. Got: {softcap}"
+        self.weight = weight
        self.ignore_index = ignore_index
        self.lse_square_scale = lse_square_scale
        self.label_smoothing = label_smoothing
@@ -39,6 +41,7 @@ class LigerCrossEntropyLoss(torch.nn.Module):
        loss, z_loss = LigerCrossEntropyFunction.apply(
            _input,
            target,
+            self.weight,
            self.ignore_index,
            self.lse_square_scale,
            self.label_smoothing,
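
With `weight` threaded through `LigerCrossEntropyLoss`, the module is intended to be used like `torch.nn.CrossEntropyLoss(weight=...)`. A minimal usage sketch (shapes and values are illustrative; a CUDA device is assumed since the loss is a Triton kernel):

```python
import torch

from liger_kernel.transformers.cross_entropy import LigerCrossEntropyLoss

vocab_size = 8
class_weights = torch.rand(vocab_size, device="cuda")  # one floating-point weight per class

loss_fn = LigerCrossEntropyLoss(weight=class_weights, ignore_index=-100, reduction="mean")

logits = torch.randn(16, vocab_size, device="cuda", requires_grad=True)  # (B*T, V)
labels = torch.randint(0, vocab_size, (16,), device="cuda")              # (B*T,)

loss = loss_fn(logits, labels)
loss.backward()
```
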
@@ -32,6 +32,7 @@ def liger_cross_entropy(
    loss, z_loss = LigerCrossEntropyFunction.apply(
        input,
        target,
+        weight,
        ignore_index,
        lse_square_scale,
        label_smoothing,
@@ -49,6 +50,7 @@ def liger_fused_linear_cross_entropy(
    weight,
    target,
    bias=None,
+    ce_weight=None,
    ignore_index: int = -100,
    lse_square_scale: float = 0.0,
    label_smoothing: float = 0.0,
@@ -60,6 +62,7 @@ def liger_fused_linear_cross_entropy(
        weight,
        target,
        bias,
+        ce_weight,
        ignore_index,
        lse_square_scale,
        label_smoothing,
@@ -8,6 +8,7 @@ from liger_kernel.ops.fused_linear_cross_entropy import LigerFusedLinearCrossEnt
 class LigerFusedLinearCrossEntropyLoss(torch.nn.Module):
     def __init__(
         self,
+        ce_weight: Optional[torch.FloatTensor] = None,
         ignore_index: int = -100,
         lse_square_scale: float = 0.0,
         label_smoothing: float = 0.0,
@@ -24,6 +25,7 @@ class LigerFusedLinearCrossEntropyLoss(torch.nn.Module):
            "none",
        }, f"reduction must be one of 'mean', 'sum', or 'none'. Got: {reduction}"
        assert softcap is None or softcap > 0, f"softcap must greater than 0.0 or None. Got: {softcap}"
+        self.ce_weight = ce_weight
        self.ignore_index = ignore_index
        self.lse_square_scale = lse_square_scale
        self.label_smoothing = label_smoothing
@@ -36,6 +38,7 @@ class LigerFusedLinearCrossEntropyLoss(torch.nn.Module):
            lin_weight,
            target,
            bias,
+            self.ce_weight,
            self.ignore_index,
            self.lse_square_scale,
            self.label_smoothing,
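
The fused-linear variant exposes the same per-class weighting as `ce_weight`, since `weight` already names the linear projection matrix. A usage sketch under the same assumptions as above, additionally assuming the module's existing argument order of `forward(lin_weight, _input, target, bias=None)`:

```python
import torch

from liger_kernel.transformers.fused_linear_cross_entropy import LigerFusedLinearCrossEntropyLoss

hidden, vocab_size = 32, 8
ce_weight = torch.rand(vocab_size, device="cuda")  # per-class rescaling weight

loss_fn = LigerFusedLinearCrossEntropyLoss(ce_weight=ce_weight, reduction="mean")

hidden_states = torch.randn(16, hidden, device="cuda", requires_grad=True)      # (B*T, H)
lm_head_weight = torch.randn(vocab_size, hidden, device="cuda", requires_grad=True)  # (V, H)
labels = torch.randint(0, vocab_size, (16,), device="cuda")                      # (B*T,)

# assumed argument order: (lin_weight, _input, target); the projection and the loss are fused
loss = loss_fn(lm_head_weight, hidden_states, labels)
loss.backward()
```
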
@@ -1,6 +1,6 @@
 Metadata-Version: 2.1
 Name: liger_kernel_nightly
-Version: 0.5.2.dev20241229035411
+Version: 0.5.2.dev20250101081922
 Summary: Efficient Triton kernels for LLM Training
 License: BSD 2-CLAUSE LICENSE
 Copyright 2024 LinkedIn Corporation
@@ -6,13 +6,13 @@ liger_kernel/chunked_loss/__init__.py,sha256=R2wCcz4Y0kTAve926DH3k182XKezpXeACMH
 liger_kernel/chunked_loss/cpo_loss.py,sha256=L4Nk38Xh5Yfhah3Vsc_sN_Q75FWt1LA-xNNXzsK8iPM,3516
 liger_kernel/chunked_loss/dpo_loss.py,sha256=VYZMOafdvE8xlhvTtwjrz81tIzxR1mHF4lXdsADnIQg,4373
 liger_kernel/chunked_loss/functional.py,sha256=9Gr-YXIuEzEJkBUhDx3G2fuQayckLor7cC7svhmPML4,549
-liger_kernel/chunked_loss/fused_linear_distillation.py,sha256=M-QWvGPnWefYDn6Hr9bPn7diMNP5qrUaeWTb_zdMO4E,10265
+liger_kernel/chunked_loss/fused_linear_distillation.py,sha256=uQtwtu-kaUZJTjNhAnIr3O794oUlUZ98XR5shYtwP5k,10440
 liger_kernel/chunked_loss/fused_linear_preference.py,sha256=25sTgvphLKAR0jyJcrsJPKK1abFpTKrajSyAx8nJ3bc,16134
 liger_kernel/chunked_loss/orpo_loss.py,sha256=jbZxx-EjPK71A6CSyNzTOAIEQgAUjfvwSViw6R_pPXQ,3510
 liger_kernel/chunked_loss/simpo_loss.py,sha256=ZvDIjT9EQrbwzH2LNZMhv84SPsOHGi_Ywk95vgA0b_o,3736
 liger_kernel/ops/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
-liger_kernel/ops/cross_entropy.py,sha256=2OPIkSXeQAIfSCODYK45Jf8xrz7HoGqFHr1MHS_pijE,15895
-liger_kernel/ops/fused_linear_cross_entropy.py,sha256=H2z-BFd9pGATlEzEeOw4EZwMoWsZtD8ovWJTkHD-9-s,9592
+liger_kernel/ops/cross_entropy.py,sha256=4zSPzdPl-d2tB3ZOj7uRMpzI4RzZMNLUzkh6eMkH5kU,19179
+liger_kernel/ops/fused_linear_cross_entropy.py,sha256=j7cgR95rFAwtPsWZ00PfMwis5F7dtO3EVEw0rZ1GPJk,10231
 liger_kernel/ops/fused_linear_jsd.py,sha256=eKqaADj7LgWfoYqyH03tjrmhNTfJOF1Dhx_bWzBTnTU,9600
 liger_kernel/ops/geglu.py,sha256=axGvCIvlBzuluoAIrWTsp2iZM4BFKNInkPov8YVvH9E,4126
 liger_kernel/ops/group_norm.py,sha256=qD4D4lSjSgVtO52EBNLC2iTseALRgPgqXE50U2woggk,10837
@@ -28,9 +28,9 @@ liger_kernel/ops/experimental/embedding.py,sha256=tolj3tItkzpSb30zWqDN2_yX4ectfl
 liger_kernel/ops/experimental/mm_int8int2.py,sha256=TrS9lpwekrik_w5qE7AhMJD1bcq-OidjtbsW80oZ6IM,13314
 liger_kernel/transformers/__init__.py,sha256=QPmYkL6hosBPpPqCUGqvIvAtD9XzLgvZqZxUyYMZeVk,2008
 liger_kernel/transformers/auto_model.py,sha256=0qCTRZt280Bj_LcFdzo9hlaR-BWNazawXOGgoCZjgEg,1545
-liger_kernel/transformers/cross_entropy.py,sha256=s5-ZM1NBMDjG-KKJKBtIkmArj1jCUjDnpL-2QKhKYho,1734
-liger_kernel/transformers/functional.py,sha256=hxReSBDEUZkOnZgURD8sf6ETYvf9yqCOOMU2k9Ywh90,4435
-liger_kernel/transformers/fused_linear_cross_entropy.py,sha256=K4tfpoNPUJpWv7rCHEcs5xhJLg5td8GcpJrAryF5NMk,1451
+liger_kernel/transformers/cross_entropy.py,sha256=s931h9UW_tV4QMRme1HYjS_R2_C5nD6VFmZIXtjJoYo,1840
+liger_kernel/transformers/functional.py,sha256=B1wkHWLx-YNhxvXBEXB4Ch1yEwF3mjwTPCeXA5aCV_c,4490
+liger_kernel/transformers/fused_linear_cross_entropy.py,sha256=LAN8-pjUI2Erz_MnfMer-0ZmxJ0JlKxGzdZGJY-N65g,1569
 liger_kernel/transformers/fused_linear_jsd.py,sha256=bZ4otCvWBuOnA5XdQL-FzZVItJlDt-ht9e_pG7PG93E,3999
 liger_kernel/transformers/geglu.py,sha256=mrgqzIUVd6lN7fkDKLkw5YaESDxDtFgbot430WwPVOQ,1107
 liger_kernel/transformers/group_norm.py,sha256=URmjkQFsrbMffzcJiGpX7ckxWlpL95AiJS-80hwAWPk,2173
@@ -58,9 +58,9 @@ liger_kernel/transformers/trainer/__init__.py,sha256=p7yQfklV8-467qSz_ZMimkbDF7H
 liger_kernel/transformers/trainer/orpo_trainer.py,sha256=MId1S_MfA3pPVQA1rkiKxp-jZDNz8VmvZzXC-Kugol4,7662
 liger_kernel/triton/__init__.py,sha256=qCiCamzCRv6lpV8IqpAc9YMdNKC7GKurClWceQPnlis,92
 liger_kernel/triton/monkey_patch.py,sha256=Rd0hUHAzDkFfHvnX7-PBaNK5EKnZhtfM_h-fgQH9HPY,1568
-liger_kernel_nightly-0.5.2.dev20241229035411.dist-info/LICENSE,sha256=OhzLDHJ0to4a8sodVLELZiCFylZ1NAAYLs-HrjPy0ag,1312
-liger_kernel_nightly-0.5.2.dev20241229035411.dist-info/METADATA,sha256=bVGSgTflxiXCSgDtaCWRTo93kcV2WSuSFYnfDHI4XIw,21055
-liger_kernel_nightly-0.5.2.dev20241229035411.dist-info/NOTICE,sha256=njwnoPZLh9AN8SJQzxvCGLHi-8X__AvWRze6joNXIY8,2066
-liger_kernel_nightly-0.5.2.dev20241229035411.dist-info/WHEEL,sha256=P9jw-gEje8ByB7_hXoICnHtVCrEwMQh-630tKvQWehc,91
-liger_kernel_nightly-0.5.2.dev20241229035411.dist-info/top_level.txt,sha256=2eghu4hA3LnkM7ElW92tQ8zegWKgSbeo-k-aGe1YnvY,13
-liger_kernel_nightly-0.5.2.dev20241229035411.dist-info/RECORD,,
+liger_kernel_nightly-0.5.2.dev20250101081922.dist-info/LICENSE,sha256=OhzLDHJ0to4a8sodVLELZiCFylZ1NAAYLs-HrjPy0ag,1312
+liger_kernel_nightly-0.5.2.dev20250101081922.dist-info/METADATA,sha256=8p2CjwfCCe9ECXdhglWrNPw2cQr2lZSLrkX6Nrg_xIQ,21055
+liger_kernel_nightly-0.5.2.dev20250101081922.dist-info/NOTICE,sha256=njwnoPZLh9AN8SJQzxvCGLHi-8X__AvWRze6joNXIY8,2066
+liger_kernel_nightly-0.5.2.dev20250101081922.dist-info/WHEEL,sha256=P9jw-gEje8ByB7_hXoICnHtVCrEwMQh-630tKvQWehc,91
+liger_kernel_nightly-0.5.2.dev20250101081922.dist-info/top_level.txt,sha256=2eghu4hA3LnkM7ElW92tQ8zegWKgSbeo-k-aGe1YnvY,13
+liger_kernel_nightly-0.5.2.dev20250101081922.dist-info/RECORD,,