liger-kernel 0.1.0__py3-none-any.whl → 0.3.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (39)
  1. liger_kernel/env_report.py +46 -0
  2. liger_kernel/ops/cross_entropy.py +130 -63
  3. liger_kernel/ops/experimental/embedding.py +143 -0
  4. liger_kernel/ops/fused_linear_cross_entropy.py +203 -126
  5. liger_kernel/ops/geglu.py +56 -44
  6. liger_kernel/ops/kl_div.py +258 -0
  7. liger_kernel/ops/layer_norm.py +236 -0
  8. liger_kernel/ops/rms_norm.py +220 -84
  9. liger_kernel/ops/rope.py +91 -84
  10. liger_kernel/ops/swiglu.py +50 -43
  11. liger_kernel/ops/utils.py +12 -0
  12. liger_kernel/transformers/__init__.py +22 -0
  13. liger_kernel/transformers/auto_model.py +45 -0
  14. liger_kernel/transformers/cross_entropy.py +11 -1
  15. liger_kernel/transformers/experimental/embedding.py +28 -0
  16. liger_kernel/transformers/functional.py +19 -0
  17. liger_kernel/transformers/fused_linear_cross_entropy.py +8 -2
  18. liger_kernel/transformers/geglu.py +4 -2
  19. liger_kernel/transformers/kl_div.py +14 -0
  20. liger_kernel/transformers/layer_norm.py +30 -0
  21. liger_kernel/transformers/model/gemma.py +138 -0
  22. liger_kernel/transformers/model/llama.py +1 -1
  23. liger_kernel/transformers/model/mistral.py +138 -0
  24. liger_kernel/transformers/model/mixtral.py +158 -0
  25. liger_kernel/transformers/model/phi3.py +136 -0
  26. liger_kernel/transformers/model/qwen2.py +135 -0
  27. liger_kernel/transformers/model/qwen2_vl.py +172 -0
  28. liger_kernel/transformers/monkey_patch.py +579 -14
  29. liger_kernel/transformers/rms_norm.py +23 -4
  30. liger_kernel/transformers/swiglu.py +24 -0
  31. liger_kernel/transformers/trainer_integration.py +2 -45
  32. liger_kernel-0.3.1.dist-info/METADATA +395 -0
  33. liger_kernel-0.3.1.dist-info/RECORD +42 -0
  34. {liger_kernel-0.1.0.dist-info → liger_kernel-0.3.1.dist-info}/WHEEL +1 -1
  35. liger_kernel-0.1.0.dist-info/METADATA +0 -16
  36. liger_kernel-0.1.0.dist-info/RECORD +0 -27
  37. {liger_kernel-0.1.0.dist-info → liger_kernel-0.3.1.dist-info}/LICENSE +0 -0
  38. {liger_kernel-0.1.0.dist-info → liger_kernel-0.3.1.dist-info}/NOTICE +0 -0
  39. {liger_kernel-0.1.0.dist-info → liger_kernel-0.3.1.dist-info}/top_level.txt +0 -0
@@ -0,0 +1,258 @@
1
+ from typing import Literal
2
+
3
+ import torch
4
+ import triton
5
+ import triton.language as tl
6
+
7
+ from liger_kernel.ops.utils import ensure_contiguous
8
+
9
+
10
def get_num_warps(BLOCK_SIZE):
    """Pick the number of warps to launch for a given Triton block size."""
    # Thresholds checked from largest to smallest; the first match wins.
    for threshold, warps in ((32768, 32), (8192, 16), (2048, 8)):
        if BLOCK_SIZE >= threshold:
            return warps
    return 4
20
+
21
+
22
# Upper bound on the per-program block size for the fused kernels.
MAX_FUSED_SIZE = 65536 // 4  # 65536 // 4 or 8 works the best

# Accepted string values for the `reduction` argument (mirrors torch.nn.KLDivLoss).
REDUCTION_LITERAL = Literal["none", "sum", "mean", "batchmean"]

# Integer reduction codes passed into the Triton kernels as constexpr values.
_REDUCTION_MODE_NONE = tl.constexpr(0)
_REDUCTION_MODE_SUM = tl.constexpr(1)
_REDUCTION_MODE_MEAN = tl.constexpr(2)
_REDUCTION_MODE_BATCHMEAN = tl.constexpr(3)

# Maps the user-facing reduction string to its kernel-side integer code.
_str_to_reduction_mode = {
    "none": _REDUCTION_MODE_NONE.value,
    "sum": _REDUCTION_MODE_SUM.value,
    "mean": _REDUCTION_MODE_MEAN.value,
    "batchmean": _REDUCTION_MODE_BATCHMEAN.value,
}
37
+
38
+
39
@triton.jit
def _kldiv_kernel_forward(
    y_ptr,  # [B, S], prediction ptr, the kernel expects the prediction in log-space
    y_stride,  # int, prediction stride
    gt_ptr,  # [B, S], ground truth ptr
    gt_stride,  # int, ground truth stride
    loss_ptr,  # [B] or [B, S] if reduction == _REDUCTION_MODE_NONE, output ptr
    loss_stride,  # int, output stride
    n_cols,  # int, number of columns in the input tensor
    eps,  # float, clamp floor for log(y_true) so log(0) is never taken
    BLOCK_SIZE: tl.constexpr,
    log_target: tl.constexpr = False,
    reduction: tl.constexpr = _REDUCTION_MODE_BATCHMEAN,
):
    """Per-row forward pass of KL(y_true || y) with y given in log-space.

    One program instance handles one row; for reductions other than "none"
    it writes a single per-row partial sum that the host finishes reducing.
    """
    pid = tl.program_id(0).to(tl.int64)
    # Advance all pointers to this program's row.
    y_ptr += pid * y_stride
    gt_ptr += pid * gt_stride
    loss_ptr += pid * loss_stride

    base_offsets = tl.arange(0, BLOCK_SIZE)

    loss_sum = 0.0
    # Sweep the row in BLOCK_SIZE-wide chunks.
    for i in range(0, n_cols, BLOCK_SIZE):
        offsets = i + base_offsets
        mask = offsets < n_cols
        y = tl.load(y_ptr + offsets, mask=mask, other=0.0)
        y_true = tl.load(gt_ptr + offsets, mask=mask, other=0.0)

        # KL(y_true || y) = y_true * (log(y_true) - log(y))
        # We compute KL(y_true || y) with y in the log-space
        if not log_target:
            # y_true holds probabilities; clamp by eps before the log.
            loss = y_true * (tl.log(tl.maximum(y_true, eps)) - y)
        else:
            # y_true already holds log-probabilities.
            loss = tl.exp(y_true) * (y_true - y)

        if reduction == _REDUCTION_MODE_NONE:
            # Elementwise output: store each position's loss.
            tl.store(loss_ptr + offsets, loss, mask=mask)
        else:
            # Accumulate a per-row scalar for sum/mean/batchmean.
            loss_sum += tl.sum(loss, axis=0)

    if reduction != _REDUCTION_MODE_NONE:
        tl.store(loss_ptr, loss_sum)
81
+
82
+
83
@triton.jit
def _kldiv_kernel_backward(
    target_ptr,  # [B, S], ground truth ptr
    target_stride,  # int, ground truth row stride
    new_grads_ptr,  # [B, S], output ptr for d(loss)/d(input)
    new_grads_stride,  # int, output row stride
    n_cols,  # int, number of columns per row
    BLOCK_SIZE: tl.constexpr,
    log_target: tl.constexpr = False,
):
    """Per-element gradient of the unreduced KL divergence w.r.t. the input.

    d/d(input) of target * (log(target) - input) is -target, or -exp(target)
    when the target is already in log-space. Reduction scaling and the
    grad_output multiply are applied on the host side.
    """
    pid = tl.program_id(0).to(tl.int64)

    target_ptr += pid * target_stride
    new_grads_ptr += pid * new_grads_stride

    # Hoisted out of the loop: the previous version computed an unused
    # offsets/mask pair before the loop and re-ran tl.arange every iteration.
    base_offsets = tl.arange(0, BLOCK_SIZE)

    for i in range(0, n_cols, BLOCK_SIZE):
        offsets = i + base_offsets
        mask = offsets < n_cols

        target = tl.load(target_ptr + offsets, mask=mask, other=0.0)

        if not log_target:
            res = target * -1
        else:
            res = -tl.exp(target)

        tl.store(new_grads_ptr + offsets, res, mask=mask)
113
+
114
+
115
def kldiv_forward_triton(y_pred, y_true, log_target, reduction, eps):  # [BT, V]
    """Launch the forward KL-divergence kernel and apply the reduction.

    ``y_pred`` holds log-probabilities; ``y_true`` holds probabilities, or
    log-probabilities when ``log_target`` is True. Returns a (BT, V) float32
    tensor for reduction "none", otherwise a scalar.
    """
    n_rows, n_cols = y_pred.shape

    block_size = min(MAX_FUSED_SIZE, triton.next_power_of_2(n_cols))
    warps = get_num_warps(block_size)

    mode = _str_to_reduction_mode[reduction]

    # "none" keeps the full elementwise loss; every other mode collapses each
    # row to a single partial sum inside the kernel.
    if mode == _REDUCTION_MODE_NONE.value:
        loss = torch.zeros((n_rows, n_cols), device=y_pred.device, dtype=torch.float32)
    else:
        loss = torch.zeros((n_rows,), device=y_pred.device, dtype=torch.float32)

    _kldiv_kernel_forward[(n_rows,)](
        y_pred,
        y_pred.stride(0),
        y_true,
        y_true.stride(0),
        loss,
        loss.stride(0),
        n_cols,
        eps=eps,
        BLOCK_SIZE=block_size,
        num_warps=warps,
        log_target=log_target,
        reduction=mode,
    )

    # calculated according to the reduction mode same as in Pytorch. In the later versions, `mean` will be changed to the same behavior as `batchmean`
    # https://pytorch.org/docs/stable/generated/torch.nn.KLDivLoss.html
    # https://github.com/pytorch/pytorch/blob/d7b57c4d63edb42e1deeeba9497fcb5f1f748ff2/torch/nn/functional.py#L3372
    if mode == _REDUCTION_MODE_BATCHMEAN.value:
        return loss.sum() / n_rows
    if mode == _REDUCTION_MODE_SUM.value:
        return loss.sum(dim=0)
    if mode == _REDUCTION_MODE_MEAN.value:
        return loss.sum() / (n_rows * n_cols)
    return loss
153
+
154
+
155
def kldiv_backward_triton(target, grad_output, new_grads, log_target):
    """Fill ``new_grads`` with d(loss)/d(input) and scale by ``grad_output``.

    ``target`` is the (BT, V) ground truth saved in forward; ``new_grads`` is a
    pre-allocated buffer of the same shape that the kernel writes into.
    """
    n_rows, n_cols = target.shape

    block_size = min(MAX_FUSED_SIZE, triton.next_power_of_2(n_cols))

    # We store the gradients in-place in the input tensor
    _kldiv_kernel_backward[(n_rows,)](
        target,
        target.stride(0),
        new_grads,
        new_grads.stride(0),
        n_cols,
        BLOCK_SIZE=block_size,
        num_warps=get_num_warps(block_size),
        log_target=log_target,
    )

    # If cross entropy is the last layer, grad_output is 1.0. Skip the mul then.
    if torch.equal(grad_output, torch.tensor(1.0, device=grad_output.device)):
        return new_grads

    return new_grads * grad_output
180
+
181
+
182
class LigerKLDivLossFunction(torch.autograd.Function):
    """
    Class implementing the forward and backward pass for the KL Divergence Loss using Triton, as defined by the following formula:
    ```python
    if not log_target:
        loss = target * (target.log() - input)
    else:
        loss = target.exp() * (target - input)
    ```,
    then the loss is reduced according to the `reduction` parameter.
    as defined in the PyTorch documentation: https://pytorch.org/docs/stable/generated/torch.nn.KLDivLoss.html

    # NOTE(review): the original docstring had the two `log_target` branches
    # swapped relative to the kernel; the kernel uses `target * (log(target) - input)`
    # when `log_target` is False.
    """

    @staticmethod
    @ensure_contiguous
    def forward(
        ctx,
        y_pred: torch.Tensor,
        y_true: torch.Tensor,
        reduction: REDUCTION_LITERAL = "batchmean",
        log_target: bool = False,
        eps: float = 1e-10,
    ) -> torch.Tensor:
        """A forward pass for the KL Divergence Loss.

        Args:
            ctx: Torch autograd context
            y_pred (torch.Tensor): A tensor of shape (BT, V) containing the predicted values, expected to be log-probabilities.
            y_true (torch.Tensor): A tensor of shape (BT, V) containing the target values, expected to be either probabilities or log-probabilities, depending on the value of `log_target`.
            reduction (REDUCTION_LITERAL, optional): Reduction to be used. Defaults to "batchmean".
            log_target (bool, optional): If set to true, expects the ground truth to already be log-probabilities. Defaults to False.
            eps: (float, optional): A small value to avoid division by zero. Defaults to 1e-10.

        Returns:
            torch.Tensor: The computed KL Divergence Loss, with shape (BT, V) if `reduction` is "none", else a scalar.
        """
        # Only y_true is needed for the gradient; reduction/log_target are
        # plain Python values so they live directly on ctx.
        ctx.save_for_backward(y_true)
        ctx.reduction = reduction
        ctx.log_target = log_target
        return kldiv_forward_triton(
            y_pred, y_true, log_target=log_target, reduction=reduction, eps=eps
        )

    @staticmethod
    @ensure_contiguous
    def backward(ctx, grad_output: torch.Tensor) -> torch.Tensor:
        """A backward pass for the KL Divergence Loss.

        Args:
            ctx: Torch autograd context
            grad_output (torch.Tensor): The gradient of the loss with respect to the output.

        Returns:
            tuple[torch.Tensor, None, None, None, None]: The gradient of the loss with respect to the inputs and None for the other arguments of the forward method.
        """
        (y_true,) = ctx.saved_tensors

        new_grads = torch.empty_like(y_true)

        derivative = kldiv_backward_triton(
            y_true, grad_output, new_grads, ctx.log_target
        )

        # Undo the reduction's normalization on the gradient side.
        if ctx.reduction == "batchmean":
            derivative = derivative / y_true.shape[0]
        elif ctx.reduction == "sum" or ctx.reduction == "none":
            pass
        elif ctx.reduction == "mean":
            derivative = derivative / (y_true.shape[0] * y_true.shape[1])

        return (
            derivative,
            None,
            None,
            None,
            None,
        )
@@ -0,0 +1,236 @@
1
+ import math
2
+ import operator
3
+
4
+ import torch
5
+ import triton
6
+ import triton.language as tl
7
+
8
+ from liger_kernel.ops.utils import (
9
+ calculate_settings,
10
+ compare_version,
11
+ ensure_contiguous,
12
+ )
13
+
14
+ if compare_version("triton", operator.ge, "3.0.0"):
15
+ try:
16
+ # typical import path with dispatch available
17
+ from triton.language.extra.libdevice import rsqrt
18
+ except ModuleNotFoundError:
19
+ # for working with NGC containers
20
+ from triton.language.extra.cuda.libdevice import rsqrt
21
+ else:
22
+ from triton.language.math import rsqrt
23
+
24
+
25
@triton.jit
def _layer_norm_forward_kernel(
    Y_ptr,  # pointer to output, shape (n_rows, n_cols)
    Y_row_stride,  # stride of each row in output
    X_ptr,  # pointer to input, shape (n_rows, n_cols)
    X_row_stride,  # stride of each row in input
    W_ptr,  # pointer to weights, shape (n_cols,)
    W_row_stride,  # stride of each row in weights
    B_ptr,  # pointer to bias, shape (n_cols,)
    B_row_stride,  # stride of each row in bias
    Mean_ptr,  # pointer to mean, shape (n_rows,)
    Mean_row_stride,  # stride of each row in mean
    RSTD_ptr,  # pointer to rstd, shape (n_rows,)
    RSTD_row_stride,  # stride of each row in rstd
    n_cols,
    eps,
    BLOCK_SIZE: tl.constexpr,
):
    """One program normalizes one row: y = (x - mean) * rstd * w + b.

    References:
    https://arxiv.org/abs/1607.06450
    https://github.com/karpathy/llm.c/blob/master/doc/layernorm/layernorm.md
    """
    row_idx = tl.program_id(0)
    col_offsets = tl.arange(0, BLOCK_SIZE)
    mask = col_offsets < n_cols

    Y_ptr += row_idx * Y_row_stride
    X_ptr += row_idx * X_row_stride
    Mean_ptr += row_idx * Mean_row_stride
    RSTD_ptr += row_idx * RSTD_row_stride

    X_row = tl.load(X_ptr + col_offsets, mask=mask, other=0)
    W_row = tl.load(W_ptr + col_offsets, mask=mask, other=0)
    B_row = tl.load(B_ptr + col_offsets, mask=mask, other=0)

    mean = tl.sum(X_row, axis=0) / n_cols
    # Zero out the padding lanes before squaring. Without the tl.where, every
    # masked lane (loaded as 0) would contribute mean^2 to the variance sum
    # whenever BLOCK_SIZE > n_cols, inflating the variance.
    X_centered = tl.where(mask, X_row - mean, 0.0)
    var = tl.sum(X_centered * X_centered, axis=0) / n_cols
    rstd = rsqrt(var + eps)

    # Cache the per-row statistics for the backward pass.
    tl.store(Mean_ptr, mean)
    tl.store(RSTD_ptr, rstd)

    Y_row = (X_row - mean) * rstd * W_row + B_row

    tl.store(Y_ptr + col_offsets, Y_row, mask=mask)
71
+
72
+
73
@triton.jit
def _layer_norm_backward_kernel(
    X_ptr,  # pointer to input, shape (n_rows, n_cols)
    W_ptr,  # pointer to weights, shape (n_cols,)
    Mean_ptr,  # pointer to mean, shape (n_rows,)
    RSTD_ptr,  # pointer to rstd, shape (n_rows,)
    DX_ptr,  # pointer to input grad, shape (n_rows, n_cols)
    DW_ptr,  # pointer to weights grad, shape (n_cols,)
    DB_ptr,  # pointer to bias grad, shape (n_cols,)
    DY_ptr,  # pointer to output grad, shape (n_rows, n_cols)
    stride_x,  # stride of each row in input
    stride_dx,  # stride of each row in input grad
    stride_dw,  # stride of each row in weights grad
    stride_db,  # stride of each row in bias grad
    stride_dy,  # stride of each row in output grad
    n_rows,
    n_cols,
    rows_per_program: tl.constexpr,
    BLOCK_SIZE: tl.constexpr,
    dtype: tl.constexpr,
):
    """Each program walks a contiguous slab of rows, writes DX per row, and
    accumulates one partial (DW, DB) row that the host later sums over programs.

    References:
    https://arxiv.org/abs/1607.06450
    https://github.com/karpathy/llm.c/blob/master/doc/layernorm/layernorm.md
    https://triton-lang.org/main/getting-started/tutorials/05-layer-norm.html
    https://github.com/Dao-AILab/flash-attention/blob/main/flash_attn/ops/triton/layer_norm.py
    """
    row_block_id = tl.program_id(0)
    row_start = row_block_id * rows_per_program
    # Clamp so the last program does not run past the input.
    row_end = min((row_block_id + 1) * rows_per_program, n_rows)
    cols = tl.arange(0, BLOCK_SIZE)
    mask = cols < n_cols

    # Partial weight/bias gradients accumulated across this program's rows,
    # kept in float32 for accumulation accuracy.
    dw_row = tl.zeros((BLOCK_SIZE,), dtype=tl.float32)
    db_row = tl.zeros((BLOCK_SIZE,), dtype=tl.float32)

    # Position all row pointers at this program's first row.
    X_ptr += row_start * stride_x
    Mean_ptr += row_start
    RSTD_ptr += row_start
    DX_ptr += row_start * stride_dx
    DY_ptr += row_start * stride_dy

    for _ in range(row_start, row_end):
        x = tl.load(X_ptr + cols, mask=mask, other=0.0)
        w = tl.load(W_ptr + cols, mask=mask, other=0.0)
        dy = tl.load(DY_ptr + cols, mask=mask, other=0.0)
        mean = tl.load(Mean_ptr)
        rstd = tl.load(RSTD_ptr)

        # Standard LayerNorm backward: dx = (wdy - (x_hat*c1 + c2)) * rstd,
        # with c1/c2 the row means of x_hat*wdy and wdy.
        x_hat = (x - mean) * rstd
        wdy = w * dy
        c1 = tl.sum(x_hat * wdy, axis=0) / n_cols
        c2 = tl.sum(wdy, axis=0) / n_cols
        dx = (wdy - (x_hat * c1 + c2)) * rstd
        tl.store(DX_ptr + cols, dx.to(dtype), mask=mask)

        dw_row += dy * x_hat
        db_row += dy

        # Advance all row pointers to the next row.
        X_ptr += stride_x
        Mean_ptr += 1
        RSTD_ptr += 1
        DX_ptr += stride_dx
        DY_ptr += stride_dy

    # One partial row per program; host reduces over the program axis.
    tl.store(DW_ptr + row_block_id * stride_dw + cols, dw_row.to(dtype), mask=mask)
    tl.store(DB_ptr + row_block_id * stride_db + cols, db_row.to(dtype), mask=mask)
141
+
142
+
143
def layer_norm_forward(X, W, B, eps):
    """Run the LayerNorm forward kernel.

    Args:
        X: input of shape (..., dim); flattened to (n_rows, dim) internally.
        W: weight (gamma) tensor of shape (dim,).
        B: bias (beta) tensor of shape (dim,).
        eps: small constant added to the variance for numerical stability.

    Returns:
        Tuple (Y, X_2d, Mean, RSTD, BLOCK_SIZE, num_warps): Y has X's original
        shape; X_2d is the flattened input; Mean/RSTD cache per-row statistics
        for the backward pass.

    Raises:
        ValueError: if the last dimension of X does not match W's size.
    """
    shape = X.shape
    dim = shape[-1]
    X = X.view(-1, dim)
    n_rows, n_cols = X.shape
    # Validate before allocating any buffers. An explicit raise replaces the
    # previous `assert`, which ran after allocation and is stripped under -O.
    if X.shape[1] != W.shape[0]:
        raise ValueError(
            f"Incompatible hidden size dimension between input tensor with shape[1] = {X.shape[1]} "
            f"and weight tensor with shape[0] = {W.shape[0]}"
        )
    BLOCK_SIZE, num_warps = calculate_settings(n_cols)
    Y = torch.empty((n_rows, n_cols), dtype=X.dtype, device=X.device)
    Mean = torch.empty(n_rows, dtype=X.dtype, device=X.device)
    RSTD = torch.empty(n_rows, dtype=X.dtype, device=X.device)

    # One program per row.
    _layer_norm_forward_kernel[(n_rows,)](
        Y,
        Y.stride(0),
        X,
        X.stride(0),
        W,
        W.stride(0),
        B,
        B.stride(0),
        Mean,
        Mean.stride(0),
        RSTD,
        RSTD.stride(0),
        n_cols,
        eps,
        BLOCK_SIZE=BLOCK_SIZE,
        num_warps=num_warps,
    )
    return Y.view(*shape), X, Mean, RSTD, BLOCK_SIZE, num_warps
175
+
176
+
177
def layer_norm_backward(dY, X, W, B, Mean, RSTD):
    """Compute LayerNorm gradients via the Triton backward kernel.

    Args:
        dY: upstream gradient of shape (..., dim).
        X: flattened (n_rows, dim) input saved by the forward pass.
        W: weight (gamma) tensor of shape (dim,).
        B: bias tensor of shape (dim,); unused here, kept for signature parity.
        Mean, RSTD: per-row statistics cached by the forward pass.

    Returns:
        (DX, DW, DB): gradients w.r.t. input (in dY's original shape),
        weight, and bias.
    """
    shape = dY.shape
    dim = shape[-1]
    dY = dY.view(-1, dim)
    n_rows, n_cols = dY.shape

    DX = torch.empty((n_rows, n_cols), dtype=X.dtype, device=X.device)
    # One partial (DW, DB) row per SM; reduced across programs below.
    sm_count = torch.cuda.get_device_properties(X.device).multi_processor_count
    _DW = torch.empty((sm_count, n_cols), dtype=W.dtype, device=W.device)
    _DB = torch.empty((sm_count, n_cols), dtype=W.dtype, device=W.device)

    BLOCK_SIZE, num_warps = calculate_settings(n_cols)
    if n_cols > BLOCK_SIZE:
        raise RuntimeError("This layer norm doesn't support feature dim >= 64KB.")

    rows_per_program = math.ceil(n_rows / sm_count)
    grid = (sm_count,)
    # Map the torch dtype to the matching Triton dtype. The previous mapping
    # (`float32 -> float32, everything else -> bfloat16`) routed float16
    # gradients through a bfloat16 conversion, silently losing precision.
    if X.dtype == torch.float32:
        triton_dtype = tl.float32
    elif X.dtype == torch.float16:
        triton_dtype = tl.float16
    else:
        triton_dtype = tl.bfloat16
    _layer_norm_backward_kernel[grid](
        X,
        W,
        Mean,
        RSTD,
        DX,
        _DW,
        _DB,
        dY,
        X.stride(0),
        DX.stride(0),
        _DW.stride(0),
        _DB.stride(0),
        dY.stride(0),
        n_rows,
        n_cols,
        rows_per_program,
        BLOCK_SIZE=BLOCK_SIZE,
        dtype=triton_dtype,
    )

    # Reduce the per-program partials into the final parameter gradients.
    DW = _DW.sum(dim=0).to(W.dtype)
    DB = _DB.sum(dim=0).to(W.dtype)

    DX = DX.view(*shape)
    return DX, DW, DB
221
+
222
+
223
class LigerLayerNormFunction(torch.autograd.Function):
    """Autograd wrapper around the Triton LayerNorm forward/backward kernels."""

    @staticmethod
    @ensure_contiguous
    def forward(ctx, X, W, B, eps):
        """Normalize X over its last dimension and apply the affine (W, B).

        Saves the flattened input and the per-row (mean, rstd) statistics on
        ctx for reuse in the backward pass.
        """
        Y, X, Mean, RSTD, BLOCK_SIZE, num_warps = layer_norm_forward(X, W, B, eps)
        ctx.save_for_backward(X, W, B, Mean, RSTD)
        return Y

    @staticmethod
    @ensure_contiguous
    def backward(ctx, dY):
        """Return (dX, dW, dB, None); eps receives no gradient."""
        X, W, B, Mean, RSTD = ctx.saved_tensors
        DX, DW, DB = layer_norm_backward(dY, X, W, B, Mean, RSTD)
        return DX, DW, DB, None