liger-kernel 0.0.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -0,0 +1,167 @@
+ import torch
+ import triton
+ import triton.language as tl
+
+ from liger_kernel.ops.utils import calculate_settings, ensure_contiguous
+
+
+ @triton.jit
+ def _rms_norm_forward(
+     Y_ptr,
+     Y_row_stride,
+     X_ptr,
+     X_row_stride,
+     W_ptr,
+     W_row_stride,
+     r_ptr,
+     r_row_stride,
+     n_cols,
+     eps,
+     BLOCK_SIZE: tl.constexpr,
+ ):
+     """
+     Reference:
+     1. https://triton-lang.org/main/getting-started/tutorials/05-layer-norm.html
+     2. https://github.com/unslothai/unsloth/blob/fd753fed99ed5f10ef8a9b7139588d9de9ddecfb/unsloth/kernels/rms_layernorm.py#L22
+     """
+
+     row_idx = tl.program_id(0)
+     col_offsets = tl.arange(0, BLOCK_SIZE)
+     mask = col_offsets < n_cols
+
+     Y_ptr += row_idx * Y_row_stride
+     X_ptr += row_idx * X_row_stride
+     r_ptr += row_idx * r_row_stride
+
+     X_row = tl.load(X_ptr + col_offsets, mask=mask, other=0)
+     W_row = tl.load(W_ptr + col_offsets, mask=mask, other=0)
+
+     row_var = tl.sum(X_row * X_row, axis=0) / n_cols
+     inv_var = tl.math.rsqrt(row_var + eps)
+
+     # trick: inv_var is a single scalar per row, so caching it is cheap and lets
+     # the backward pass skip recomputing four ops (square, sum, divide, rsqrt)
+     tl.store(r_ptr, inv_var)
+
+     normed = X_row * inv_var
+
+     output = normed * W_row
+     tl.store(Y_ptr + col_offsets, output, mask=mask)
+
+
+ @triton.jit
+ def _rms_norm_backward(
+     dY_ptr,
+     dY_row_stride,
+     X_ptr,
+     X_row_stride,
+     W_ptr,
+     W_row_stride,
+     r_ptr,
+     r_row_stride,
+     dW_ptr,
+     dW_row_stride,
+     n_cols,
+     eps,
+     BLOCK_SIZE: tl.constexpr,
+ ):
+     """
+     Let r = 1 / rms(x) = rsqrt(mean(x^2) + eps) and x_hat = x * r. Then:
+     dx = r * (dy * w - (1/N) * ((dy * w) dot x_hat) * x_hat)
+     dw = sum over rows of (dy * x_hat)
+     """
+     row_idx = tl.program_id(0)
+     col_offsets = tl.arange(0, BLOCK_SIZE)
+     mask = col_offsets < n_cols
+
+     dY_ptr += row_idx * dY_row_stride
+     X_ptr += row_idx * X_row_stride
+     r_ptr += row_idx * r_row_stride
+     dW_ptr += row_idx * dW_row_stride
+
+     dY_row = tl.load(dY_ptr + col_offsets, mask=mask, other=0)
+     X_row = tl.load(X_ptr + col_offsets, mask=mask, other=0)
+     W_row = tl.load(W_ptr + col_offsets, mask=mask, other=0)
+
+     # load the inverse RMS cached by the forward pass (despite the name,
+     # inv_var is 1 / rms(x), not 1 / var(x))
+     inv_var = tl.load(r_ptr)
+
+     normed = X_row * inv_var
+
+     dY_W = dY_row * W_row
+     dY_normed = dY_row * normed
+
+     rowsum_dY_normed = tl.sum(dY_W * normed, axis=0)
+     output = inv_var / n_cols * (n_cols * dY_W - normed * rowsum_dY_normed)
+     # dX is written in place into the dY buffer
+     tl.store(dY_ptr + col_offsets, output, mask=mask)
+
+     # store this row's contribution to dW; rows are summed on the host afterwards
+     tl.store(dW_ptr + col_offsets, dY_normed, mask=mask)
+
+
+ class LigerRMSNormFunction(torch.autograd.Function):
+     @staticmethod
+     @ensure_contiguous
+     def forward(ctx, X, W, eps):
+         shape = X.shape
+         dim = shape[-1]
+         X = X.view(-1, dim)
+         n_rows, n_cols = X.shape
+         BLOCK_SIZE, num_warps = calculate_settings(n_cols)
+
+         Y = torch.empty((n_rows, n_cols), dtype=X.dtype, device=X.device)
+         r = torch.empty(n_rows, dtype=X.dtype, device=X.device)
+
+         # Check constraints.
+         assert (
+             X.shape[1] == W.shape[0]
+         ), "Incompatible hidden size: X.shape[1] must match W.shape[0]"
+
+         _rms_norm_forward[(n_rows,)](
+             Y,
+             Y.stride(0),
+             X,
+             X.stride(0),
+             W,
+             W.stride(0),
+             r,
+             r.stride(0),
+             n_cols,
+             eps,
+             BLOCK_SIZE=BLOCK_SIZE,
+             num_warps=num_warps,
+         )
+         ctx.eps = eps
+         ctx.BLOCK_SIZE = BLOCK_SIZE
+         ctx.num_warps = num_warps
+
+         ctx.save_for_backward(X, W, r)
+         return Y.view(*shape)
+
+     @staticmethod
+     @ensure_contiguous
+     def backward(ctx, dY):
+         shape = dY.shape
+         dim = shape[-1]
+         dY = dY.view(-1, dim)
+         X, W, r = ctx.saved_tensors
+         n_rows, n_cols = dY.shape
+         dW = torch.zeros_like(X)
+
+         _rms_norm_backward[(n_rows,)](
+             dY,
+             dY.stride(0),
+             X,
+             X.stride(0),
+             W,
+             W.stride(0),
+             r,
+             r.stride(0),
+             dW,
+             dW.stride(0),
+             n_cols,
+             ctx.eps,
+             BLOCK_SIZE=ctx.BLOCK_SIZE,
+             num_warps=ctx.num_warps,
+         )
+         dX = dY.view(*shape)
+         dW = torch.sum(dW, dim=0)
+         return dX, dW, None
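
A minimal smoke test for the kernel above, assuming a CUDA device and that the file ships as liger_kernel/ops/rms_norm.py (the module path is not visible in this diff), could compare the fused op against a plain PyTorch RMSNorm:

    import torch

    from liger_kernel.ops.rms_norm import LigerRMSNormFunction  # path assumed

    x = torch.randn(2, 8, 64, device="cuda", requires_grad=True)
    w = torch.randn(64, device="cuda", requires_grad=True)
    eps = 1e-6

    y = LigerRMSNormFunction.apply(x, w, eps)

    # reference: y = x / rms(x) * w in plain PyTorch
    ref = x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + eps) * w
    assert torch.allclose(y, ref, atol=1e-4, rtol=1e-4)
    y.sum().backward()  # exercises _rms_norm_backward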
@@ -0,0 +1,234 @@
+ import torch
+ import triton
+ import triton.language as tl
+
+
+ @triton.jit
+ def _triton_rope(
+     q_ptr,
+     q_row_stride,
+     k_ptr,
+     k_row_stride,
+     cos,
+     cos_row_stride,
+     sin,
+     sin_row_stride,
+     bs: tl.constexpr,
+     sl: tl.constexpr,
+     n_qh: tl.constexpr,
+     n_kh: tl.constexpr,
+     hd: tl.constexpr,
+     pad_n_qh: tl.constexpr,
+     pad_n_kh: tl.constexpr,
+     pad_hd: tl.constexpr,
+     BLOCK_SIZE: tl.constexpr,
+     BACKWARD_PASS: tl.constexpr = False,
+ ):
+     # q size: (bsz, seq_len, num_q_heads, head_dim)
+     # q stride: (seq_len * num_q_heads * head_dim, num_q_heads * head_dim, head_dim, 1)
+     # k size: (bsz, seq_len, num_kv_heads, head_dim)
+     # k stride: (seq_len * num_kv_heads * head_dim, num_kv_heads * head_dim, head_dim, 1)
+
+     # cos size: (1, seq_len, head_dim)
+     # stride: (seq_len * head_dim, head_dim, 1)
+     pid = tl.program_id(0)
+
+     # locate start address
+     q_ptr = q_ptr + pid * q_row_stride
+     k_ptr = k_ptr + pid * k_row_stride
+
+     # ####################################################################
+     # get the cos(mθ_{i...d/2}) and sin(mθ_{i...d/2}) for token position
+     # m of this program instance
+     # ####################################################################
+
+     # 1. program instances are laid out in a 1D vector of size bsz * seq_len, which
+     #    effectively represents a 2D grid of size [bsz, seq_len] with seq_len as the
+     #    fastest-changing dimension. Thus we can simply do pid // sl to get the batch
+     #    index and pid % sl to get the sequence index.
+     # 2. We only need the left half of the cos and sin matrices because the right
+     #    half is just a clone of the left half.
+     cos_row_idx = pid % sl
+     cos = cos + cos_row_idx * cos_row_stride
+     sin = sin + cos_row_idx * sin_row_stride
+     cos_offsets = tl.arange(0, pad_hd // 2)
+     cos_mask = cos_offsets < hd // 2
+     cos_row = tl.load(cos + cos_offsets, mask=cos_mask, other=0)
+     sin_row = tl.load(sin + cos_offsets, mask=cos_mask, other=0)
+
+     # ####################################################################
+     # Load the left and right half of q and k for the current
+     # program instance (i.e. for the current token) separately
+     # ####################################################################
+     # left half of the head
+     first_half_q_offsets = (
+         tl.arange(0, pad_n_qh)[:, None] * hd + tl.arange(0, pad_hd // 2)[None, :]
+     )
+     first_half_k_offsets = (
+         tl.arange(0, pad_n_kh)[:, None] * hd + tl.arange(0, pad_hd // 2)[None, :]
+     )
+     first_q_mask = (tl.arange(0, pad_n_qh)[:, None] < n_qh) & (
+         tl.arange(0, pad_hd // 2)[None, :] < hd // 2
+     )
+     first_k_mask = (tl.arange(0, pad_n_kh)[:, None] < n_kh) & (
+         tl.arange(0, pad_hd // 2)[None, :] < hd // 2
+     )
+     q_tile_1 = tl.load(q_ptr + first_half_q_offsets, mask=first_q_mask, other=0).to(
+         sin_row.dtype
+     )
+     k_tile_1 = tl.load(k_ptr + first_half_k_offsets, mask=first_k_mask, other=0).to(
+         sin_row.dtype
+     )
+
+     # right half of the head
+     second_half_q_offsets = first_half_q_offsets + (hd // 2)
+     second_half_k_offsets = first_half_k_offsets + (hd // 2)
+     second_q_mask = first_q_mask
+     second_k_mask = first_k_mask
+     q_tile_2 = tl.load(q_ptr + second_half_q_offsets, mask=second_q_mask, other=0).to(
+         sin_row.dtype
+     )
+     k_tile_2 = tl.load(k_ptr + second_half_k_offsets, mask=second_k_mask, other=0).to(
+         sin_row.dtype
+     )
+
+     if not BACKWARD_PASS:
+         # y = [x1, x2] * [cos, cos] + [-x2, x1] * [sin, sin]
+         new_q_tile_1 = q_tile_1 * cos_row - q_tile_2 * sin_row
+         tl.store(q_ptr + first_half_q_offsets, new_q_tile_1, mask=first_q_mask)
+         new_q_tile_2 = q_tile_2 * cos_row + q_tile_1 * sin_row
+         tl.store(q_ptr + second_half_q_offsets, new_q_tile_2, mask=second_q_mask)
+
+         new_k_tile_1 = k_tile_1 * cos_row - k_tile_2 * sin_row
+         tl.store(k_ptr + first_half_k_offsets, new_k_tile_1, mask=first_k_mask)
+         new_k_tile_2 = k_tile_2 * cos_row + k_tile_1 * sin_row
+         tl.store(k_ptr + second_half_k_offsets, new_k_tile_2, mask=second_k_mask)
+     else:
+         # the backward pass applies the inverse rotation (the transpose of the forward
+         # rotation matrix): dx = [dy1, dy2] * [cos, cos] + [dy2, -dy1] * [sin, sin]
+         new_q_tile_1 = q_tile_1 * cos_row + q_tile_2 * sin_row
+         tl.store(q_ptr + first_half_q_offsets, new_q_tile_1, mask=first_q_mask)
+         new_q_tile_2 = q_tile_2 * cos_row - q_tile_1 * sin_row
+         tl.store(q_ptr + second_half_q_offsets, new_q_tile_2, mask=second_q_mask)
+
+         new_k_tile_1 = k_tile_1 * cos_row + k_tile_2 * sin_row
+         tl.store(k_ptr + first_half_k_offsets, new_k_tile_1, mask=first_k_mask)
+         new_k_tile_2 = k_tile_2 * cos_row - k_tile_1 * sin_row
+         tl.store(k_ptr + second_half_k_offsets, new_k_tile_2, mask=second_k_mask)
+
+
+ class LigerRopeFunction(torch.autograd.Function):
+     """
+     Triton implementation of the Rotary Positional Embedding (RoPE) operation. Note that
+     this implements the HuggingFace Llama & Mistral version, whose rotation matrix differs
+     slightly from the one in the original RoPE paper.
+
+     Please find the corresponding HuggingFace implementation here:
+     https://github.com/huggingface/transformers/blob/v4.40.2/src/transformers/models/llama/modeling_llama.py#L184
+
+     For more details about the rotation matrix used here, please refer to:
+     https://discuss.huggingface.co/t/is-llama-rotary-embedding-implementation-correct/44509/2
+     """
+
+     @staticmethod
+     def forward(ctx, q, k, cos, sin, position_ids=None, unsqueeze_dim=1):
+         """
+         q size: (bsz, n_q_head, seq_len, head_dim)
+         k size: (bsz, n_kv_head, seq_len, head_dim)
+         cos size: (1, seq_len, head_dim)
+         sin size: (1, seq_len, head_dim)
+         """
+
+         # transpose back to the physical (bsz, seq_len, n_head, head_dim) layout,
+         # since Triton addresses the underlying storage directly. Note: q and k are
+         # typically non-contiguous views before this transpose and contiguous after it
+         q = q.transpose(1, 2)
+         k = k.transpose(1, 2)
+
+         batch_size, seq_len, n_q_head, head_dim = q.shape
+         n_kv_head = k.shape[2]
+         pad_hd = triton.next_power_of_2(head_dim)
+         pad_n_q_head = triton.next_power_of_2(n_q_head)
+         pad_n_kv_head = triton.next_power_of_2(n_kv_head)
+         BLOCK_SIZE = max(pad_n_q_head, pad_n_kv_head)
+
+         n_row = batch_size * seq_len
+
+         # ensure tensors passed into the kernel are contiguous; this is a no-op
+         # if they are already contiguous
+         q = q.contiguous()
+         k = k.contiguous()
+         cos = cos.contiguous()
+         sin = sin.contiguous()
+
+         _triton_rope[(n_row,)](
+             q,
+             q.stride(1),
+             k,
+             k.stride(1),
+             cos,
+             cos.stride(-2),
+             sin,
+             sin.stride(-2),
+             batch_size,
+             seq_len,
+             n_q_head,
+             n_kv_head,
+             head_dim,
+             pad_n_q_head,
+             pad_n_kv_head,
+             pad_hd,
+             BLOCK_SIZE=BLOCK_SIZE,
+             BACKWARD_PASS=False,
+         )
+
+         ctx.save_for_backward(cos, sin)
+         return q.transpose(1, 2), k.transpose(1, 2)
+
+     @staticmethod
+     def backward(ctx, dq, dk):
+         """
+         dq size: (bsz, n_q_head, seq_len, head_dim)
+         dk size: (bsz, n_kv_head, seq_len, head_dim)
+         cos size: (1, seq_len, head_dim)
+         sin size: (1, seq_len, head_dim)
+         """
+
+         cos, sin = ctx.saved_tensors
+
+         dq = dq.transpose(1, 2)
+         dk = dk.transpose(1, 2)
+
+         batch_size, seq_len, n_q_head, head_dim = dq.shape
+         n_kv_head = dk.shape[2]
+         pad_hd = triton.next_power_of_2(head_dim)
+         pad_n_q_head = triton.next_power_of_2(n_q_head)
+         pad_n_kv_head = triton.next_power_of_2(n_kv_head)
+         BLOCK_SIZE = max(pad_n_q_head, pad_n_kv_head)
+
+         n_row = batch_size * seq_len
+
+         # ensure dq and dk are contiguous
+         dq = dq.contiguous()
+         dk = dk.contiguous()
+
+         # the backward pass reuses the forward kernel; BACKWARD_PASS=True flips
+         # the sign of the sin terms to apply the inverse rotation
+         _triton_rope[(n_row,)](
+             dq,
+             dq.stride(1),
+             dk,
+             dk.stride(1),
+             cos,
+             cos.stride(-2),
+             sin,
+             sin.stride(-2),
+             batch_size,
+             seq_len,
+             n_q_head,
+             n_kv_head,
+             head_dim,
+             pad_n_q_head,
+             pad_n_kv_head,
+             pad_hd,
+             BLOCK_SIZE=BLOCK_SIZE,
+             BACKWARD_PASS=True,
+         )
+
+         return dq.transpose(1, 2), dk.transpose(1, 2), None, None, None, None
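
A usage sketch for the RoPE function above. The module path liger_kernel.ops.rope and the HuggingFace-style cos/sin construction are assumptions for illustration and are not shown in this diff:

    import torch

    from liger_kernel.ops.rope import LigerRopeFunction  # path assumed

    bsz, n_q_head, n_kv_head, seq_len, head_dim = 1, 8, 2, 16, 64
    q = torch.randn(bsz, n_q_head, seq_len, head_dim, device="cuda")
    k = torch.randn(bsz, n_kv_head, seq_len, head_dim, device="cuda")

    # build (1, seq_len, head_dim) cos/sin the way HF Llama does: duplicate the
    # frequencies for the left and right halves of the head
    theta = 1.0 / (10000.0 ** (torch.arange(0, head_dim, 2, device="cuda").float() / head_dim))
    freqs = torch.outer(torch.arange(seq_len, device="cuda").float(), theta)
    emb = torch.cat((freqs, freqs), dim=-1)
    cos, sin = emb.cos()[None], emb.sin()[None]

    q_rot, k_rot = LigerRopeFunction.apply(q, k, cos, sin)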
@@ -0,0 +1,113 @@
+ import torch
+ import triton
+ import triton.language as tl
+
+ from liger_kernel.ops.utils import calculate_settings, ensure_contiguous
+
+
+ @triton.jit
+ def silu(x):
+     return x * tl.sigmoid(x)
+
+
+ @triton.jit
+ def _swiglu_forward_kernel(
+     a, b, c, stride, n_cols: tl.constexpr, BLOCK_SIZE: tl.constexpr
+ ):
+     program_id = tl.program_id(0)
+
+     # locate start index
+     a += program_id * stride
+     b += program_id * stride
+     c += program_id * stride
+
+     col_offsets = tl.arange(0, BLOCK_SIZE)
+     mask = col_offsets < n_cols
+
+     # sigmoid requires type float32
+     a_row = tl.load(a + col_offsets, mask=mask, other=0).to(tl.float32)
+     b_row = tl.load(b + col_offsets, mask=mask, other=0)
+     c_row = silu(a_row) * b_row
+     tl.store(c + col_offsets, c_row, mask=mask)
+
+
+ @triton.jit
+ def _swiglu_backward_kernel(
+     dc, a, b, stride, n_cols: tl.constexpr, BLOCK_SIZE: tl.constexpr
+ ):
+     program_id = tl.program_id(0)
+
+     # locate start index
+     dc += program_id * stride
+     a += program_id * stride
+     b += program_id * stride
+
+     col_offsets = tl.arange(0, BLOCK_SIZE)
+     mask = col_offsets < n_cols
+
+     dc_row = tl.load(dc + col_offsets, mask=mask, other=0)
+     # sigmoid requires type float32
+     a_row = tl.load(a + col_offsets, mask=mask, other=0).to(tl.float32)
+     b_row = tl.load(b + col_offsets, mask=mask, other=0)
+
+     # recompute silu(a) here instead of saving it in the forward pass to save memory
+     sig_a = tl.sigmoid(a_row)
+     silu_a = a_row * sig_a
+     db_row = dc_row * silu_a
+     # d/da silu(a) = sig_a + a * sig_a * (1 - sig_a) = silu_a * (1 - sig_a) + sig_a
+     da_row = dc_row * (silu_a * (1 - sig_a) + sig_a) * b_row
+
+     # gradients are written in place into the a and b buffers
+     tl.store(a + col_offsets, da_row, mask=mask)
+     tl.store(b + col_offsets, db_row, mask=mask)
+
+
+ class LigerSiLUMulFunction(torch.autograd.Function):
+     @staticmethod
+     @ensure_contiguous
+     def forward(ctx, a, b):
+         ori_shape = a.shape
+
+         n_cols = ori_shape[-1]
+         a = a.view(-1, n_cols)
+         b = b.view(-1, n_cols)
+         c = torch.zeros_like(a)
+         n_rows = a.shape[0]
+
+         BLOCK_SIZE, num_warps = calculate_settings(n_cols)
+
+         _swiglu_forward_kernel[(n_rows,)](
+             a,
+             b,
+             c,
+             c.stride(-2),
+             n_cols=n_cols,
+             BLOCK_SIZE=BLOCK_SIZE,
+             num_warps=num_warps,
+         )
+
+         ctx.save_for_backward(a, b)
+
+         return c.view(*ori_shape)
+
+     @staticmethod
+     @ensure_contiguous
+     def backward(ctx, dc):
+         ori_shape = dc.shape
+         n_cols = ori_shape[-1]
+         dc = dc.view(-1, n_cols)
+         a, b = ctx.saved_tensors
+         n_rows = dc.shape[0]
+
+         BLOCK_SIZE, num_warps = calculate_settings(n_cols)
+
+         _swiglu_backward_kernel[(n_rows,)](
+             dc,
+             a,
+             b,
+             dc.stride(-2),
+             n_cols=n_cols,
+             BLOCK_SIZE=BLOCK_SIZE,
+             num_warps=num_warps,
+         )
+
+         return a.view(*ori_shape), b.view(*ori_shape)
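
A minimal sketch of the autograd function above (the module path liger_kernel.ops.swiglu is assumed). Note that the backward kernel writes the gradients in place into the saved a and b buffers, so the saved activations are consumed by the backward pass:

    import torch
    import torch.nn.functional as F

    from liger_kernel.ops.swiglu import LigerSiLUMulFunction  # path assumed

    a = torch.randn(4, 128, device="cuda", requires_grad=True)
    b = torch.randn(4, 128, device="cuda", requires_grad=True)

    out = LigerSiLUMulFunction.apply(a, b)  # computes silu(a) * b
    assert torch.allclose(out, F.silu(a) * b, atol=1e-4, rtol=1e-4)
    out.sum().backward()  # da = silu'(a) * b, db = silu(a)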
@@ -0,0 +1,38 @@
+ import functools
+
+ import torch
+ import triton
+
+
+ def ensure_contiguous(fn):
+     @functools.wraps(fn)
+     def wrapper(ctx, *args, **kwargs):
+         def maybe_to_contiguous(x):
+             return x.contiguous() if isinstance(x, torch.Tensor) else x
+
+         args = [maybe_to_contiguous(arg) for arg in args]
+         kwargs = {k: maybe_to_contiguous(v) for k, v in kwargs.items()}
+         return fn(ctx, *args, **kwargs)
+
+     return wrapper
+
+
+ def calculate_settings(n):
+     # reference: https://github.com/unslothai/unsloth/blob/fd753fed99ed5f10ef8a9b7139588d9de9ddecfb/unsloth/kernels/utils.py#L43
+
+     MAX_FUSED_SIZE = 65536
+     BLOCK_SIZE = triton.next_power_of_2(n)
+     if BLOCK_SIZE > MAX_FUSED_SIZE:
+         raise RuntimeError(
+             f"Cannot launch Triton kernel since n = {n} exceeds "
+             f"the recommended Triton blocksize = {MAX_FUSED_SIZE}."
+         )
+
+     num_warps = 4
+     if BLOCK_SIZE >= 32768:
+         num_warps = 32
+     elif BLOCK_SIZE >= 8192:
+         num_warps = 16
+     elif BLOCK_SIZE >= 2048:
+         num_warps = 8
+     return BLOCK_SIZE, num_warps
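
To make the heuristic concrete, a couple of worked values for common hidden sizes:

    from liger_kernel.ops.utils import calculate_settings

    # 4096 is already a power of two and falls in the [2048, 8192) bucket
    print(calculate_settings(4096))  # (4096, 8)

    # 5120 rounds up to the next power of two, 8192, which uses 16 warps
    print(calculate_settings(5120))  # (8192, 16)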
@@ -0,0 +1,5 @@
+ from liger_kernel.transformers.monkey_patch import (  # noqa: F401
+     apply_liger_kernel_to_llama,
+     apply_liger_kernel_to_mistral,
+     apply_liger_kernel_to_mixtral,
+ )
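
These patch functions are the package's top-level entry points. A hedged usage sketch (the exact signatures are not visible in this diff, so the no-argument call is an assumption):

    import transformers

    from liger_kernel.transformers import apply_liger_kernel_to_llama

    # patch the HF Llama modeling code in place before instantiating the model
    apply_liger_kernel_to_llama()
    model = transformers.AutoModelForCausalLM.from_pretrained("meta-llama/Llama-2-7b-hf")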
@@ -0,0 +1,11 @@
+ from torch.nn import CrossEntropyLoss
+
+ from liger_kernel.ops.cross_entropy import LigerCrossEntropyFunction
+
+
+ class LigerCrossEntropyLoss(CrossEntropyLoss):
+     def __init__(self, *args, **kwargs):
+         super().__init__(*args, **kwargs)
+
+     def forward(self, _input, target):
+         return LigerCrossEntropyFunction.apply(_input, target, self.ignore_index)
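
The wrapper keeps nn.CrossEntropyLoss's constructor, and only ignore_index is forwarded to the Triton op (LigerCrossEntropyFunction itself is defined in a file not shown in this section). A minimal sketch, assuming the module ships as liger_kernel/transformers/cross_entropy.py and a CUDA device:

    import torch

    from liger_kernel.transformers.cross_entropy import LigerCrossEntropyLoss  # path assumed

    loss_fn = LigerCrossEntropyLoss(ignore_index=-100)
    logits = torch.randn(8, 32000, device="cuda", requires_grad=True)
    labels = torch.randint(0, 32000, (8,), device="cuda")
    labels[0] = -100  # ignored position
    loss = loss_fn(logits, labels)
    loss.backward()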
@@ -0,0 +1,15 @@
+ from torch.nn import CrossEntropyLoss
+
+ from liger_kernel.ops.fused_linear_cross_entropy import (
+     LigerFusedLinearCrossEntropyFunction,
+ )
+
+
+ class LigerFusedLinearCrossEntropyLoss(CrossEntropyLoss):
+     def __init__(self, *args, **kwargs):
+         super().__init__(*args, **kwargs)
+
+     def forward(self, lin_weight, _input, target):
+         return LigerFusedLinearCrossEntropyFunction.apply(
+             _input, lin_weight, target, self.ignore_index
+         )
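
Unlike the plain loss above, this variant takes the lm_head weight together with the pre-projection hidden states, presumably so the final linear layer and the loss can be fused without materializing the full logits tensor (the underlying kernel is not shown in this section). A usage sketch with illustrative, assumed shapes:

    import torch

    from liger_kernel.transformers.fused_linear_cross_entropy import (  # path assumed
        LigerFusedLinearCrossEntropyLoss,
    )

    hidden = torch.randn(8, 4096, device="cuda", requires_grad=True)  # last hidden states
    lm_head_weight = torch.randn(32000, 4096, device="cuda", requires_grad=True)
    labels = torch.randint(0, 32000, (8,), device="cuda")

    loss_fn = LigerFusedLinearCrossEntropyLoss()
    # note the argument order: weight first, then input, then target
    loss = loss_fn(lm_head_weight, hidden, labels)
    loss.backward()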
@@ -0,0 +1,23 @@
+ import torch.nn as nn
+
+ from liger_kernel.ops.geglu import LigerGELUMulFunction
+
+
+ class LigerGEGLUMLP(nn.Module):
+     def __init__(self, config):
+         super().__init__()
+         self.config = config
+         self.hidden_size = config.hidden_size
+         self.intermediate_size = config.intermediate_size
+         self.gate_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False)
+         self.up_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False)
+         self.down_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias=False)
+         # TODO: support exact GELU
+         if config.hidden_act not in ["gelu_pytorch_tanh"]:
+             raise ValueError(f"Activation function {config.hidden_act} not supported.")
+
+     def forward(self, x):
+         return self.down_proj(
+             LigerGELUMulFunction.apply(self.gate_proj(x), self.up_proj(x))
+         )
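
A construction sketch: LigerGEGLUMLP only reads hidden_size, intermediate_size, and hidden_act from its config, so any attribute container works for illustration (the module path is assumed; LigerGELUMulFunction is defined in a file not shown in this section):

    from types import SimpleNamespace

    import torch

    from liger_kernel.transformers.geglu import LigerGEGLUMLP  # path assumed

    config = SimpleNamespace(
        hidden_size=64, intermediate_size=256, hidden_act="gelu_pytorch_tanh"
    )
    mlp = LigerGEGLUMLP(config).to("cuda")
    out = mlp(torch.randn(2, 16, 64, device="cuda"))  # (bsz, seq_len, hidden_size)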