liger-kernel 0.3.1__py3-none-any.whl → 0.4.1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- liger_kernel/env_report.py +2 -0
- liger_kernel/ops/cross_entropy.py +144 -65
- liger_kernel/ops/experimental/mm_int8int2.py +355 -0
- liger_kernel/ops/fused_linear_cross_entropy.py +31 -11
- liger_kernel/ops/fused_linear_jsd.py +245 -0
- liger_kernel/ops/geglu.py +2 -2
- liger_kernel/ops/group_norm.py +322 -0
- liger_kernel/ops/jsd.py +176 -0
- liger_kernel/ops/kl_div.py +2 -2
- liger_kernel/ops/rms_norm.py +92 -46
- liger_kernel/ops/swiglu.py +2 -2
- liger_kernel/ops/utils.py +62 -1
- liger_kernel/transformers/__init__.py +3 -0
- liger_kernel/transformers/cross_entropy.py +44 -12
- liger_kernel/transformers/functional.py +38 -1
- liger_kernel/transformers/fused_linear_cross_entropy.py +31 -4
- liger_kernel/transformers/fused_linear_jsd.py +98 -0
- liger_kernel/transformers/group_norm.py +56 -0
- liger_kernel/transformers/jsd.py +75 -0
- liger_kernel/transformers/model/gemma.py +124 -1
- liger_kernel/transformers/model/gemma2.py +277 -0
- liger_kernel/transformers/model/llama.py +135 -4
- liger_kernel/transformers/model/mistral.py +3 -0
- liger_kernel/transformers/model/mixtral.py +153 -2
- liger_kernel/transformers/model/mllama.py +274 -0
- liger_kernel/transformers/model/phi3.py +140 -2
- liger_kernel/transformers/model/qwen2.py +123 -2
- liger_kernel/transformers/model/qwen2_vl.py +8 -1
- liger_kernel/transformers/monkey_patch.py +258 -68
- liger_kernel/transformers/rms_norm.py +11 -3
- {liger_kernel-0.3.1.dist-info → liger_kernel-0.4.1.dist-info}/METADATA +63 -29
- liger_kernel-0.4.1.dist-info/NOTICE +58 -0
- liger_kernel-0.4.1.dist-info/RECORD +51 -0
- {liger_kernel-0.3.1.dist-info → liger_kernel-0.4.1.dist-info}/WHEEL +1 -1
- liger_kernel-0.3.1.dist-info/NOTICE +0 -4
- liger_kernel-0.3.1.dist-info/RECORD +0 -42
- {liger_kernel-0.3.1.dist-info → liger_kernel-0.4.1.dist-info}/LICENSE +0 -0
- {liger_kernel-0.3.1.dist-info → liger_kernel-0.4.1.dist-info}/top_level.txt +0 -0
liger_kernel/ops/experimental/mm_int8int2.py (new file)

@@ -0,0 +1,355 @@
+import torch
+import triton
+import triton.language as tl
+
+
+def unpack_weights(packed: torch.Tensor, bits: int = 2) -> torch.Tensor:
+    values_per_item = 8 // bits
+    packed_shape = packed.shape
+
+    if len(packed_shape) == 1:
+        original_row_dim = packed_shape[0] * values_per_item
+        unpacked_shape = (original_row_dim,)
+    else:
+        original_row_dim = packed_shape[0] * values_per_item
+        unpacked_shape = (original_row_dim, *packed_shape[1:])
+
+    unpacked = torch.zeros(unpacked_shape, device=packed.device, dtype=torch.uint8)
+
+    for i in range(values_per_item):
+        start = i * packed_shape[0]
+        end = start + packed_shape[0]
+        mask = 3 << (2 * i)
+        unpacked[start:end] = (packed & mask) >> (2 * i)
+
+    unpacked = unpacked.to(torch.int32) - 1
+    return unpacked
+
+
+def pack_weights(intweights: torch.Tensor, bits: int = 2) -> torch.Tensor:
+    intweights += 1
+    original_shape = intweights.shape
+    values_per_item = 8 // bits
+    row_dim = (original_shape[0] + values_per_item - 1) // values_per_item
+
+    if len(original_shape) == 1:
+        packed_tensor_shape = (row_dim,)
+    else:
+        packed_tensor_shape = (row_dim, *original_shape[1:])
+
+    packed = torch.zeros(
+        packed_tensor_shape, device=intweights.device, dtype=torch.uint8
+    )
+    unpacked = intweights.to(torch.uint8)
+
+    def lshift(t: torch.Tensor, bits: int):
+        return t << bits
+
+    it = min(values_per_item, (original_shape[0] // row_dim) + 1)
+    for i in range(it):
+        start = i * row_dim
+        end = min(start + row_dim, original_shape[0])
+        packed[: (end - start)] |= lshift(unpacked[start:end], bits * i)
+
+    return packed
+
+
+def get_autotune_config():
+    return [
+        triton.Config(
+            {
+                "BLOCK_SIZE_M": 128,
+                "BLOCK_SIZE_N": 256,
+                "BLOCK_SIZE_K": 64,
+                "GROUP_SIZE_M": 8,
+            },
+            num_stages=3,
+            num_warps=8,
+        ),
+        triton.Config(
+            {
+                "BLOCK_SIZE_M": 64,
+                "BLOCK_SIZE_N": 256,
+                "BLOCK_SIZE_K": 32,
+                "GROUP_SIZE_M": 8,
+            },
+            num_stages=4,
+            num_warps=4,
+        ),
+        triton.Config(
+            {
+                "BLOCK_SIZE_M": 128,
+                "BLOCK_SIZE_N": 128,
+                "BLOCK_SIZE_K": 32,
+                "GROUP_SIZE_M": 8,
+            },
+            num_stages=4,
+            num_warps=4,
+        ),
+        triton.Config(
+            {
+                "BLOCK_SIZE_M": 128,
+                "BLOCK_SIZE_N": 64,
+                "BLOCK_SIZE_K": 32,
+                "GROUP_SIZE_M": 8,
+            },
+            num_stages=4,
+            num_warps=4,
+        ),
+        triton.Config(
+            {
+                "BLOCK_SIZE_M": 64,
+                "BLOCK_SIZE_N": 128,
+                "BLOCK_SIZE_K": 32,
+                "GROUP_SIZE_M": 8,
+            },
+            num_stages=4,
+            num_warps=4,
+        ),
+        triton.Config(
+            {
+                "BLOCK_SIZE_M": 128,
+                "BLOCK_SIZE_N": 32,
+                "BLOCK_SIZE_K": 32,
+                "GROUP_SIZE_M": 8,
+            },
+            num_stages=4,
+            num_warps=4,
+        ),
+        triton.Config(
+            {
+                "BLOCK_SIZE_M": 128,
+                "BLOCK_SIZE_N": 256,
+                "BLOCK_SIZE_K": 128,
+                "GROUP_SIZE_M": 8,
+            },
+            num_stages=3,
+            num_warps=8,
+        ),
+        triton.Config(
+            {
+                "BLOCK_SIZE_M": 256,
+                "BLOCK_SIZE_N": 128,
+                "BLOCK_SIZE_K": 128,
+                "GROUP_SIZE_M": 8,
+            },
+            num_stages=3,
+            num_warps=8,
+        ),
+        triton.Config(
+            {
+                "BLOCK_SIZE_M": 256,
+                "BLOCK_SIZE_N": 64,
+                "BLOCK_SIZE_K": 128,
+                "GROUP_SIZE_M": 8,
+            },
+            num_stages=4,
+            num_warps=4,
+        ),
+        triton.Config(
+            {
+                "BLOCK_SIZE_M": 64,
+                "BLOCK_SIZE_N": 256,
+                "BLOCK_SIZE_K": 128,
+                "GROUP_SIZE_M": 8,
+            },
+            num_stages=4,
+            num_warps=4,
+        ),
+        triton.Config(
+            {
+                "BLOCK_SIZE_M": 128,
+                "BLOCK_SIZE_N": 128,
+                "BLOCK_SIZE_K": 128,
+                "GROUP_SIZE_M": 8,
+            },
+            num_stages=4,
+            num_warps=4,
+        ),
+        triton.Config(
+            {
+                "BLOCK_SIZE_M": 128,
+                "BLOCK_SIZE_N": 64,
+                "BLOCK_SIZE_K": 64,
+                "GROUP_SIZE_M": 8,
+            },
+            num_stages=4,
+            num_warps=4,
+        ),
+        triton.Config(
+            {
+                "BLOCK_SIZE_M": 64,
+                "BLOCK_SIZE_N": 128,
+                "BLOCK_SIZE_K": 64,
+                "GROUP_SIZE_M": 8,
+            },
+            num_stages=4,
+            num_warps=4,
+        ),
+        triton.Config(
+            {
+                "BLOCK_SIZE_M": 128,
+                "BLOCK_SIZE_N": 32,
+                "BLOCK_SIZE_K": 64,
+                "GROUP_SIZE_M": 8,
+            },
+            num_stages=4,
+            num_warps=4,
+        ),
+        triton.Config(
+            {
+                "BLOCK_SIZE_M": 32,
+                "BLOCK_SIZE_N": 32,
+                "BLOCK_SIZE_K": 32,
+                "GROUP_SIZE_M": 4,
+            },
+            num_stages=4,
+            num_warps=4,
+        ),
+    ]
+
+
+@triton.autotune(
+    configs=get_autotune_config(),
+    key=["M", "N", "K"],
+)
+@triton.jit
+def matmul_kernel(
+    a_ptr,
+    b_ptr,
+    c_ptr,
+    M,
+    N,
+    K: tl.constexpr,
+    stride_am,
+    stride_ak,
+    stride_bk,
+    stride_bn,
+    stride_cm,
+    stride_cn,
+    BLOCK_SIZE_M: tl.constexpr,
+    BLOCK_SIZE_N: tl.constexpr,
+    BLOCK_SIZE_K: tl.constexpr,
+    GROUP_SIZE_M: tl.constexpr,
+):
+    # We want K / 4 to be divisible by BLOCK_SIZE_K so that the multiplication can be aligned
+    tl.static_assert(
+        K % (4 * BLOCK_SIZE_K) == 0,
+        "K / 4 must be divisible by BLOCK_SIZE_K => K divisible by 4*BLOCK_SIZE_K",
+    )
+    # determine the block id in the 1D grid, pid <=> blockId in cuda
+    pid = tl.program_id(axis=0)
+    # number of blocks we would need in the M dimension
+    num_pid_m = tl.cdiv(M, BLOCK_SIZE_M)
+    # number of blocks we would need in the N dimension
+    num_pid_n = tl.cdiv(N, BLOCK_SIZE_N)
+    # blocks are grouped along the M dimension. num_pid_in_group computes how many blocks are grouped together,
+    # and group_id calculates the group to which the current block (pid) belongs.
+    num_pid_in_group = GROUP_SIZE_M * num_pid_n
+    group_id = pid // num_pid_in_group
+
+    # pid of the first block in the group that the current block belongs to
+    first_pid_m = group_id * GROUP_SIZE_M
+
+    # pid_m: pid of the block along the M dimension of the output matrix; pid_n: pid of the block along the N dimension of the output matrix
+    # remember that the grid of blocks is 1D, but we calculate pid_m and pid_n to locate the block's place in the output matrix
+    group_size_m = min(num_pid_m - first_pid_m, GROUP_SIZE_M)
+    pid_m = first_pid_m + ((pid % num_pid_in_group) % group_size_m)
+    pid_n = (pid % num_pid_in_group) // group_size_m
+
+    # offs_am represents the indices of elements within the block for matrix A with respect to the M dimension
+    # offs_bn represents the indices of elements within the block for matrix B with respect to the N dimension
+    offs_am = (pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)) % M
+    offs_bn = (pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)) % N
+    offs_k = tl.arange(0, BLOCK_SIZE_K)
+
+    """
+    This part of the code generates pointers to the specific blocks of matrices A and B that the current thread block will process.
+
+    As described in the PyTorch documentation, a stride refers to the step size needed to move from one element to the next along a given dimension:
+
+    For matrix A: stride_am = A.stride(0) = K (stride along the rows), and stride_ak = A.stride(1) = 1 (stride along the columns).
+    For matrix B: stride_bk = B.stride(0) = N (stride along the rows), and stride_bn = B.stride(1) = 1 (stride along the columns).
+    Now, let's break down the pointer generation:
+
+    offs_am[:, None] creates a column of shape [BLOCK_SIZE_M, 1], which represents the row indices of matrix A that this block is processing. It is multiplied by K (the number of columns in matrix A) since A is stored in row-major order, so the element at position (i, j) in A is located at index i*K + j in memory.
+    offs_k[None, :] creates a row vector representing the column indices of the block, i.e., a range from 0 to BLOCK_SIZE_K. This is used to compute the positions of the columns within the block.
+    When combined, the result has the shape [BLOCK_SIZE_M, BLOCK_SIZE_K], where each entry (i, j) points to the element in matrix A at position (i, j) for the current block.
+
+    The same logic is applied to matrix B, but the resulting shape is [BLOCK_SIZE_K, BLOCK_SIZE_N], representing the block of matrix B that the thread block will work on.
+    """
+    a_ptrs = a_ptr + (offs_am[:, None] * stride_am + offs_k[None, :] * stride_ak)
+    b_ptrs = b_ptr + (offs_k[:, None] * stride_bk + offs_bn[None, :] * stride_bn)
+
+    # An accumulator matrix is initialized with zeros. It stores the intermediate results of the block matrix multiplication.
+    accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.int32)
+    """
+    We split the loop into two layers. The outer loop runs 4 times, and each iteration focuses on a specific portion of matrix A.
+
+    For example, when i = 0, we're only concerned with the blocks of matrix A that cover the range from 0 to K // (4 * BLOCK_SIZE_K).
+    Since matrix B is packed, its first dimension is effectively divided by 4. So, while we process the first segment of matrix A,
+    we still iterate over the entire first dimension of matrix B.
+
+    In each of the 4 iterations of the outer loop, we go through the full blocks of matrix B, but what changes is the data we extract.
+    Matrix B elements contain 4 weights, all packed into an int8 format, and during each iteration of the outer loop,
+    we extract a different weight by using bitwise shifting operations. This way, we access a unique weight on each pass.
+    """
+    for i in range(4):
+        b_ptrs = b_ptr + (offs_k[:, None] * stride_bk + offs_bn[None, :] * stride_bn)
+        for j in range(0, tl.cdiv(K // 4, BLOCK_SIZE_K)):
+            k = i * tl.cdiv(K // 4, BLOCK_SIZE_K) + j
+            # load the block of matrix A
+            a = tl.load(a_ptrs, mask=offs_k[None, :] < K - k * BLOCK_SIZE_K, other=0)
+            # load the block of matrix B
+            b_uint8 = tl.load(b_ptrs, mask=offs_k[:, None] < K, other=0)
+            # when i = 0, for example, we only care about the first 2 bits of the elements of matrix B, so we use the mask 00000011 to mask out the other bits
+            mask = 3 << (2 * i)
+            # we shift the results after the mask
+            b = (b_uint8 & mask) >> (2 * i)
+            # During packing it's easier to store 0, 1, 2 than -1, 0, 1, so we add 1 to the weight tensor when packing and subtract it here
+            tensor_full = tl.full((1,), 1, dtype=tl.int8)
+            # We accumulate the result of the multiplication of the blocks along the K dimension in int32 to avoid any overflows or underflows.
+            accumulator += tl.dot(a, (b.to(tl.int8) - tensor_full), out_dtype=tl.int32)
+            # we move the pointers: a_ptrs moves horizontally along the second dimension -> we use stride_ak=1;
+            # b_ptrs moves vertically, along the rows -> we use stride_bk=N
+            a_ptrs += BLOCK_SIZE_K * stride_ak
+            b_ptrs += BLOCK_SIZE_K * stride_bk
+
+    c = accumulator
+    # These lines compute the offsets into matrix C where the result of this block's computation should be stored.
+    # stride_cm = N & stride_cn = 1
+    offs_cm = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)
+    offs_cn = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)
+    c_ptrs = c_ptr + stride_cm * offs_cm[:, None] + stride_cn * offs_cn[None, :]
+    # we do a boundary check to ensure only elements within matrix bounds are stored
+    c_mask = (offs_cm[:, None] < M) & (offs_cn[None, :] < N)
+    tl.store(c_ptrs, c, mask=c_mask)
+
+
+def matmul(a, b):
+    assert (
+        a.shape[1] == b.shape[0] * 4
+    ), "Incompatible dimensions, the weight matrix needs to be packed"
+    assert a.is_contiguous(), "Matrix A must be contiguous"
+    M, K = a.shape
+    _, N = b.shape
+    # c is in int32 to avoid any overflows or underflows
+    c = torch.empty((M, N), device=a.device, dtype=torch.int32)
+    grid = lambda META: (
+        triton.cdiv(M, META["BLOCK_SIZE_M"]) * triton.cdiv(N, META["BLOCK_SIZE_N"]),
+    )
+    matmul_kernel[grid](
+        a,
+        b,
+        c,
+        M,
+        N,
+        K,
+        a.stride(0),
+        a.stride(1),
+        b.stride(0),
+        b.stride(1),
+        c.stride(0),
+        c.stride(1),
+    )
+    return c
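The three helpers above are meant to be used together: pack_weights stores four 2-bit ternary weights per uint8 along the first (K) dimension, unpack_weights reverses that, and matmul multiplies an int8 activation matrix by the packed weight matrix. The following is a minimal usage sketch, not part of the package; it assumes a CUDA GPU with Triton available, and the sizes are illustrative (K is chosen so the kernel's static assert holds for every autotune config).

import torch

from liger_kernel.ops.experimental.mm_int8int2 import matmul, pack_weights, unpack_weights

# Illustrative sizes: K must be divisible by 4 * BLOCK_SIZE_K for every autotuned
# config (largest BLOCK_SIZE_K is 128), so K = 2048 satisfies the static assert.
M, K, N = 32, 2048, 512

a = torch.randint(-128, 128, (M, K), device="cuda", dtype=torch.int8)
# Ternary weights in {-1, 0, 1}, laid out as (K, N) before packing.
w = torch.randint(-1, 2, (K, N), device="cuda", dtype=torch.int32)

# pack_weights adds 1 in place, so pass a clone to keep the reference weights intact.
packed = pack_weights(w.clone(), bits=2)  # (K // 4, N), uint8, four weights per byte
assert torch.equal(unpack_weights(packed, bits=2), w)  # round trip

c = matmul(a, packed)  # int32 result, expected to equal a @ w
# Integer matmul is not supported on CUDA, so check the reference on CPU.
assert torch.equal(c.cpu(), a.cpu().to(torch.int32) @ w.cpu())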
liger_kernel/ops/fused_linear_cross_entropy.py

@@ -1,9 +1,12 @@
 import torch
 import triton
 
-from liger_kernel.ops.cross_entropy import (
+from liger_kernel.ops.cross_entropy import liger_cross_entropy_kernel
+from liger_kernel.ops.utils import (
+    amp_custom_bwd,
+    amp_custom_fwd,
     element_mul_kernel,
-    liger_cross_entropy_kernel,
+    is_hip,
 )
 
 # The hard limit of TRITON_MAX_TENSOR_NUMEL is 1048576 https://github.com/triton-lang/triton/blob/ba42a5c68fd0505f8c42f4202d53be0f8d9a5fe0/python/triton/language/core.py#L19

@@ -18,12 +21,12 @@ def fused_linear_cross_entropy_forward(
     target,
     bias=None,
     ignore_index=-100,
+    lse_square_scale=0.0,
     label_smoothing=0.0,
     reduction="mean",
+    softcap=None,
 ):
-    dtype = (
-        torch.get_autocast_gpu_dtype() if torch.is_autocast_enabled() else _input.dtype
-    )
+    dtype = _input.dtype
     device = _input.device
 
     # inputs have shape: BT x H

@@ -85,14 +88,19 @@ def fused_linear_cross_entropy_forward(
             Y_ptr=target_chunk,
             Y_stride=target_chunk.stride(-1),  # always 1
             loss_ptr=loss_1d_slice,
+            z_loss_ptr=loss_1d_slice,  # dummy ptr, not used
             loss_stride=loss_1d_slice.stride(-1),  # always 1
             n_cols=V,
             n_non_ignore=n_non_ignore,
             ignore_index=ignore_index,
+            lse_square_scale=lse_square_scale,
             label_smoothing=label_smoothing,
             reduction=reduction,
+            softcap=softcap if softcap is not None else 0.0,
+            RETURN_Z_LOSS=0,  # False
+            HAS_SOFTCAPPING=True if softcap is not None else False,
             BLOCK_SIZE=BLOCK_SIZE,
-            num_warps=32,
+            num_warps=32 if not is_hip() else 16,
         )
 
         # gradient of logits_chunk is computed in-place by the above triton kernel.

@@ -157,7 +165,7 @@ def fused_linear_cross_entropy_backward(
             grad_output,
             H,
             BLOCK_SIZE=BLOCK_SIZE,
-            num_warps=32,
+            num_warps=32 if not is_hip() else 16,
         )
 
         # handle grad_weight

@@ -171,7 +179,7 @@ def fused_linear_cross_entropy_backward(
                 grad_output,
                 H,
                 BLOCK_SIZE=BLOCK_SIZE,
-                num_warps=32,
+                num_warps=32 if not is_hip() else 16,
             )
 
         if grad_bias is not None:

@@ -184,13 +192,14 @@ def fused_linear_cross_entropy_backward(
                 grad_output,
                 1,
                 BLOCK_SIZE=BLOCK_SIZE,
-                num_warps=32,
+                num_warps=32 if not is_hip() else 16,
             )
     return grad_input, grad_weight, grad_bias
 
 
 class LigerFusedLinearCrossEntropyFunction(torch.autograd.Function):
     @staticmethod
+    @amp_custom_fwd
     def forward(
         ctx,
         _input,

@@ -198,8 +207,10 @@ class LigerFusedLinearCrossEntropyFunction(torch.autograd.Function):
         target,
         bias=None,
         ignore_index=-100,
+        lse_square_scale=0.0,
         label_smoothing=0.0,
         reduction="mean",
+        softcap=None,
     ):
         """
         Fusing the last linear layer with cross-entropy loss

@@ -219,7 +230,15 @@ class LigerFusedLinearCrossEntropyFunction(torch.autograd.Function):
         reduction: reduction to apply
         """
         loss, grad_input, grad_weight, grad_bias = fused_linear_cross_entropy_forward(
-            _input, weight, target, bias, ignore_index, label_smoothing, reduction
+            _input,
+            weight,
+            target,
+            bias,
+            ignore_index,
+            lse_square_scale,
+            label_smoothing,
+            reduction,
+            softcap,
         )
         # downcast to dtype and store for backward
         ctx.save_for_backward(

@@ -230,9 +249,10 @@ class LigerFusedLinearCrossEntropyFunction(torch.autograd.Function):
         return loss
 
     @staticmethod
+    @amp_custom_bwd
     def backward(ctx, grad_output):
         (grad_input, grad_weight, grad_bias) = ctx.saved_tensors
         grad_input, grad_weight, grad_bias = fused_linear_cross_entropy_backward(
             grad_output, grad_input, grad_weight, grad_bias
         )
-        return (grad_input, grad_weight, None, grad_bias, None, None, None)
+        return (grad_input, grad_weight, None, grad_bias, None, None, None, None, None)