liger-kernel 0.4.0__py3-none-any.whl → 0.4.2__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -0,0 +1,57 @@
+import torch.nn.functional as F
+
+from liger_kernel.chunked_loss.fused_linear_preference import (
+    LigerFusedLinearPreferenceBase,
+)
+
+
+class LigerFusedLinearDPOFunction(LigerFusedLinearPreferenceBase):
+
+    @staticmethod
+    def preference_loss_fn(chosen_logps, rejected_logps, beta=0.1):
+        """
+        Compute DPO loss (Direct Preference Optimization).
+        Args:
+            chosen_logps (torch.Tensor): Avg log probabilities of chosen tokens. Shape: (batch_size,).
+            rejected_logps (torch.Tensor): Avg log probabilities of rejected tokens. Shape: (batch_size,).
+            beta (float): Weight for the direct preference loss.
+        """
+        logits_diff = beta * (chosen_logps - rejected_logps)
+        losses = -F.logsigmoid(logits_diff)
+        return losses.sum()
+
+    @staticmethod
+    def forward(
+        ctx,
+        _input,
+        weight,
+        target,
+        bias=None,
+        ignore_index=-100,
+        beta=0.1,
+        compute_nll_loss=True,
+        compiled=True,
+    ):
+        """
+        Fused linear layer with DPO (Direct Preference Optimization) loss.
+        Handles both the forward and backward pass of the final linear layer with DPO loss.
+        """
+        return LigerFusedLinearPreferenceBase.forward(
+            ctx=ctx,
+            _input=_input,
+            weight=weight,
+            target=target,
+            bias=bias,
+            loss_fn=LigerFusedLinearDPOFunction.preference_loss_fn,
+            compute_nll_loss=compute_nll_loss,
+            ignore_index=ignore_index,
+            beta=beta,
+            compiled=compiled,
+        )
+
+    @staticmethod
+    def backward(ctx, grad_output):
+        # Get gradients for _input, weight, bias, and target from the base class
+        grads = LigerFusedLinearPreferenceBase.backward(ctx, grad_output)[:4]
+        # Return these gradients, followed by None for the remaining inputs
+        return *grads, None, None, None, None
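
For orientation, the added function computes the pairwise DPO term -log sigma(beta * (avg chosen log-prob - avg rejected log-prob)) on top of the chunked base class in the next file. Below is a minimal sketch of how it might be invoked; the module path liger_kernel.chunked_loss.dpo_loss, the tensor shapes, and compiled=False are assumptions for illustration, not taken from the diff.

import torch

from liger_kernel.chunked_loss.dpo_loss import LigerFusedLinearDPOFunction  # assumed path

# Hypothetical sizes: 2 chosen + 2 rejected sequences stacked on the batch dim (chosen first).
B, T, H, V = 4, 16, 64, 128
_input = torch.randn(B, T, H, requires_grad=True)   # last hidden states
weight = torch.randn(V, H, requires_grad=True)      # lm_head weight, shape (vocab_size, hidden_size)
target = torch.randint(0, V, (B, T))                # token ids; -100 would mark ignored positions

# Positional args mirror forward(): bias, ignore_index, beta, compute_nll_loss, compiled.
loss = LigerFusedLinearDPOFunction.apply(_input, weight, target, None, -100, 0.1, True, False)
loss.backward()
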
@@ -0,0 +1,206 @@
+from abc import abstractmethod
+from functools import partial
+
+import torch
+from torch.nn import functional as F
+
+
+class LigerFusedLinearPreferenceBase(torch.autograd.Function):
+
+    @abstractmethod
+    def preference_loss_fn(chosen_logps, rejected_logps, beta=0.1):
+        """
+        Compute preference loss.
+        Args:
+            chosen_logps (torch.Tensor): Avg log probabilities of chosen tokens. Shape: (batch_size,).
+            rejected_logps (torch.Tensor): Avg log probabilities of rejected tokens. Shape: (batch_size,).
+            beta (float): Weight for the odds ratio loss.
+        """
+        raise NotImplementedError("Preference loss function must be implemented.")
+
+    @staticmethod
+    def forward(
+        ctx,
+        _input,
+        weight,
+        target,
+        bias=None,
+        loss_fn=None,
+        chunk_size=1,
+        compute_nll_loss=True,
+        ignore_index=-100,
+        beta=0.1,
+        compiled=True,
+    ):
+        """
+        Base class for fused linear layer with preference loss.
+        Expects _input to be stacked with chosen and rejected inputs on the batch dimension.
+
+        Args:
+            _input (torch.Tensor): Input tensor. Shape: (batch_size, seq_len, hidden_size).
+            weight (torch.Tensor): Weight tensor. Shape: (vocab_size, hidden_size).
+            target (torch.Tensor): Target tensor. Shape: (batch_size, seq_len).
+            bias (torch.Tensor, optional): Bias tensor. Shape: (vocab_size,).
+            loss_fn (callable): Loss function to compute the loss on a chunk of input/target.
+            chunk_size (int): Size of a chunk (# of batches of stacked chosen and rejected inputs).
+            compute_nll_loss (bool): Whether to compute NLL loss.
+            ignore_index (int): Index to ignore for loss computation.
+            beta (float): Weight for the odds ratio loss.
+            compiled (bool): Whether to use torch compile for chunk accumulation.
+        """
+        # TODO: Tune CHUNK_SIZE to fully utilize the GPU
+        CHUNK_SIZE = chunk_size
+
+        grad_weight = torch.zeros_like(weight)
+        grad_chosen_inputs = []
+        grad_rejected_inputs = []
+        grad_bias = torch.zeros_like(bias) if bias is not None else None
+        loss_acc = torch.zeros((), device=_input.device)
+
+        chunks = max(1, _input.shape[0] // (2 * CHUNK_SIZE))
+        loss_func_to_call = partial(
+            LigerFusedLinearPreferenceBase._compute_loss,
+            preference_loss_fn=loss_fn,
+            ignore_index=ignore_index,
+            beta=beta,
+            compute_nll_loss=compute_nll_loss,
+            full_target=target,
+        )
+
+        def accumulate_chunk(input_chunk, target_chunk):
+            if bias is not None:
+                (chunk_grad_input, chunk_grad_weight, chunk_grad_bias), (
+                    chunk_loss,
+                    (chunk_or_loss, chunk_chosen_logps, chunk_rejected_logps),
+                ) = torch.func.grad_and_value(
+                    loss_func_to_call, argnums=(0, 1, 3), has_aux=True
+                )(
+                    input_chunk, weight, target_chunk, bias
+                )
+                grad_bias.add_(chunk_grad_bias)
+            else:
+                (chunk_grad_input, chunk_grad_weight), (
+                    chunk_loss,
+                    (chunk_or_loss, chunk_chosen_logps, chunk_rejected_logps),
+                ) = torch.func.grad_and_value(
+                    loss_func_to_call, argnums=(0, 1), has_aux=True
+                )(
+                    input_chunk, weight, target_chunk
+                )
+            grad_weight.add_(chunk_grad_weight)
+            loss_acc.add_(chunk_loss)
+            return chunk_grad_input
+
+        len_chosen = target.shape[0] // 2
+        _chosen_input_chunks = torch.chunk(_input[:len_chosen], chunks=chunks, dim=0)
+        _chosen_target_chunks = torch.chunk(target[:len_chosen], chunks=chunks, dim=0)
+        _rejected_input_chunks = torch.chunk(_input[len_chosen:], chunks=chunks, dim=0)
+        _rejected_target_chunks = torch.chunk(target[len_chosen:], chunks=chunks, dim=0)
+
+        for (
+            chosen_input_chunk,
+            rejected_input_chunk,
+            chosen_target_chunk,
+            rejected_target_chunk,
+        ) in zip(
+            _chosen_input_chunks,
+            _rejected_input_chunks,
+            _chosen_target_chunks,
+            _rejected_target_chunks,
+        ):
+            input_chunk = torch.cat([chosen_input_chunk, rejected_input_chunk], dim=0)
+            target_chunk = torch.cat(
+                [chosen_target_chunk, rejected_target_chunk], dim=0
+            )
+
+            if compiled:
+                accumulate_chunk = torch.compile(accumulate_chunk)
+            grad_input = accumulate_chunk(input_chunk, target_chunk)
+
+            grad_chosen_inputs.append(grad_input[: chosen_target_chunk.shape[0]])
+            grad_rejected_inputs.append(grad_input[chosen_target_chunk.shape[0] :])
+
+        # combine grad_chosen_inputs and grad_rejected_inputs
+        grad_inputs = grad_chosen_inputs + grad_rejected_inputs
+
+        ctx.save_for_backward(
+            torch.cat(grad_inputs, dim=0),
+            grad_weight,
+            grad_bias,
+        )
+        return loss_acc
+
+    @staticmethod
+    def backward(ctx, grad_output):
+        grad_input, grad_weight, grad_bias = ctx.saved_tensors
+        if torch.ne(grad_output, torch.tensor(1.0, device=grad_output.device)):
+            grad_input = grad_input * grad_output
+            grad_weight = grad_weight * grad_output
+            grad_bias = grad_bias * grad_output if grad_bias is not None else None
+
+        return grad_input, grad_weight, None, grad_bias, None, None, None
+
+    @staticmethod
+    def _compute_loss(
+        input_chunk,
+        weight,
+        target_chunk,
+        bias=None,
+        preference_loss_fn=None,
+        full_target=None,
+        ignore_index=-100,
+        beta=0.1,
+        compute_nll_loss=True,
+        **loss_kwargs,
+    ):
+        """
+        Compute the total loss for a chunk of input and target, while using an alignment/preference loss function.
+        Args:
+            preference_loss_fn (callable): Loss function to compute the loss on a chunk of input/target.
+            input_chunk (torch.Tensor): Chunk of input tensor. Shape: (2 * chunk_size, sequence_length, hidden_size).
+            weight (torch.Tensor): Weight tensor. Shape: (vocab_size, hidden_size).
+            target_chunk (torch.Tensor): Chunk of target tensor. Shape: (2 * chunk_size, sequence_length).
+            bias (torch.Tensor, optional): Bias tensor. Shape: (vocab_size,).
+            full_target (torch.Tensor): Full target tensor. Shape: (batch_size, sequence_length).
+            ignore_index (int): Index to ignore for loss computation.
+            beta (float): Weight for the odds ratio loss.
+            loss_kwargs (dict): Additional arguments for the loss function.
+        """
+        len_chosen_chunk = target_chunk.shape[0] // 2
+
+        logits_chunk = input_chunk @ weight.t()  # chunk_size x V
+        if bias is not None:
+            logits_chunk = logits_chunk + bias
+        log_probs_chunk = F.log_softmax(logits_chunk.float(), dim=-1)
+
+        chosen_nll_loss = 0.0
+        if compute_nll_loss:
+            chosen_nll_loss = F.nll_loss(
+                log_probs_chunk[:len_chosen_chunk].view(-1, log_probs_chunk.shape[-1]),
+                target_chunk[:len_chosen_chunk].view(-1),
+                reduction="sum",
+                ignore_index=ignore_index,
+            )
+            chosen_nll_loss = (
+                chosen_nll_loss
+                / (full_target[: full_target.shape[0] // 2] != ignore_index).sum()
+            )
+
+        loss_mask = target_chunk != ignore_index
+        label_chunk = torch.where(loss_mask, target_chunk, 0)
+
+        per_token_logps = log_probs_chunk.gather(-1, label_chunk.unsqueeze(-1)).squeeze(
+            -1
+        )
+        average_log_prob = (per_token_logps * loss_mask).sum(-1) / loss_mask.sum(-1)
+
+        chosen_logps = average_log_prob[:len_chosen_chunk]
+        rejected_logps = average_log_prob[len_chosen_chunk:]
+
+        alignment_loss = preference_loss_fn(
+            chosen_logps, rejected_logps, beta=beta, **loss_kwargs
+        )
+        alignment_loss = alignment_loss / (full_target.shape[0] // 2)
+
+        loss = chosen_nll_loss - alignment_loss
+        return loss, (alignment_loss, chosen_logps, rejected_logps)
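
The chunked accumulation above leans on torch.func.grad_and_value with has_aux=True, which returns the gradients for the selected argnums together with the (loss, aux) pair; the base class unpacks that result per chunk and accumulates into preallocated grad buffers so only one chunk's (chunk_size x V) logits are alive at a time. A minimal, self-contained sketch of that pattern follows; the toy loss function and shapes are illustrative, not the library's code.

import torch

def chunk_loss(x, w):
    # toy per-chunk objective; the second element is auxiliary output passed through untouched
    logits = x @ w.t()
    return logits.logsumexp(dim=-1).mean(), logits.detach()

x = torch.randn(2, 8)
w = torch.randn(4, 8)

# gradients for argnums (0, 1), plus the primal value and the aux, exactly as accumulate_chunk unpacks them
(grad_x, grad_w), (loss, logits) = torch.func.grad_and_value(
    chunk_loss, argnums=(0, 1), has_aux=True
)(x, w)
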
@@ -0,0 +1,63 @@
+import torch
+import torch.nn.functional as F
+
+from liger_kernel.chunked_loss.fused_linear_preference import (
+    LigerFusedLinearPreferenceBase,
+)
+
+
+class LigerFusedLinearORPOFunction(LigerFusedLinearPreferenceBase):
+
+    @staticmethod
+    def preference_loss_fn(chosen_logps, rejected_logps, beta=0.1):
+        """
+        Compute odds-ratio loss.
+        Args:
+            chosen_logps (torch.Tensor): Avg log probabilities of chosen tokens. Shape: (batch_size,).
+            rejected_logps (torch.Tensor): Avg log probabilities of rejected tokens. Shape: (batch_size,).
+            beta (float): Weight for the odds ratio loss.
+        """
+        log_odds = (chosen_logps - rejected_logps) - (
+            torch.log1p(-torch.exp(chosen_logps))
+            - torch.log1p(-torch.exp(rejected_logps))
+        )
+        ratio = F.logsigmoid(log_odds)
+        return beta * ratio.sum()
+
+    @staticmethod
+    def forward(
+        ctx,
+        _input,
+        weight,
+        target,
+        bias=None,
+        ignore_index=-100,
+        beta=0.1,
+        compute_nll_loss=True,
+        compiled=True,
+    ):
+        """
+        Fused linear layer with ORPO (Odds-Ratio Preference Optimization) loss.
+        Handles both the forward and backward pass of the final linear layer with ORPO loss.
+        Inspired by LigerFusedLinearCrossEntropyFunction (https://arxiv.org/abs/2410.10989) which fuses final linear layer and CE loss.
+        """
+
+        return LigerFusedLinearPreferenceBase.forward(
+            ctx=ctx,
+            _input=_input,
+            weight=weight,
+            target=target,
+            bias=bias,
+            loss_fn=LigerFusedLinearORPOFunction.preference_loss_fn,
+            compute_nll_loss=compute_nll_loss,
+            ignore_index=ignore_index,
+            beta=beta,
+            compiled=compiled,
+        )
+
+    @staticmethod
+    def backward(ctx, grad_output):
+        # Get gradients for _input, weight, bias, and target from the base class
+        grads = LigerFusedLinearPreferenceBase.backward(ctx, grad_output)[:4]
+        # Return these gradients, followed by None for the remaining inputs
+        return *grads, None, None, None, None
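
For reference, the odds-ratio term implemented by preference_loss_fn above can be written as follows, with $p_c = \exp(\text{chosen\_logps})$ and $p_r = \exp(\text{rejected\_logps})$ denoting the average per-token probabilities:

\[
\log \mathrm{OR} = \log\frac{p_c/(1-p_c)}{p_r/(1-p_r)}
= (\log p_c - \log p_r) - \big(\log(1-p_c) - \log(1-p_r)\big),
\qquad
L_{\mathrm{OR}} = \beta \sum \log\sigma(\log \mathrm{OR}).
\]

The base class then subtracts this term from the chosen-side NLL (loss = chosen_nll_loss - alignment_loss), so minimizing the total loss pushes the log-odds ratio of chosen over rejected upward.
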
@@ -4,11 +4,13 @@ import sys
 
 def print_env_report():
     """
+
    Prints a report of the environment. Useful for debugging and reproducibility.
    Usage:
    ```
    python -m liger_kernel.env_report
    ```
+
     """
     print("Environment Report:")
     print("-------------------")
@@ -1,8 +1,24 @@
+import operator
+from typing import Optional
+
 import torch
 import triton
 import triton.language as tl
 
-from liger_kernel.ops.utils import element_mul_kernel, is_hip
+from liger_kernel.ops.utils import compare_version, element_mul_kernel, is_hip
+
+if compare_version("triton", operator.ge, "3.0.0"):
+    try:
+        # typical import path with dispatch available
+        from triton.language.extra.libdevice import tanh
+    except ModuleNotFoundError:
+        # for working with NGC containers
+        from triton.language.extra.cuda.libdevice import tanh
+else:
+    from triton.language.math import tanh
+
+_TRUE = tl.constexpr(1)
+_FALSE = tl.constexpr(0)
 
 
 @triton.jit
@@ -12,13 +28,18 @@ def liger_cross_entropy_kernel(
     Y_ptr,
     Y_stride,
     loss_ptr,
+    z_loss_ptr,
     loss_stride,
     n_cols,
     n_non_ignore,
     ignore_index,
+    lse_square_scale: tl.constexpr,
     label_smoothing: tl.constexpr,
     reduction: tl.constexpr,  # set it as constexpr since reduction is always known at compile time
+    softcap,
+    RETURN_Z_LOSS: tl.constexpr,
     BLOCK_SIZE: tl.constexpr,
+    HAS_SOFTCAPPING: tl.constexpr,
 ):
     """
     This kernel computes both cross entropy loss and the gradient of the input.
@@ -30,13 +51,18 @@ def liger_cross_entropy_kernel(
     Y_ptr: Pointer to target tensor.
     Y_stride (int): The stride of the target tensor.
     loss_ptr: Pointer to tensor to store the loss.
+    z_loss_ptr: Pointer to tensor to store the z loss. No operation if RETURN_Z_LOSS is 0.
     loss_stride (int): The stride of the loss tensor.
     n_cols (int): The number of columns in the input tensor.
     n_non_ignore (int): The number of non-ignored elements in the batch.
     ignore_index (int): The index to ignore in the target.
     label_smoothing (float): The amount of smoothing when computing the loss, where 0.0 means no smoothing.
+    lse_square_scale (float): The scaler of (logsumexp(_input)) ^ 2 adding to the loss for the stability of training.
+    RETURN_Z_LOSS (int): The boolean value to decide whether storing z loss to z_loss_ptr or not. It must be 0 or 1.
     reduction (str): The string for the reduction to apply
+    softcap (float): The upper threshold for scaling logits to the range (-softcap, +softcap).
     BLOCK_SIZE (int): The block size for Triton operations.
+    HAS_SOFTCAPPING (bool): The boolean value to determine whether applying soft-capping or not.
     """
 
     # https://github.com/triton-lang/triton/issues/1058
@@ -58,6 +84,7 @@ def liger_cross_entropy_kernel(
        return
 
    loss_ptr += program_id * loss_stride
+    z_loss_ptr += program_id * loss_stride
 
    # Online softmax: 2 loads + 1 store (compared with 3 loads + 1 store for the safe softmax)
    # Refer to Algorithm 3 in the paper: https://arxiv.org/pdf/1805.02867
@@ -68,6 +95,8 @@ def liger_cross_entropy_kernel(
    ori_X_y = tl.load(
        X_ptr + y
    )  # we need to store the original value of X_y for the loss calculation
+    if HAS_SOFTCAPPING:
+        ori_X_y = softcap * tanh(ori_X_y / softcap)
 
    # Label smoothing is a general case of normal cross entropy
    # See the full derivation at https://github.com/linkedin/Liger-Kernel/pull/198#issue-2503665310
@@ -79,6 +108,8 @@ def liger_cross_entropy_kernel(
        X_block = tl.load(
            X_ptr + X_offsets, mask=X_offsets < n_cols, other=float("-inf")
        )
+        if HAS_SOFTCAPPING:
+            X_block = softcap * tanh(X_block / softcap)
        block_max = tl.max(X_block)
        if label_smoothing > 0:
            # scale X beforehand to avoid overflow
@@ -87,32 +118,49 @@ def liger_cross_entropy_kernel(
        d = d * tl.exp(m - m_new) + tl.sum(tl.exp(X_block - m_new))
        m = m_new
 
+    # log (sum(e^(X_i))) = log (sum(e ^ (max(X) * e ^ (X_i - max(X)))))
+    #                    = log (e^(max(X)) * sum(e ^ (X_i - max(X))))
+    #                    = max(X) + log (sum(e ^ (X_i - max(X)))) = m + log d
+    lse = m + tl.log(d)
+
    # 4. [Online Softmax] Second pass: compute gradients
    # For 'mean' reduction, gradients are normalized by number of non-ignored elements (N)
    # dx_y = (softmax(x_y) - 1) / N
    # dx_i = softmax(x_i) / N, i != y
    # For label smoothing:
-    # dx_i = (softmax(x_y) - label_smoothing / V) / N, V = n_cols, i != y
+    # dx_i = (softmax(x_i) - label_smoothing / V) / N, V = n_cols, i != y
    # dx_y = (softmax(x_y) - label_smoothing / V - (1 - label_smoothing)) / N
    #      = dx_i - (1 - label_smoothing) / N
-    #
+    # With Z loss:
+    # dx_i = ((1 + 2 * lse_square_scale * lse) * softmax(x_i) - label_smoothing / V) / N, i != y
+    # dx_y = dx_i - (1 - label_smoothing) / N
    # For 'sum' reduction, no normalization is applied:
    # dx_y = softmax(x_y) - 1
    # dx_i = softmax(x_i), for i ≠ y
-    # For label smoothing:
-    # dx_i = (softmax(x_y) - label_smoothing / V), V = n_cols, i != y
-    # dx_y = (softmax(x_y) - label_smoothing / V - (1 - label_smoothing))
-    #      = dx_i - (1 - label_smoothing)
 
    for i in range(0, n_cols, BLOCK_SIZE):
        X_offsets = i + tl.arange(0, BLOCK_SIZE)
        X_block = tl.load(
            X_ptr + X_offsets, mask=X_offsets < n_cols, other=float("-inf")
        )
+        if HAS_SOFTCAPPING:
+            intermediate = tanh(X_block / softcap)
+            X_block = softcap * intermediate
+        # softmax(x_i)
+        X_block = tl.exp(X_block - m) / d
+        # derivative of z-loss: 2 * lse_square_scale * lse * softmax(x_i)
+        X_block += 2 * lse_square_scale * lse * X_block
+        # smoothing term
+        X_block += -eps
+        # special handle dx_y
+        X_block = tl.where(X_offsets != y, X_block, X_block - (1 - label_smoothing))
+        # reduction scale
        if reduction == "mean":
-            X_block = (tl.exp(X_block - m) / d - eps) / (n_non_ignore)
-        else:
-            X_block = tl.exp(X_block - m) / d - eps
+            X_block = X_block / (n_non_ignore)
+        # chain rule
+        # d(softcap * tanh(x / softcap)) = (1 - tanh^2(x / softcap))
+        if HAS_SOFTCAPPING:
+            X_block = X_block * (1 - intermediate * intermediate)
 
        tl.store(X_ptr + X_offsets, X_block, mask=X_offsets < n_cols)
 
@@ -124,35 +172,35 @@ def liger_cross_entropy_kernel(
 
    # loss = log (softmax(X_y)) = log ((e ^ (X_y - max(X)) / sum(e ^ (X - max(X))))
    #      = (X_y - max(X)) - log(sum(e ^ (X - max(X))))
+    #      = X_y - m - log d = X_y - lse
    # sum(e ^ (X - max(X))) must >= 1 because the max term is e ^ 0 = 1
    # So we can safely calculate log (softmax(X_y)) without overflow
-    loss = -(ori_X_y - m - tl.log(d))
+    loss = lse - ori_X_y
 
    # Original loss = H(q, p), with label smoothing regularization = H(q', p) and (label_smoothing / V) = eps
    # H(q', p) = (1 - label_smoothing) * H(q, p) + label_smoothing * H(u, p)
    #          = (1 - label_smoothing) * H(q, p) + eps * sum(logsoftmax(x_i))
    # By using m (global max of xi) and d (sum of e^(xi-m)), we can simplify as:
-    #          = (1 - label_smoothing) * H(q, p) + (-sum(x_i * eps) + label_smoothing * (m + logd))
+    #          = (1 - label_smoothing) * H(q, p) + (sum(-eps * x_i) + label_smoothing * (m + logd))
    # Refer to H(q', p) in section 7 of the paper: https://arxiv.org/pdf/1512.00567
    # pytorch: https://github.com/pytorch/pytorch/blob/2981534f54d49fa3a9755c9b0855e7929c2527f0/aten/src/ATen/native/LossNLL.cpp#L516
    # See full derivation at https://github.com/linkedin/Liger-Kernel/pull/198#issuecomment-2333753087
    if label_smoothing > 0:
-        smooth_loss = scaled_x_sum + label_smoothing * (m + tl.log(d))
+        smooth_loss = scaled_x_sum + label_smoothing * lse
        loss = loss * (1 - label_smoothing) + smooth_loss
 
+    # An auxiliary loss, z_loss
+    # Refer to Page14 Loss function section in the paper PaLM: https://www.jmlr.org/papers/v24/22-1144.html
+    z_loss = lse_square_scale * lse * lse
+    loss += z_loss
    # Normalize the loss by the number of non-ignored elements if reduction is "mean"
    if reduction == "mean":
+        z_loss = z_loss / n_non_ignore
        loss = loss / n_non_ignore
 
-    # 6. Specially handle the i==y case where `dx_y = (softmax(x_y) - (1 - label_smoothing) / N`
-    X_y = tl.load(X_ptr + y)
-    if reduction == "mean":
-        X_y += -(1 - label_smoothing) / (n_non_ignore)
-    else:
-        X_y += -(1 - label_smoothing)
-
    tl.store(loss_ptr, loss)
-    tl.store(X_ptr + y, X_y)
+    if RETURN_Z_LOSS == _TRUE:
+        tl.store(z_loss_ptr, z_loss)
 
 
 # The hard limit of TRITON_MAX_TENSOR_NUMEL is 1048576 https://github.com/triton-lang/triton/blob/ba42a5c68fd0505f8c42f4202d53be0f8d9a5fe0/python/triton/language/core.py#L19
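
Written out, the quantities these hunks introduce are, with $\alpha$ = lse_square_scale and $s$ = softcap:

\[
\mathrm{lse}(x) = \log\sum_i e^{x_i} = m + \log d,\qquad
L_z = \alpha\,\mathrm{lse}(x)^2,\qquad
\frac{\partial L_z}{\partial x_i} = 2\alpha\,\mathrm{lse}(x)\,\mathrm{softmax}(x)_i,
\]
\[
\tilde{x}_i = s\,\tanh(x_i/s),\qquad
\frac{\partial \tilde{x}_i}{\partial x_i} = 1 - \tanh^2(x_i/s),
\]

which is exactly the extra gradient term (2 * lse_square_scale * lse * softmax) and the soft-capping chain-rule factor (1 - tanh^2) applied in the second pass above, and the z_loss term added to the stored loss here.
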
@@ -161,7 +209,32 @@ def liger_cross_entropy_kernel(
 MAX_FUSED_SIZE = 65536 // 2  # the best size we found by manually tuning
 
 
-def cross_entropy_forward(_input, target, ignore_index, label_smoothing, reduction):
+_bool_to_return_z_loss = {
+    True: _TRUE.value,
+    False: _FALSE.value,
+}
+
+
+def cross_entropy_forward(
+    _input,
+    target,
+    ignore_index,
+    lse_square_scale,
+    label_smoothing,
+    reduction,
+    softcap,
+    return_z_loss,
+):
+    if not isinstance(return_z_loss, int):
+        assert (
+            return_z_loss in _bool_to_return_z_loss
+        ), f"return_z_loss must be True or False. Got: {return_z_loss}"
+        return_z_loss = _bool_to_return_z_loss[return_z_loss]
+    else:
+        assert (
+            return_z_loss in _bool_to_return_z_loss
+        ), f"return_z_loss must be True or False. Got: {return_z_loss}"
+
    BT, V = _input.shape
    n_rows = BT
 
@@ -169,6 +242,10 @@ def cross_entropy_forward(_input, target, ignore_index, label_smoothing, reducti
 
    # unreduced loss
    loss_1d = torch.zeros(n_rows, dtype=_input.dtype, device=_input.device)
+    if return_z_loss == _TRUE.value:
+        z_loss_1d = torch.zeros(n_rows, dtype=_input.dtype, device=_input.device)
+    else:
+        z_loss_1d = loss_1d  # dummy ptr when return_z_loss == False
 
    n_non_ignore = (target != ignore_index).sum().item()
 
@@ -185,20 +262,30 @@ def cross_entropy_forward(_input, target, ignore_index, label_smoothing, reducti
        Y_ptr=target,
        Y_stride=target.stride(-1),  # always 1
        loss_ptr=loss_1d,
+        z_loss_ptr=z_loss_1d,
        loss_stride=loss_1d.stride(-1),  # always 1
        n_cols=V,
        n_non_ignore=n_non_ignore,
        ignore_index=ignore_index,
+        lse_square_scale=lse_square_scale,
        label_smoothing=label_smoothing,
        reduction=reduction,
+        softcap=softcap if softcap is not None else 0.0,
+        RETURN_Z_LOSS=return_z_loss,
        BLOCK_SIZE=BLOCK_SIZE,
+        HAS_SOFTCAPPING=True if softcap is not None else False,
        # TODO: 32 seems to give the best performance
        # Performance is quite sensitive to num_warps
        num_warps=32 if not is_hip() else 16,
    )
 
    loss = torch.sum(loss_1d)
-    return loss, _input
+    if return_z_loss == _TRUE.value:
+        z_loss = torch.sum(z_loss_1d)
+    else:
+        z_loss = None
+
+    return loss, z_loss, _input
 
 
 def cross_entropy_backward(_input, grad_output):
@@ -233,7 +320,15 @@ class LigerCrossEntropyFunction(torch.autograd.Function):
 
    @staticmethod
    def forward(
-        ctx, _input, target, ignore_index=-100, label_smoothing=0.0, reduction="mean"
+        ctx,
+        _input: torch.Tensor,
+        target: torch.Tensor,
+        ignore_index: int = -100,
+        lse_square_scale: float = 0.0,
+        label_smoothing: float = 0.0,
+        reduction: str = "mean",
+        softcap: Optional[float] = None,
+        return_z_loss: bool = False,
    ):
        """
        The forward pass of the Liger Cross Entropy loss.
@@ -243,33 +338,48 @@ class LigerCrossEntropyFunction(torch.autograd.Function):
        _input (tensor): The input tensor of shape (BT, V) where B is batch size, T is sequence length, V is vocab size.
        target (tensor): The target tensor of shape (BT) where each value is in [0, V-1].
        ignore_index (int): The index to ignore in the target.
+        lse_square_scale (float): The scaler of (logsumexp(_input)) ^ 2 adding to the loss for the stability of training.
        label_smoothing (float): The amount of smoothing when computing the loss, where 0.0 means no smoothing.
        reduction (str): The reduction to apply to the output: "none" | "mean" | "sum".
+        softcap (Optional[float]): The upper threshold for scaling logits to the range (-softcap, +softcap).
+        return_z_loss (bool): When `return_z_loss` is `True`, returns (loss, z_loss) instead of (loss, None). Default: `False`
 
        Returns:
-        tensor: The computed loss.
+        tuple: A tuple with the computed losses with respect to loss and z loss. The elements are tensors or None.
        """
-        loss, _input = cross_entropy_forward(
-            _input, target, ignore_index, label_smoothing, reduction
+        loss, z_loss, _input = cross_entropy_forward(
+            _input,
+            target,
+            ignore_index,
+            lse_square_scale,
+            label_smoothing,
+            reduction,
+            softcap,
+            return_z_loss,
        )
        # TODO: investigation
        # If we don't detach the _input tensor, the memory will double
        # Not sure why but seems that there will be a time both grad and value exist but in different location
        ctx.save_for_backward(_input.detach())
-        return loss
+        ctx.return_z_loss = return_z_loss
+
+        return loss, z_loss
 
    @staticmethod
-    def backward(ctx, grad_output):
+    def backward(ctx, grad_output, grad_output2):
        """
        The backward pass of the Liger Cross Entropy loss.
 
        Parameters:
        ctx : The context object with saved tensors.
        grad_output (tensor): The tensor containing the gradient of the loss with respect to the output.
-
+        grad_output2 (tensor): No use.
        Returns:
        tuple: A tuple with the gradients with respect to the inputs. The elements are tensors or None.
        """
+        if ctx.return_z_loss:
+            del grad_output2  # z_loss is only for logging
+
        (_input,) = ctx.saved_tensors
        _input = cross_entropy_backward(_input, grad_output)
        return (
@@ -278,4 +388,7 @@ class LigerCrossEntropyFunction(torch.autograd.Function):
            None,
            None,
            None,
+            None,
+            None,
+            None,
        )
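
To close, here is a small sketch of how the extended autograd function might be driven after this change; the import path liger_kernel.ops.cross_entropy is assumed, a CUDA device is required since this is a Triton kernel, and the shapes and hyperparameter values are made up for illustration.

import torch

from liger_kernel.ops.cross_entropy import LigerCrossEntropyFunction  # assumed path

BT, V = 8, 128
_input = torch.randn(BT, V, device="cuda", requires_grad=True)
target = torch.randint(0, V, (BT,), device="cuda")

# Positional args after target: ignore_index, lse_square_scale, label_smoothing, reduction, softcap, return_z_loss.
loss, z_loss = LigerCrossEntropyFunction.apply(
    _input, target, -100, 1e-4, 0.0, "mean", 30.0, True
)
loss.backward()  # the second output is only for logging; its gradient is ignored in backward

With return_z_loss=False (the default), the second element of the returned tuple is simply None, matching the (loss, None) behavior documented in the new docstring.
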