liger-kernel 0.0.0 (tar.gz)

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (28)
  1. liger_kernel-0.0.0/PKG-INFO +4 -0
  2. liger_kernel-0.0.0/setup.cfg +4 -0
  3. liger_kernel-0.0.0/setup.py +26 -0
  4. liger_kernel-0.0.0/src/liger_kernel/ops/__init__.py +0 -0
  5. liger_kernel-0.0.0/src/liger_kernel/ops/cross_entropy.py +277 -0
  6. liger_kernel-0.0.0/src/liger_kernel/ops/fused_linear_cross_entropy.py +161 -0
  7. liger_kernel-0.0.0/src/liger_kernel/ops/geglu.py +129 -0
  8. liger_kernel-0.0.0/src/liger_kernel/ops/rms_norm.py +167 -0
  9. liger_kernel-0.0.0/src/liger_kernel/ops/rope.py +234 -0
  10. liger_kernel-0.0.0/src/liger_kernel/ops/swiglu.py +113 -0
  11. liger_kernel-0.0.0/src/liger_kernel/ops/utils.py +38 -0
  12. liger_kernel-0.0.0/src/liger_kernel/transformers/__init__.py +5 -0
  13. liger_kernel-0.0.0/src/liger_kernel/transformers/cross_entropy.py +11 -0
  14. liger_kernel-0.0.0/src/liger_kernel/transformers/fused_linear_cross_entropy.py +15 -0
  15. liger_kernel-0.0.0/src/liger_kernel/transformers/geglu.py +23 -0
  16. liger_kernel-0.0.0/src/liger_kernel/transformers/model/__init__.py +0 -0
  17. liger_kernel-0.0.0/src/liger_kernel/transformers/model/llama.py +143 -0
  18. liger_kernel-0.0.0/src/liger_kernel/transformers/monkey_patch.py +103 -0
  19. liger_kernel-0.0.0/src/liger_kernel/transformers/rms_norm.py +16 -0
  20. liger_kernel-0.0.0/src/liger_kernel/transformers/rope.py +20 -0
  21. liger_kernel-0.0.0/src/liger_kernel/transformers/swiglu.py +40 -0
  22. liger_kernel-0.0.0/src/liger_kernel/triton/__init__.py +3 -0
  23. liger_kernel-0.0.0/src/liger_kernel/triton/monkey_patch.py +44 -0
  24. liger_kernel-0.0.0/src/liger_kernel.egg-info/PKG-INFO +4 -0
  25. liger_kernel-0.0.0/src/liger_kernel.egg-info/SOURCES.txt +26 -0
  26. liger_kernel-0.0.0/src/liger_kernel.egg-info/dependency_links.txt +1 -0
  27. liger_kernel-0.0.0/src/liger_kernel.egg-info/requires.txt +11 -0
  28. liger_kernel-0.0.0/src/liger_kernel.egg-info/top_level.txt +1 -0
liger_kernel-0.0.0/PKG-INFO
@@ -0,0 +1,4 @@
+ Metadata-Version: 2.1
+ Name: liger_kernel
+ Version: 0.0.0
+ Provides-Extra: dev
liger_kernel-0.0.0/setup.cfg
@@ -0,0 +1,4 @@
+ [egg_info]
+ tag_build =
+ tag_date = 0
+
liger_kernel-0.0.0/setup.py
@@ -0,0 +1,26 @@
+ from setuptools import find_namespace_packages, setup
+
+ __version__ = "0.0.0"
+
+ setup(
+     name="liger_kernel",
+     version=__version__,
+     package_dir={"": "src"},
+     packages=find_namespace_packages(where="src"),
+     include_package_data=True,
+     install_requires=[
+         "torch>=2.1.2",
+         "triton>=2.3.0",
+         "transformers>=4.40.1",
+     ],
+     extras_require={
+         "dev": [
+             "matplotlib>=3.7.2",
+             "flake8>=4.0.1.1",
+             "black>=24.4.2",
+             "isort>=5.13.2",
+             "pre-commit>=3.7.1",
+             "torch-tb-profiler>=0.4.1",
+         ]
+     },
+ )
liger_kernel-0.0.0/src/liger_kernel/ops/__init__.py
File without changes
liger_kernel-0.0.0/src/liger_kernel/ops/cross_entropy.py
@@ -0,0 +1,277 @@
+ import torch
+ import triton
+ import triton.language as tl
+
+
+ @triton.jit
+ def liger_cross_entropy_kernel(
+     X_ptr,
+     X_stride,
+     Y_ptr,
+     Y_stride,
+     loss_ptr,
+     loss_stride,
+     n_cols,
+     n_non_ignore,
+     ignore_index,
+     BLOCK_SIZE: tl.constexpr,
+ ):
+     """
+     This kernel computes both the cross entropy loss and the gradient of the _input.
+     We only consider hard label + mean reduction for now. Please refer to https://pytorch.org/docs/stable/generated/torch.nn.CrossEntropyLoss.html for the math.
+
+     Parameters:
+     X_ptr: Pointer to input tensor.
+     X_stride (int): The stride of the input tensor.
+     Y_ptr: Pointer to target tensor.
+     Y_stride (int): The stride of the target tensor.
+     loss_ptr: Pointer to tensor to store the loss.
+     loss_stride (int): The stride of the loss tensor.
+     n_cols (int): The number of columns in the input tensor.
+     n_non_ignore (int): The number of non-ignored elements in the batch.
+     ignore_index (int): The index to ignore in the target.
+     BLOCK_SIZE (int): The block size for Triton operations.
+     """
+
+     # https://github.com/triton-lang/triton/issues/1058
+     # Essentially, if B*T*V is too large, program_id * stride will overflow out of int32
+     program_id = tl.program_id(0).to(tl.int64)
+
+     # 1. Load Y_ptr first because if the target is ignore_index, we can return right away
+     Y_ptr += program_id * Y_stride
+     y = tl.load(Y_ptr)
+
+     # 2. Locate the start index
+     X_ptr += program_id * X_stride
+
+     if y == ignore_index:
+         # set all X_ptr as 0
+         for i in range(0, n_cols, BLOCK_SIZE):
+             X_offsets = i + tl.arange(0, BLOCK_SIZE)
+             tl.store(X_ptr + X_offsets, 0.0, mask=X_offsets < n_cols)
+         return
+
+     loss_ptr += program_id * loss_stride
+
+     # Online softmax: 2 loads + 1 store (compared with 3 loads + 1 store for the safe softmax)
+     # Refer to Algorithm 3 in the paper: https://arxiv.org/pdf/1805.02867
+
+     # 3. [Online softmax] first pass: find max + sum
+     m = float("-inf")  # m is the max value. use the notation from the paper
+     d = 0.0  # d is the sum. use the notation from the paper
+     ori_X_y = tl.load(
+         X_ptr + y
+     )  # we need to store the original value of X_y for the loss calculation
+
+     for i in range(0, n_cols, BLOCK_SIZE):
+         X_offsets = i + tl.arange(0, BLOCK_SIZE)
+         X_block = tl.load(
+             X_ptr + X_offsets, mask=X_offsets < n_cols, other=float("-inf")
+         )
+         block_max = tl.max(X_block)
+         m_new = tl.maximum(m, block_max)
+         d = d * tl.exp(m - m_new) + tl.sum(tl.exp(X_block - m_new))
+         m = m_new
+
+     # 4. [Online softmax] second pass: calculate the gradients
+     # dx_y = (softmax(x_y) - 1) / N
+     # dx_i = softmax(x_i) / N, i != y
+     # N is the number of non-ignored elements in the batch
+     for i in range(0, n_cols, BLOCK_SIZE):
+         X_offsets = i + tl.arange(0, BLOCK_SIZE)
+         X_block = tl.load(
+             X_ptr + X_offsets, mask=X_offsets < n_cols, other=float("-inf")
+         )
+         X_block = (tl.exp(X_block - m) / d) / (n_non_ignore)
+         tl.store(X_ptr + X_offsets, X_block, mask=X_offsets < n_cols)
+
+     # We need tl.debug_barrier() to ensure the new result of X_ptr is written, as mentioned in
+     # https://github.com/triton-lang/triton/blob/ba42a5c68fd0505f8c42f4202d53be0f8d9a5fe0/python/triton/ops/cross_entropy.py#L34
+     tl.debug_barrier()
+
+     # 5. Calculate the loss
+     # Old approach: problematic LogSoftmax
+     # min of bfloat16 and float32 is 1e-38, so we set a value larger than that but small enough
+     # This will overflow if X_y * n_non_ignore is too small. Even if we add a tiny epsilon, it will still overflow
+     # loss = -tl.log(X_y * n_non_ignore)
+
+     # New approach: safe LogSoftmax
+     # Therefore, we propose to use safe logsoftmax by reordering the formula.
+     # loss = log(softmax(X_y)) = log(e ^ (X_y - max(X)) / sum(e ^ (X - max(X))))
+     #      = (X_y - max(X)) - log(sum(e ^ (X - max(X))))
+     # sum(e ^ (X - max(X))) must be >= 1 because the max term is e ^ 0 = 1
+     # So we can safely calculate log(softmax(X_y)) without overflow
+     loss = -(ori_X_y - m - tl.log(d))
+
+     # 6. Specially handle the i == y case where `dx_y = (softmax(x_y) - 1) / N`
+     X_y = tl.load(X_ptr + y)
+     X_y += -1 / (n_non_ignore)
+
+     tl.store(loss_ptr, loss)
+     tl.store(X_ptr + y, X_y)
+
+
+ # The hard limit of TRITON_MAX_TENSOR_NUMEL is 1048576 https://github.com/triton-lang/triton/blob/ba42a5c68fd0505f8c42f4202d53be0f8d9a5fe0/python/triton/language/core.py#L19
+ # However, setting the limit to 65536 as in the LayerNorm tutorial is faster because of less register spilling
+ # The optimal maximum block size depends on your hardware, your kernel, and your dtype
+ MAX_FUSED_SIZE = 65536 // 2  # manually tuned a bit
+
+
+ @triton.jit
+ def element_mul(
+     X_ptr,
+     X_stride,
+     grad_output_ptr,
+     n_cols,
+     BLOCK_SIZE: tl.constexpr,
+ ):
+     """
+     This function multiplies each element of the tensor pointed to by X_ptr with the value pointed to by grad_output_ptr.
+     The multiplication is performed in-place on the tensor pointed to by X_ptr.
+
+     Parameters:
+     X_ptr: Pointer to the input tensor.
+     X_stride (int): The stride of the input tensor.
+     grad_output_ptr: Pointer to the gradient output value.
+     n_cols (int): The number of columns in the input tensor.
+     BLOCK_SIZE (int): The block size for Triton operations.
+     """
+
+     # Get the program ID and convert it to int64 to avoid overflow
+     program_id = tl.program_id(0).to(tl.int64)
+
+     # Locate the start index
+     X_ptr += program_id * X_stride
+
+     # Load the gradient output value
+     grad_output = tl.load(grad_output_ptr)
+
+     # Perform the element-wise multiplication
+     for i in range(0, n_cols, BLOCK_SIZE):
+         X_offsets = i + tl.arange(0, BLOCK_SIZE)
+         X_block = tl.load(X_ptr + X_offsets, mask=X_offsets < n_cols)
+         tl.store(X_ptr + X_offsets, X_block * grad_output, mask=X_offsets < n_cols)
+
+
+ class LigerCrossEntropyFunction(torch.autograd.Function):
+     """
+     This class implements a custom autograd function for the Liger Cross Entropy loss.
+     It overrides the forward and backward methods of the torch.autograd.Function class.
+     """
+
+     @staticmethod
+     def forward(ctx, _input, target, ignore_index):
+         """
+         The forward pass of the Liger Cross Entropy loss.
+
+         Parameters:
+         ctx : The context object.
+         _input (tensor): The input tensor of shape (BT, V) where B is batch size, T is sequence length, V is vocab size.
+         target (tensor): The target tensor of shape (BT) where each value is in [0, V-1].
+         ignore_index (int): The index to ignore in the target.
+
+         Returns:
+         tensor: The computed loss.
+         """
+         BT, V = _input.shape
+         n_rows = BT
+
+         BLOCK_SIZE = min(MAX_FUSED_SIZE, triton.next_power_of_2(V))
+
+         # unreduced loss
+         loss_1d = torch.zeros(n_rows, dtype=_input.dtype, device=_input.device)
+
+         n_non_ignore = (target != ignore_index).sum().item()
+
+         # Ensure _input and target are contiguous in the last dimension.
+         # There are examples that are NOT contiguous overall but ARE contiguous in the last dimension:
+         ####################################################################
+         # tensor = torch.arange(1, 21).reshape(5, -1)
+         # print(tensor)
+         # tensor([[ 1,  2,  3,  4],
+         #         [ 5,  6,  7,  8],
+         #         [ 9, 10, 11, 12],
+         #         [13, 14, 15, 16],
+         #         [17, 18, 19, 20]])
+         # print(tensor.is_contiguous())
+         # True
+         # slice = tensor[::2, :]
+         # print(slice)
+         # tensor([[ 1,  2,  3,  4],
+         #         [ 9, 10, 11, 12],
+         #         [17, 18, 19, 20]])
+         # print(slice.is_contiguous())
+         # False
+         # print(slice.stride())
+         # (8, 1)
+         # slice is NOT a contiguous tensor but is contiguous in the last dimension; the CE kernel can still execute because the stride is 8, so each Triton program will jump by 8
+         ####################################################################
+         if _input.stride(-1) != 1:
+             _input = _input.contiguous()
+         if target.stride(-1) != 1:
+             target = target.contiguous()
+
+         # Here we use a trick to store the gradient of X_ptr in X_ptr itself so we can save memory
+         liger_cross_entropy_kernel[(n_rows,)](
+             X_ptr=_input,
+             X_stride=_input.stride(-2),
+             Y_ptr=target,
+             Y_stride=target.stride(-1),  # always 1
+             loss_ptr=loss_1d,
+             loss_stride=loss_1d.stride(-1),  # always 1
+             n_cols=V,
+             n_non_ignore=n_non_ignore,
+             ignore_index=ignore_index,
+             BLOCK_SIZE=BLOCK_SIZE,
+             # TODO: 32 seems to give the best performance
+             # Performance is quite sensitive to num_warps
+             num_warps=32,
+         )
+
+         loss = torch.sum(loss_1d) / n_non_ignore
+
+         # TODO: investigation
+         # If we don't detach the _input tensor, the memory will double
+         # Not sure why, but it seems there is a moment when both the grad and the value exist, in different locations
+         ctx.save_for_backward(_input.detach())
+         return loss
+
+     @staticmethod
+     def backward(ctx, grad_output):
+         """
+         The backward pass of the Liger Cross Entropy loss.
+
+         Parameters:
+         ctx : The context object with saved tensors.
+         grad_output (tensor): The tensor containing the gradient of the loss with respect to the output.
+
+         Returns:
+         tuple: A tuple with the gradients with respect to the inputs. The elements are tensors or None.
+         """
+         (_input,) = ctx.saved_tensors
+         # If cross entropy is the last layer, grad_output is 1.0. Skip the mul to save time
+         if torch.equal(grad_output, torch.tensor(1.0, device=grad_output.device)):
+             pass
+         # We use a Triton kernel instead of a PyTorch operation because modifying inputs in-place
+         # for gradient storage and calling backward multiple times causes anomalies with PyTorch but not with Triton.
+         # Although the Brew trainer should only perform backward once, it encounters this issue.
+         # https://github.com/triton-lang/triton/issues/4004
+         else:
+             BT, V = _input.shape
+             n_rows = BT
+             BLOCK_SIZE = min(MAX_FUSED_SIZE, triton.next_power_of_2(V))
+
+             element_mul[(n_rows,)](
+                 _input,
+                 _input.stride(-2),
+                 grad_output,
+                 V,
+                 BLOCK_SIZE=BLOCK_SIZE,
+                 num_warps=32,
+             )
+
+         return (
+             _input,
+             None,
+             None,
+         )
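A minimal sketch (not part of the package) of exercising the autograd function above against PyTorch's reference cross entropy; it assumes a CUDA device with Triton installed, and the sizes are hypothetical:

    import torch
    import torch.nn.functional as F

    from liger_kernel.ops.cross_entropy import LigerCrossEntropyFunction

    # Hypothetical sizes: 8 rows (B*T) over a 4096-way vocab, chosen only for illustration.
    BT, V, ignore_index = 8, 4096, -100
    x = torch.randn(BT, V, device="cuda")
    target = torch.randint(0, V, (BT,), device="cuda")
    target[0] = ignore_index  # exercise the ignore_index early-return path

    # The kernel writes dloss/dx into its input buffer (the in-place trick above), so pass a clone.
    x_liger = x.clone()
    loss = LigerCrossEntropyFunction.apply(x_liger, target, ignore_index)

    # Hard-label, mean-reduced cross entropy should agree with the PyTorch reference.
    ref = F.cross_entropy(x, target, ignore_index=ignore_index)
    print(torch.allclose(loss, ref, atol=1e-4))

Because the gradient is stored in the logits buffer during the forward pass, the reference is computed on the untouched copy of the logits.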
liger_kernel-0.0.0/src/liger_kernel/ops/fused_linear_cross_entropy.py
@@ -0,0 +1,161 @@
+ """Fusing the last linear layer with cross-entropy loss
+
+ Reference: https://github.com/mgmalek/efficient_cross_entropy
+ """
+
+ import torch
+ import triton
+
+ from liger_kernel.ops.cross_entropy import element_mul, liger_cross_entropy_kernel
+
+ # The hard limit of TRITON_MAX_TENSOR_NUMEL is 1048576 https://github.com/triton-lang/triton/blob/ba42a5c68fd0505f8c42f4202d53be0f8d9a5fe0/python/triton/language/core.py#L19
+ # However, setting the limit to 65536 as in the LayerNorm tutorial is faster because of less register spilling
+ # The optimal maximum block size depends on your hardware, your kernel, and your dtype
+ MAX_FUSED_SIZE = 65536 // 2  # manually tuned a bit
+
+
+ class LigerFusedLinearCrossEntropyFunction(torch.autograd.Function):
+     @staticmethod
+     def forward(ctx, _input, linear, target, ignore_index):
+         """
+         Handle the forward and backward pass of the final linear layer via cross-entropy loss by avoiding
+         the materialization of the large logits tensor. Since cross entropy loss is the last layer, we can
+         compute the gradient during the forward pass. By doing so, we don't have to store the _input and target
+         for the backward pass.
+
+         _input: (B*T, H) where B is batch size, T is sequence length, H is hidden dimension.
+         target: (B*T) where each value is in [0, V-1]
+         linear: linear projection matrix of shape V x H.
+         ignore_index: the index to ignore in the target
+         """
+         dtype = (
+             torch.get_autocast_gpu_dtype()
+             if torch.is_autocast_enabled()
+             else _input.dtype
+         )
+         device = _input.device
+
+         # inputs have shape: BT x H
+         # materialized activations will have shape: BT x V
+         # the increase in memory = BT x V
+         # reduction can be achieved by partitioning the number of tokens BT into smaller chunks.
+         # for ex: if we were to achieve the same memory consumption as BT x H, then the chunk size should be:
+         # inc_factor = (V + H - 1) // H, chunk_size = (BT + inc_factor - 1) // inc_factor
+         # for ex: BT = 4096*4, V = 32000, H = 4096 ==> inc_factor = 8, chunk_size = 2048
+         BT, H = _input.shape
+         V = linear.shape[0]
+         BLOCK_SIZE = min(MAX_FUSED_SIZE, triton.next_power_of_2(V))
+
+         inc_factor = triton.cdiv(V, H)  # (V + H - 1) // H
+         chunk_size = triton.next_power_of_2(
+             triton.cdiv(BT, inc_factor)
+         )  # (BT + inc_factor - 1) // inc_factor
+         num_chunks = triton.cdiv(BT, chunk_size)  # (BT + chunk_size - 1) // chunk_size
+
+         grad_linear = torch.zeros_like(linear, device=device)
+         grad_input = torch.zeros_like(_input, device=device)
+         loss_1d = torch.zeros(BT, dtype=torch.float32, device=device)
+
+         total_n_non_ignore = (target != ignore_index).sum().item()
+
+         for chunk_id in range(num_chunks):
+             start_idx = chunk_id * chunk_size
+             end_idx = min((chunk_id + 1) * chunk_size, BT)
+             _input_chunk = _input[start_idx:end_idx]  # chunk_size x H
+
+             # when doing matmul, use the original precision
+             logits_chunk = _input_chunk @ linear.t()  # chunk_size x V
+             target_chunk = target[start_idx:end_idx]  # chunk_size,
+
+             n_rows = logits_chunk.shape[0]
+
+             # unreduced loss
+             loss_1d_slice = loss_1d[start_idx:end_idx]  # chunk_size,
+             n_non_ignore = (target_chunk != ignore_index).sum().item()
+
+             # when doing CE, use the upcasted precision
+             logits_chunk = logits_chunk.float()
+
+             # ensure _input and target are contiguous
+             logits_chunk = logits_chunk.contiguous()
+             target_chunk = target_chunk.contiguous()
+
+             # Here we calculate the gradient of logits_chunk in place so we can save memory.
+             liger_cross_entropy_kernel[(n_rows,)](
+                 X_ptr=logits_chunk,
+                 X_stride=logits_chunk.stride(-2),
+                 Y_ptr=target_chunk,
+                 Y_stride=target_chunk.stride(-1),  # always 1
+                 loss_ptr=loss_1d_slice,
+                 loss_stride=loss_1d_slice.stride(-1),  # always 1
+                 n_cols=V,
+                 n_non_ignore=n_non_ignore,
+                 ignore_index=ignore_index,
+                 BLOCK_SIZE=BLOCK_SIZE,
+                 num_warps=32,
+             )
+
+             # The gradient of logits_chunk is computed in place by the above Triton kernel.
+             # Following the Hugging Face model source code, we do the forward and backward
+             # w.r.t. logits in fp32 for numerical stability, especially as the number of classes (vocab size) is huge.
+             # (reference: https://github.com/huggingface/transformers/blob/v4.42.4/src/transformers/models/llama/modeling_llama.py#L1194)
+             # Propagating to lm_head's backward, we'll switch back to the original dtype.
+             logits_chunk = logits_chunk.to(dtype)
+
+             # The gradient of logits_chunk is computed in place by the above Triton kernel and has shape: chunk_size x V
+             # thus grad_input[start_idx:end_idx] should be of shape: chunk_size x H
+             # additionally, since we are chunking the inputs, observe that the loss and gradients are calculated only
+             # on `n_non_ignore` tokens. However, the gradient of the input should be calculated for all tokens.
+             # Thus, we need an additional scaling factor of (n_non_ignore / total_n_non_ignore) to scale the gradients.
+             grad_logits_chunk = logits_chunk * (n_non_ignore / total_n_non_ignore)
+             grad_input[start_idx:end_idx] = grad_logits_chunk @ linear
+
+             torch.addmm(
+                 input=grad_linear,
+                 mat1=logits_chunk.t(),
+                 mat2=_input_chunk,
+                 out=grad_linear,
+                 alpha=n_non_ignore / total_n_non_ignore,
+                 beta=1.0,
+             )
+
+         loss = torch.sum(loss_1d) / total_n_non_ignore
+
+         # downcast to dtype and store for backward
+         ctx.save_for_backward(grad_input.detach(), grad_linear.detach())
+         return loss
+
+     @staticmethod
+     def backward(ctx, grad_output):
+         (grad_input, grad_linear) = ctx.saved_tensors
+         # If cross entropy is the last layer, grad_output is 1.0. Skip the mul to save time
+         if torch.ne(grad_output, torch.tensor(1.0, device=grad_output.device)):
+             # We use a Triton kernel instead of a PyTorch operation because modifying inputs in-place
+             # for gradient storage and calling backward multiple times causes anomalies with PyTorch but not with Triton.
+             BT, H = grad_input.shape
+             n_rows = BT
+             BLOCK_SIZE = min(MAX_FUSED_SIZE, triton.next_power_of_2(H))
+
+             element_mul[(n_rows,)](
+                 grad_input,
+                 grad_input.stride(-2),
+                 grad_output,
+                 H,
+                 BLOCK_SIZE=BLOCK_SIZE,
+                 num_warps=32,
+             )
+
+             # handle grad_linear
+             V, H = grad_linear.shape
+             n_rows = V
+
+             element_mul[(n_rows,)](
+                 grad_linear,
+                 grad_linear.stride(-2),
+                 grad_output,
+                 H,
+                 BLOCK_SIZE=BLOCK_SIZE,
+                 num_warps=32,
+             )
+
+         return (grad_input, grad_linear, None, None)
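The chunking arithmetic in the forward pass is easiest to see with the worked example from its comments (BT = 4096*4, V = 32000, H = 4096). The sketch below reproduces it with plain integer helpers that mirror `triton.cdiv` and `triton.next_power_of_2`; the helpers are illustrative stand-ins, not part of the package:

    def cdiv(a: int, b: int) -> int:
        # ceiling division, mirroring triton.cdiv
        return (a + b - 1) // b

    def next_power_of_2(n: int) -> int:
        # smallest power of two >= n, mirroring triton.next_power_of_2
        return 1 << (n - 1).bit_length()

    BT, V, H = 4096 * 4, 32000, 4096
    inc_factor = cdiv(V, H)                              # 8: logits are ~8x wider than the hidden states
    chunk_size = next_power_of_2(cdiv(BT, inc_factor))   # 2048 tokens per chunk
    num_chunks = cdiv(BT, chunk_size)                    # 8 chunks
    print(inc_factor, chunk_size, num_chunks)            # 8 2048 8

    # Each chunk materializes chunk_size x V logits (2048 * 32000 floats) instead of BT x V,
    # roughly the same footprint as the BT x H input itself (16384 * 4096 floats).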
liger_kernel-0.0.0/src/liger_kernel/ops/geglu.py
@@ -0,0 +1,129 @@
+ import torch
+ import triton
+ import triton.language as tl
+
+ from liger_kernel.ops.utils import calculate_settings, ensure_contiguous
+
+
+ @triton.jit
+ def _geglu_tanh_forward_kernel(
+     a, b, c, stride, n_cols: tl.constexpr, BLOCK_SIZE: tl.constexpr
+ ):
+     program_id = tl.program_id(0)
+
+     # locate start index
+     a += program_id * stride
+     b += program_id * stride
+     c += program_id * stride
+
+     col_offsets = tl.arange(0, BLOCK_SIZE)
+     mask = col_offsets < n_cols
+     a_row = tl.load(a + col_offsets, mask=mask, other=0).to(tl.float32)
+     b_row = tl.load(b + col_offsets, mask=mask, other=0)
+
+     # The tanh approximation form of GELU is computed with:
+     # 0.5 * a * (1 + tanh(sqrt(2 / pi) * (a + 0.044715 * a^3)))
+     sqrt_2_over_pi = 0.7978845608028654  # sqrt(2 / pi)
+     a_cubed = a_row * a_row * a_row
+     tanh_arg = sqrt_2_over_pi * (a_row + 0.044715 * a_cubed)
+     tanh_result = tl.math.tanh(tanh_arg)
+     geglu_a = 0.5 * a_row * (1 + tanh_result)
+     c_row = geglu_a * b_row
+     tl.store(c + col_offsets, c_row, mask=mask)
+
+
+ @triton.jit
+ def _geglu_tanh_backward_kernel(
+     dc, a, b, stride, n_cols: tl.constexpr, BLOCK_SIZE: tl.constexpr
+ ):
+     program_id = tl.program_id(0)
+
+     # locate start index
+     dc += program_id * stride
+     a += program_id * stride
+     b += program_id * stride
+
+     col_offsets = tl.arange(0, BLOCK_SIZE)
+     mask = col_offsets < n_cols
+
+     dc_row = tl.load(dc + col_offsets, mask=mask, other=0)
+     a_row = tl.load(a + col_offsets, mask=mask, other=0).to(tl.float32)
+     b_row = tl.load(b + col_offsets, mask=mask, other=0)
+
+     # recomputation to save memory
+     sqrt_2_over_pi = 0.7978845608028654  # sqrt(2 / pi)
+     a_cubed = a_row * a_row * a_row
+     tanh_arg = sqrt_2_over_pi * (a_row + 0.044715 * a_cubed)
+     tanh_result = tl.math.tanh(tanh_arg)
+     geglu_a = 0.5 * a_row * (1 + tanh_result)
+
+     db_row = dc_row * geglu_a
+
+     # The gradient w.r.t. a can be computed with:
+     # b * (0.5 * (1 + tanh(z)) + 0.5 * a * (1 - tanh(z)^2) * (sqrt(2/pi) * (1 + 3 * 0.044715 * a^2)))
+     # where z = sqrt(2/pi) * (a + 0.044715 * a^3)
+     term1 = 0.5 * (1 + tanh_result)
+     tanh_sq = tanh_result * tanh_result
+     term2 = (
+         0.5
+         * a_row
+         * (1 - tanh_sq)
+         * (sqrt_2_over_pi * (1 + 3 * 0.044715 * a_row * a_row))
+     )
+     da_row = dc_row * b_row * (term1 + term2)
+
+     tl.store(a + col_offsets, da_row, mask=mask)
+     tl.store(b + col_offsets, db_row, mask=mask)
+
+
+ class LigerGELUMulFunction(torch.autograd.Function):
+     @staticmethod
+     @ensure_contiguous
+     def forward(ctx, a, b):
+         ori_shape = a.shape
+
+         n_cols = ori_shape[-1]
+         a = a.view(-1, n_cols)
+         b = b.view(-1, n_cols)
+         c = torch.zeros_like(a)
+         n_rows = a.shape[0]
+
+         BLOCK_SIZE, num_warps = calculate_settings(n_cols)
+
+         _geglu_tanh_forward_kernel[(n_rows,)](
+             a,
+             b,
+             c,
+             c.stride(-2),
+             n_cols=n_cols,
+             BLOCK_SIZE=BLOCK_SIZE,
+             num_warps=num_warps,
+         )
+
+         ctx.save_for_backward(a, b)
+
+         return c.view(*ori_shape)
+
+     @staticmethod
+     @ensure_contiguous
+     def backward(ctx, dc):
+         ori_shape = dc.shape
+         n_cols = ori_shape[-1]
+         dc = dc.view(-1, n_cols)
+         a, b = ctx.saved_tensors
+         n_rows = dc.shape[0]
+
+         BLOCK_SIZE, num_warps = calculate_settings(n_cols)
+
+         _geglu_tanh_backward_kernel[(n_rows,)](
+             dc,
+             a,
+             b,
+             dc.stride(-2),
+             n_cols=n_cols,
+             BLOCK_SIZE=BLOCK_SIZE,
+             num_warps=num_warps,
+         )
+
+         return a.view(*ori_shape), b.view(*ori_shape)
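As a sanity check on the tanh-approximation GELU used above, the fused forward output should match a plain PyTorch GEGLU built from `F.gelu(..., approximate="tanh")`. A minimal sketch (not part of the package), assuming a CUDA device with Triton installed and hypothetical sizes:

    import torch
    import torch.nn.functional as F

    from liger_kernel.ops.geglu import LigerGELUMulFunction

    # Hypothetical sizes: 4 rows with a hidden width of 256, chosen only for illustration.
    a = torch.randn(4, 256, device="cuda")
    b = torch.randn(4, 256, device="cuda")

    out = LigerGELUMulFunction.apply(a.clone(), b.clone())  # fused kernel: gelu_tanh(a) * b
    ref = F.gelu(a, approximate="tanh") * b                 # 0.5 * a * (1 + tanh(sqrt(2/pi) * (a + 0.044715 * a^3))) * b
    print(torch.allclose(out, ref, atol=1e-5))

The clones are passed because the backward kernel reuses the saved `a` and `b` buffers to store the gradients in place.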