liger-kernel 0.3.1__py3-none-any.whl → 0.4.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (38)
  1. liger_kernel/env_report.py +2 -0
  2. liger_kernel/ops/cross_entropy.py +144 -65
  3. liger_kernel/ops/experimental/mm_int8int2.py +355 -0
  4. liger_kernel/ops/fused_linear_cross_entropy.py +31 -11
  5. liger_kernel/ops/fused_linear_jsd.py +245 -0
  6. liger_kernel/ops/geglu.py +2 -2
  7. liger_kernel/ops/group_norm.py +322 -0
  8. liger_kernel/ops/jsd.py +176 -0
  9. liger_kernel/ops/kl_div.py +2 -2
  10. liger_kernel/ops/rms_norm.py +92 -46
  11. liger_kernel/ops/swiglu.py +2 -2
  12. liger_kernel/ops/utils.py +62 -1
  13. liger_kernel/transformers/__init__.py +3 -0
  14. liger_kernel/transformers/cross_entropy.py +44 -12
  15. liger_kernel/transformers/functional.py +38 -1
  16. liger_kernel/transformers/fused_linear_cross_entropy.py +31 -4
  17. liger_kernel/transformers/fused_linear_jsd.py +98 -0
  18. liger_kernel/transformers/group_norm.py +56 -0
  19. liger_kernel/transformers/jsd.py +75 -0
  20. liger_kernel/transformers/model/gemma.py +124 -1
  21. liger_kernel/transformers/model/gemma2.py +277 -0
  22. liger_kernel/transformers/model/llama.py +135 -4
  23. liger_kernel/transformers/model/mistral.py +3 -0
  24. liger_kernel/transformers/model/mixtral.py +153 -2
  25. liger_kernel/transformers/model/mllama.py +274 -0
  26. liger_kernel/transformers/model/phi3.py +140 -2
  27. liger_kernel/transformers/model/qwen2.py +123 -2
  28. liger_kernel/transformers/model/qwen2_vl.py +8 -1
  29. liger_kernel/transformers/monkey_patch.py +258 -68
  30. liger_kernel/transformers/rms_norm.py +11 -3
  31. {liger_kernel-0.3.1.dist-info → liger_kernel-0.4.1.dist-info}/METADATA +63 -29
  32. liger_kernel-0.4.1.dist-info/NOTICE +58 -0
  33. liger_kernel-0.4.1.dist-info/RECORD +51 -0
  34. {liger_kernel-0.3.1.dist-info → liger_kernel-0.4.1.dist-info}/WHEEL +1 -1
  35. liger_kernel-0.3.1.dist-info/NOTICE +0 -4
  36. liger_kernel-0.3.1.dist-info/RECORD +0 -42
  37. {liger_kernel-0.3.1.dist-info → liger_kernel-0.4.1.dist-info}/LICENSE +0 -0
  38. {liger_kernel-0.3.1.dist-info → liger_kernel-0.4.1.dist-info}/top_level.txt +0 -0
liger_kernel/ops/fused_linear_jsd.py ADDED
@@ -0,0 +1,245 @@
+from typing import Optional
+
+import torch
+import triton
+
+from liger_kernel.ops.jsd import _jsd_kernel
+from liger_kernel.ops.utils import (
+    amp_custom_bwd,
+    amp_custom_fwd,
+    element_mul_kernel,
+    is_hip,
+)
+
+# The hard limit of TRITON_MAX_TENSOR_NUMEL is 1048576 https://github.com/triton-lang/triton/blob/ba42a5c68fd0505f8c42f4202d53be0f8d9a5fe0/python/triton/language/core.py#L19
+# However, setting limit as 65536 as in LayerNorm tutorial is faster because of less register spilling
+# The optimal maximum block size depends on your hardware, your kernel, and your dtype
+MAX_FUSED_SIZE = 65536 // 2
+
+
+def fused_linear_jsd_forward(
+    student_input,
+    student_weight,
+    teacher_input,
+    teacher_weight,
+    shift_labels,
+    jsd_beta,
+    ignore_index,
+    has_label,
+    temperature,
+):
+    device = student_input.device
+    dtype = student_input.dtype
+
+    # inputs have shape: BT x H
+    # materialized activations will have shape: BT x V
+    # the increase in memory = BT x V
+    # reduction can be achieved by partitioning the number of tokens BT into smaller chunks.
+    # for ex: if we were to achieve the same memory consumption as BT x H, then the chunk size should be:
+    # inc_factor = (V+H-1)//H, chunk_size = (BT + inc_factor - 1)//inc_factor
+    # for ex: BT = 4096*4, V = 32000, H = 4096 ==> inc_factor = 8, chunk_size = 2048
+    BT, H = student_input.shape
+    V = student_weight.shape[0]
+    BLOCK_SIZE = min(MAX_FUSED_SIZE, triton.next_power_of_2(V))
+
+    inc_factor = triton.cdiv(V, H)  # (V + H - 1) // H
+    chunk_size = triton.next_power_of_2(
+        triton.cdiv(BT, inc_factor)
+    )  # (BT + inc_factor - 1) // inc_factor
+    num_chunks = triton.cdiv(BT, chunk_size)  # (BT + chunk_size - 1) // chunk_size
+
+    grad_weight = (
+        torch.zeros_like(student_weight, device=device)
+        if student_weight.requires_grad
+        else None
+    )
+    grad_input = torch.zeros_like(student_input)
+    # we use fp32 for loss accumulator
+    loss_1d = torch.zeros((BT, V), dtype=torch.float32, device=device)
+
+    if has_label:
+        n_non_ignore = (shift_labels != ignore_index).sum().item()
+    else:
+        n_non_ignore = BT
+
+    for chunk_id in range(num_chunks):
+        start_idx = chunk_id * chunk_size
+        end_idx = min((chunk_id + 1) * chunk_size, BT)
+
+        # chunk both inputs, shape: chunk_size x H
+        student_input_chunk = student_input[start_idx:end_idx]
+        teacher_input_chunk = teacher_input[start_idx:end_idx]
+
+        # shape: chunk_size x V
+        # For anything starting from logits to the final JSD loss, we do computation
+        # in FP32 to avoid losing numerical stability.
+        student_logits_chunk = (student_input_chunk @ student_weight.t()).to(
+            torch.float32
+        )
+        teacher_logits_chunk = (teacher_input_chunk @ teacher_weight.t()).to(
+            torch.float32
+        )
+        chunk_n_rows = student_logits_chunk.shape[0]
+
+        # unreduced loss
+        loss_1d_slice = loss_1d[start_idx:end_idx]  # chunk_size
+        # log-softmax with temperature
+        student_logits_chunk = student_logits_chunk / temperature
+        teacher_logits_chunk = teacher_logits_chunk / temperature
+        student_prob_chunk = torch.log_softmax(student_logits_chunk, dim=-1)
+        teacher_prob_chunk = torch.log_softmax(teacher_logits_chunk, dim=-1)
+
+        # ensure _input and target are contiguous
+        student_prob_chunk = student_prob_chunk.contiguous()
+        teacher_prob_chunk = teacher_prob_chunk.contiguous()
+
+        # Here we calculate the gradient of prob_chunk in place so we can save memory.
+        _jsd_kernel[(chunk_n_rows,)](
+            X_ptr=student_prob_chunk,
+            X_stride=student_prob_chunk.stride(-2),
+            Y_ptr=teacher_prob_chunk,
+            Y_stride=teacher_prob_chunk.stride(-2),
+            loss_ptr=loss_1d_slice,
+            loss_stride=loss_1d_slice.stride(-2),
+            dX_ptr=student_prob_chunk,
+            dX_stride=student_prob_chunk.stride(-2),
+            label_ptr=(
+                shift_labels[start_idx:end_idx]
+                if has_label
+                else torch.empty(1, device=device)
+            ),  # dummy ptr if no label
+            beta=jsd_beta,
+            n_non_ignore=n_non_ignore,
+            ignore_index=ignore_index,
+            n_cols=V,
+            BLOCK_SIZE=BLOCK_SIZE,
+            HAS_LABEL=has_label,
+        )
+        loss_1d[start_idx:end_idx] = loss_1d_slice
+        # gradients of prob_chunk in place, shape: chunk_size x V
+        # gradients of logits_chunk in place, shape: chunk_size x V
+        student_logits_chunk = (
+            student_prob_chunk
+            - torch.softmax(student_logits_chunk, dim=-1)
+            * student_prob_chunk.sum(dim=-1, keepdim=True).broadcast_to(
+                student_prob_chunk.shape
+            )
+        ) / temperature
+        # now we traverse back to grad w.r.t. input to `lm_head` and grad
+        # w.r.t. `lm_head` which should be computed in original dtype
+        student_logits_chunk = student_logits_chunk.to(dtype)
+        grad_input[start_idx:end_idx] = student_logits_chunk @ student_weight
+
+        if grad_weight is not None:
+            grad_weight.add_(student_logits_chunk.t() @ student_input_chunk)
+
+    loss = torch.sum(loss_1d)
+    return loss, grad_input, grad_weight
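
Note: to make the chunking arithmetic above concrete, here is the worked example from the comments, checked with the same triton helpers (a sketch; any BT, V, H with V > H behaves analogously):

    import triton

    BT, V, H = 4096 * 4, 32000, 4096
    inc_factor = triton.cdiv(V, H)                                    # ceil(32000 / 4096) = 8
    chunk_size = triton.next_power_of_2(triton.cdiv(BT, inc_factor))  # next_pow2(ceil(16384 / 8)) = 2048
    num_chunks = triton.cdiv(BT, chunk_size)                          # ceil(16384 / 2048) = 8
    # each iteration materializes only chunk_size x V logits instead of BT x V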
+
+
+def fused_linear_jsd_backward(grad_output, grad_input, grad_weight):
+    # If JSD is the last layer, grad_output is 1.0. Skip the mul to save time
+    if torch.ne(grad_output, torch.tensor(1.0, device=grad_output.device)):
+        # We use a Triton kernel instead of a PyTorch operation because modifying inputs in-place
+        # for gradient storage and backward multiple times causes anomalies with PyTorch but not with Triton.
+        BT, H = grad_input.shape
+        n_rows = BT
+        BLOCK_SIZE = min(MAX_FUSED_SIZE, triton.next_power_of_2(H))
+
+        element_mul_kernel[(n_rows,)](
+            grad_input,
+            grad_input.stride(-2),
+            grad_output,
+            H,
+            BLOCK_SIZE=BLOCK_SIZE,
+            num_warps=32 if not is_hip() else 16,
+        )
+
+        # handle grad_weight
+        if grad_weight is not None:
+            V, H = grad_weight.shape
+            n_rows = V
+
+            element_mul_kernel[(n_rows,)](
+                grad_weight,
+                grad_weight.stride(-2),
+                grad_output,
+                H,
+                BLOCK_SIZE=BLOCK_SIZE,
+                num_warps=32 if not is_hip() else 16,
+            )
+
+    return grad_input, grad_weight
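
Note: semantically, the two element_mul_kernel launches above just scale the saved gradients in place by the scalar grad_output; a plain-PyTorch sketch of the same effect (the Triton path is used to sidestep autograd's in-place checks, as the comment notes) would be:

    if grad_output.item() != 1.0:
        grad_input.mul_(grad_output)
        if grad_weight is not None:
            grad_weight.mul_(grad_output)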
+
+
+class LigerFusedLinearJSDFunction(torch.autograd.Function):
+    """
+    Fusing the last linear layer with generalized JSD
+
+    Handle the forward and backward pass of the final linear layer via JSD by avoiding
+    the materialization of the large logits tensor. Since JSD is the last layer, we can
+    compute the gradient at the forward pass.
+    """
+
+    @staticmethod
+    @amp_custom_fwd
+    def forward(
+        ctx,
+        student_input: torch.Tensor,
+        student_weight: torch.Tensor,
+        teacher_input: torch.Tensor,
+        teacher_weight: torch.Tensor,
+        shift_labels: Optional[torch.Tensor] = None,
+        jsd_beta: float = 0.5,
+        ignore_index: int = -100,
+        temperature: float = 1.0,
+    ):
+        """
+        Args:
+            student_input (torch.Tensor): input of the last projection layer in the student model, with shape (B*T, H), where B is batch size, T is sequence length, H is hidden dimension.
+            student_weight (torch.Tensor): the last projection layer in the student model, with shape (V, H), where V is vocab size.
+            teacher_input (torch.Tensor): input of the last projection layer in the teacher model, with shape (B*T, H).
+            teacher_weight (torch.Tensor): the last projection layer in the teacher model, with shape (V, H).
+            shift_labels (Optional[torch.LongTensor]): indicator of the next predicted vocab, with shape (BT,) where each value is in [0, V-1].
+            jsd_beta (float): coefficient beta of generalized JSD in the open interval (0, 1). Default: `0.5`
+            ignore_index (int): the index to ignore. Default: -100
+            temperature (float): temperature in the softmax function to control the output probability distribution. Default: `1.0`
+
+        Returns:
+            loss (torch.Tensor): generalized JSD
+        """
+        has_label = False
+        if shift_labels is not None:
+            assert shift_labels.shape == (
+                teacher_input.shape[0],
+            ), f"the shape of shift_labels must be (BT,). Got: {shift_labels.shape}"
+            shift_labels = shift_labels.contiguous()
+            has_label = True
+
+        loss, grad_input, grad_weight = fused_linear_jsd_forward(
+            student_input,
+            student_weight,
+            teacher_input,
+            teacher_weight,
+            shift_labels,
+            jsd_beta,
+            ignore_index,
+            has_label,
+            temperature,
+        )
+        # downcast to dtype and store for backward
+        ctx.save_for_backward(
+            grad_input.detach(),
+            grad_weight.detach() if grad_weight is not None else None,
+        )
+        return loss
+
+    @staticmethod
+    @amp_custom_bwd
+    def backward(ctx, grad_output):
+        (grad_input, grad_weight) = ctx.saved_tensors
+        grad_input, grad_weight = fused_linear_jsd_backward(
+            grad_output, grad_input, grad_weight
+        )
+        return (grad_input, grad_weight, None, None, None, None, None, None)
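
Note: a minimal usage sketch of the autograd function above (hypothetical shapes, assuming a CUDA device; the liger_kernel/transformers/fused_linear_jsd.py module listed in the manifest provides the higher-level wrapper):

    import torch
    from liger_kernel.ops.fused_linear_jsd import LigerFusedLinearJSDFunction

    BT, H, V = 512, 1024, 32000
    student_input = torch.randn(BT, H, device="cuda", requires_grad=True)
    student_weight = torch.randn(V, H, device="cuda", requires_grad=True)
    teacher_input = torch.randn(BT, H, device="cuda")
    teacher_weight = torch.randn(V, H, device="cuda")

    loss = LigerFusedLinearJSDFunction.apply(
        student_input,
        student_weight,
        teacher_input,
        teacher_weight,
        None,   # shift_labels
        0.5,    # jsd_beta
        -100,   # ignore_index
        1.0,    # temperature
    )
    loss.backward()  # reuses the gradients already computed in the forward pass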
liger_kernel/ops/geglu.py CHANGED
@@ -25,7 +25,7 @@ else:
 def _geglu_tanh_forward_kernel(
     a, b, c, stride, n_cols: tl.constexpr, BLOCK_SIZE: tl.constexpr
 ):
-    program_id = tl.program_id(0).cast(tl.int64)
+    program_id = tl.program_id(0).to(tl.int64)
 
     # locate start index
     a += program_id * stride
@@ -52,7 +52,7 @@ def _geglu_tanh_forward_kernel(
 def _geglu_tanh_backward_kernel(
     dc, a, b, stride, n_cols: tl.constexpr, BLOCK_SIZE: tl.constexpr
 ):
-    program_id = tl.program_id(0).cast(tl.int64)
+    program_id = tl.program_id(0).to(tl.int64)
 
     # locate start index
    dc += program_id * stride
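
Note: the only functional change in geglu.py swaps the tensor method `.cast(...)` for `.to(...)`. The point of the line itself is to widen the program id to int64 before the pointer arithmetic, so `program_id * stride` cannot overflow 32-bit indexing on very large tensors. A standalone sketch of the same pattern (hypothetical kernel, not part of the package):

    import triton
    import triton.language as tl

    @triton.jit
    def _double_rows_kernel(x, stride, n_cols, BLOCK_SIZE: tl.constexpr):
        # widen to int64 first: the multiply below then happens in 64-bit
        row_id = tl.program_id(0).to(tl.int64)
        x += row_id * stride
        offsets = tl.arange(0, BLOCK_SIZE)
        mask = offsets < n_cols
        row = tl.load(x + offsets, mask=mask)
        tl.store(x + offsets, row * 2.0, mask=mask)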
liger_kernel/ops/group_norm.py ADDED
@@ -0,0 +1,322 @@
+import operator
+
+import torch
+import triton
+import triton.language as tl
+
+from liger_kernel.ops.utils import compare_version, ensure_contiguous
+
+if compare_version("triton", operator.ge, "3.0.0"):
+    try:
+        # typical import path with dispatch available
+        from triton.language.extra.libdevice import rsqrt
+    except ModuleNotFoundError:
+        # for working with NGC containers
+        from triton.language.extra.cuda.libdevice import rsqrt
+else:
+    from triton.language.math import rsqrt
+
+MAX_FUSED_SIZE = 65536
+
+
+@triton.jit
+def _group_norm_forward_kernel(
+    Y_ptr,  # pointer to output, shape (n_rows, n_groups, hidden_size)
+    Y_row_stride,  # stride of each row in output
+    Y_col_stride,  # stride of each column in output
+    X_ptr,  # pointer to input, shape (n_rows, n_groups, hidden_size)
+    X_row_stride,  # stride of each row in input
+    X_col_stride,  # stride of each column in input
+    Mean_ptr,  # pointer to mean, shape (n_rows, n_groups)
+    Mean_row_stride,  # stride of each row in mean
+    Mean_col_stride,  # stride of each column in mean
+    RSTD_ptr,  # pointer to rstd, shape (n_rows, n_groups)
+    RSTD_row_stride,  # stride of each row in rstd
+    RSTD_col_stride,  # stride of each column in rstd
+    W_ptr,  # pointer to W
+    B_ptr,  # pointer to B
+    hidden_size,  # hidden size of X
+    channels_per_group,  # the number of channels per group
+    eps,
+    BLOCK_SIZE: tl.constexpr,
+):
+    """
+    References:
+    https://nn.labml.ai/normalization/group_norm/index.html
+    """
+    batch_idx = tl.program_id(0)
+    group_idx = tl.program_id(1)
+
+    X_ptr += batch_idx * X_row_stride + group_idx * X_col_stride
+    Y_ptr += batch_idx * Y_row_stride + group_idx * Y_col_stride
+
+    block_range = tl.arange(0, BLOCK_SIZE)
+
+    # Compute the mean and variance in a single pass over the group
+    s = 0.0
+    squared_sum = 0.0
+    for i in tl.range(0, hidden_size, BLOCK_SIZE):
+        hidden_size_offsets = i + block_range
+        mask = hidden_size_offsets < hidden_size
+        X = tl.load(X_ptr + hidden_size_offsets, mask=mask, other=0.0)
+        s += tl.sum(X)
+        # X**2
+        squared_sum += tl.sum(X * X)
+
+    m = s / hidden_size
+
+    # variance = E[X**2] - E[X]**2
+    variance = (squared_sum / hidden_size) - (m * m)
+
+    # 1/std
+    rstd = rsqrt(variance + eps)
+
+    # Normalize
+    hidden_size_per_channel = hidden_size // channels_per_group
+    for channel_idx in tl.range(
+        group_idx * channels_per_group, (group_idx + 1) * channels_per_group
+    ):
+        W = tl.load(W_ptr + channel_idx)
+        B = tl.load(B_ptr + channel_idx)
+        for i in range(0, hidden_size_per_channel, BLOCK_SIZE):
+            hidden_size_offsets = i + block_range
+            mask = hidden_size_offsets < hidden_size_per_channel
+            X = tl.load(X_ptr + hidden_size_offsets, mask=mask, other=m)
+            Y = (X - m) * rstd * W + B
+            tl.store(Y_ptr + hidden_size_offsets, Y, mask=mask)
+
+        X_ptr += hidden_size_per_channel
+        Y_ptr += hidden_size_per_channel
+
+    tl.store(Mean_ptr + batch_idx * Mean_row_stride + group_idx * Mean_col_stride, m)
+    tl.store(RSTD_ptr + batch_idx * RSTD_row_stride + group_idx * RSTD_col_stride, rstd)
+
+
+@triton.jit
+def _group_norm_backward_kernel(
+    X_ptr,  # pointer to input, shape (n_rows, n_channels, hidden_size)
+    X_row_stride,  # stride of each row in input
+    X_col_stride,  # stride of each column in input
+    W_ptr,  # pointer to weights, shape (n_channels)
+    Mean_ptr,  # pointer to mean, shape (n_rows, n_groups)
+    Mean_ptr_row_stride,  # stride of each row in mean
+    Mean_ptr_col_stride,  # stride of each column in mean
+    RSTD_ptr,  # pointer to rstd, shape (n_rows, n_groups)
+    DX_ptr,  # pointer to input grad, shape (n_rows, n_groups, hidden_size)
+    DW_ptr,  # pointer to weights grad, shape (n_channels)
+    DB_ptr,  # pointer to bias grad, shape (n_channels)
+    UPSTREAM_ptr,  # pointer to output grad, shape (n_rows, n_channels, hidden_size)
+    hidden_size: tl.constexpr,  # hidden size
+    channels_per_group: tl.constexpr,  # number of channels per group
+    BLOCK_SIZE: tl.constexpr,
+    dtype: tl.constexpr,
+):
+    """
+    References:
+    https://nn.labml.ai/normalization/group_norm/index.html
+    https://github.com/karpathy/llm.c/blob/master/doc/layernorm/layernorm.md
+
+    The backprop equations are the same for group_norm and layer_norm;
+    the only difference here is that we load the Mean and Rstd corresponding to the
+    group we're computing gradients for, and the mean and rstd are computed over the
+    whole group, so the total number of elements we compute the mean over is
+    channels_per_group * hidden_size.
+
+    We also need to load the weight corresponding to the current channel to compute the gradients.
+    """
+    batch_idx = tl.program_id(0)
+    group_idx = tl.program_id(1)
+
+    # Move the pointers to the correct batch
+    X_ptr += batch_idx * X_row_stride
+    DX_ptr += batch_idx * X_row_stride
+    UPSTREAM_ptr += batch_idx * X_row_stride
+
+    # Mean and rstd have the same shape, so they share strides
+    mean = tl.load(
+        Mean_ptr + batch_idx * Mean_ptr_row_stride + group_idx * Mean_ptr_col_stride
+    )
+    rstd = tl.load(
+        RSTD_ptr + batch_idx * Mean_ptr_row_stride + group_idx * Mean_ptr_col_stride
+    )
+
+    c1 = 0.0
+    c2 = 0.0
+    block_range = tl.arange(0, BLOCK_SIZE)
+
+    # We need to compute the sum terms of the backprop equations across all channels in the group
+    for channel_idx in range(
+        group_idx * channels_per_group, (group_idx + 1) * channels_per_group
+    ):
+        dW = 0.0
+        dB = 0.0
+        # Move the pointers to the correct channel
+        W = tl.load(W_ptr + channel_idx)
+        for i in tl.range(0, hidden_size, BLOCK_SIZE):
+            hidden_size_offsets = i + block_range
+            mask = hidden_size_offsets < hidden_size
+            X = tl.load(
+                X_ptr + channel_idx * X_col_stride + hidden_size_offsets,
+                mask=mask,
+                other=0.0,
+            )
+            UPSTREAM_grad = tl.load(
+                UPSTREAM_ptr + channel_idx * X_col_stride + hidden_size_offsets,
+                mask=mask,
+                other=0.0,
+            )
+
+            x_hat = (X - mean) * rstd
+            dW += tl.sum(UPSTREAM_grad * x_hat)
+            dB += tl.sum(UPSTREAM_grad)
+
+            wdy = W * UPSTREAM_grad
+            c1 += tl.sum(x_hat * wdy)
+            c2 += tl.sum(wdy)
+
+        # Need to ensure additions to the same channel are atomic
+        tl.atomic_add(DW_ptr + channel_idx, dW.to(dtype))
+        tl.atomic_add(DB_ptr + channel_idx, dB.to(dtype))
+
+    N = hidden_size * channels_per_group
+    c1 = c1 / N
+    c2 = c2 / N
+
+    for channel_idx in tl.range(
+        group_idx * channels_per_group, (group_idx + 1) * channels_per_group
+    ):
+        # Move the pointers to the correct channel
+        W = tl.load(W_ptr + channel_idx)
+        for i in range(0, hidden_size, BLOCK_SIZE):
+            hidden_size_offsets = i + block_range
+            mask = hidden_size_offsets < hidden_size
+            X = tl.load(
+                X_ptr + channel_idx * X_col_stride + hidden_size_offsets,
+                mask=mask,
+                other=0.0,
+            )
+            UPSTREAM_grad = tl.load(
+                UPSTREAM_ptr + channel_idx * X_col_stride + hidden_size_offsets,
+                mask=mask,
+                other=0.0,
+            )
+
+            x_hat = (X - mean) * rstd
+            wdy = W * UPSTREAM_grad
+            dx = (wdy - (x_hat * c1 + c2)) * rstd
+            tl.store(
+                DX_ptr + channel_idx * X_col_stride + hidden_size_offsets, dx, mask=mask
+            )
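
Note: for reference, the dx stored in the second loop is the standard layer-norm backward with the reductions taken over the whole group (N = channels_per_group * hidden_size elements), matching the llm.c notes linked in the docstring:

    x_hat = (X - mean) * rstd
    wdy   = W * dY
    c1    = (1/N) * sum(x_hat * wdy)
    c2    = (1/N) * sum(wdy)
    dX    = (wdy - (x_hat * c1 + c2)) * rstd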
+
+
+def group_norm_forward(X, num_channels, num_groups, W, B, eps):
+    shape = X.shape
+    batch_size = shape[0]
+    channels_per_group = num_channels // num_groups
+    # Reshape X so that the mean and std are computed across the groups
+    X = X.view(batch_size, num_groups, -1).contiguous()
+    hidden_size = X.shape[-1]
+    BLOCK_SIZE = min(MAX_FUSED_SIZE, triton.next_power_of_2(hidden_size))
+    Y = torch.empty(
+        (batch_size, num_groups, hidden_size), dtype=X.dtype, device=X.device
+    )
+    Mean = torch.zeros((batch_size, num_groups), dtype=X.dtype, device=X.device)
+    RSTD = torch.zeros((batch_size, num_groups), dtype=X.dtype, device=X.device)
+
+    _group_norm_forward_kernel[(batch_size, num_groups)](
+        Y,
+        Y.stride(0),
+        Y.stride(1),
+        X,
+        X.stride(0),
+        X.stride(1),
+        Mean,
+        Mean.stride(0),
+        Mean.stride(1),
+        RSTD,
+        RSTD.stride(0),
+        RSTD.stride(1),
+        W,
+        B,
+        hidden_size,
+        channels_per_group,
+        eps,
+        BLOCK_SIZE=BLOCK_SIZE,
+    )
+    # Return tensors in the original shape
+    return Y.view(*shape), X.view(*shape), Mean, RSTD, BLOCK_SIZE
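
Note: a small PyTorch check of what the forward pass computes per (batch, group) slice (a sketch with hypothetical sizes; group_norm_forward views the input as (batch, groups, channels_per_group * hidden) and each kernel program normalizes one group slice):

    import torch

    B, C, H, G, eps = 2, 6, 4, 3, 1e-6
    x = torch.randn(B, C, H)
    x_grouped = x.view(B, G, -1)
    mean = x_grouped.mean(dim=-1)                                 # matches Mean, shape (B, G)
    rstd = (x_grouped.var(dim=-1, unbiased=False) + eps).rsqrt()  # matches RSTD, shape (B, G)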
+
+
+def group_norm_backward(dY, X, W, B, Mean, RSTD, num_channels, num_groups):
+    shape = dY.shape
+    batch_size = shape[0]
+    hidden_size = dY.shape[-1]
+    channels_per_group = num_channels // num_groups
+    dY = dY.view(batch_size, num_groups, -1)
+    DX = torch.empty(
+        (batch_size, num_groups, hidden_size * channels_per_group),
+        dtype=X.dtype,
+        device=X.device,
+    )
+    DW = torch.zeros((num_channels), dtype=W.dtype, device=W.device)
+    DB = torch.zeros((num_channels), dtype=B.dtype, device=B.device)
+    triton_dtype = tl.float32 if X.dtype == torch.float32 else tl.bfloat16
+
+    BLOCK_SIZE = min(MAX_FUSED_SIZE, triton.next_power_of_2(hidden_size))
+    _group_norm_backward_kernel[(batch_size, num_groups)](
+        X,
+        X.stride(0),
+        X.stride(1),
+        W,
+        Mean,
+        Mean.stride(0),
+        Mean.stride(1),
+        RSTD,
+        DX,
+        DW,
+        DB,
+        dY,
+        hidden_size,
+        channels_per_group,
+        BLOCK_SIZE=BLOCK_SIZE,
+        dtype=triton_dtype,
+    )
+
+    # Return tensors in the original shape
+    return DX.view(*shape), DW, DB
+
+
+class LigerGroupNormFunction(torch.autograd.Function):
+    @staticmethod
+    @ensure_contiguous
+    def forward(
+        ctx,
+        X,
+        affine_scaling_weight,
+        affine_shifting_bias,
+        num_channels,
+        num_groups,
+        eps,
+    ):
+        Y, X, Mean, RSTD, BLOCK_SIZE = group_norm_forward(
+            X,
+            num_channels,
+            num_groups,
+            affine_scaling_weight,
+            affine_shifting_bias,
+            eps,
+        )
+        ctx.num_channels = num_channels
+        ctx.num_groups = num_groups
+        ctx.save_for_backward(
+            X, affine_scaling_weight, affine_shifting_bias, Mean, RSTD
+        )
+        return Y
+
+    @staticmethod
+    @ensure_contiguous
+    def backward(ctx, dY):
+        X, W, B, Mean, RSTD = ctx.saved_tensors
+        DX, DW, DB = group_norm_backward(
+            dY, X, W, B, Mean, RSTD, ctx.num_channels, ctx.num_groups
+        )
+        return DX, DW, DB, None, None, None
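
Note: a minimal usage sketch of LigerGroupNormFunction (hypothetical shapes, assuming a CUDA device; num_groups must divide num_channels, and the liger_kernel/transformers/group_norm.py module in the manifest provides the higher-level wrapper):

    import torch
    from liger_kernel.ops.group_norm import LigerGroupNormFunction

    B, C, H, G = 4, 8, 1024, 2
    x = torch.randn(B, C, H, device="cuda", requires_grad=True)
    w = torch.ones(C, device="cuda", requires_grad=True)   # affine_scaling_weight
    b = torch.zeros(C, device="cuda", requires_grad=True)  # affine_shifting_bias

    y = LigerGroupNormFunction.apply(x, w, b, C, G, 1e-6)
    y.sum().backward()  # populates x.grad, w.grad, b.grad via the backward kernel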