liger-kernel 0.3.0__py3-none-any.whl → 0.4.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (33)
  1. liger_kernel/ops/cross_entropy.py +5 -39
  2. liger_kernel/ops/experimental/mm_int8int2.py +355 -0
  3. liger_kernel/ops/fused_linear_cross_entropy.py +13 -10
  4. liger_kernel/ops/fused_linear_jsd.py +245 -0
  5. liger_kernel/ops/geglu.py +2 -2
  6. liger_kernel/ops/jsd.py +176 -0
  7. liger_kernel/ops/kl_div.py +45 -34
  8. liger_kernel/ops/rms_norm.py +67 -42
  9. liger_kernel/ops/swiglu.py +2 -2
  10. liger_kernel/ops/utils.py +62 -1
  11. liger_kernel/transformers/__init__.py +3 -0
  12. liger_kernel/transformers/auto_model.py +18 -6
  13. liger_kernel/transformers/functional.py +4 -0
  14. liger_kernel/transformers/fused_linear_jsd.py +98 -0
  15. liger_kernel/transformers/jsd.py +75 -0
  16. liger_kernel/transformers/kl_div.py +3 -2
  17. liger_kernel/transformers/model/gemma.py +124 -1
  18. liger_kernel/transformers/model/llama.py +135 -4
  19. liger_kernel/transformers/model/mistral.py +3 -0
  20. liger_kernel/transformers/model/mixtral.py +153 -2
  21. liger_kernel/transformers/model/mllama.py +274 -0
  22. liger_kernel/transformers/model/phi3.py +140 -2
  23. liger_kernel/transformers/model/qwen2.py +123 -2
  24. liger_kernel/transformers/model/qwen2_vl.py +8 -1
  25. liger_kernel/transformers/monkey_patch.py +254 -129
  26. {liger_kernel-0.3.0.dist-info → liger_kernel-0.4.0.dist-info}/METADATA +74 -35
  27. liger_kernel-0.4.0.dist-info/NOTICE +58 -0
  28. liger_kernel-0.4.0.dist-info/RECORD +48 -0
  29. {liger_kernel-0.3.0.dist-info → liger_kernel-0.4.0.dist-info}/WHEEL +1 -1
  30. liger_kernel-0.3.0.dist-info/NOTICE +0 -4
  31. liger_kernel-0.3.0.dist-info/RECORD +0 -42
  32. {liger_kernel-0.3.0.dist-info → liger_kernel-0.4.0.dist-info}/LICENSE +0 -0
  33. {liger_kernel-0.3.0.dist-info → liger_kernel-0.4.0.dist-info}/top_level.txt +0 -0
@@ -0,0 +1,245 @@
1
+ from typing import Optional
2
+
3
+ import torch
4
+ import triton
5
+
6
+ from liger_kernel.ops.jsd import _jsd_kernel
7
+ from liger_kernel.ops.utils import (
8
+ amp_custom_bwd,
9
+ amp_custom_fwd,
10
+ element_mul_kernel,
11
+ is_hip,
12
+ )
13
+
14
# The hard limit of TRITON_MAX_TENSOR_NUMEL is 1048576 https://github.com/triton-lang/triton/blob/ba42a5c68fd0505f8c42f4202d53be0f8d9a5fe0/python/triton/language/core.py#L19
# However, setting limit as 65536 as in LayerNorm tutorial is faster because of less register spilling
# The optimal maximum block size depends on your hardware, your kernel, and your dtype
# NOTE: this file halves the cap (32768) — presumably tuned for this kernel's
# register pressure; liger_kernel.ops.jsd uses the full 65536.
MAX_FUSED_SIZE = 65536 // 2
18
+
19
+
20
def fused_linear_jsd_forward(
    student_input,
    student_weight,
    teacher_input,
    teacher_weight,
    shift_labels,
    jsd_beta,
    ignore_index,
    has_label,
    temperature,
):
    """Compute the generalized JSD loss fused with the final linear projection.

    Both student and teacher logits are materialized only chunk-by-chunk
    (never the full BT x V tensor), and the gradients w.r.t. the student
    input and weight are computed during this forward pass.

    Args:
        student_input: (BT, H) input to the student's lm_head.
        student_weight: (V, H) student lm_head weight.
        teacher_input: (BT, H) input to the teacher's lm_head.
        teacher_weight: (V, H) teacher lm_head weight.
        shift_labels: (BT,) label ids, or None-equivalent when has_label is False.
        jsd_beta: beta coefficient of the generalized JSD.
        ignore_index: label value whose rows are excluded from the loss.
        has_label: whether shift_labels should be consulted.
        temperature: softmax temperature applied to both logit sets.

    Returns:
        (loss, grad_input, grad_weight) where grad_weight is None when
        student_weight does not require grad.
    """
    device = student_input.device
    dtype = student_input.dtype

    # inputs have shape: BT x H
    # materialized activations will have shape: BT x V
    # the increase in memory = BT x V
    # reduction can be achieved by partitioning the number of tokens BT into smaller chunks.
    # for ex: if we were to achieve the same memory consumption as BT x H, then the chunk size should be:
    # inc_factor = (V+H-1)//H, chunk_size = (BT + inc_factor - 1)//inc_factor
    # for ex: BT = 4096*4, V = 32000, H = 4096 ==> inc_factor = 8, chunk_size = 2048
    BT, H = student_input.shape
    V = student_weight.shape[0]
    BLOCK_SIZE = min(MAX_FUSED_SIZE, triton.next_power_of_2(V))

    inc_factor = triton.cdiv(V, H)  # (V + H - 1) // H
    chunk_size = triton.next_power_of_2(
        triton.cdiv(BT, inc_factor)
    )  # (BT + inc_factor - 1) // inc_factor
    num_chunks = triton.cdiv(BT, chunk_size)  # (BT + chunk_size - 1) // chunk_size

    grad_weight = (
        torch.zeros_like(student_weight, device=device)
        if student_weight.requires_grad
        else None
    )
    grad_input = torch.zeros_like(student_input)
    # we use fp32 for loss accumulator
    # NOTE(review): this buffer is (BT, V) because the kernel stores an
    # unreduced per-column loss; it is pre-zeroed so ignored rows contribute 0.
    loss_1d = torch.zeros((BT, V), dtype=torch.float32, device=device)

    if has_label:
        n_non_ignore = (shift_labels != ignore_index).sum().item()
    else:
        n_non_ignore = BT

    for chunk_id in range(num_chunks):
        start_idx = chunk_id * chunk_size
        end_idx = min((chunk_id + 1) * chunk_size, BT)

        # chunk both inputs, shape: chunk_size x H
        student_input_chunk = student_input[start_idx:end_idx]
        teacher_input_chunk = teacher_input[start_idx:end_idx]

        # shape: chunk_size x V
        # For anything starting from logits to the final JSD loss, we do computation
        # in FP32 to avoid losing numerical stability.
        student_logits_chunk = (student_input_chunk @ student_weight.t()).to(
            torch.float32
        )
        teacher_logits_chunk = (teacher_input_chunk @ teacher_weight.t()).to(
            torch.float32
        )
        chunk_n_rows = student_logits_chunk.shape[0]

        # unreduced loss, shape: chunk_size x V
        loss_1d_slice = loss_1d[start_idx:end_idx]
        # log-softmax with temperature
        student_logits_chunk = student_logits_chunk / temperature
        teacher_logits_chunk = teacher_logits_chunk / temperature
        student_prob_chunk = torch.log_softmax(student_logits_chunk, dim=-1)
        teacher_prob_chunk = torch.log_softmax(teacher_logits_chunk, dim=-1)

        # ensure _input and target are contiguous
        student_prob_chunk = student_prob_chunk.contiguous()
        teacher_prob_chunk = teacher_prob_chunk.contiguous()

        # Here we calculate the gradient of prob_chunk in place so we can save memory.
        # dX is written into student_prob_chunk (dX_ptr aliases X_ptr).
        _jsd_kernel[(chunk_n_rows,)](
            X_ptr=student_prob_chunk,
            X_stride=student_prob_chunk.stride(-2),
            Y_ptr=teacher_prob_chunk,
            Y_stride=teacher_prob_chunk.stride(-2),
            loss_ptr=loss_1d_slice,
            loss_stride=loss_1d_slice.stride(-2),
            dX_ptr=student_prob_chunk,
            dX_stride=student_prob_chunk.stride(-2),
            label_ptr=(
                shift_labels[start_idx:end_idx]
                if has_label
                else torch.empty(1, device=device)
            ),  # dummy ptr if no label
            beta=jsd_beta,
            n_non_ignore=n_non_ignore,
            ignore_index=ignore_index,
            n_cols=V,
            BLOCK_SIZE=BLOCK_SIZE,
            HAS_LABEL=has_label,
        )
        loss_1d[start_idx:end_idx] = loss_1d_slice
        # gradients of prob_chunk in place, shape: chunk_size x V
        # gradients of logits_chunk in place, shape: chunk_size x V
        # (chain rule through log_softmax, then through the temperature scaling)
        student_logits_chunk = (
            student_prob_chunk
            - torch.softmax(student_logits_chunk, dim=-1)
            * student_prob_chunk.sum(dim=-1, keepdim=True).broadcast_to(
                student_prob_chunk.shape
            )
        ) / temperature
        # now we traverse back to grad w.r.t. input to `lm_head` and grad
        # w.r.t. `lm_head` which should be computed in original dtype
        student_logits_chunk = student_logits_chunk.to(dtype)
        grad_input[start_idx:end_idx] = student_logits_chunk @ student_weight

        if grad_weight is not None:
            grad_weight.add_(student_logits_chunk.t() @ student_input_chunk)

    loss = torch.sum(loss_1d)
    return loss, grad_input, grad_weight
138
+
139
+
140
def fused_linear_jsd_backward(grad_output, grad_input, grad_weight):
    """Scale the pre-computed gradients by grad_output.

    Args:
        grad_output: scalar (0-dim) tensor, gradient of the loss output.
        grad_input: (BT, H) gradient w.r.t. the student input, scaled in place.
        grad_weight: (V, H) gradient w.r.t. the student weight (or None),
            scaled in place.

    Returns:
        (grad_input, grad_weight), the (possibly in-place scaled) gradients.
    """
    # If JSD is the last layer, grad_output is 1.0. Skip the mul to save time.
    # Use torch.equal (as in liger_kernel.ops.jsd.jsd_backward) rather than
    # torch.ne, which relied on implicit truthiness of a 0-dim bool tensor.
    if not torch.equal(grad_output, torch.tensor(1.0, device=grad_output.device)):
        # We use a Triton kernel instead of a PyTorch operation because modifying inputs in-place
        # for gradient storage and backward multiple times causes anomalies with PyTorch but not with Triton.
        BT, H = grad_input.shape
        n_rows = BT
        BLOCK_SIZE = min(MAX_FUSED_SIZE, triton.next_power_of_2(H))

        element_mul_kernel[(n_rows,)](
            grad_input,
            grad_input.stride(-2),
            grad_output,
            H,
            BLOCK_SIZE=BLOCK_SIZE,
            num_warps=32 if not is_hip() else 16,
        )

        # handle grad_weight
        if grad_weight is not None:
            V, H = grad_weight.shape
            n_rows = V

            element_mul_kernel[(n_rows,)](
                grad_weight,
                grad_weight.stride(-2),
                grad_output,
                H,
                BLOCK_SIZE=BLOCK_SIZE,
                num_warps=32 if not is_hip() else 16,
            )

    return grad_input, grad_weight
173
+
174
+
175
class LigerFusedLinearJSDFunction(torch.autograd.Function):
    """
    Fusing the last linear layer with generalized JSD

    Handle the forward and backward pass of the final linear layer via JSD by avoiding
    the materialization of the large logits tensor. Since JSD is the last layer, we can
    compute the gradient at the forward pass.
    """

    @staticmethod
    @amp_custom_fwd
    def forward(
        ctx,
        student_input: torch.Tensor,
        student_weight: torch.Tensor,
        teacher_input: torch.Tensor,
        teacher_weight: torch.Tensor,
        shift_labels: Optional[torch.Tensor] = None,
        jsd_beta: float = 0.5,
        ignore_index: int = -100,
        temperature: float = 1.0,
    ):
        """
        Args:

            student_input (torch.tensor): input of the last projection layer in student model, with shape (B*T, H), where B is batch size, T is sequence length, H is hidden dimension.
            student_weight (torch.tensor): the last projection layer in student model, with shape (V, H), where V is vocab size
            teacher_input (torch.tensor): input of the last projection layer in teacher model, with shape (B*T, H), where B is batch size, T is sequence length, H is hidden dimension.
            teacher_weight (torch.tensor): the last projection layer in teacher model, with shape (V, H), where V is vocab size
            shift_labels (Optional[torch.LongTensor]): indicator of next predicted vocab with shape (BT) where each value is in [0, V-1].
            jsd_beta (float): coefficient beta of generalized JSD in the open interval (0, 1). Default: `0.5`
            ignore_index (int): the index to ignore. Default: -100
            temperature (float): temperature in softmax function to control the output probability distribution. Default: `1.0`

        Returns:
            loss (torch.Tensor): generalized JSD
        """
        has_label = False
        if shift_labels is not None:
            assert shift_labels.shape == (
                teacher_input.shape[0],
            ), f"the shape of shift_labels must be (BT,). Got: {shift_labels.shape}"
            shift_labels = shift_labels.contiguous()
            has_label = True

        loss, grad_input, grad_weight = fused_linear_jsd_forward(
            student_input,
            student_weight,
            teacher_input,
            teacher_weight,
            shift_labels,
            jsd_beta,
            ignore_index,
            has_label,
            temperature,
        )
        # downcast to dtype and store for backward
        # (gradients were computed during forward; backward only rescales them)
        ctx.save_for_backward(
            grad_input.detach(),
            grad_weight.detach() if grad_weight is not None else None,
        )
        return loss

    @staticmethod
    @amp_custom_bwd
    def backward(ctx, grad_output):
        """Rescale the gradients cached at forward time by grad_output.

        Returns one entry per forward argument: grads for student_input and
        student_weight, then None for teacher_input, teacher_weight,
        shift_labels, jsd_beta, ignore_index and temperature.
        """
        (grad_input, grad_weight) = ctx.saved_tensors
        grad_input, grad_weight = fused_linear_jsd_backward(
            grad_output, grad_input, grad_weight
        )
        return (grad_input, grad_weight, None, None, None, None, None, None)
liger_kernel/ops/geglu.py CHANGED
@@ -25,7 +25,7 @@ else:
25
25
  def _geglu_tanh_forward_kernel(
26
26
  a, b, c, stride, n_cols: tl.constexpr, BLOCK_SIZE: tl.constexpr
27
27
  ):
28
- program_id = tl.program_id(0)
28
+ program_id = tl.program_id(0).to(tl.int64)
29
29
 
30
30
  # locate start index
31
31
  a += program_id * stride
@@ -52,7 +52,7 @@ def _geglu_tanh_forward_kernel(
52
52
  def _geglu_tanh_backward_kernel(
53
53
  dc, a, b, stride, n_cols: tl.constexpr, BLOCK_SIZE: tl.constexpr
54
54
  ):
55
- program_id = tl.program_id(0)
55
+ program_id = tl.program_id(0).to(tl.int64)
56
56
 
57
57
  # locate start index
58
58
  dc += program_id * stride
@@ -0,0 +1,176 @@
1
+ from typing import Optional
2
+
3
+ import torch
4
+ import triton
5
+ import triton.language as tl
6
+
7
+ from liger_kernel.ops.utils import ensure_contiguous
8
+
9
+
10
@triton.jit
def _jsd_kernel(
    X_ptr,  # input in logspace, X = log Q
    X_stride,
    Y_ptr,  # ground truth in logspace, Y = log P
    Y_stride,
    loss_ptr,
    loss_stride,
    dX_ptr,
    dX_stride,
    label_ptr,
    beta,
    n_non_ignore: int,
    ignore_index: tl.constexpr,
    n_cols,
    BLOCK_SIZE: tl.constexpr,
    HAS_LABEL: tl.constexpr,
):
    # One program per row. Writes the unreduced per-column loss to loss_ptr
    # and the gradient w.r.t. X to dX_ptr (which may alias X_ptr).
    # JSD(P || Q) = (KL(P || M) + KL(Q || M)) / 2, M = (1/2) * (P + Q) = (1/2) * (e ^ Y + e ^ X)
    # = sum(P * log P + Q * log Q - 2 * M * log M) / 2
    # = sum(e ^ Y * Y + e ^ X * X - 2 * M * log M) / 2
    # grad_x_i = 0.5 * Q * (X - log_M)
    pid = tl.program_id(0).to(tl.int64)  # int64: row offsets can exceed int32
    X_ptr += pid * X_stride
    dX_ptr += pid * dX_stride
    Y_ptr += pid * Y_stride
    loss_ptr += pid * loss_stride
    label_ptr += pid

    if HAS_LABEL:
        label = tl.load(label_ptr)
        if label == ignore_index:
            # Ignored token: zero the gradient row; the loss row is left
            # untouched (callers pre-zero the loss buffer).
            for i in range(0, n_cols, BLOCK_SIZE):
                offsets = i + tl.arange(0, BLOCK_SIZE)
                tl.store(dX_ptr + offsets, 0.0, mask=offsets < n_cols)
            return

    for i in range(0, n_cols, BLOCK_SIZE):
        offsets = i + tl.arange(0, BLOCK_SIZE)
        mask = offsets < n_cols
        X = tl.load(X_ptr + offsets, mask=mask, other=float("-inf")).to(tl.float32)
        Y = tl.load(Y_ptr + offsets, mask=mask, other=float("-inf")).to(tl.float32)

        Q = tl.exp(X)
        P = tl.exp(Y)
        M = beta * P + (1 - beta) * Q
        log_M = tl.log(M)

        loss = beta * P * Y + (1 - beta) * Q * X - M * log_M
        # reduction == "batchmean"
        loss = loss / n_non_ignore
        tl.store(loss_ptr + offsets, loss, mask=mask)

        dX = (1 - beta) * Q * (X - log_M) / n_non_ignore
        tl.store(dX_ptr + offsets, dX, mask=mask)
65
+
66
+
67
# Cap on the per-program block width; see TRITON_MAX_TENSOR_NUMEL notes in
# liger_kernel.ops (65536 chosen for less register spilling).
MAX_FUSED_SIZE = 65536
68
+
69
+
70
def jsd_forward(_input, target, shift_labels, beta, ignore_index, has_label):
    """Launch the JSD kernel and reduce the loss.

    Args:
        _input: (BT, V) student log-probabilities (log Q).
        target: (BT, V) teacher log-probabilities (log P).
        shift_labels: (BT,) labels, consulted only when has_label is True.
        beta: beta coefficient of the generalized JSD.
        ignore_index: label value whose rows are excluded.
        has_label: whether shift_labels should be used.

    Returns:
        (loss, dX): scalar loss in _input's dtype and the (BT, V) gradient
        w.r.t. _input.
    """
    BT, V = _input.shape
    n_rows = BT
    BLOCK_SIZE = min(MAX_FUSED_SIZE, triton.next_power_of_2(V))
    # non reduction loss; zero-initialized so ignored rows contribute 0
    loss = torch.zeros(_input.shape, dtype=torch.float32, device=_input.device)
    dX = torch.empty_like(_input)

    if has_label:
        n_non_ignore = (shift_labels != ignore_index).sum().item()
    else:
        n_non_ignore = BT

    _jsd_kernel[(n_rows,)](
        X_ptr=_input,  # input in logspace, X = log Q
        X_stride=_input.stride(-2),
        Y_ptr=target,  # ground truth in logspace, Y = log P
        Y_stride=target.stride(-2),
        loss_ptr=loss,
        loss_stride=loss.stride(-2),
        dX_ptr=dX,
        dX_stride=dX.stride(-2),
        label_ptr=(
            shift_labels if has_label else torch.empty(1, device=_input.device)
        ),  # dummy ptr if no label
        beta=beta,
        n_non_ignore=n_non_ignore,
        ignore_index=ignore_index,
        n_cols=V,
        BLOCK_SIZE=BLOCK_SIZE,
        HAS_LABEL=has_label,
    )

    loss = torch.sum(loss)
    return loss.to(_input.dtype), dX
105
+
106
+
107
def jsd_backward(dX, grad_output):
    """Scale the cached JSD gradient by the incoming grad_output.

    When JSD is the final layer, grad_output is exactly 1.0 and the
    multiplication is skipped to save time.
    """
    unit = torch.tensor(1.0, device=grad_output.device)
    if torch.equal(grad_output, unit):
        return dX
    return grad_output * dX
113
+
114
+
115
class LigerJSDFunction(torch.autograd.Function):
    r"""
    This class implements the forward and backward pass for the generalized Jensen-Shannon Divergence.
    .. math::
        JSD(\beta)(P || Q)
            = \beta * KLDiv(P || (\beta * P + (1 - \beta) * Q)) + (1 - \beta) * KLDiv(Q || (\beta * P + (1 - \beta) * Q))

    .. note::
        As all the other losses in PyTorch, this function expects the first argument,
        :attr:`_input`, to be the predictions, the output of the student model, in log-space
        and the second, :attr:`target`, to be the observations, the output of the teacher model, in log-space.
        This differs from the standard mathematical notation :math:`JSD(P || Q)` where
        :math:`P` denotes the teacher model and :math:`Q` denotes the student model.
    """

    @staticmethod
    @ensure_contiguous
    def forward(
        ctx,
        _input: torch.Tensor,
        target: torch.Tensor,
        shift_labels: Optional[torch.Tensor] = None,
        beta: float = 0.5,
        ignore_index: int = -100,
    ) -> torch.Tensor:
        """
        Args:
            _input (torch.Tensor): predict values with shape (BT, V) in logspace
            target (torch.Tensor): ground truth values with shape (BT, V) in logspace
            shift_labels (Optional[torch.LongTensor]): indicator of next predicted vocab with shape (BT) where each value is in [0, V-1].
            beta (float): coefficient beta of generalized JSD in the open interval (0, 1)
            ignore_index (int): the index to ignore. Default: -100

        Returns:
            loss (torch.Tensor): generalized JSD
        """
        has_label = False
        if shift_labels is not None:
            assert shift_labels.shape == (
                _input.shape[0],
            ), f"the shape of shift_labels must be (BT,). Got: {shift_labels.shape}"
            shift_labels = shift_labels.contiguous()
            has_label = True

        loss, dX = jsd_forward(
            _input, target, shift_labels, beta, ignore_index, has_label
        )
        # dX is the full gradient w.r.t. _input; backward only rescales it.
        ctx.save_for_backward(dX)
        return loss

    @staticmethod
    @ensure_contiguous
    def backward(ctx, grad_output: torch.Tensor) -> torch.Tensor:
        """Return one entry per forward argument: grad for _input, then None
        for target, shift_labels, beta and ignore_index."""
        (dX,) = ctx.saved_tensors
        dX = jsd_backward(dX, grad_output)
        return (
            dX,
            None,
            None,
            None,
            None,
        )
@@ -4,13 +4,13 @@ import torch
4
4
  import triton
5
5
  import triton.language as tl
6
6
 
7
- from liger_kernel.ops.utils import ensure_contiguous
7
+ from liger_kernel.ops.utils import ensure_contiguous, is_hip
8
8
 
9
9
 
10
10
  def get_num_warps(BLOCK_SIZE):
11
11
  num_warps = 4
12
12
  if BLOCK_SIZE >= 32768:
13
- num_warps = 32
13
+ num_warps = 32 if not is_hip() else 16
14
14
  elif BLOCK_SIZE >= 8192:
15
15
  num_warps = 16
16
16
  elif BLOCK_SIZE >= 2048:
@@ -45,6 +45,7 @@ def _kldiv_kernel_forward(
45
45
  loss_ptr, # [B] or [B, S] if reduction == _REDUCTION_MODE_NONE, output ptr
46
46
  loss_stride, # int, output stride
47
47
  n_cols, # int, number of columns in the input tensor
48
+ eps,
48
49
  BLOCK_SIZE: tl.constexpr,
49
50
  log_target: tl.constexpr = False,
50
51
  reduction: tl.constexpr = _REDUCTION_MODE_BATCHMEAN,
@@ -56,6 +57,7 @@ def _kldiv_kernel_forward(
56
57
 
57
58
  base_offsets = tl.arange(0, BLOCK_SIZE)
58
59
 
60
+ loss_sum = 0.0
59
61
  for i in range(0, n_cols, BLOCK_SIZE):
60
62
  offsets = i + base_offsets
61
63
  mask = offsets < n_cols
@@ -65,32 +67,33 @@ def _kldiv_kernel_forward(
65
67
  # KL(y_true || y) = y_true * (log(y_true) - log(y))
66
68
  # We compute KL(y_true || y) with y in the log-space
67
69
  if not log_target:
68
- loss = y_true * (tl.log(y_true) - y)
70
+ loss = y_true * (tl.log(tl.maximum(y_true, eps)) - y)
69
71
  else:
70
72
  loss = tl.exp(y_true) * (y_true - y)
71
73
 
72
74
  if reduction == _REDUCTION_MODE_NONE:
73
75
  tl.store(loss_ptr + offsets, loss, mask=mask)
74
76
  else:
75
- loss = tl.sum(loss, axis=0)
76
- tl.store(loss_ptr, loss)
77
- loss_ptr += 1 # in case of reduction, the output tensor has dimensions [B,], therefore stride is always 1
77
+ loss_sum += tl.sum(loss, axis=0)
78
+
79
+ if reduction != _REDUCTION_MODE_NONE:
80
+ tl.store(loss_ptr, loss_sum)
78
81
 
79
82
 
80
83
  @triton.jit
81
84
  def _kldiv_kernel_backward(
82
- input_ptr,
83
- input_stride,
84
85
  target_ptr,
85
86
  target_stride,
87
+ new_grads_ptr,
88
+ new_grads_stride,
86
89
  n_cols,
87
90
  BLOCK_SIZE: tl.constexpr,
88
91
  log_target: tl.constexpr = False,
89
92
  ):
90
93
  pid = tl.program_id(0).to(tl.int64)
91
94
 
92
- input_ptr += pid * input_stride
93
95
  target_ptr += pid * target_stride
96
+ new_grads_ptr += pid * new_grads_stride
94
97
 
95
98
  offsets = tl.arange(0, BLOCK_SIZE)
96
99
  mask = offsets < n_cols
@@ -106,19 +109,19 @@ def _kldiv_kernel_backward(
106
109
  else:
107
110
  res = -tl.exp(target)
108
111
 
109
- tl.store(input_ptr + offsets, res, mask=mask)
112
+ tl.store(new_grads_ptr + offsets, res, mask=mask)
110
113
 
111
114
 
112
- def kldiv_forward_triton(y_pred, y_true, log_target, reduction): # [B, S] # [B, S]
113
- B, S = y_pred.shape
115
+ def kldiv_forward_triton(y_pred, y_true, log_target, reduction, eps): # [BT, V]
116
+ BT, V = y_pred.shape
114
117
 
115
- BLOCK_SIZE = min(MAX_FUSED_SIZE, triton.next_power_of_2(S))
118
+ BLOCK_SIZE = min(MAX_FUSED_SIZE, triton.next_power_of_2(V))
116
119
  num_warps = get_num_warps(BLOCK_SIZE)
117
120
 
118
- grid = (B,)
121
+ grid = (BT,)
119
122
  reduction = _str_to_reduction_mode[reduction]
120
123
 
121
- out_size = (B, S) if reduction == _REDUCTION_MODE_NONE.value else (B,)
124
+ out_size = (BT, V) if reduction == _REDUCTION_MODE_NONE.value else (BT,)
122
125
  output_tensor = torch.zeros(out_size, device=y_pred.device, dtype=torch.float32)
123
126
 
124
127
  _kldiv_kernel_forward[grid](
@@ -128,7 +131,8 @@ def kldiv_forward_triton(y_pred, y_true, log_target, reduction): # [B, S] # [B
128
131
  y_true.stride(0),
129
132
  output_tensor,
130
133
  output_tensor.stride(0),
131
- S,
134
+ V,
135
+ eps=eps,
132
136
  BLOCK_SIZE=BLOCK_SIZE,
133
137
  num_warps=num_warps,
134
138
  log_target=log_target,
@@ -139,30 +143,30 @@ def kldiv_forward_triton(y_pred, y_true, log_target, reduction): # [B, S] # [B
139
143
  # https://pytorch.org/docs/stable/generated/torch.nn.KLDivLoss.html
140
144
  # https://github.com/pytorch/pytorch/blob/d7b57c4d63edb42e1deeeba9497fcb5f1f748ff2/torch/nn/functional.py#L3372
141
145
  if reduction == _REDUCTION_MODE_BATCHMEAN.value:
142
- return output_tensor.sum() / B
146
+ return output_tensor.sum() / BT
143
147
  elif reduction == _REDUCTION_MODE_SUM.value:
144
148
  return output_tensor.sum(dim=0)
145
149
  elif reduction == _REDUCTION_MODE_MEAN.value:
146
- return output_tensor.mean(dim=0)
150
+ return output_tensor.sum() / (BT * V)
147
151
  else:
148
152
  return output_tensor
149
153
 
150
154
 
151
- def kldiv_backward_triton(input, target, grad_output, log_target):
152
- B, S = input.shape
155
+ def kldiv_backward_triton(target, grad_output, new_grads, log_target):
156
+ BT, V = target.shape
153
157
 
154
- BLOCK_SIZE = min(MAX_FUSED_SIZE, triton.next_power_of_2(S))
158
+ BLOCK_SIZE = min(MAX_FUSED_SIZE, triton.next_power_of_2(V))
155
159
  num_warps = get_num_warps(BLOCK_SIZE)
156
160
 
157
- grid = (B,)
161
+ grid = (BT,)
158
162
 
159
163
  # We store the gradients in-place in the input tensor
160
164
  _kldiv_kernel_backward[grid](
161
- input,
162
- input.stride(0),
163
165
  target,
164
166
  target.stride(0),
165
- S,
167
+ new_grads,
168
+ new_grads.stride(0),
169
+ V,
166
170
  BLOCK_SIZE=BLOCK_SIZE,
167
171
  num_warps=num_warps,
168
172
  log_target=log_target,
@@ -170,9 +174,9 @@ def kldiv_backward_triton(input, target, grad_output, log_target):
170
174
 
171
175
  # If cross entropy is the last layer, grad_output is 1.0. Skip the mul then.
172
176
  if torch.equal(grad_output, torch.tensor(1.0, device=grad_output.device)):
173
- return input
177
+ return new_grads
174
178
 
175
- return input * grad_output
179
+ return new_grads * grad_output
176
180
 
177
181
 
178
182
  class LigerKLDivLossFunction(torch.autograd.Function):
@@ -196,6 +200,7 @@ class LigerKLDivLossFunction(torch.autograd.Function):
196
200
  y_true: torch.Tensor,
197
201
  reduction: REDUCTION_LITERAL = "batchmean",
198
202
  log_target: bool = False,
203
+ eps: float = 1e-10,
199
204
  ) -> torch.Tensor:
200
205
  """A forward pass for the KL Divergence Loss.
201
206
 
@@ -205,15 +210,16 @@ class LigerKLDivLossFunction(torch.autograd.Function):
205
210
  y_true (torch.Tensor): A tensor of shape (BT, V) containing the target values, expected to be either probabilities or log-probabilities, depending on the value of `log_target`.
206
211
  reduction (REDUCTION_LITERAL, optional): Reduction to be used. Defaults to "batchmean".
207
212
  log_target (bool, optional): If set to true, expects the ground truth to already be log-probabilities. Defaults to False.
213
+ eps: (float, optional): A small value to avoid division by zero. Defaults to 1e-10.
208
214
 
209
215
  Returns:
210
216
  torch.Tensor: The computed KL Divergence Loss, with shape (BT, V) if `reduction` is "none", else a scalar.
211
217
  """
212
- ctx.save_for_backward(y_pred, y_true)
218
+ ctx.save_for_backward(y_true)
213
219
  ctx.reduction = reduction
214
220
  ctx.log_target = log_target
215
221
  return kldiv_forward_triton(
216
- y_pred, y_true, log_target=log_target, reduction=reduction
222
+ y_pred, y_true, log_target=log_target, reduction=reduction, eps=eps
217
223
  )
218
224
 
219
225
  @staticmethod
@@ -226,22 +232,27 @@ class LigerKLDivLossFunction(torch.autograd.Function):
226
232
  grad_output (torch.Tensor): The gradient of the loss with respect to the output.
227
233
 
228
234
  Returns:
229
- tuple[torch.Tensor, None, None, None]: The gradient of the loss with respect to the inputs and None for the other arguments of the forward method.
235
+ tuple[torch.Tensor, None, None, None, None]: The gradient of the loss with respect to the inputs and None for the other arguments of the forward method.
230
236
  """
231
- y_pred, y_true = ctx.saved_tensors
237
+ (y_true,) = ctx.saved_tensors
238
+
239
+ new_grads = torch.empty_like(y_true)
232
240
 
233
- derivative = kldiv_backward_triton(y_pred, y_true, grad_output, ctx.log_target)
241
+ derivative = kldiv_backward_triton(
242
+ y_true, grad_output, new_grads, ctx.log_target
243
+ )
234
244
 
235
245
  if ctx.reduction == "batchmean":
236
- derivative = derivative / y_pred.shape[0]
246
+ derivative = derivative / y_true.shape[0]
237
247
  elif ctx.reduction == "sum" or ctx.reduction == "none":
238
248
  pass
239
249
  elif ctx.reduction == "mean":
240
- derivative = derivative / (y_pred.shape[0] * y_pred.shape[1])
250
+ derivative = derivative / (y_true.shape[0] * y_true.shape[1])
241
251
 
242
252
  return (
243
253
  derivative,
244
254
  None,
245
255
  None,
246
256
  None,
257
+ None,
247
258
  )