liger-kernel 0.3.1__py3-none-any.whl → 0.4.1__py3-none-any.whl

Files changed (38)
  1. liger_kernel/env_report.py +2 -0
  2. liger_kernel/ops/cross_entropy.py +144 -65
  3. liger_kernel/ops/experimental/mm_int8int2.py +355 -0
  4. liger_kernel/ops/fused_linear_cross_entropy.py +31 -11
  5. liger_kernel/ops/fused_linear_jsd.py +245 -0
  6. liger_kernel/ops/geglu.py +2 -2
  7. liger_kernel/ops/group_norm.py +322 -0
  8. liger_kernel/ops/jsd.py +176 -0
  9. liger_kernel/ops/kl_div.py +2 -2
  10. liger_kernel/ops/rms_norm.py +92 -46
  11. liger_kernel/ops/swiglu.py +2 -2
  12. liger_kernel/ops/utils.py +62 -1
  13. liger_kernel/transformers/__init__.py +3 -0
  14. liger_kernel/transformers/cross_entropy.py +44 -12
  15. liger_kernel/transformers/functional.py +38 -1
  16. liger_kernel/transformers/fused_linear_cross_entropy.py +31 -4
  17. liger_kernel/transformers/fused_linear_jsd.py +98 -0
  18. liger_kernel/transformers/group_norm.py +56 -0
  19. liger_kernel/transformers/jsd.py +75 -0
  20. liger_kernel/transformers/model/gemma.py +124 -1
  21. liger_kernel/transformers/model/gemma2.py +277 -0
  22. liger_kernel/transformers/model/llama.py +135 -4
  23. liger_kernel/transformers/model/mistral.py +3 -0
  24. liger_kernel/transformers/model/mixtral.py +153 -2
  25. liger_kernel/transformers/model/mllama.py +274 -0
  26. liger_kernel/transformers/model/phi3.py +140 -2
  27. liger_kernel/transformers/model/qwen2.py +123 -2
  28. liger_kernel/transformers/model/qwen2_vl.py +8 -1
  29. liger_kernel/transformers/monkey_patch.py +258 -68
  30. liger_kernel/transformers/rms_norm.py +11 -3
  31. {liger_kernel-0.3.1.dist-info → liger_kernel-0.4.1.dist-info}/METADATA +63 -29
  32. liger_kernel-0.4.1.dist-info/NOTICE +58 -0
  33. liger_kernel-0.4.1.dist-info/RECORD +51 -0
  34. {liger_kernel-0.3.1.dist-info → liger_kernel-0.4.1.dist-info}/WHEEL +1 -1
  35. liger_kernel-0.3.1.dist-info/NOTICE +0 -4
  36. liger_kernel-0.3.1.dist-info/RECORD +0 -42
  37. {liger_kernel-0.3.1.dist-info → liger_kernel-0.4.1.dist-info}/LICENSE +0 -0
  38. {liger_kernel-0.3.1.dist-info → liger_kernel-0.4.1.dist-info}/top_level.txt +0 -0
liger_kernel/ops/jsd.py ADDED
@@ -0,0 +1,176 @@
+ from typing import Optional
+
+ import torch
+ import triton
+ import triton.language as tl
+
+ from liger_kernel.ops.utils import ensure_contiguous
+
+
+ @triton.jit
+ def _jsd_kernel(
+     X_ptr,  # input in logspace, X = log Q
+     X_stride,
+     Y_ptr,  # ground truth in logspace, Y = log P
+     Y_stride,
+     loss_ptr,
+     loss_stride,
+     dX_ptr,
+     dX_stride,
+     label_ptr,
+     beta,
+     n_non_ignore: int,
+     ignore_index: tl.constexpr,
+     n_cols,
+     BLOCK_SIZE: tl.constexpr,
+     HAS_LABEL: tl.constexpr,
+ ):
+     # JSD(P || Q) = (KL(P || M) + KL(Q || M)) / 2, M = (1/2) * (P + Q) = (1/2) * (e ^ Y + e ^ X)
+     #             = sum(P * log P + Q * log Q - 2 * M * log M) / 2
+     #             = sum(e ^ Y * Y + e ^ X * X - 2 * M * log M) / 2
+     # grad_x_i = (1 - beta) * Q * (X - log_M), i.e. 0.5 * Q * (X - log_M) when beta = 0.5
+     pid = tl.program_id(0).to(tl.int64)
+     X_ptr += pid * X_stride
+     dX_ptr += pid * dX_stride
+     Y_ptr += pid * Y_stride
+     loss_ptr += pid * loss_stride
+     label_ptr += pid
+
+     if HAS_LABEL:
+         label = tl.load(label_ptr)
+         if label == ignore_index:
+             for i in range(0, n_cols, BLOCK_SIZE):
+                 offsets = i + tl.arange(0, BLOCK_SIZE)
+                 tl.store(dX_ptr + offsets, 0.0, mask=offsets < n_cols)
+             return
+
+     for i in range(0, n_cols, BLOCK_SIZE):
+         offsets = i + tl.arange(0, BLOCK_SIZE)
+         mask = offsets < n_cols
+         X = tl.load(X_ptr + offsets, mask=mask, other=float("-inf")).to(tl.float32)
+         Y = tl.load(Y_ptr + offsets, mask=mask, other=float("-inf")).to(tl.float32)
+
+         Q = tl.exp(X)
+         P = tl.exp(Y)
+         M = beta * P + (1 - beta) * Q
+         log_M = tl.log(M)
+
+         loss = beta * P * Y + (1 - beta) * Q * X - M * log_M
+         # reduction == "batchmean"
+         loss = loss / n_non_ignore
+         tl.store(loss_ptr + offsets, loss, mask=mask)
+
+         dX = (1 - beta) * Q * (X - log_M) / n_non_ignore
+         tl.store(dX_ptr + offsets, dX, mask=mask)
+
+
+ MAX_FUSED_SIZE = 65536
+
+
+ def jsd_forward(_input, target, shift_labels, beta, ignore_index, has_label):
+     BT, V = _input.shape
+     n_rows = BT
+     BLOCK_SIZE = min(MAX_FUSED_SIZE, triton.next_power_of_2(V))
+     # non-reduced loss
+     loss = torch.zeros(_input.shape, dtype=torch.float32, device=_input.device)
+     dX = torch.empty_like(_input)
+
+     if has_label:
+         n_non_ignore = (shift_labels != ignore_index).sum().item()
+     else:
+         n_non_ignore = BT
+
+     _jsd_kernel[(n_rows,)](
+         X_ptr=_input,  # input in logspace, X = log Q
+         X_stride=_input.stride(-2),
+         Y_ptr=target,  # ground truth in logspace, Y = log P
+         Y_stride=target.stride(-2),
+         loss_ptr=loss,
+         loss_stride=loss.stride(-2),
+         dX_ptr=dX,
+         dX_stride=dX.stride(-2),
+         label_ptr=(
+             shift_labels if has_label else torch.empty(1, device=_input.device)
+         ),  # dummy ptr if no label
+         beta=beta,
+         n_non_ignore=n_non_ignore,
+         ignore_index=ignore_index,
+         n_cols=V,
+         BLOCK_SIZE=BLOCK_SIZE,
+         HAS_LABEL=has_label,
+     )
+
+     loss = torch.sum(loss)
+     return loss.to(_input.dtype), dX
+
+
+ def jsd_backward(dX, grad_output):
+     # If jsd is the last layer, grad_output is 1.0. Skip the mul to save time
+     if torch.equal(grad_output, torch.tensor(1.0, device=grad_output.device)):
+         return dX
+     else:
+         return grad_output * dX
+
+
+ class LigerJSDFunction(torch.autograd.Function):
+     r"""
+     This class implements the forward and backward pass for the generalized Jensen-Shannon Divergence.
+     .. math::
+         JSD(\beta)(P || Q)
+             = \beta * KLDiv(P || (\beta * P + (1 - \beta) * Q)) + (1 - \beta) * KLDiv(Q || (\beta * P + (1 - \beta) * Q))
+
+     .. note::
+         As with all the other losses in PyTorch, this function expects the first argument,
+         :attr:`_input`, to be the predictions (the output of the student model) in log-space,
+         and the second, :attr:`target`, to be the observations (the output of the teacher model) in log-space.
+         This differs from the standard mathematical notation :math:`JSD(P || Q)`, where
+         :math:`P` denotes the teacher model and :math:`Q` denotes the student model.
+     """
+
+     @staticmethod
+     @ensure_contiguous
+     def forward(
+         ctx,
+         _input: torch.Tensor,
+         target: torch.Tensor,
+         shift_labels: Optional[torch.Tensor] = None,
+         beta: float = 0.5,
+         ignore_index: int = -100,
+     ) -> torch.Tensor:
+         """
+         Args:
+             _input (torch.Tensor): predicted values with shape (BT, V) in logspace
+             target (torch.Tensor): ground truth values with shape (BT, V) in logspace
+             shift_labels (Optional[torch.LongTensor]): indicator of next predicted vocab with shape (BT,) where each value is in [0, V-1].
+             beta (float): coefficient beta of generalized JSD in the open interval (0, 1)
+             ignore_index (int): the index to ignore. Default: -100
+
+         Returns:
+             loss (torch.Tensor): generalized JSD
+         """
+         has_label = False
+         if shift_labels is not None:
+             assert shift_labels.shape == (
+                 _input.shape[0],
+             ), f"the shape of shift_labels must be (BT,). Got: {shift_labels.shape}"
+             shift_labels = shift_labels.contiguous()
+             has_label = True
+
+         loss, dX = jsd_forward(
+             _input, target, shift_labels, beta, ignore_index, has_label
+         )
+         ctx.save_for_backward(dX)
+         return loss
+
+     @staticmethod
+     @ensure_contiguous
+     def backward(ctx, grad_output: torch.Tensor) -> torch.Tensor:
+         (dX,) = ctx.saved_tensors
+         dX = jsd_backward(dX, grad_output)
+         return (
+             dX,
+             None,
+             None,
+             None,
+             None,
+         )
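For readers checking the kernel's math: the per-element term it stores, beta * P * Y + (1 - beta) * Q * X - M * log M, sums to beta * KL(P || M) + (1 - beta) * KL(Q || M). A minimal eager-PyTorch reference (illustrative, not part of the package) that the Triton path should agree with when no labels are ignored:

import torch

def naive_generalized_jsd(log_q: torch.Tensor, log_p: torch.Tensor, beta: float = 0.5):
    """Reference for LigerJSDFunction with shift_labels=None (all BT rows counted)."""
    q, p = log_q.exp(), log_p.exp()   # student Q, teacher P
    m = beta * p + (1 - beta) * q     # mixture distribution M
    # Same per-element term the kernel stores, then "batchmean" over the BT rows.
    loss = beta * p * log_p + (1 - beta) * q * log_q - m * m.log()
    return loss.sum() / log_q.shape[0]

# Hypothetical CUDA check against the fused kernel:
# log_q = torch.randn(4, 8, device="cuda").log_softmax(-1).requires_grad_()
# log_p = torch.randn(4, 8, device="cuda").log_softmax(-1)
# torch.testing.assert_close(LigerJSDFunction.apply(log_q, log_p),
#                            naive_generalized_jsd(log_q, log_p))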
liger_kernel/ops/kl_div.py CHANGED
@@ -4,13 +4,13 @@ import torch
  import triton
  import triton.language as tl

- from liger_kernel.ops.utils import ensure_contiguous
+ from liger_kernel.ops.utils import ensure_contiguous, is_hip


  def get_num_warps(BLOCK_SIZE):
      num_warps = 4
      if BLOCK_SIZE >= 32768:
-         num_warps = 32
+         num_warps = 32 if not is_hip() else 16
      elif BLOCK_SIZE >= 8192:
          num_warps = 16
      elif BLOCK_SIZE >= 2048:
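The HIP-specific cap is presumably because AMD wavefronts are 64 lanes wide, so 32 warps would request 2048 threads and exceed the usual 1024-thread workgroup limit; the identical `32 if not is_hip() else 16` change lands in `calculate_settings` in `liger_kernel/ops/utils.py` below.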
liger_kernel/ops/rms_norm.py CHANGED
@@ -10,6 +10,7 @@ https://github.com/unslothai/unsloth/blob/fd753fed99ed5f10ef8a9b7139588d9de9ddec
  Modifications made by Yanning Chen, 2024.
  """

+ import math
  import operator

  import torch
@@ -20,6 +21,7 @@ from liger_kernel.ops.utils import (
      calculate_settings,
      compare_version,
      ensure_contiguous,
+     torch_to_triton_dtype,
  )

  if compare_version("triton", operator.ge, "3.0.0"):
@@ -84,6 +86,10 @@ def _rms_norm_forward_kernel(
          W_row = W_row.to(tl.float32)
          X_row = X_row.to(tl.float32)

+     if casting_mode == _CASTING_MODE_NONE:
+         eps = eps.to(X_row_dtype)
+         offset = offset.to(X_row_dtype)
+
      mean_square = tl.sum(X_row * X_row, axis=0) / n_cols
      rstd = rsqrt(mean_square + eps)
@@ -100,6 +106,9 @@ def _rms_norm_forward_kernel(

      Y_row = X_row * (offset + W_row)

+     if casting_mode == _CASTING_MODE_GEMMA:
+         Y_row = Y_row.to(X_row_dtype)
+
      tl.store(Y_ptr + col_offsets, Y_row, mask=mask)


@@ -107,16 +116,21 @@ def _rms_norm_forward_kernel(
  def _rms_norm_backward_kernel(
      dY_ptr,
      dY_row_stride,
+     dX_ptr,
+     dX_row_stride,
      X_ptr,
      X_row_stride,
+     X_dtype: tl.constexpr,
      W_ptr,
      W_row_stride,
      RSTD_ptr,
      RSTD_row_stride,
      dW_ptr,
      dW_row_stride,
+     n_rows,
      n_cols,
      offset,
+     rows_per_program: tl.constexpr,
      casting_mode: tl.constexpr,
      BLOCK_SIZE: tl.constexpr,
  ):
@@ -125,54 +139,63 @@ def _rms_norm_backward_kernel(
      dw = sum(dy * (x / RMS)). summation over BxT dimension
      """

-     row_idx = tl.program_id(0)
+     row_block_id = tl.program_id(0)
+     row_start = row_block_id * rows_per_program
+     row_end = min((row_block_id + 1) * rows_per_program, n_rows)
      col_offsets = tl.arange(0, BLOCK_SIZE)
      mask = col_offsets < n_cols

-     dY_ptr += row_idx * dY_row_stride
-     X_ptr += row_idx * X_row_stride
-     RSTD_ptr += row_idx * RSTD_row_stride
-     dW_ptr += row_idx * dW_row_stride
+     dW_row = tl.zeros((BLOCK_SIZE,), dtype=tl.float32)

-     dY_row = tl.load(dY_ptr + col_offsets, mask=mask, other=0)
-     X_row = tl.load(X_ptr + col_offsets, mask=mask, other=0)
-     W_row = tl.load(W_ptr + col_offsets, mask=mask, other=0)
-     original_x_dtype = X_row.dtype
+     dY_ptr += row_start * dY_row_stride
+     dX_ptr += row_start * dX_row_stride

-     # Get cached rms
-     rstd_row = tl.load(RSTD_ptr)
+     X_ptr += row_start * X_row_stride
+     RSTD_ptr += row_start

+     W_row = tl.load(W_ptr + col_offsets, mask=mask, other=0.0)
      W_row = W_row + offset

-     X_row = X_row.to(tl.float32)
+     for _ in range(row_start, row_end):
+         dY_row = tl.load(dY_ptr + col_offsets, mask=mask, other=0.0)
+         X_row = tl.load(X_ptr + col_offsets, mask=mask, other=0.0)

-     # Different backward graphs for different casting modes
-     if casting_mode == _CASTING_MODE_LLAMA:
-         m = (dY_row * W_row).to(tl.float32)
+         # Get cached rms
+         rstd_row = tl.load(RSTD_ptr)

-     elif casting_mode == _CASTING_MODE_GEMMA:
-         dY_row, W_row = (
-             dY_row.to(tl.float32),
-             W_row.to(tl.float32),
-         )
+         X_row = X_row.to(tl.float32)

-         m = dY_row * W_row
+         # Different backward graphs for different casting modes
+         if casting_mode == _CASTING_MODE_LLAMA:
+             m = (dY_row * W_row).to(tl.float32)

-     dX_row = rstd_row * m
+         elif casting_mode == _CASTING_MODE_GEMMA:
+             dY_row = dY_row.to(tl.float32)
+             m = dY_row * W_row
+         else:
+             m = dY_row * W_row

-     dX_row += (rstd_row) * (
-         -(1 / n_cols) * rstd_row * rstd_row * tl.sum(m * X_row, axis=0) * X_row
-     )
+         dX_row = rstd_row * m

-     # calculate the gradient of W
-     if casting_mode == _CASTING_MODE_LLAMA:
-         dW_row = dY_row * (X_row * rstd_row).to(original_x_dtype)
-     else:
-         # here X_row is already in fp32 (see previous if block)
-         dW_row = dY_row * (X_row * rstd_row)
+         dX_row += (rstd_row) * (
+             -(1 / n_cols) * rstd_row * rstd_row * tl.sum(m * X_row, axis=0) * X_row
+         )
+
+         # calculate the gradient of W
+         if casting_mode == _CASTING_MODE_LLAMA:
+             dW_row += dY_row * (X_row * rstd_row).to(X_dtype)
+         else:
+             # here X_row is already in fp32 (see previous if block)
+             dW_row += dY_row * (X_row * rstd_row)
+
+         tl.store(dX_ptr + col_offsets, dX_row.to(X_dtype), mask=mask)
+
+         dY_ptr += dY_row_stride
+         dX_ptr += dX_row_stride
+         X_ptr += X_row_stride
+         RSTD_ptr += RSTD_row_stride

-     tl.store(dY_ptr + col_offsets, dX_row, mask=mask)
-     tl.store(dW_ptr + col_offsets, dW_row, mask=mask)
+     tl.store(dW_ptr + row_block_id * dW_row_stride + col_offsets, dW_row, mask=mask)


@@ -233,36 +256,53 @@ def rms_norm_forward(X, W, eps, offset, casting_mode):
      return Y.view(*shape), X, RSTD, BLOCK_SIZE, num_warps, casting_mode


- def rms_norm_backward(dY, X, W, RSTD, offset, casting_mode, BLOCK_SIZE, num_warps):
+ def rms_norm_backward(
+     dY, X, W, RSTD, offset, casting_mode, BLOCK_SIZE, num_warps, in_place
+ ):
      shape = dY.shape
      dim = shape[-1]
      dY = dY.view(-1, dim)
      n_rows, n_cols = dY.shape
-     dW = torch.empty_like(
-         X,
-         dtype=(torch.float32 if casting_mode == _CASTING_MODE_GEMMA.value else W.dtype),
-     )

-     # Here we use dY to store the value of dX to save memory
-     _rms_norm_backward_kernel[(n_rows,)](
+     sm_count = torch.cuda.get_device_properties(X.device).multi_processor_count
+     # fp32 for numerical stability
+     _dW = torch.empty((sm_count, n_cols), dtype=torch.float32, device=W.device)
+
+     if n_cols > BLOCK_SIZE:
+         raise RuntimeError("This layer norm doesn't support feature dim >= 64KB.")
+     rows_per_program = math.ceil(n_rows / sm_count)
+     grid = (sm_count,)
+
+     if in_place is True:
+         dX = dY
+     else:
+         dX = torch.zeros_like(dY)
+
+     _rms_norm_backward_kernel[grid](
          dY,
          dY.stride(0),
+         dX,
+         dX.stride(0),
          X,
          X.stride(0),
+         torch_to_triton_dtype[X.dtype],
          W,
          W.stride(0),
          RSTD,
          RSTD.stride(0),
-         dW,
-         dW.stride(0),
+         _dW,
+         _dW.stride(0),
+         n_rows,
          n_cols,
          offset,
+         rows_per_program,
          casting_mode,
          BLOCK_SIZE=BLOCK_SIZE,
          num_warps=num_warps,
      )
-     dX = dY.view(*shape)
-     dW = torch.sum(dW, dim=0).to(W.dtype)
+     dX = dX.view(*shape)
+     dW = _dW.sum(dim=0).to(W.dtype)
+
      return dX, dW


@@ -282,11 +322,15 @@ class LigerRMSNormFunction(torch.autograd.Function):
      - 'llama': matches the Llama implementation, where only the inverse RMS is computed on fp32.
      - 'gemma': matches the Gemma implementation, where everything is cast to fp32, then computed, then cast back to the original dtype.
      - 'none': no casting is done. The computation is done in the original dtype. This saves memory and is slightly faster, but has more error w.r.t. the original implementation.
+
+     The `in_place` option controls whether dY is modified in place to store dX. It defaults to `True` to save memory. However, in certain cases it can produce incorrect gradients:
+     for example, gemma2 applies two RMSNorms sequentially with a residual in between, and the residual branch still needs dY, so dY cannot be modified in place.
+     Therefore, when patching RMSNorm in gemma2, we set `in_place` to `False`.
      """

      @staticmethod
      @ensure_contiguous
-     def forward(ctx, X, W, eps, offset=0.0, casting_mode="llama"):
+     def forward(ctx, X, W, eps, offset=0.0, casting_mode="llama", in_place=True):
          """
          X: (B, T, H) or (BxT, H)
          W: (H,)
@@ -296,6 +340,7 @@ class LigerRMSNormFunction(torch.autograd.Function):
          )
          ctx.offset = offset
          ctx.casting_mode = casting_mode
+         ctx.in_place = in_place
          ctx.BLOCK_SIZE = BLOCK_SIZE
          ctx.num_warps = num_warps
          ctx.save_for_backward(X, W, RSTD)
@@ -317,5 +362,6 @@ class LigerRMSNormFunction(torch.autograd.Function):
              ctx.casting_mode,
              ctx.BLOCK_SIZE,
              ctx.num_warps,
+             ctx.in_place,
          )
-         return dX, dW, None, None, None
+         return dX, dW, None, None, None, None
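A hedged usage sketch of the new `in_place` switch; the positional order follows the `forward` signature shown above, and the shapes, eps, offset=1.0, and "gemma" casting mode are illustrative choices, not prescribed by the package:

import torch
from liger_kernel.ops.rms_norm import LigerRMSNormFunction

x = torch.randn(2, 16, 4096, device="cuda", dtype=torch.bfloat16, requires_grad=True)
w = torch.randn(4096, device="cuda", dtype=torch.bfloat16, requires_grad=True)

# Positional args mirror forward(ctx, X, W, eps, offset, casting_mode, in_place).
# in_place=False keeps dY intact, which gemma2's residual-between-norms pattern needs.
y = LigerRMSNormFunction.apply(x, w, 1e-6, 1.0, "gemma", False)
y.sum().backward()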
liger_kernel/ops/swiglu.py CHANGED
@@ -14,7 +14,7 @@ def silu(x):
  def _swiglu_forward_kernel(
      a_ptr, b_ptr, c_ptr, stride, n_cols: tl.constexpr, BLOCK_SIZE: tl.constexpr
  ):
-     program_id = tl.program_id(0).cast(tl.int64)
+     program_id = tl.program_id(0).to(tl.int64)

      # locate start index
      a_ptr += program_id * stride
@@ -35,7 +35,7 @@ def _swiglu_forward_kernel(
  def _swiglu_backward_kernel(
      dc_ptr, a_ptr, b_ptr, stride, n_cols: tl.constexpr, BLOCK_SIZE: tl.constexpr
  ):
-     program_id = tl.program_id(0).cast(tl.int64)
+     program_id = tl.program_id(0).to(tl.int64)

      # locate start index
      dc_ptr += program_id * stride
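Both swiglu kernels now promote the program id with `.to(tl.int64)`, matching the conversion used elsewhere in this release (`_jsd_kernel` above, `element_mul_kernel` in `utils.py` below) to avoid 32-bit offset overflow on large grids.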
liger_kernel/ops/utils.py CHANGED
@@ -12,13 +12,19 @@ Modifications made by Yanning Chen, 2024.

  import functools
  import importlib
+ import operator
  from typing import Callable

  import torch
  import triton
+ import triton.language as tl
  from packaging.version import Version


+ def is_hip() -> bool:
+     return torch.version.hip is not None
+
+
  def ensure_contiguous(fn):
      @functools.wraps(fn)
      def wrapper(ctx, *args, **kwargs):
@@ -45,7 +51,7 @@ def calculate_settings(n):

      num_warps = 4
      if BLOCK_SIZE >= 32768:
-         num_warps = 32
+         num_warps = 32 if not is_hip() else 16
      elif BLOCK_SIZE >= 8192:
          num_warps = 16
      elif BLOCK_SIZE >= 2048:
@@ -60,3 +66,58 @@ def compare_version(package: str, operator: Callable, target: str):
          return False
      pkg_version = Version(pkg.__version__)
      return operator(pkg_version, Version(target))
+
+
+ def get_amp_custom_fwd_bwd() -> Callable:
+     if compare_version("torch", operator.ge, "2.4.0"):
+         return (
+             functools.partial(torch.amp.custom_fwd, device_type="cuda"),
+             functools.partial(torch.amp.custom_bwd, device_type="cuda"),
+         )
+     return torch.cuda.amp.custom_fwd, torch.cuda.amp.custom_bwd
+
+
+ amp_custom_fwd, amp_custom_bwd = get_amp_custom_fwd_bwd()
+
+
+ torch_to_triton_dtype = {
+     torch.float32: tl.float32,
+     torch.float16: tl.float16,
+     torch.bfloat16: tl.bfloat16,
+ }
+
+
+ @triton.jit
+ def element_mul_kernel(
+     X_ptr,
+     X_stride,
+     grad_output_ptr,
+     n_cols,
+     BLOCK_SIZE: tl.constexpr,
+ ):
+     """
+     This function multiplies each element of the tensor pointed by X_ptr with the value pointed by grad_output_ptr.
+     The multiplication is performed in-place on the tensor pointed by X_ptr.
+
+     Parameters:
+     X_ptr: Pointer to the input tensor.
+     X_stride (int): The stride of the input tensor.
+     grad_output_ptr: Pointer to the gradient output value.
+     n_cols (int): The number of columns in the input tensor.
+     BLOCK_SIZE (int): The block size for Triton operations.
+     """
+
+     # Get the program ID and convert it to int64 to avoid overflow
+     program_id = tl.program_id(0).to(tl.int64)
+
+     # Locate the start index
+     X_ptr += program_id * X_stride
+
+     # Load the gradient output value
+     grad_output = tl.load(grad_output_ptr)
+
+     # Perform the element-wise multiplication
+     for i in range(0, n_cols, BLOCK_SIZE):
+         X_offsets = i + tl.arange(0, BLOCK_SIZE)
+         X_block = tl.load(X_ptr + X_offsets, mask=X_offsets < n_cols)
+         tl.store(X_ptr + X_offsets, X_block * grad_output, mask=X_offsets < n_cols)
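A sketch of how this in-place scaling kernel is meant to be launched, one program per row; the tensor shapes are illustrative, and the 65536 cap mirrors MAX_FUSED_SIZE from jsd.py above rather than anything this file prescribes:

import torch
import triton
from liger_kernel.ops.utils import element_mul_kernel

grad = torch.randn(8, 4096, device="cuda")      # stored per-row gradients
grad_output = torch.tensor(2.0, device="cuda")  # scalar upstream gradient
n_rows, n_cols = grad.shape
BLOCK_SIZE = min(65536, triton.next_power_of_2(n_cols))

# Each program multiplies its row of `grad` by *grad_output_ptr in place.
element_mul_kernel[(n_rows,)](
    grad, grad.stride(0), grad_output, n_cols, BLOCK_SIZE=BLOCK_SIZE
)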
liger_kernel/transformers/__init__.py CHANGED
@@ -5,7 +5,9 @@ from liger_kernel.transformers.cross_entropy import LigerCrossEntropyLoss  # noq
  from liger_kernel.transformers.fused_linear_cross_entropy import (  # noqa: F401
      LigerFusedLinearCrossEntropyLoss,
  )
+ from liger_kernel.transformers.fused_linear_jsd import LigerFusedLinearJSD  # noqa: F401
  from liger_kernel.transformers.geglu import LigerGEGLUMLP  # noqa: F401
+ from liger_kernel.transformers.jsd import LigerJSD  # noqa: F401
  from liger_kernel.transformers.layer_norm import LigerLayerNorm  # noqa: F401
  from liger_kernel.transformers.monkey_patch import (  # noqa: F401
      _apply_liger_kernel,
@@ -15,6 +17,7 @@ from liger_kernel.transformers.monkey_patch import (  # noqa: F401
      apply_liger_kernel_to_llama,
      apply_liger_kernel_to_mistral,
      apply_liger_kernel_to_mixtral,
+     apply_liger_kernel_to_mllama,
      apply_liger_kernel_to_phi3,
      apply_liger_kernel_to_qwen2,
      apply_liger_kernel_to_qwen2_vl,
liger_kernel/transformers/cross_entropy.py CHANGED
@@ -1,21 +1,53 @@
- from torch.nn import CrossEntropyLoss
+ from typing import Optional
+
+ import torch

  from liger_kernel.ops.cross_entropy import LigerCrossEntropyFunction


- class LigerCrossEntropyLoss(CrossEntropyLoss):
-     def __init__(self, *args, **kwargs):
-         super(LigerCrossEntropyLoss, self).__init__(*args, **kwargs)
-         assert (self.label_smoothing >= 0) and (
-             self.label_smoothing <= 1
-         ), f"label_smoothing must be between 0.0 and 1.0. Got: {self.label_smoothing}"
-         assert self.reduction in {
+ class LigerCrossEntropyLoss(torch.nn.Module):
+     def __init__(
+         self,
+         ignore_index: int = -100,
+         lse_square_scale: float = 0.0,
+         label_smoothing: float = 0.0,
+         reduction: str = "mean",
+         softcap: Optional[float] = None,
+         return_z_loss: bool = False,
+     ):
+         super().__init__()
+         assert (label_smoothing >= 0) and (
+             label_smoothing <= 1
+         ), f"label_smoothing must be between 0.0 and 1.0. Got: {label_smoothing}"
+         assert reduction in {
              "mean",
              "sum",
              "none",
-         }, f"reduction must be one of 'mean', 'sum', or 'none'. Got: {self.reduction}"
+         }, f"reduction must be one of 'mean', 'sum', or 'none'. Got: {reduction}"
+         assert (
+             softcap is None or softcap > 0
+         ), f"softcap must be greater than 0.0 or None. Got: {softcap}"
+         self.ignore_index = ignore_index
+         self.lse_square_scale = lse_square_scale
+         self.label_smoothing = label_smoothing
+         self.reduction = reduction
+         self.softcap = softcap
+         self.return_z_loss = return_z_loss

-     def forward(self, _input, target):
-         return LigerCrossEntropyFunction.apply(
-             _input, target, self.ignore_index, self.label_smoothing, self.reduction
+     def forward(self, _input: torch.Tensor, target: torch.Tensor):
+         loss, z_loss = LigerCrossEntropyFunction.apply(
+             _input,
+             target,
+             self.ignore_index,
+             self.lse_square_scale,
+             self.label_smoothing,
+             self.reduction,
+             self.softcap,
+             self.return_z_loss,
          )
+         if not self.return_z_loss:
+             return loss
+         return loss, z_loss
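A hedged example of the reworked module; argument values and shapes are illustrative:

import torch
from liger_kernel.transformers import LigerCrossEntropyLoss

logits = torch.randn(16, 32000, device="cuda", requires_grad=True)  # (BT, V)
targets = torch.randint(0, 32000, (16,), device="cuda")             # (BT,)

loss_fn = LigerCrossEntropyLoss(
    lse_square_scale=1e-4,  # weight of the z-loss term
    softcap=30.0,           # logit softcapping, as used by gemma2
    return_z_loss=True,     # forward returns (loss, z_loss) instead of loss
)
loss, z_loss = loss_fn(logits, targets)
loss.backward()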
liger_kernel/transformers/functional.py CHANGED
@@ -1,8 +1,13 @@
+ from typing import Optional
+
  from liger_kernel.ops.cross_entropy import LigerCrossEntropyFunction
  from liger_kernel.ops.fused_linear_cross_entropy import (
      LigerFusedLinearCrossEntropyFunction,
  )
+ from liger_kernel.ops.fused_linear_jsd import LigerFusedLinearJSDFunction
  from liger_kernel.ops.geglu import LigerGELUMulFunction
+ from liger_kernel.ops.group_norm import LigerGroupNormFunction
+ from liger_kernel.ops.jsd import LigerJSDFunction
  from liger_kernel.ops.kl_div import LigerKLDivLossFunction
  from liger_kernel.ops.layer_norm import LigerLayerNormFunction
  from liger_kernel.ops.rms_norm import LigerRMSNormFunction
@@ -10,10 +15,42 @@ from liger_kernel.ops.rope import LigerRopeFunction
  from liger_kernel.ops.swiglu import LigerSiLUMulFunction

  liger_swiglu = LigerSiLUMulFunction.apply
- liger_cross_entropy = LigerCrossEntropyFunction.apply
  liger_fused_linear_cross_entropy = LigerFusedLinearCrossEntropyFunction.apply
  liger_geglu = LigerGELUMulFunction.apply
  liger_rms_norm = LigerRMSNormFunction.apply
  liger_rope = LigerRopeFunction.apply
  liger_layer_norm = LigerLayerNormFunction.apply
  liger_kl_div = LigerKLDivLossFunction.apply
+ liger_jsd = LigerJSDFunction.apply
+ liger_fused_linear_jsd = LigerFusedLinearJSDFunction.apply
+ liger_group_norm = LigerGroupNormFunction.apply
+
+
+ # conform to the function signature in https://pytorch.org/docs/stable/generated/torch.nn.functional.cross_entropy.html
+ # `weight` and `size_average` are placeholders and not implemented yet
+ def liger_cross_entropy(
+     input,
+     target,
+     weight=None,
+     size_average=None,
+     ignore_index: int = -100,
+     reduce=None,
+     reduction: str = "mean",
+     label_smoothing: float = 0.0,
+     lse_square_scale: float = 0.0,
+     softcap: Optional[float] = None,
+     return_z_loss: bool = False,
+ ):
+     loss, z_loss = LigerCrossEntropyFunction.apply(
+         input,
+         target,
+         ignore_index,
+         lse_square_scale,
+         label_smoothing,
+         reduction,
+         softcap,
+         return_z_loss,
+     )
+     if not return_z_loss:
+         return loss
+     return loss, z_loss
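Because the wrapper now mirrors the signature of torch.nn.functional.cross_entropy, it can be swapped in where the torch functional was used; a sketch (`weight`, `size_average`, and `reduce` are accepted but ignored, per the comment above):

import torch
from liger_kernel.transformers.functional import liger_cross_entropy

logits = torch.randn(8, 128, device="cuda", requires_grad=True)
labels = torch.randint(0, 128, (8,), device="cuda")

# Same call shape as F.cross_entropy, with the liger-specific extras available.
loss = liger_cross_entropy(logits, labels, reduction="mean", label_smoothing=0.1)
loss.backward()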