liger-kernel 0.1.0__py3-none-any.whl → 0.3.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (39)
  1. liger_kernel/env_report.py +46 -0
  2. liger_kernel/ops/cross_entropy.py +130 -63
  3. liger_kernel/ops/experimental/embedding.py +143 -0
  4. liger_kernel/ops/fused_linear_cross_entropy.py +203 -126
  5. liger_kernel/ops/geglu.py +56 -44
  6. liger_kernel/ops/kl_div.py +258 -0
  7. liger_kernel/ops/layer_norm.py +236 -0
  8. liger_kernel/ops/rms_norm.py +220 -84
  9. liger_kernel/ops/rope.py +91 -84
  10. liger_kernel/ops/swiglu.py +50 -43
  11. liger_kernel/ops/utils.py +12 -0
  12. liger_kernel/transformers/__init__.py +22 -0
  13. liger_kernel/transformers/auto_model.py +45 -0
  14. liger_kernel/transformers/cross_entropy.py +11 -1
  15. liger_kernel/transformers/experimental/embedding.py +28 -0
  16. liger_kernel/transformers/functional.py +19 -0
  17. liger_kernel/transformers/fused_linear_cross_entropy.py +8 -2
  18. liger_kernel/transformers/geglu.py +4 -2
  19. liger_kernel/transformers/kl_div.py +14 -0
  20. liger_kernel/transformers/layer_norm.py +30 -0
  21. liger_kernel/transformers/model/gemma.py +138 -0
  22. liger_kernel/transformers/model/llama.py +1 -1
  23. liger_kernel/transformers/model/mistral.py +138 -0
  24. liger_kernel/transformers/model/mixtral.py +158 -0
  25. liger_kernel/transformers/model/phi3.py +136 -0
  26. liger_kernel/transformers/model/qwen2.py +135 -0
  27. liger_kernel/transformers/model/qwen2_vl.py +172 -0
  28. liger_kernel/transformers/monkey_patch.py +579 -14
  29. liger_kernel/transformers/rms_norm.py +23 -4
  30. liger_kernel/transformers/swiglu.py +24 -0
  31. liger_kernel/transformers/trainer_integration.py +2 -45
  32. liger_kernel-0.3.1.dist-info/METADATA +395 -0
  33. liger_kernel-0.3.1.dist-info/RECORD +42 -0
  34. {liger_kernel-0.1.0.dist-info → liger_kernel-0.3.1.dist-info}/WHEEL +1 -1
  35. liger_kernel-0.1.0.dist-info/METADATA +0 -16
  36. liger_kernel-0.1.0.dist-info/RECORD +0 -27
  37. {liger_kernel-0.1.0.dist-info → liger_kernel-0.3.1.dist-info}/LICENSE +0 -0
  38. {liger_kernel-0.1.0.dist-info → liger_kernel-0.3.1.dist-info}/NOTICE +0 -0
  39. {liger_kernel-0.1.0.dist-info → liger_kernel-0.3.1.dist-info}/top_level.txt +0 -0
liger_kernel/env_report.py
@@ -0,0 +1,46 @@
+ import platform
+ import sys
+
+
+ def print_env_report():
+     """
+     Prints a report of the environment. Useful for debugging and reproducibility.
+     Usage:
+     ```
+     python -m liger_kernel.env_report
+     ```
+     """
+     print("Environment Report:")
+     print("-------------------")
+     print(f"Operating System: {platform.platform()}")
+     print(f"Python version: {sys.version.split()[0]}")
+
+     try:
+         import torch
+
+         print(f"PyTorch version: {torch.__version__}")
+         cuda_version = (
+             torch.version.cuda if torch.cuda.is_available() else "Not available"
+         )
+         print(f"CUDA version: {cuda_version}")
+     except ImportError:
+         print("PyTorch: Not installed")
+         print("CUDA version: Unable to query")
+
+     try:
+         import triton
+
+         print(f"Triton version: {triton.__version__}")
+     except ImportError:
+         print("Triton: Not installed")
+
+     try:
+         import transformers
+
+         print(f"Transformers version: {transformers.__version__}")
+     except ImportError:
+         print("Transformers: Not installed")
+
+
+ if __name__ == "__main__":
+     print_env_report()
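Besides the `python -m liger_kernel.env_report` entry point shown in the docstring, the helper can be imported directly; a minimal sketch (assuming liger-kernel 0.3.1 is installed):

```python
# Minimal usage sketch: call the new helper directly instead of the -m entry point.
from liger_kernel.env_report import print_env_report

# Prints OS, Python, PyTorch/CUDA, Triton, and Transformers versions,
# falling back to "Not installed" when an optional dependency is missing.
print_env_report()
```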
liger_kernel/ops/cross_entropy.py
@@ -14,6 +14,8 @@ def liger_cross_entropy_kernel(
      n_cols,
      n_non_ignore,
      ignore_index,
+     label_smoothing: tl.constexpr,
+     reduction: tl.constexpr, # set it as constexpr since reduction is always known at compile time
      BLOCK_SIZE: tl.constexpr,
  ):
      """
@@ -30,6 +32,8 @@ def liger_cross_entropy_kernel(
      n_cols (int): The number of columns in the input tensor.
      n_non_ignore (int): The number of non-ignored elements in the batch.
      ignore_index (int): The index to ignore in the target.
+     label_smoothing (float): The amount of smoothing when computing the loss, where 0.0 means no smoothing.
+     reduction (str): The string for the reduction to apply
      BLOCK_SIZE (int): The block size for Triton operations.
      """

@@ -56,37 +60,62 @@ def liger_cross_entropy_kernel(
      # Online softmax: 2 loads + 1 store (compared with 3 loads + 1 store for the safe softmax)
      # Refer to Algorithm 3 in the paper: https://arxiv.org/pdf/1805.02867

-     # 3. [Oneline softmax] first pass: find max + sum
+     # 3. [Online softmax] first pass: find max + sum
      m = float("-inf") # m is the max value. use the notation from the paper
      d = 0.0 # d is the sum. use the notation from the paper
      ori_X_y = tl.load(
          X_ptr + y
      ) # we need to store the original value of X_y for the loss calculation

+     # Label smoothing is a general case of normal cross entropy
+     # See the full derivation at https://github.com/linkedin/Liger-Kernel/pull/198#issue-2503665310
+     scaled_x_sum = 0.0
+     eps = label_smoothing / n_cols
+
      for i in range(0, n_cols, BLOCK_SIZE):
          X_offsets = i + tl.arange(0, BLOCK_SIZE)
          X_block = tl.load(
              X_ptr + X_offsets, mask=X_offsets < n_cols, other=float("-inf")
          )
          block_max = tl.max(X_block)
+         if label_smoothing > 0:
+             # scale X beforehand to avoid overflow
+             scaled_x_sum += tl.sum(tl.where(X_offsets < n_cols, -eps * X_block, 0.0))
          m_new = tl.maximum(m, block_max)
          d = d * tl.exp(m - m_new) + tl.sum(tl.exp(X_block - m_new))
          m = m_new

-     # 4. [Oneline softmax] second pass: calculate the gradients
+     # 4. [Online Softmax] Second pass: compute gradients
+     # For 'mean' reduction, gradients are normalized by number of non-ignored elements (N)
      # dx_y = (softmax(x_y) - 1) / N
      # dx_i = softmax(x_i) / N, i != y
-     # N is the number of non ingored elements in the batch
+     # For label smoothing:
+     # dx_i = (softmax(x_y) - label_smoothing / V) / N, V = n_cols, i != y
+     # dx_y = (softmax(x_y) - label_smoothing / V - (1 - label_smoothing)) / N
+     #      = dx_i - (1 - label_smoothing) / N
+     #
+     # For 'sum' reduction, no normalization is applied:
+     # dx_y = softmax(x_y) - 1
+     # dx_i = softmax(x_i), for i ≠ y
+     # For label smoothing:
+     # dx_i = (softmax(x_y) - label_smoothing / V), V = n_cols, i != y
+     # dx_y = (softmax(x_y) - label_smoothing / V - (1 - label_smoothing))
+     #      = dx_i - (1 - label_smoothing)
+
      for i in range(0, n_cols, BLOCK_SIZE):
          X_offsets = i + tl.arange(0, BLOCK_SIZE)
          X_block = tl.load(
              X_ptr + X_offsets, mask=X_offsets < n_cols, other=float("-inf")
          )
-         X_block = (tl.exp(X_block - m) / d) / (n_non_ignore)
+         if reduction == "mean":
+             X_block = (tl.exp(X_block - m) / d - eps) / (n_non_ignore)
+         else:
+             X_block = tl.exp(X_block - m) / d - eps
+
          tl.store(X_ptr + X_offsets, X_block, mask=X_offsets < n_cols)

      # We need tl.debug_barrier() to ensure the new result of X_ptr is written as mentioned in
-     # ttps://github.com/triton-lang/triton/blob/ba42a5c68fd0505f8c42f4202d53be0f8d9a5fe0/python/triton/ops/cross_entropy.py#L34
+     # https://github.com/triton-lang/triton/blob/ba42a5c68fd0505f8c42f4202d53be0f8d9a5fe0/python/triton/ops/cross_entropy.py#L34
      tl.debug_barrier()

      # 5. Calculate the loss
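To make the gradient comments above concrete, here is a small reference sketch (not part of the diff; `reference_grad_row` is an illustrative name) of the per-row gradient the second pass writes back over the logits of a non-ignored row, for reduction="mean":

```python
import torch

# Reference sketch of the gradient the second pass stores in place of the logits.
def reference_grad_row(x: torch.Tensor, y: int, label_smoothing: float, n_non_ignore: int) -> torch.Tensor:
    V = x.numel()
    eps = label_smoothing / V
    dx = torch.softmax(x.float(), dim=0) - eps  # dx_i = softmax(x_i) - label_smoothing / V
    dx[y] -= 1.0 - label_smoothing              # dx_y additionally subtracts (1 - label_smoothing)
    return dx / n_non_ignore                    # for reduction="sum", drop this normalization
```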
@@ -97,9 +126,28 @@ def liger_cross_entropy_kernel(
      # So we can safely calculate log (softmax(X_y)) without overflow
      loss = -(ori_X_y - m - tl.log(d))

-     # 6. Specially handle the i==y case where `dx_y = (softmax(x_y) - 1) / N`
+     # Orginal loss = H(q, p), with label smoothing regularization = H(q', p) and (label_smoothing / V) = eps
+     # H(q', p) = (1 - label_smoothing) * H(q, p) + label_smoothing * H(u, p)
+     # = (1 - label_smoothing) * H(q, p) + eps * sum(logsoftmax(x_i))
+     # By using m (global max of xi) and d (sum of e^(xi-m)), we can simplify as:
+     # = (1 - label_smoothing) * H(q, p) + (-sum(x_i * eps) + label_smoothing * (m + logd))
+     # Refer to H(q', p) in section 7 of the paper: https://arxiv.org/pdf/1512.00567
+     # pytorch: https://github.com/pytorch/pytorch/blob/2981534f54d49fa3a9755c9b0855e7929c2527f0/aten/src/ATen/native/LossNLL.cpp#L516
+     # See full derivation at https://github.com/linkedin/Liger-Kernel/pull/198#issuecomment-2333753087
+     if label_smoothing > 0:
+         smooth_loss = scaled_x_sum + label_smoothing * (m + tl.log(d))
+         loss = loss * (1 - label_smoothing) + smooth_loss
+
+     # Normalize the loss by the number of non-ignored elements if reduction is "mean"
+     if reduction == "mean":
+         loss = loss / n_non_ignore
+
+     # 6. Specially handle the i==y case where `dx_y = (softmax(x_y) - (1 - label_smoothing) / N`
      X_y = tl.load(X_ptr + y)
-     X_y += -1 / (n_non_ignore)
+     if reduction == "mean":
+         X_y += -(1 - label_smoothing) / (n_non_ignore)
+     else:
+         X_y += -(1 - label_smoothing)

      tl.store(loss_ptr, loss)
      tl.store(X_ptr + y, X_y)
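As a reference rendering (not text from the package) of the comment above, with V = n_cols, eps = label_smoothing / V, m the global max of the logits, and d the sum of e^(x_i - m), the smoothed loss being accumulated is:

```latex
H(q', p) = (1 - \epsilon_{ls})\, H(q, p) + \epsilon_{ls}\, H(u, p)
         = (1 - \epsilon_{ls}) \bigl(-(x_y - m - \log d)\bigr)
           + \Bigl(\underbrace{-\tfrac{\epsilon_{ls}}{V} \sum_i x_i}_{\text{scaled\_x\_sum}} + \epsilon_{ls}\,(m + \log d)\Bigr)
```

Here H(u, p) is the cross entropy against the uniform distribution over the V classes, which is what the `scaled_x_sum + label_smoothing * (m + tl.log(d))` term computes.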
@@ -112,7 +160,7 @@ MAX_FUSED_SIZE = 65536 // 2 # the best size we found by manually tuning


  @triton.jit
- def element_mul(
+ def element_mul_kernel(
      X_ptr,
      X_stride,
      grad_output_ptr,
@@ -147,6 +195,70 @@ def element_mul(
          tl.store(X_ptr + X_offsets, X_block * grad_output, mask=X_offsets < n_cols)


+ def cross_entropy_forward(_input, target, ignore_index, label_smoothing, reduction):
+     BT, V = _input.shape
+     n_rows = BT
+
+     BLOCK_SIZE = min(MAX_FUSED_SIZE, triton.next_power_of_2(V))
+
+     # unreduced loss
+     loss_1d = torch.zeros(n_rows, dtype=_input.dtype, device=_input.device)
+
+     n_non_ignore = (target != ignore_index).sum().item()
+
+     # ensure _input and target are contiguous in the last dimension
+     if _input.stride(-1) != 1:
+         _input = _input.contiguous()
+     if target.stride(-1) != 1:
+         target = target.contiguous()
+
+     # Here we use a trick to store X_ptr gradient in X_ptr so we can save memory
+     liger_cross_entropy_kernel[(n_rows,)](
+         X_ptr=_input,
+         X_stride=_input.stride(-2),
+         Y_ptr=target,
+         Y_stride=target.stride(-1), # always 1
+         loss_ptr=loss_1d,
+         loss_stride=loss_1d.stride(-1), # always 1
+         n_cols=V,
+         n_non_ignore=n_non_ignore,
+         ignore_index=ignore_index,
+         label_smoothing=label_smoothing,
+         reduction=reduction,
+         BLOCK_SIZE=BLOCK_SIZE,
+         # TODO: 32 seems to give the best performance
+         # Performance is quite sensitive to num_warps
+         num_warps=32,
+     )
+
+     loss = torch.sum(loss_1d)
+     return loss, _input
+
+
+ def cross_entropy_backward(_input, grad_output):
+     # If cross entropy is the last layer, grad_output is 1.0. Skip the mul to save time
+     if torch.equal(grad_output, torch.tensor(1.0, device=grad_output.device)):
+         pass
+
+     # We use a Triton kernel instead of a PyTorch operation because modifying inputs in-place
+     # for gradient storage and backward multiple times causes anomalies with PyTorch but not with Triton.
+     else:
+         BT, V = _input.shape
+         n_rows = BT
+         BLOCK_SIZE = min(MAX_FUSED_SIZE, triton.next_power_of_2(V))
+
+         element_mul_kernel[(n_rows,)](
+             _input,
+             _input.stride(-2),
+             grad_output,
+             V,
+             BLOCK_SIZE=BLOCK_SIZE,
+             num_warps=32,
+         )
+
+     return _input
+
+
  class LigerCrossEntropyFunction(torch.autograd.Function):
      """
      This class implements a custom autograd function for the Liger Cross Entropy loss.
@@ -154,7 +266,9 @@ class LigerCrossEntropyFunction(torch.autograd.Function):
      """

      @staticmethod
-     def forward(ctx, _input, target, ignore_index):
+     def forward(
+         ctx, _input, target, ignore_index=-100, label_smoothing=0.0, reduction="mean"
+     ):
          """
          The forward pass of the Liger Cross Entropy loss.

@@ -163,45 +277,15 @@ class LigerCrossEntropyFunction(torch.autograd.Function):
          _input (tensor): The input tensor of shape (BT, V) where B is batch size, T is sequence length, V is vocab size.
          target (tensor): The target tensor of shape (BT) where each value is in [0, V-1].
          ignore_index (int): The index to ignore in the target.
+         label_smoothing (float): The amount of smoothing when computing the loss, where 0.0 means no smoothing.
+         reduction (str): The reduction to apply to the output: "none" | "mean | "sum".

          Returns:
          tensor: The computed loss.
          """
-         BT, V = _input.shape
-         n_rows = BT
-
-         BLOCK_SIZE = min(MAX_FUSED_SIZE, triton.next_power_of_2(V))
-
-         # unreduced loss
-         loss_1d = torch.zeros(n_rows, dtype=_input.dtype, device=_input.device)
-
-         n_non_ignore = (target != ignore_index).sum().item()
-
-         # ensure _input and target are contiguous in the last dimension
-         if _input.stride(-1) != 1:
-             _input = _input.contiguous()
-         if target.stride(-1) != 1:
-             target = target.contiguous()
-
-         # Here we use a trick to store X_ptr gradient in X_ptr so we can save memory
-         liger_cross_entropy_kernel[(n_rows,)](
-             X_ptr=_input,
-             X_stride=_input.stride(-2),
-             Y_ptr=target,
-             Y_stride=target.stride(-1), # always 1
-             loss_ptr=loss_1d,
-             loss_stride=loss_1d.stride(-1), # always 1
-             n_cols=V,
-             n_non_ignore=n_non_ignore,
-             ignore_index=ignore_index,
-             BLOCK_SIZE=BLOCK_SIZE,
-             # TODO: 32 seems to give the best performance
-             # Performance is quite sentitive to num_warps
-             num_warps=32,
+         loss, _input = cross_entropy_forward(
+             _input, target, ignore_index, label_smoothing, reduction
          )
-
-         loss = torch.sum(loss_1d) / n_non_ignore
-
          # TODO: investigation
          # If we don't detach the _input tensor, the memory will double
          # Not sure why but seems that there will be a time both grad and value exist but in different location
@@ -221,28 +305,11 @@ class LigerCrossEntropyFunction(torch.autograd.Function):
          tuple: A tuple with the gradients with respect to the inputs. The elements are tensors or None.
          """
          (_input,) = ctx.saved_tensors
-         # If cross entropy is the last layer, grad_output is 1.0. Skip the mul to save time
-         if torch.equal(grad_output, torch.tensor(1.0, device=grad_output.device)):
-             pass
-
-         # We use a Triton kernel instead of a PyTorch operation because modifying inputs in-place
-         # for gradient storage and backward multiple times causes anomalies with PyTorch but not with Triton.
-         else:
-             BT, V = _input.shape
-             n_rows = BT
-             BLOCK_SIZE = min(MAX_FUSED_SIZE, triton.next_power_of_2(V))
-
-             element_mul[(n_rows,)](
-                 _input,
-                 _input.stride(-2),
-                 grad_output,
-                 V,
-                 BLOCK_SIZE=BLOCK_SIZE,
-                 num_warps=32,
-             )
-
+         _input = cross_entropy_backward(_input, grad_output)
          return (
              _input,
              None,
              None,
+             None,
+             None,
          )
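For orientation (not part of the diff), the refactored function is typically exercised through `LigerCrossEntropyFunction.apply`; the two new trailing `None`s in the return above line up with the new `label_smoothing` and `reduction` inputs. A minimal sketch, assuming a CUDA device with Triton available:

```python
import torch
from liger_kernel.ops.cross_entropy import LigerCrossEntropyFunction

# Toy shapes; BT = batch * sequence length, V = vocab size.
BT, V = 8, 32000
logits = torch.randn(BT, V, device="cuda", requires_grad=True)
target = torch.randint(0, V, (BT,), device="cuda")

# Positional args mirror forward(ctx, _input, target, ignore_index, label_smoothing, reduction).
loss = LigerCrossEntropyFunction.apply(logits, target, -100, 0.1, "mean")
loss.backward()  # backward rescales the gradient that was stored in-place during the forward pass
```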
liger_kernel/ops/experimental/embedding.py
@@ -0,0 +1,143 @@
+ import torch
+ import triton
+ import triton.language as tl
+
+ from liger_kernel.ops.utils import ensure_contiguous
+
+
+ @triton.jit
+ def embedding_forward_kernel(
+     embeddings_ptr,
+     indices_ptr,
+     output_ptr,
+     n_elements,
+     embedding_dim: tl.constexpr,
+     BLOCK_SIZE_M: tl.constexpr,
+     BLOCK_SIZE_N: tl.constexpr,
+ ):
+     pid_m = tl.program_id(0)
+     pid_n = tl.program_id(1)
+
+     start_m = pid_m * BLOCK_SIZE_M
+     start_n = pid_n * BLOCK_SIZE_N
+     offsets_m = start_m + tl.arange(0, BLOCK_SIZE_M)
+     mask_m = offsets_m < n_elements
+     indices = tl.load(indices_ptr + offsets_m, mask=mask_m, other=0)
+     offsets_n = start_n + tl.arange(0, BLOCK_SIZE_N)
+     mask_n = offsets_n < embedding_dim
+
+     embedding_offsets = indices[:, None] * embedding_dim + offsets_n[None, :]
+     embeddings = tl.load(
+         embeddings_ptr + embedding_offsets,
+         mask=mask_m[:, None] & mask_n[None, :],
+         other=0.0,
+     )
+
+     output_offsets = offsets_m[:, None] * embedding_dim + offsets_n[None, :]
+     tl.store(
+         output_ptr + output_offsets, embeddings, mask=mask_m[:, None] & mask_n[None, :]
+     )
+
+
+ @triton.jit
+ def embedding_backward_kernel(
+     grad_output_ptr,
+     grad_weight_ptr,
+     indices_ptr,
+     n_elements,
+     embedding_dim: tl.constexpr,
+     BLOCK_SIZE_M: tl.constexpr,
+     BLOCK_SIZE_N: tl.constexpr,
+ ):
+     pid_m = tl.program_id(0)
+     pid_n = tl.program_id(1)
+
+     start_m = pid_m * BLOCK_SIZE_M
+     start_n = pid_n * BLOCK_SIZE_N
+     offsets_m = start_m + tl.arange(0, BLOCK_SIZE_M)
+     mask_m = offsets_m < n_elements
+     indices = tl.load(indices_ptr + offsets_m, mask=mask_m, other=0)
+     offsets_n = start_n + tl.arange(0, BLOCK_SIZE_N)
+     mask_n = offsets_n < embedding_dim
+
+     grad_output = tl.load(
+         grad_output_ptr + offsets_m[:, None] * embedding_dim + offsets_n[None, :],
+         mask=mask_m[:, None] & mask_n[None, :],
+         other=0.0,
+     )
+
+     grad_weight_offsets = indices[:, None] * embedding_dim + offsets_n[None, :]
+
+     tl.atomic_add(
+         grad_weight_ptr + grad_weight_offsets,
+         grad_output,
+         mask=mask_m[:, None] & mask_n[None, :],
+     )
+
+
+ class LigerEmbeddingFunction(torch.autograd.Function):
+     @staticmethod
+     @ensure_contiguous
+     def forward(ctx, embeddings: torch.Tensor, indices: torch.Tensor):
+         ori_shape = indices.shape
+         indices = indices.view(-1)
+         output = torch.empty(
+             indices.shape[0],
+             embeddings.shape[1],
+             device=indices.device,
+             dtype=embeddings.dtype,
+         )
+
+         n_elements = indices.numel()
+         embedding_dim = embeddings.shape[1]
+
+         BLOCK_SIZE_M = triton.next_power_of_2(min(128, embedding_dim))
+         BLOCK_SIZE_N = triton.next_power_of_2(min(128, embedding_dim))
+         grid = (
+             triton.cdiv(n_elements, BLOCK_SIZE_M),
+             triton.cdiv(embedding_dim, BLOCK_SIZE_N),
+         )
+
+         embedding_forward_kernel[grid](
+             embeddings,
+             indices,
+             output,
+             n_elements,
+             embedding_dim=embedding_dim,
+             BLOCK_SIZE_M=BLOCK_SIZE_M,
+             BLOCK_SIZE_N=BLOCK_SIZE_N,
+         )
+
+         ctx.save_for_backward(indices, embeddings)
+
+         return output.view(*ori_shape, -1)
+
+     @staticmethod
+     @ensure_contiguous
+     def backward(ctx, grad_output: torch.Tensor):
+         indices, embedding_table = ctx.saved_tensors
+         grad_output = grad_output.contiguous().view(-1, embedding_table.shape[1])
+
+         grad_weight = torch.zeros_like(embedding_table)
+
+         n_elements = indices.numel()
+         embedding_dim = embedding_table.shape[1]
+
+         BLOCK_SIZE_M = triton.next_power_of_2(min(128, embedding_dim))
+         BLOCK_SIZE_N = triton.next_power_of_2(min(128, embedding_dim))
+         grid = (
+             triton.cdiv(n_elements, BLOCK_SIZE_M),
+             triton.cdiv(embedding_dim, BLOCK_SIZE_N),
+         )
+
+         embedding_backward_kernel[grid](
+             grad_output,
+             grad_weight,
+             indices,
+             n_elements,
+             embedding_dim=embedding_dim,
+             BLOCK_SIZE_M=BLOCK_SIZE_M,
+             BLOCK_SIZE_N=BLOCK_SIZE_N,
+         )
+
+         return grad_weight, None
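A usage sketch for the new experimental op (again not part of the diff, and assuming a CUDA device with Triton): the forward is a plain lookup comparable to `torch.nn.functional.embedding`, and the backward scatters gradients into the weight table with `tl.atomic_add`.

```python
import torch
from liger_kernel.ops.experimental.embedding import LigerEmbeddingFunction

num_embeddings, embedding_dim = 1000, 128
weight = torch.randn(num_embeddings, embedding_dim, device="cuda", requires_grad=True)
indices = torch.randint(0, num_embeddings, (4, 16), device="cuda")

out = LigerEmbeddingFunction.apply(weight, indices)  # shape: (4, 16, 128)
out.sum().backward()                                 # weight.grad now holds the accumulated rows
```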