liger-kernel 0.1.0__py3-none-any.whl → 0.3.0__py3-none-any.whl

This diff shows the content of publicly released versions of the package as published to a supported registry. It is provided for informational purposes only and reflects the changes between the two versions as they appear in their public registry.
Files changed (39)
  1. liger_kernel/env_report.py +46 -0
  2. liger_kernel/ops/cross_entropy.py +130 -63
  3. liger_kernel/ops/experimental/embedding.py +143 -0
  4. liger_kernel/ops/fused_linear_cross_entropy.py +203 -126
  5. liger_kernel/ops/geglu.py +54 -42
  6. liger_kernel/ops/kl_div.py +247 -0
  7. liger_kernel/ops/layer_norm.py +236 -0
  8. liger_kernel/ops/rms_norm.py +220 -84
  9. liger_kernel/ops/rope.py +91 -84
  10. liger_kernel/ops/swiglu.py +48 -41
  11. liger_kernel/ops/utils.py +12 -0
  12. liger_kernel/transformers/__init__.py +22 -0
  13. liger_kernel/transformers/auto_model.py +33 -0
  14. liger_kernel/transformers/cross_entropy.py +11 -1
  15. liger_kernel/transformers/experimental/embedding.py +28 -0
  16. liger_kernel/transformers/functional.py +19 -0
  17. liger_kernel/transformers/fused_linear_cross_entropy.py +8 -2
  18. liger_kernel/transformers/geglu.py +4 -2
  19. liger_kernel/transformers/kl_div.py +13 -0
  20. liger_kernel/transformers/layer_norm.py +30 -0
  21. liger_kernel/transformers/model/gemma.py +138 -0
  22. liger_kernel/transformers/model/llama.py +1 -1
  23. liger_kernel/transformers/model/mistral.py +138 -0
  24. liger_kernel/transformers/model/mixtral.py +158 -0
  25. liger_kernel/transformers/model/phi3.py +136 -0
  26. liger_kernel/transformers/model/qwen2.py +135 -0
  27. liger_kernel/transformers/model/qwen2_vl.py +172 -0
  28. liger_kernel/transformers/monkey_patch.py +605 -14
  29. liger_kernel/transformers/rms_norm.py +23 -4
  30. liger_kernel/transformers/swiglu.py +24 -0
  31. liger_kernel/transformers/trainer_integration.py +2 -45
  32. liger_kernel-0.3.0.dist-info/METADATA +388 -0
  33. liger_kernel-0.3.0.dist-info/RECORD +42 -0
  34. {liger_kernel-0.1.0.dist-info → liger_kernel-0.3.0.dist-info}/WHEEL +1 -1
  35. liger_kernel-0.1.0.dist-info/METADATA +0 -16
  36. liger_kernel-0.1.0.dist-info/RECORD +0 -27
  37. {liger_kernel-0.1.0.dist-info → liger_kernel-0.3.0.dist-info}/LICENSE +0 -0
  38. {liger_kernel-0.1.0.dist-info → liger_kernel-0.3.0.dist-info}/NOTICE +0 -0
  39. {liger_kernel-0.1.0.dist-info → liger_kernel-0.3.0.dist-info}/top_level.txt +0 -0
liger_kernel/ops/fused_linear_cross_entropy.py CHANGED
@@ -1,7 +1,10 @@
 import torch
 import triton
 
-from liger_kernel.ops.cross_entropy import element_mul, liger_cross_entropy_kernel
+from liger_kernel.ops.cross_entropy import (
+    element_mul_kernel,
+    liger_cross_entropy_kernel,
+)
 
 # The hard limit of TRITON_MAX_TENSOR_NUMEL is 1048576 https://github.com/triton-lang/triton/blob/ba42a5c68fd0505f8c42f4202d53be0f8d9a5fe0/python/triton/language/core.py#L19
 # However, setting limit as 65536 as in LayerNorm tutorial is faster because of less register spilling
@@ -9,153 +12,227 @@ from liger_kernel.ops.cross_entropy import element_mul, liger_cross_entropy_kern
 MAX_FUSED_SIZE = 65536 // 2
 
 
-class LigerFusedLinearCrossEntropyFunction(torch.autograd.Function):
-    @staticmethod
-    def forward(ctx, _input, linear, target, ignore_index):
-        """
-        Fusing the last linear layer with cross-entropy loss
-        Reference: https://github.com/mgmalek/efficient_cross_entropy
+def fused_linear_cross_entropy_forward(
+    _input,
+    weight,
+    target,
+    bias=None,
+    ignore_index=-100,
+    label_smoothing=0.0,
+    reduction="mean",
+):
+    dtype = (
+        torch.get_autocast_gpu_dtype() if torch.is_autocast_enabled() else _input.dtype
+    )
+    device = _input.device
+
+    # inputs have shape: BT x H
+    # materialized activations will have shape: BT x V
+    # the increase in memory = BT x V
+    # reduction can be achieved by partitioning the number of tokens BT into smaller chunks.
+    # for ex: if we were to achieve the same memory consumption as BT x H, then the chunk size should be:
+    # inc_factor = (V+H-1)//H, chunk_size = (BT + inc_factor - 1)//inc_factor
+    # for ex: BT = 4096*4, V = 32000, H = 4096 ==> inc_factor = 8, chunk_size = 2048
+    BT, H = _input.shape
+    V = weight.shape[0]
+    BLOCK_SIZE = min(MAX_FUSED_SIZE, triton.next_power_of_2(V))
+
+    inc_factor = triton.cdiv(V, H)  # (V + H - 1) // H
+    chunk_size = triton.next_power_of_2(
+        triton.cdiv(BT, inc_factor)
+    )  # (BT + inc_factor - 1) // inc_factor
+    num_chunks = triton.cdiv(BT, chunk_size)  # (BT + chunk_size - 1) // chunk_size
+
+    grad_weight = (
+        torch.zeros_like(weight, device=device) if weight.requires_grad else None
+    )
+    grad_input = torch.zeros_like(_input, device=device)
+    grad_bias = torch.zeros_like(bias, device=device) if bias is not None else None
+    # we use fp32 for loss accumulator
+    loss_1d = torch.zeros(BT, dtype=torch.float32, device=device)
+
+    # NOTE: skip .item() here to avoid CUDA synchronization
+    total_n_non_ignore = (target != ignore_index).sum()
+
+    for chunk_id in range(num_chunks):
+        start_idx = chunk_id * chunk_size
+        end_idx = min((chunk_id + 1) * chunk_size, BT)
+        _input_chunk = _input[start_idx:end_idx]  # chunk_size x H
+
+        # when doing matmul, use the original precision
+        logits_chunk = _input_chunk @ weight.t()  # chunk_size x V
+        if bias is not None:
+            logits_chunk = logits_chunk + bias
+        target_chunk = target[start_idx:end_idx]  # chunk_size,
+
+        n_rows = logits_chunk.shape[0]
+
+        # unreduced loss
+        loss_1d_slice = loss_1d[start_idx:end_idx]  # chunk_size,
+        n_non_ignore = (target_chunk != ignore_index).sum().item()
+
+        # when doing CE, use the upcasted precision
+        logits_chunk = logits_chunk.float()
+
+        # ensure _input and target are contiguous
+        logits_chunk = logits_chunk.contiguous()
+        target_chunk = target_chunk.contiguous()
+
+        # Here we calculate the gradient of logits_chunk in place so we can save memory.
+        liger_cross_entropy_kernel[(n_rows,)](
+            X_ptr=logits_chunk,
+            X_stride=logits_chunk.stride(-2),
+            Y_ptr=target_chunk,
+            Y_stride=target_chunk.stride(-1),  # always 1
+            loss_ptr=loss_1d_slice,
+            loss_stride=loss_1d_slice.stride(-1),  # always 1
+            n_cols=V,
+            n_non_ignore=n_non_ignore,
+            ignore_index=ignore_index,
+            label_smoothing=label_smoothing,
+            reduction=reduction,
+            BLOCK_SIZE=BLOCK_SIZE,
+            num_warps=32,
+        )
 
-        Handle the forward and backward pass of the final linear layer via cross-entropy loss by avoiding
-        the materialization of the large logits tensor. Since Cross Entropy Loss is the last layer, we can
-        compute the gradient at the forward pass. By doing so, we don't have to store the _input and target
-        for the backward pass.
+        # gradient of logits_chunk is computed in-place by the above triton kernel.
+        # Following HuggingFace model source code, we do the forward and backward
+        # w.r.t. logits in fp32 for numerical stability especially as the num classes (vocab size) os huge.
+        # (reference: https://github.com/huggingface/transformers/blob/v4.42.4/src/transformers/models/llama/modeling_llama.py#L1194)
+        # Propagating to lm_head's backward, we'll switch back to the original dtype.
+        logits_chunk = logits_chunk.to(dtype)
 
-        _input: (B*T, H) where B is batch size, T is sequence length, H is hidden dimension.
-        target: (B*T) where each value is in [0, V-1]
-        linear: linear projection matrix of shape V x H.
-        ignore_index: the index to ignore in the target
-        """
-        dtype = (
-            torch.get_autocast_gpu_dtype()
-            if torch.is_autocast_enabled()
-            else _input.dtype
-        )
-        device = _input.device
-
-        # inputs have shape: BT x H
-        # materialized activations will have shape: BT x V
-        # the increase in memory = BT x V
-        # reduction can be achieved by paritioning the number of tokens BT into smaller chunks.
-        # for ex: if we were to achieve the same memory consumption as BT x H, then the chunk size should be:
-        # inc_factor = (V+H-1)//H, chunk_size = (BT + inc_factor - 1)//inc_factor
-        # for ex: BT = 4096*4, V = 32000, H = 4096 ==> inc_factor = 8, chunk_size = 2048
-        BT, H = _input.shape
-        V = linear.shape[0]
-        BLOCK_SIZE = min(MAX_FUSED_SIZE, triton.next_power_of_2(V))
-
-        inc_factor = triton.cdiv(V, H)  # (V + H - 1) // H
-        chunk_size = triton.next_power_of_2(
-            triton.cdiv(BT, inc_factor)
-        )  # (BT + inc_factor - 1) // inc_factor
-        num_chunks = triton.cdiv(BT, chunk_size)  # (BT + chunk_size - 1) // chunk_size
-
-        grad_linear = torch.zeros_like(linear, device=device)
-        grad_input = torch.zeros_like(_input, device=device)
-
-        # we use fp32 for loss accumulator
-        loss_1d = torch.zeros(BT, dtype=torch.float32, device=device)
-
-        total_n_non_ignore = (target != ignore_index).sum().item()
-
-        for chunk_id in range(num_chunks):
-            start_idx = chunk_id * chunk_size
-            end_idx = min((chunk_id + 1) * chunk_size, BT)
-            _input_chunk = _input[start_idx:end_idx]  # chunk_size x H
-
-            # when doing matmul, use the original precision
-            logits_chunk = _input_chunk @ linear.t()  # chunk_size x V
-            target_chunk = target[start_idx:end_idx]  # chunk_size,
-
-            n_rows = logits_chunk.shape[0]
-
-            # unreduced loss
-            loss_1d_slice = loss_1d[start_idx:end_idx]  # chunk_size,
-            n_non_ignore = (target_chunk != ignore_index).sum().item()
-
-            # when doing CE, use the upcasted precision
-            logits_chunk = logits_chunk.float()
-
-            # ensure _input and target are contiguous
-            logits_chunk = logits_chunk.contiguous()
-            target_chunk = target_chunk.contiguous()
-
-            # Here we calculate the gradient of logits_chunk in place so we can save memory.
-            liger_cross_entropy_kernel[(n_rows,)](
-                X_ptr=logits_chunk,
-                X_stride=logits_chunk.stride(-2),
-                Y_ptr=target_chunk,
-                Y_stride=target_chunk.stride(-1),  # always 1
-                loss_ptr=loss_1d_slice,
-                loss_stride=loss_1d_slice.stride(-1),  # always 1
-                n_cols=V,
-                n_non_ignore=n_non_ignore,
-                ignore_index=ignore_index,
-                BLOCK_SIZE=BLOCK_SIZE,
-                num_warps=32,
-            )
+        # gradient of logits_chunk is computed in-place by the above triton kernel and is of shape: chunk_size x V
+        # thus grad_input[start_idx: end_idx] should be of shape: chunk_size x H
+        # additionally, since we are chunking the inputs, observe that the loss and gradients are calculated only
+        # on `n_non_ignore` tokens. However, the gradient of the input should be calculated for all tokens.
+        # Thus, we need an additional scaling factor of (n_non_ignore/total_n_non_ignore) to scale the gradients.
 
-            # gradient of logits_chunk is computed inplace by the above triton kernel.
-            # Following HuggingFace model source code, we do the forward and backward
-            # w.r.t. logits in fp32 for numerical stability especially as the num classes (vocab size) os huge.
-            # (reference: https://github.com/huggingface/transformers/blob/v4.42.4/src/transformers/models/llama/modeling_llama.py#L1194)
-            # Propagating to lm_head's backward, we'll switch back to the original dtype.
-            logits_chunk = logits_chunk.to(dtype)
-
-            # gradient of logits_chunk is computed inplace by the above triton kernel and is of shape: chunk_size x V
-            # thus grad_input[start_idx: end_idx] should be of shape: chunk_size x H
-            # additionally, since we are chunking the inputs, observe that the loss and gradients are calculated only
-            # on `n_non_ignore` tokens. However, the gradient of the input should be calculated for all tokens.
-            # Thus, we need an additional scaling factor of (n_non_ignore/total_n_non_ignore) to scale the gradients.
-            grad_logits_chunk = logits_chunk * (n_non_ignore / total_n_non_ignore)
-            grad_input[start_idx:end_idx] = grad_logits_chunk @ linear
+        if reduction == "mean":
+            alpha = n_non_ignore / total_n_non_ignore if total_n_non_ignore > 0 else 0.0
+        else:
+            alpha = 1.0
 
+        loss_1d[start_idx:end_idx] = loss_1d_slice * alpha
+        grad_logits_chunk = logits_chunk * alpha  # chunk_size x V
+
+        grad_input[start_idx:end_idx] = grad_logits_chunk @ weight
+
+        if grad_weight is not None:
             torch.addmm(
-                input=grad_linear,
+                input=grad_weight,
                 mat1=logits_chunk.t(),
                 mat2=_input_chunk,
-                out=grad_linear,
-                alpha=n_non_ignore / total_n_non_ignore,
+                out=grad_weight,
+                alpha=alpha,
                 beta=1.0,
             )
 
-        loss = torch.sum(loss_1d) / total_n_non_ignore
+        if bias is not None:
+            torch.add(
+                input=grad_bias,
+                other=logits_chunk.sum(dim=0),
+                out=grad_bias,
+                alpha=alpha,
+            )
 
-        # downcast to dtype and store for backward
-        ctx.save_for_backward(grad_input.detach(), grad_linear.detach())
-        return loss
+    loss = torch.sum(loss_1d)
+    return loss, grad_input, grad_weight, grad_bias
+
+
+def fused_linear_cross_entropy_backward(
+    grad_output, grad_input, grad_weight, grad_bias
+):
+    # If cross entropy is the last layer, grad_output is 1.0. Skip the mul to save time
+    if torch.ne(grad_output, torch.tensor(1.0, device=grad_output.device)):
+        # We use a Triton kernel instead of a PyTorch operation because modifying inputs in-place
+        # for gradient storage and backward multiple times causes anomalies with PyTorch but not with Triton.
+        BT, H = grad_input.shape
+        n_rows = BT
+        BLOCK_SIZE = min(MAX_FUSED_SIZE, triton.next_power_of_2(H))
+
+        element_mul_kernel[(n_rows,)](
+            grad_input,
+            grad_input.stride(-2),
+            grad_output,
+            H,
+            BLOCK_SIZE=BLOCK_SIZE,
+            num_warps=32,
+        )
 
-    @staticmethod
-    def backward(ctx, grad_output):
-        (grad_input, grad_linear) = ctx.saved_tensors
-        # If cross entropy is the last layer, grad_output is 1.0. Skip the mul to save time
-        if torch.ne(grad_output, torch.tensor(1.0, device=grad_output.device)):
-            # We use a Triton kernel instead of a PyTorch operation because modifying inputs in-place
-            # for gradient storage and backward multiple times causes anomalies with PyTorch but not with Triton.
-            BT, H = grad_input.shape
-            n_rows = BT
-            BLOCK_SIZE = min(MAX_FUSED_SIZE, triton.next_power_of_2(H))
-
-            element_mul[(n_rows,)](
-                grad_input,
-                grad_input.stride(-2),
+        # handle grad_weight
+        if grad_weight is not None:
+            V, H = grad_weight.shape
+            n_rows = V
+
+            element_mul_kernel[(n_rows,)](
+                grad_weight,
+                grad_weight.stride(-2),
                 grad_output,
                 H,
                 BLOCK_SIZE=BLOCK_SIZE,
                 num_warps=32,
             )
 
-            # handle grad_linear
-            V, H = grad_linear.shape
+        if grad_bias is not None:
+            V = grad_bias.shape[0]
             n_rows = V
 
-            element_mul[(n_rows,)](
-                grad_linear,
-                grad_linear.stride(-2),
+            element_mul_kernel[(n_rows,)](
+                grad_bias,
+                grad_bias.stride(-1),
                 grad_output,
-                H,
+                1,
                 BLOCK_SIZE=BLOCK_SIZE,
                 num_warps=32,
             )
+    return grad_input, grad_weight, grad_bias
+
+
+class LigerFusedLinearCrossEntropyFunction(torch.autograd.Function):
+    @staticmethod
+    def forward(
+        ctx,
+        _input,
+        weight,
+        target,
+        bias=None,
+        ignore_index=-100,
+        label_smoothing=0.0,
+        reduction="mean",
+    ):
+        """
+        Fusing the last linear layer with cross-entropy loss
+        Reference: https://github.com/mgmalek/efficient_cross_entropy
+
+        Handle the forward and backward pass of the final linear layer via cross-entropy loss by avoiding
+        the materialization of the large logits tensor. Since Cross Entropy Loss is the last layer, we can
+        compute the gradient at the forward pass. By doing so, we don't have to store the _input and target
+        for the backward pass.
 
-        return (grad_input, grad_linear, None, None)
+        _input: (B*T, H) where B is batch size, T is sequence length, H is hidden dimension.
+        target: (B*T) where each value is in [0, V-1]
+        weight: (V, H) where V is the number of classes
+        bias: (V) where V is the number of classes
+        ignore_index: the index to ignore in the target
+        label_smoothing (float): The amount of smoothing when computing the loss, where 0.0 means no smoothing.
+        reduction: reduction to apply
+        """
+        loss, grad_input, grad_weight, grad_bias = fused_linear_cross_entropy_forward(
+            _input, weight, target, bias, ignore_index, label_smoothing, reduction
+        )
+        # downcast to dtype and store for backward
+        ctx.save_for_backward(
+            grad_input.detach(),
+            grad_weight.detach() if grad_weight is not None else None,
+            grad_bias.detach() if bias is not None else None,
+        )
+        return loss
+
+    @staticmethod
+    def backward(ctx, grad_output):
+        (grad_input, grad_weight, grad_bias) = ctx.saved_tensors
+        grad_input, grad_weight, grad_bias = fused_linear_cross_entropy_backward(
+            grad_output, grad_input, grad_weight, grad_bias
+        )
+        return (grad_input, grad_weight, None, grad_bias, None, None, None)
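
For orientation, here is a minimal sketch of invoking the refactored autograd function above. The call signature follows the new forward() shown in the diff; the shapes, device, and dtype below are illustrative assumptions, not part of the package.

# Minimal usage sketch (assumes a CUDA device; shapes are illustrative).
import torch

from liger_kernel.ops.fused_linear_cross_entropy import (
    LigerFusedLinearCrossEntropyFunction,
)

B, T, H, V = 2, 128, 4096, 32000
_input = torch.randn(B * T, H, device="cuda", dtype=torch.bfloat16, requires_grad=True)
weight = torch.randn(V, H, device="cuda", dtype=torch.bfloat16, requires_grad=True)
target = torch.randint(0, V, (B * T,), device="cuda")

# New in 0.3.0: optional bias, label_smoothing, and reduction arguments
# (passed positionally, since autograd.Function.apply takes no keywords).
loss = LigerFusedLinearCrossEntropyFunction.apply(
    _input, weight, target, None, -100, 0.0, "mean"
)
loss.backward()  # gradients were already computed in forward and are replayed here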
liger_kernel/ops/geglu.py CHANGED
@@ -11,7 +11,12 @@ from liger_kernel.ops.utils import (
 )
 
 if compare_version("triton", operator.ge, "3.0.0"):
-    from triton.language.extra.libdevice import tanh
+    try:
+        # typical import path with dispatch available
+        from triton.language.extra.libdevice import tanh
+    except ModuleNotFoundError:
+        # for working with NGC containers
+        from triton.language.extra.cuda.libdevice import tanh
 else:
     from triton.language.math import tanh
 
@@ -87,54 +92,61 @@ def _geglu_tanh_backward_kernel(
     tl.store(b + col_offsets, db_row, mask=mask)
 
 
+def geglu_forward(a, b):
+    ori_shape = a.shape
+
+    n_cols = ori_shape[-1]
+    a = a.view(-1, n_cols)
+    b = b.view(-1, n_cols)
+    c = torch.empty_like(a)
+    n_rows = a.shape[0]
+
+    BLOCK_SIZE, num_warps = calculate_settings(n_cols)
+
+    _geglu_tanh_forward_kernel[(n_rows,)](
+        a,
+        b,
+        c,
+        c.stride(-2),
+        n_cols=n_cols,
+        BLOCK_SIZE=BLOCK_SIZE,
+        num_warps=num_warps,
+    )
+    return a, b, c.view(*ori_shape)
+
+
+def geglu_backward(a, b, dc):
+    ori_shape = dc.shape
+    n_cols = ori_shape[-1]
+    dc = dc.view(-1, n_cols)
+    n_rows = dc.shape[0]
+
+    BLOCK_SIZE, num_warps = calculate_settings(n_cols)
+
+    _geglu_tanh_backward_kernel[(n_rows,)](
+        dc,
+        a,
+        b,
+        dc.stride(-2),
+        n_cols=n_cols,
+        BLOCK_SIZE=BLOCK_SIZE,
+        num_warps=num_warps,
+    )
+
+    return a.view(*ori_shape), b.view(*ori_shape)
+
+
 class LigerGELUMulFunction(torch.autograd.Function):
     @staticmethod
     @ensure_contiguous
     def forward(ctx, a, b):
-        ori_shape = a.shape
-
-        n_cols = ori_shape[-1]
-        a = a.view(-1, n_cols)
-        b = b.view(-1, n_cols)
-        c = torch.zeros_like(a)
-        n_rows = a.shape[0]
-
-        BLOCK_SIZE, num_warps = calculate_settings(n_cols)
-
-        _geglu_tanh_forward_kernel[(n_rows,)](
-            a,
-            b,
-            c,
-            c.stride(-2),
-            n_cols=n_cols,
-            BLOCK_SIZE=BLOCK_SIZE,
-            num_warps=num_warps,
-        )
-
+        a, b, c = geglu_forward(a, b)
         ctx.save_for_backward(a, b)
-
-        return c.view(*ori_shape)
+        return c
 
     @staticmethod
     @ensure_contiguous
     def backward(ctx, dc):
-
-        ori_shape = dc.shape
-        n_cols = ori_shape[-1]
-        dc = dc.view(-1, n_cols)
         a, b = ctx.saved_tensors
-        n_rows = dc.shape[0]
-
-        BLOCK_SIZE, num_warps = calculate_settings(n_cols)
-
-        _geglu_tanh_backward_kernel[(n_rows,)](
-            dc,
-            a,
-            b,
-            dc.stride(-2),
-            n_cols=n_cols,
-            BLOCK_SIZE=BLOCK_SIZE,
-            num_warps=num_warps,
-        )
-
-        return a.view(*ori_shape), b.view(*ori_shape)
+        a, b = geglu_backward(a, b, dc)
+        return a, b
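
Similarly, a minimal sketch of exercising the GEGLU op after this refactor: the forward/backward split above means LigerGELUMulFunction.apply() now delegates to geglu_forward and geglu_backward. The shapes and device are illustrative, and the eager-mode reference is an assumption inferred from the tanh-approximation kernels, not something stated in this diff.

# Minimal usage sketch (assumes a CUDA device; shapes are illustrative).
import torch

from liger_kernel.ops.geglu import LigerGELUMulFunction

a = torch.randn(2, 128, 11008, device="cuda", requires_grad=True)  # gate projection
b = torch.randn(2, 128, 11008, device="cuda", requires_grad=True)  # up projection

c = LigerGELUMulFunction.apply(a, b)  # forward now calls geglu_forward under the hood
c.sum().backward()                    # backward now calls geglu_backward

# If the op follows the usual GeGLU definition (inferred from the kernels above),
# the result should roughly match eager PyTorch:
ref = torch.nn.functional.gelu(a.detach(), approximate="tanh") * b.detach()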