liger-kernel 0.3.1__py3-none-any.whl → 0.4.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (31)
  1. liger_kernel/ops/cross_entropy.py +5 -39
  2. liger_kernel/ops/experimental/mm_int8int2.py +355 -0
  3. liger_kernel/ops/fused_linear_cross_entropy.py +12 -9
  4. liger_kernel/ops/fused_linear_jsd.py +245 -0
  5. liger_kernel/ops/geglu.py +2 -2
  6. liger_kernel/ops/jsd.py +176 -0
  7. liger_kernel/ops/kl_div.py +2 -2
  8. liger_kernel/ops/rms_norm.py +67 -42
  9. liger_kernel/ops/swiglu.py +2 -2
  10. liger_kernel/ops/utils.py +62 -1
  11. liger_kernel/transformers/__init__.py +3 -0
  12. liger_kernel/transformers/functional.py +4 -0
  13. liger_kernel/transformers/fused_linear_jsd.py +98 -0
  14. liger_kernel/transformers/jsd.py +75 -0
  15. liger_kernel/transformers/model/gemma.py +124 -1
  16. liger_kernel/transformers/model/llama.py +135 -4
  17. liger_kernel/transformers/model/mistral.py +3 -0
  18. liger_kernel/transformers/model/mixtral.py +153 -2
  19. liger_kernel/transformers/model/mllama.py +274 -0
  20. liger_kernel/transformers/model/phi3.py +140 -2
  21. liger_kernel/transformers/model/qwen2.py +123 -2
  22. liger_kernel/transformers/model/qwen2_vl.py +8 -1
  23. liger_kernel/transformers/monkey_patch.py +158 -7
  24. {liger_kernel-0.3.1.dist-info → liger_kernel-0.4.0.dist-info}/METADATA +60 -28
  25. liger_kernel-0.4.0.dist-info/NOTICE +58 -0
  26. liger_kernel-0.4.0.dist-info/RECORD +48 -0
  27. {liger_kernel-0.3.1.dist-info → liger_kernel-0.4.0.dist-info}/WHEEL +1 -1
  28. liger_kernel-0.3.1.dist-info/NOTICE +0 -4
  29. liger_kernel-0.3.1.dist-info/RECORD +0 -42
  30. {liger_kernel-0.3.1.dist-info → liger_kernel-0.4.0.dist-info}/LICENSE +0 -0
  31. {liger_kernel-0.3.1.dist-info → liger_kernel-0.4.0.dist-info}/top_level.txt +0 -0
liger_kernel/ops/fused_linear_jsd.py ADDED
@@ -0,0 +1,245 @@
+ from typing import Optional
+
+ import torch
+ import triton
+
+ from liger_kernel.ops.jsd import _jsd_kernel
+ from liger_kernel.ops.utils import (
+     amp_custom_bwd,
+     amp_custom_fwd,
+     element_mul_kernel,
+     is_hip,
+ )
+
+ # The hard limit of TRITON_MAX_TENSOR_NUMEL is 1048576 https://github.com/triton-lang/triton/blob/ba42a5c68fd0505f8c42f4202d53be0f8d9a5fe0/python/triton/language/core.py#L19
+ # However, setting limit as 65536 as in LayerNorm tutorial is faster because of less register spilling
+ # The optimal maximum block size depends on your hardware, your kernel, and your dtype
+ MAX_FUSED_SIZE = 65536 // 2
+
+
+ def fused_linear_jsd_forward(
+     student_input,
+     student_weight,
+     teacher_input,
+     teacher_weight,
+     shift_labels,
+     jsd_beta,
+     ignore_index,
+     has_label,
+     temperature,
+ ):
+     device = student_input.device
+     dtype = student_input.dtype
+
+     # inputs have shape: BT x H
+     # materialized activations will have shape: BT x V
+     # the increase in memory = BT x V
+     # reduction can be achieved by partitioning the number of tokens BT into smaller chunks.
+     # for ex: if we were to achieve the same memory consumption as BT x H, then the chunk size should be:
+     # inc_factor = (V+H-1)//H, chunk_size = (BT + inc_factor - 1)//inc_factor
+     # for ex: BT = 4096*4, V = 32000, H = 4096 ==> inc_factor = 8, chunk_size = 2048
+     BT, H = student_input.shape
+     V = student_weight.shape[0]
+     BLOCK_SIZE = min(MAX_FUSED_SIZE, triton.next_power_of_2(V))
+
+     inc_factor = triton.cdiv(V, H)  # (V + H - 1) // H
+     chunk_size = triton.next_power_of_2(
+         triton.cdiv(BT, inc_factor)
+     )  # (BT + inc_factor - 1) // inc_factor
+     num_chunks = triton.cdiv(BT, chunk_size)  # (BT + chunk_size - 1) // chunk_size
+
+     grad_weight = (
+         torch.zeros_like(student_weight, device=device)
+         if student_weight.requires_grad
+         else None
+     )
+     grad_input = torch.zeros_like(student_input)
+     # we use fp32 for loss accumulator
+     loss_1d = torch.zeros((BT, V), dtype=torch.float32, device=device)
+
+     if has_label:
+         n_non_ignore = (shift_labels != ignore_index).sum().item()
+     else:
+         n_non_ignore = BT
+
+     for chunk_id in range(num_chunks):
+         start_idx = chunk_id * chunk_size
+         end_idx = min((chunk_id + 1) * chunk_size, BT)
+
+         # chunk both inputs, shape: chunk_size x H
+         student_input_chunk = student_input[start_idx:end_idx]
+         teacher_input_chunk = teacher_input[start_idx:end_idx]
+
+         # shape: chunk_size x V
+         # For anything starting from logits to the final JSD loss, we do computation
+         # in FP32 to avoid losing numerical stability.
+         student_logits_chunk = (student_input_chunk @ student_weight.t()).to(
+             torch.float32
+         )
+         teacher_logits_chunk = (teacher_input_chunk @ teacher_weight.t()).to(
+             torch.float32
+         )
+         chunk_n_rows = student_logits_chunk.shape[0]
+
+         # unreduced loss
+         loss_1d_slice = loss_1d[start_idx:end_idx]  # chunk_size
+         # log-softmax with temperature
+         student_logits_chunk = student_logits_chunk / temperature
+         teacher_logits_chunk = teacher_logits_chunk / temperature
+         student_prob_chunk = torch.log_softmax(student_logits_chunk, dim=-1)
+         teacher_prob_chunk = torch.log_softmax(teacher_logits_chunk, dim=-1)
+
+         # ensure _input and target are contiguous
+         student_prob_chunk = student_prob_chunk.contiguous()
+         teacher_prob_chunk = teacher_prob_chunk.contiguous()
+
+         # Here we calculate the gradient of prob_chunk in place so we can save memory.
+         _jsd_kernel[(chunk_n_rows,)](
+             X_ptr=student_prob_chunk,
+             X_stride=student_prob_chunk.stride(-2),
+             Y_ptr=teacher_prob_chunk,
+             Y_stride=teacher_prob_chunk.stride(-2),
+             loss_ptr=loss_1d_slice,
+             loss_stride=loss_1d_slice.stride(-2),
+             dX_ptr=student_prob_chunk,
+             dX_stride=student_prob_chunk.stride(-2),
+             label_ptr=(
+                 shift_labels[start_idx:end_idx]
+                 if has_label
+                 else torch.empty(1, device=device)
+             ),  # dummy ptr if no label
+             beta=jsd_beta,
+             n_non_ignore=n_non_ignore,
+             ignore_index=ignore_index,
+             n_cols=V,
+             BLOCK_SIZE=BLOCK_SIZE,
+             HAS_LABEL=has_label,
+         )
+         loss_1d[start_idx:end_idx] = loss_1d_slice
+         # gradients of prob_chunk in place, shape: chunk_size x V
+         # gradients of logits_chunk in place, shape: chunk_size x V
+         student_logits_chunk = (
+             student_prob_chunk
+             - torch.softmax(student_logits_chunk, dim=-1)
+             * student_prob_chunk.sum(dim=-1, keepdim=True).broadcast_to(
+                 student_prob_chunk.shape
+             )
+         ) / temperature
+         # now we traverse back to grad w.r.t. input to `lm_head` and grad
+         # w.r.t. `lm_head` which should be computed in original dtype
+         student_logits_chunk = student_logits_chunk.to(dtype)
+         grad_input[start_idx:end_idx] = student_logits_chunk @ student_weight
+
+         if grad_weight is not None:
+             grad_weight.add_(student_logits_chunk.t() @ student_input_chunk)
+
+     loss = torch.sum(loss_1d)
+     return loss, grad_input, grad_weight
+
+
+ def fused_linear_jsd_backward(grad_output, grad_input, grad_weight):
+     # If JSD is the last layer, grad_output is 1.0. Skip the mul to save time
+     if torch.ne(grad_output, torch.tensor(1.0, device=grad_output.device)):
+         # We use a Triton kernel instead of a PyTorch operation because modifying inputs in-place
+         # for gradient storage and backward multiple times causes anomalies with PyTorch but not with Triton.
+         BT, H = grad_input.shape
+         n_rows = BT
+         BLOCK_SIZE = min(MAX_FUSED_SIZE, triton.next_power_of_2(H))
+
+         element_mul_kernel[(n_rows,)](
+             grad_input,
+             grad_input.stride(-2),
+             grad_output,
+             H,
+             BLOCK_SIZE=BLOCK_SIZE,
+             num_warps=32 if not is_hip() else 16,
+         )
+
+         # handle grad_weight
+         if grad_weight is not None:
+             V, H = grad_weight.shape
+             n_rows = V
+
+             element_mul_kernel[(n_rows,)](
+                 grad_weight,
+                 grad_weight.stride(-2),
+                 grad_output,
+                 H,
+                 BLOCK_SIZE=BLOCK_SIZE,
+                 num_warps=32 if not is_hip() else 16,
+             )
+
+     return grad_input, grad_weight
+
+
+ class LigerFusedLinearJSDFunction(torch.autograd.Function):
+     """
+     Fusing the last linear layer with generalized JSD
+
+     Handle the forward and backward pass of the final linear layer via JSD by avoiding
+     the materialization of the large logits tensor. Since JSD is the last layer, we can
+     compute the gradient at the forward pass.
+     """
+
+     @staticmethod
+     @amp_custom_fwd
+     def forward(
+         ctx,
+         student_input: torch.Tensor,
+         student_weight: torch.Tensor,
+         teacher_input: torch.Tensor,
+         teacher_weight: torch.Tensor,
+         shift_labels: Optional[torch.Tensor] = None,
+         jsd_beta: float = 0.5,
+         ignore_index: int = -100,
+         temperature: float = 1.0,
+     ):
+         """
+         Args:
+
+             student_input (torch.tensor): input of the last projection layer in student model, with shape (B*T, H), where B is batch size, T is sequence length, H is hidden dimension.
+             student_weight (torch.tensor): the last projection layer in student model, with shape (V, H), where V is vocab size
+             teacher_input (torch.tensor): input of the last projection layer in teacher model, with shape (B*T, H), where B is batch size, T is sequence length, H is hidden dimension.
+             teacher_weight (torch.tensor): the last projection layer in teacher model, with shape (V, H), where V is vocab size
+             shift_labels (Optional[torch.LongTensor]): indicator of next predicted vocab with shape (BT) where each value is in [0, V-1].
+             jsd_beta (float): coefficient beta of generalized JSD in the open interval (0, 1). Default: `0.5`
+             ignore_index (int): the index to ignore. Default: -100
+             temperature (float): temperature in softmax function to control the output probability distribution. Default: `1.0`
+
+         Returns:
+             loss (torch.Tensor): generalized JSD
+         """
+         has_label = False
+         if shift_labels is not None:
+             assert shift_labels.shape == (
+                 teacher_input.shape[0],
+             ), f"the shape of shift_labels must be (BT,). Got: {shift_labels.shape}"
+             shift_labels = shift_labels.contiguous()
+             has_label = True
+
+         loss, grad_input, grad_weight = fused_linear_jsd_forward(
+             student_input,
+             student_weight,
+             teacher_input,
+             teacher_weight,
+             shift_labels,
+             jsd_beta,
+             ignore_index,
+             has_label,
+             temperature,
+         )
+         # downcast to dtype and store for backward
+         ctx.save_for_backward(
+             grad_input.detach(),
+             grad_weight.detach() if grad_weight is not None else None,
+         )
+         return loss
+
+     @staticmethod
+     @amp_custom_bwd
+     def backward(ctx, grad_output):
+         (grad_input, grad_weight) = ctx.saved_tensors
+         grad_input, grad_weight = fused_linear_jsd_backward(
+             grad_output, grad_input, grad_weight
+         )
+         return (grad_input, grad_weight, None, None, None, None, None, None)
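
Note: as a quick sanity check of the chunking arithmetic described in the comments of fused_linear_jsd_forward above, the following plain-Python sketch (no GPU or Triton required; next_power_of_2 is a local stand-in for triton.next_power_of_2, and the sizes are the illustrative ones from the comment, BT = 4096*4, V = 32000, H = 4096) reproduces inc_factor = 8, chunk_size = 2048 and hence 8 chunks:

    def next_power_of_2(n: int) -> int:
        # local helper, same result as triton.next_power_of_2 for positive n
        return 1 << (n - 1).bit_length()

    BT, V, H = 4096 * 4, 32000, 4096
    inc_factor = (V + H - 1) // H                                       # triton.cdiv(V, H) -> 8
    chunk_size = next_power_of_2((BT + inc_factor - 1) // inc_factor)   # -> 2048
    num_chunks = (BT + chunk_size - 1) // chunk_size                    # -> 8
    print(inc_factor, chunk_size, num_chunks)                           # 8 2048 8
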
liger_kernel/ops/geglu.py CHANGED
@@ -25,7 +25,7 @@ else:
  def _geglu_tanh_forward_kernel(
      a, b, c, stride, n_cols: tl.constexpr, BLOCK_SIZE: tl.constexpr
  ):
-     program_id = tl.program_id(0).cast(tl.int64)
+     program_id = tl.program_id(0).to(tl.int64)

      # locate start index
      a += program_id * stride
@@ -52,7 +52,7 @@ def _geglu_tanh_forward_kernel(
  def _geglu_tanh_backward_kernel(
      dc, a, b, stride, n_cols: tl.constexpr, BLOCK_SIZE: tl.constexpr
  ):
-     program_id = tl.program_id(0).cast(tl.int64)
+     program_id = tl.program_id(0).to(tl.int64)

      # locate start index
      dc += program_id * stride
liger_kernel/ops/jsd.py ADDED
@@ -0,0 +1,176 @@
+ from typing import Optional
+
+ import torch
+ import triton
+ import triton.language as tl
+
+ from liger_kernel.ops.utils import ensure_contiguous
+
+
+ @triton.jit
+ def _jsd_kernel(
+     X_ptr,  # input in logspace, X = log Q
+     X_stride,
+     Y_ptr,  # ground truth in logspace, Y = log P
+     Y_stride,
+     loss_ptr,
+     loss_stride,
+     dX_ptr,
+     dX_stride,
+     label_ptr,
+     beta,
+     n_non_ignore: int,
+     ignore_index: tl.constexpr,
+     n_cols,
+     BLOCK_SIZE: tl.constexpr,
+     HAS_LABEL: tl.constexpr,
+ ):
+     # JSD(P || Q) = (KL(P || M) + KL(Q || M)) / 2, M = (1/2) * (P + Q) = (1/2) * (e ^ Y + e ^ X)
+     #             = sum(P * log P + Q * log Q - 2 * M * log M) / 2
+     #             = sum(e ^ Y * Y + e ^ X * X - 2 * M * log M) / 2
+     # grad_x_i = 0.5 * Q * (X - log_M)
+     pid = tl.program_id(0).to(tl.int64)
+     X_ptr += pid * X_stride
+     dX_ptr += pid * dX_stride
+     Y_ptr += pid * Y_stride
+     loss_ptr += pid * loss_stride
+     label_ptr += pid
+
+     if HAS_LABEL:
+         label = tl.load(label_ptr)
+         if label == ignore_index:
+             for i in range(0, n_cols, BLOCK_SIZE):
+                 offsets = i + tl.arange(0, BLOCK_SIZE)
+                 tl.store(dX_ptr + offsets, 0.0, mask=offsets < n_cols)
+             return
+
+     for i in range(0, n_cols, BLOCK_SIZE):
+         offsets = i + tl.arange(0, BLOCK_SIZE)
+         mask = offsets < n_cols
+         X = tl.load(X_ptr + offsets, mask=mask, other=float("-inf")).to(tl.float32)
+         Y = tl.load(Y_ptr + offsets, mask=mask, other=float("-inf")).to(tl.float32)
+
+         Q = tl.exp(X)
+         P = tl.exp(Y)
+         M = beta * P + (1 - beta) * Q
+         log_M = tl.log(M)
+
+         loss = beta * P * Y + (1 - beta) * Q * X - M * log_M
+         # reduction == "batchmean"
+         loss = loss / n_non_ignore
+         tl.store(loss_ptr + offsets, loss, mask=mask)
+
+         dX = (1 - beta) * Q * (X - log_M) / n_non_ignore
+         tl.store(dX_ptr + offsets, dX, mask=mask)
+
+
+ MAX_FUSED_SIZE = 65536
+
+
+ def jsd_forward(_input, target, shift_labels, beta, ignore_index, has_label):
+     BT, V = _input.shape
+     n_rows = BT
+     BLOCK_SIZE = min(MAX_FUSED_SIZE, triton.next_power_of_2(V))
+     # non reduction loss
+     loss = torch.zeros(_input.shape, dtype=torch.float32, device=_input.device)
+     dX = torch.empty_like(_input)
+
+     if has_label:
+         n_non_ignore = (shift_labels != ignore_index).sum().item()
+     else:
+         n_non_ignore = BT
+
+     _jsd_kernel[(n_rows,)](
+         X_ptr=_input,  # input in logspace, X = log Q
+         X_stride=_input.stride(-2),
+         Y_ptr=target,  # ground truth in logspace, Y = log P
+         Y_stride=target.stride(-2),
+         loss_ptr=loss,
+         loss_stride=loss.stride(-2),
+         dX_ptr=dX,
+         dX_stride=dX.stride(-2),
+         label_ptr=(
+             shift_labels if has_label else torch.empty(1, device=_input.device)
+         ),  # dummy ptr if no label
+         beta=beta,
+         n_non_ignore=n_non_ignore,
+         ignore_index=ignore_index,
+         n_cols=V,
+         BLOCK_SIZE=BLOCK_SIZE,
+         HAS_LABEL=has_label,
+     )
+
+     loss = torch.sum(loss)
+     return loss.to(_input.dtype), dX
+
+
+ def jsd_backward(dX, grad_output):
+     # If jsd is the last layer, grad_output is 1.0. Skip the mul to save time
+     if torch.equal(grad_output, torch.tensor(1.0, device=grad_output.device)):
+         return dX
+     else:
+         return grad_output * dX
+
+
+ class LigerJSDFunction(torch.autograd.Function):
+     r"""
+     This class implements the forward and backward pass for the generalized Jensen-Shannon Divergence.
+     .. math::
+         JSD(\beta)(P || Q)
+             = \beta * KLDiv(P || (\beta * P + (1 - \beta) * Q)) + (1 - \beta) * KLDiv(Q || (\beta * P + (1 - \beta) * Q))
+
+     .. note::
+         As all the other losses in PyTorch, this function expects the first argument,
+         :attr:`_input`, to be the predictions, the output of the student model, in log-space
+         and the second, :attr:`target`, to be the observations, the output of the teacher model, in log-space.
+         This differs from the standard mathematical notation :math:`JSD(P || Q)` where
+         :math:`P` denotes the teacher model and :math:`Q` denotes the student model.
+     """
+
+     @staticmethod
+     @ensure_contiguous
+     def forward(
+         ctx,
+         _input: torch.Tensor,
+         target: torch.Tensor,
+         shift_labels: Optional[torch.Tensor] = None,
+         beta: float = 0.5,
+         ignore_index: int = -100,
+     ) -> torch.Tensor:
+         """
+         Args:
+             _input (torch.Tensor): predict values with shape (BT, V) in logspace
+             target (torch.Tensor): ground truth values with shape (BT, V) in logspace
+             shift_labels (Optional[torch.LongTensor]): indicator of next predicted vocab with shape (BT) where each value is in [0, V-1].
+             beta (float): coefficient beta of generalized JSD in the open interval (0, 1)
+             ignore_index (int): the index to ignore. Default: -100
+
+         Returns:
+             loss (torch.Tensor): generalized JSD
+         """
+         has_label = False
+         if shift_labels is not None:
+             assert shift_labels.shape == (
+                 _input.shape[0],
+             ), f"the shape of shift_labels must be (BT,). Got: {shift_labels.shape}"
+             shift_labels = shift_labels.contiguous()
+             has_label = True
+
+         loss, dX = jsd_forward(
+             _input, target, shift_labels, beta, ignore_index, has_label
+         )
+         ctx.save_for_backward(dX)
+         return loss
+
+     @staticmethod
+     @ensure_contiguous
+     def backward(ctx, grad_output: torch.Tensor) -> torch.Tensor:
+         (dX,) = ctx.saved_tensors
+         dX = jsd_backward(dX, grad_output)
+         return (
+             dX,
+             None,
+             None,
+             None,
+             None,
+         )
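
Note: a minimal usage sketch of the new LigerJSDFunction above (assuming a CUDA device is available; as the docstring notes, both arguments are log-probabilities, student predictions first, teacher targets second):

    import torch
    from liger_kernel.ops.jsd import LigerJSDFunction

    BT, V = 8, 128                                        # illustrative sizes only
    student_logits = torch.randn(BT, V, device="cuda", requires_grad=True)
    teacher_logits = torch.randn(BT, V, device="cuda")

    log_q = torch.log_softmax(student_logits, dim=-1)     # student predictions (_input)
    log_p = torch.log_softmax(teacher_logits, dim=-1)     # teacher targets (target)
    loss = LigerJSDFunction.apply(log_q, log_p)           # beta=0.5, ignore_index=-100 by default
    loss.backward()                                       # gradient flows back to student_logits only
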
liger_kernel/ops/kl_div.py CHANGED
@@ -4,13 +4,13 @@ import torch
  import triton
  import triton.language as tl

- from liger_kernel.ops.utils import ensure_contiguous
+ from liger_kernel.ops.utils import ensure_contiguous, is_hip


  def get_num_warps(BLOCK_SIZE):
      num_warps = 4
      if BLOCK_SIZE >= 32768:
-         num_warps = 32
+         num_warps = 32 if not is_hip() else 16
      elif BLOCK_SIZE >= 8192:
          num_warps = 16
      elif BLOCK_SIZE >= 2048:
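
Note: the branches visible in this hunk give the following plain-Python sketch of the HIP-aware warp-count heuristic (get_num_warps_sketch is a hypothetical name; the BLOCK_SIZE >= 2048 branch and anything below it fall outside the diff context and are not reproduced here):

    def get_num_warps_sketch(block_size: int, on_hip: bool) -> int:
        # mirrors only the branches shown in the hunk above
        num_warps = 4
        if block_size >= 32768:
            num_warps = 32 if not on_hip else 16
        elif block_size >= 8192:
            num_warps = 16
        # branches for smaller block sizes are elided by the diff context
        return num_warps
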
liger_kernel/ops/rms_norm.py CHANGED
@@ -10,6 +10,7 @@ https://github.com/unslothai/unsloth/blob/fd753fed99ed5f10ef8a9b7139588d9de9ddec
  Modifications made by Yanning Chen, 2024.
  """

+ import math
  import operator

  import torch
@@ -20,6 +21,7 @@ from liger_kernel.ops.utils import (
      calculate_settings,
      compare_version,
      ensure_contiguous,
+     torch_to_triton_dtype,
  )

  if compare_version("triton", operator.ge, "3.0.0"):
@@ -84,6 +86,10 @@ def _rms_norm_forward_kernel(
          W_row = W_row.to(tl.float32)
          X_row = X_row.to(tl.float32)

+     if casting_mode == _CASTING_MODE_NONE:
+         eps = eps.to(X_row_dtype)
+         offset = offset.to(X_row_dtype)
+
      mean_square = tl.sum(X_row * X_row, axis=0) / n_cols
      rstd = rsqrt(mean_square + eps)

@@ -100,6 +106,9 @@

      Y_row = X_row * (offset + W_row)

+     if casting_mode == _CASTING_MODE_GEMMA:
+         Y_row = Y_row.to(X_row_dtype)
+
      tl.store(Y_ptr + col_offsets, Y_row, mask=mask)


@@ -109,14 +118,17 @@ def _rms_norm_backward_kernel(
      dY_row_stride,
      X_ptr,
      X_row_stride,
+     X_dtype: tl.constexpr,
      W_ptr,
      W_row_stride,
      RSTD_ptr,
      RSTD_row_stride,
      dW_ptr,
      dW_row_stride,
+     n_rows,
      n_cols,
      offset,
+     rows_per_program: tl.constexpr,
      casting_mode: tl.constexpr,
      BLOCK_SIZE: tl.constexpr,
  ):
@@ -125,54 +137,60 @@ def _rms_norm_backward_kernel(
      dw = sum(dy * (x / RMS)). summation over BxT dimension
      """

-     row_idx = tl.program_id(0)
+     row_block_id = tl.program_id(0)
+     row_start = row_block_id * rows_per_program
+     row_end = min((row_block_id + 1) * rows_per_program, n_rows)
      col_offsets = tl.arange(0, BLOCK_SIZE)
      mask = col_offsets < n_cols

-     dY_ptr += row_idx * dY_row_stride
-     X_ptr += row_idx * X_row_stride
-     RSTD_ptr += row_idx * RSTD_row_stride
-     dW_ptr += row_idx * dW_row_stride
+     dW_row = tl.zeros((BLOCK_SIZE,), dtype=tl.float32)

-     dY_row = tl.load(dY_ptr + col_offsets, mask=mask, other=0)
-     X_row = tl.load(X_ptr + col_offsets, mask=mask, other=0)
-     W_row = tl.load(W_ptr + col_offsets, mask=mask, other=0)
-     original_x_dtype = X_row.dtype
-
-     # Get cached rms
-     rstd_row = tl.load(RSTD_ptr)
+     dY_ptr += row_start * dY_row_stride
+     X_ptr += row_start * X_row_stride
+     RSTD_ptr += row_start

+     W_row = tl.load(W_ptr + col_offsets, mask=mask, other=0.0)
      W_row = W_row + offset

-     X_row = X_row.to(tl.float32)
+     for _ in range(row_start, row_end):
+         dY_row = tl.load(dY_ptr + col_offsets, mask=mask, other=0.0)
+         X_row = tl.load(X_ptr + col_offsets, mask=mask, other=0.0)

-     # Different bacward graphs for different casting modes
-     if casting_mode == _CASTING_MODE_LLAMA:
-         m = (dY_row * W_row).to(tl.float32)
+         # Get cached rms
+         rstd_row = tl.load(RSTD_ptr)

-     elif casting_mode == _CASTING_MODE_GEMMA:
-         dY_row, W_row = (
-             dY_row.to(tl.float32),
-             W_row.to(tl.float32),
-         )
+         X_row = X_row.to(tl.float32)

-         m = dY_row * W_row
+         # Different bacward graphs for different casting modes
+         if casting_mode == _CASTING_MODE_LLAMA:
+             m = (dY_row * W_row).to(tl.float32)

-     dX_row = rstd_row * m
+         elif casting_mode == _CASTING_MODE_GEMMA:
+             dY_row = dY_row.to(tl.float32)
+             m = dY_row * W_row
+         else:
+             m = dY_row * W_row

-     dX_row += (rstd_row) * (
-         -(1 / n_cols) * rstd_row * rstd_row * tl.sum(m * X_row, axis=0) * X_row
-     )
+         dX_row = rstd_row * m

-     # calculate the gradient of W
-     if casting_mode == _CASTING_MODE_LLAMA:
-         dW_row = dY_row * (X_row * rstd_row).to(original_x_dtype)
-     else:
-         # here X_row is already in fp32 (see previous if block)
-         dW_row = dY_row * (X_row * rstd_row)
+         dX_row += (rstd_row) * (
+             -(1 / n_cols) * rstd_row * rstd_row * tl.sum(m * X_row, axis=0) * X_row
+         )
+
+         # calculate the gradient of W
+         if casting_mode == _CASTING_MODE_LLAMA:
+             dW_row += dY_row * (X_row * rstd_row).to(X_dtype)
+         else:
+             # here X_row is already in fp32 (see previous if block)
+             dW_row += dY_row * (X_row * rstd_row)

-     tl.store(dY_ptr + col_offsets, dX_row, mask=mask)
-     tl.store(dW_ptr + col_offsets, dW_row, mask=mask)
+         tl.store(dY_ptr + col_offsets, dX_row.to(X_dtype), mask=mask)
+
+         dY_ptr += dY_row_stride
+         X_ptr += X_row_stride
+         RSTD_ptr += RSTD_row_stride
+
+     tl.store(dW_ptr + row_block_id * dW_row_stride + col_offsets, dW_row, mask=mask)


  _str_to_casting_mode = {
@@ -238,31 +256,38 @@ def rms_norm_backward(dY, X, W, RSTD, offset, casting_mode, BLOCK_SIZE, num_warp
      dim = shape[-1]
      dY = dY.view(-1, dim)
      n_rows, n_cols = dY.shape
-     dW = torch.empty_like(
-         X,
-         dtype=(torch.float32 if casting_mode == _CASTING_MODE_GEMMA.value else W.dtype),
-     )

+     sm_count = torch.cuda.get_device_properties(X.device).multi_processor_count
+     # fp32 for numerical stability especially.
+     _dW = torch.empty((sm_count, n_cols), dtype=torch.float32, device=W.device)
+
+     if n_cols > BLOCK_SIZE:
+         raise RuntimeError("This layer norm doesn't support feature dim >= 64KB.")
+     rows_per_program = math.ceil(n_rows / sm_count)
+     grid = (sm_count,)
      # Here we use dY to store the value of dX to save memory
-     _rms_norm_backward_kernel[(n_rows,)](
+     _rms_norm_backward_kernel[grid](
          dY,
          dY.stride(0),
          X,
          X.stride(0),
+         torch_to_triton_dtype[X.dtype],
          W,
          W.stride(0),
          RSTD,
          RSTD.stride(0),
-         dW,
-         dW.stride(0),
+         _dW,
+         _dW.stride(0),
+         n_rows,
          n_cols,
          offset,
+         rows_per_program,
          casting_mode,
          BLOCK_SIZE=BLOCK_SIZE,
          num_warps=num_warps,
      )
      dX = dY.view(*shape)
-     dW = torch.sum(dW, dim=0).to(W.dtype)
+     dW = _dW.sum(dim=0).to(W.dtype)
      return dX, dW


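
Note: a minimal PyTorch sketch (not the Triton kernel itself; sizes and the per_row_dW stand-in are illustrative only) of the two-stage dW reduction introduced above, where each of sm_count programs accumulates a partial fp32 row of dW over its assigned rows and rms_norm_backward then sums the partials:

    import math
    import torch

    n_rows, n_cols, sm_count = 4096, 1024, 8           # illustrative sizes only
    rows_per_program = math.ceil(n_rows / sm_count)
    per_row_dW = torch.randn(n_rows, n_cols)           # stand-in for dY * (X * rstd) of each row

    _dW = torch.zeros(sm_count, n_cols, dtype=torch.float32)
    for pid in range(sm_count):                        # one iteration per Triton program
        start = pid * rows_per_program
        end = min((pid + 1) * rows_per_program, n_rows)
        _dW[pid] = per_row_dW[start:end].sum(dim=0)    # partial accumulation in fp32

    dW = _dW.sum(dim=0)                                # final reduction, as in rms_norm_backward
    assert torch.allclose(dW, per_row_dW.sum(dim=0), rtol=1e-4, atol=1e-4)
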
liger_kernel/ops/swiglu.py CHANGED
@@ -14,7 +14,7 @@ def silu(x):
  def _swiglu_forward_kernel(
      a_ptr, b_ptr, c_ptr, stride, n_cols: tl.constexpr, BLOCK_SIZE: tl.constexpr
  ):
-     program_id = tl.program_id(0).cast(tl.int64)
+     program_id = tl.program_id(0).to(tl.int64)

      # locate start index
      a_ptr += program_id * stride
@@ -35,7 +35,7 @@ def _swiglu_forward_kernel(
  def _swiglu_backward_kernel(
      dc_ptr, a_ptr, b_ptr, stride, n_cols: tl.constexpr, BLOCK_SIZE: tl.constexpr
  ):
-     program_id = tl.program_id(0).cast(tl.int64)
+     program_id = tl.program_id(0).to(tl.int64)

      # locate start index
      dc_ptr += program_id * stride