liger-kernel 0.1.1__py3-none-any.whl → 0.2.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- liger_kernel/env_report.py +46 -0
- liger_kernel/ops/cross_entropy.py +5 -5
- liger_kernel/ops/fused_linear_cross_entropy.py +50 -21
- liger_kernel/ops/geglu.py +6 -1
- liger_kernel/ops/rms_norm.py +142 -20
- liger_kernel/ops/rope.py +3 -3
- liger_kernel/transformers/__init__.py +6 -0
- liger_kernel/transformers/auto_model.py +33 -0
- liger_kernel/transformers/fused_linear_cross_entropy.py +2 -2
- liger_kernel/transformers/geglu.py +4 -2
- liger_kernel/transformers/model/gemma.py +138 -0
- liger_kernel/transformers/model/llama.py +1 -1
- liger_kernel/transformers/model/mistral.py +138 -0
- liger_kernel/transformers/model/phi3.py +136 -0
- liger_kernel/transformers/model/qwen2.py +135 -0
- liger_kernel/transformers/monkey_patch.py +203 -10
- liger_kernel/transformers/rms_norm.py +20 -4
- liger_kernel/transformers/swiglu.py +24 -0
- liger_kernel/transformers/trainer_integration.py +2 -45
- {liger_kernel-0.1.1.dist-info → liger_kernel-0.2.0.dist-info}/METADATA +87 -25
- liger_kernel-0.2.0.dist-info/RECORD +33 -0
- liger_kernel-0.1.1.dist-info/RECORD +0 -27
- {liger_kernel-0.1.1.dist-info → liger_kernel-0.2.0.dist-info}/LICENSE +0 -0
- {liger_kernel-0.1.1.dist-info → liger_kernel-0.2.0.dist-info}/NOTICE +0 -0
- {liger_kernel-0.1.1.dist-info → liger_kernel-0.2.0.dist-info}/WHEEL +0 -0
- {liger_kernel-0.1.1.dist-info → liger_kernel-0.2.0.dist-info}/top_level.txt +0 -0
liger_kernel/env_report.py ADDED
@@ -0,0 +1,46 @@
+import platform
+import sys
+
+
+def print_env_report():
+    """
+    Prints a report of the environment. Useful for debugging and reproducibility.
+    Usage:
+    ```
+    python -m liger_kernel.env_report
+    ```
+    """
+    print("Environment Report:")
+    print("-------------------")
+    print(f"Operating System: {platform.platform()}")
+    print(f"Python version: {sys.version.split()[0]}")
+
+    try:
+        import torch
+
+        print(f"PyTorch version: {torch.__version__}")
+        cuda_version = (
+            torch.version.cuda if torch.cuda.is_available() else "Not available"
+        )
+        print(f"CUDA version: {cuda_version}")
+    except ImportError:
+        print("PyTorch: Not installed")
+        print("CUDA version: Unable to query")
+
+    try:
+        import triton
+
+        print(f"Triton version: {triton.__version__}")
+    except ImportError:
+        print("Triton: Not installed")
+
+    try:
+        import transformers
+
+        print(f"Transformers version: {transformers.__version__}")
+    except ImportError:
+        print("Transformers: Not installed")
+
+
+if __name__ == "__main__":
+    print_env_report()
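The new module can be exercised via the CLI shown in its docstring or directly from Python; a minimal sketch, using nothing beyond what the diff adds:

```python
# Shell: python -m liger_kernel.env_report
from liger_kernel.env_report import print_env_report

print_env_report()  # prints OS, Python, PyTorch/CUDA, Triton and Transformers versions
```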
liger_kernel/ops/cross_entropy.py CHANGED
@@ -56,7 +56,7 @@ def liger_cross_entropy_kernel(
     # Online softmax: 2 loads + 1 store (compared with 3 loads + 1 store for the safe softmax)
     # Refer to Algorithm 3 in the paper: https://arxiv.org/pdf/1805.02867

-    # 3. [
+    # 3. [Online softmax] first pass: find max + sum
     m = float("-inf")  # m is the max value. use the notation from the paper
     d = 0.0  # d is the sum. use the notation from the paper
     ori_X_y = tl.load(
@@ -73,10 +73,10 @@ def liger_cross_entropy_kernel(
         d = d * tl.exp(m - m_new) + tl.sum(tl.exp(X_block - m_new))
         m = m_new

-    # 4. [
+    # 4. [Online softmax] second pass: calculate the gradients
     # dx_y = (softmax(x_y) - 1) / N
     # dx_i = softmax(x_i) / N, i != y
-    # N is the number of non
+    # N is the number of non ignored elements in the batch
     for i in range(0, n_cols, BLOCK_SIZE):
         X_offsets = i + tl.arange(0, BLOCK_SIZE)
         X_block = tl.load(
@@ -86,7 +86,7 @@ def liger_cross_entropy_kernel(
         tl.store(X_ptr + X_offsets, X_block, mask=X_offsets < n_cols)

     # We need tl.debug_barrier() to ensure the new result of X_ptr is written as mentioned in
-    #
+    # https://github.com/triton-lang/triton/blob/ba42a5c68fd0505f8c42f4202d53be0f8d9a5fe0/python/triton/ops/cross_entropy.py#L34
     tl.debug_barrier()

     # 5. Calculate the loss
@@ -196,7 +196,7 @@ class LigerCrossEntropyFunction(torch.autograd.Function):
            ignore_index=ignore_index,
            BLOCK_SIZE=BLOCK_SIZE,
            # TODO: 32 seems to give the best performance
-           # Performance is quite
+           # Performance is quite sensitive to num_warps
            num_warps=32,
        )
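The reworded comments above describe the two-pass online-softmax scheme from the linked paper (arxiv.org/pdf/1805.02867). A plain-PyTorch sketch of the first-pass bookkeeping, for readers who want to check the algebra; the helper name and block size are illustrative and not part of the kernel:

```python
import math
import torch

def online_softmax_stats(x: torch.Tensor, block_size: int = 4):
    # One pass over blocks, keeping a running max m and a rescaled sum d
    # (Algorithm 3 in https://arxiv.org/pdf/1805.02867).
    m, d = float("-inf"), 0.0
    for i in range(0, x.numel(), block_size):
        block = x[i : i + block_size]
        m_new = max(m, block.max().item())
        d = d * math.exp(m - m_new) + torch.exp(block - m_new).sum().item()
        m = m_new
    return m, d

x = torch.randn(10)
m, d = online_softmax_stats(x)
# softmax(x) can then be recovered as exp(x - m) / d
assert torch.allclose(torch.softmax(x, dim=0), torch.exp(x - m) / d)
```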
liger_kernel/ops/fused_linear_cross_entropy.py CHANGED
@@ -11,7 +11,7 @@ MAX_FUSED_SIZE = 65536 // 2

 class LigerFusedLinearCrossEntropyFunction(torch.autograd.Function):
     @staticmethod
-    def forward(ctx, _input,
+    def forward(ctx, _input, weight, target, bias=None, ignore_index=-100):
         """
         Fusing the last linear layer with cross-entropy loss
         Reference: https://github.com/mgmalek/efficient_cross_entropy
@@ -23,7 +23,8 @@ class LigerFusedLinearCrossEntropyFunction(torch.autograd.Function):

         _input: (B*T, H) where B is batch size, T is sequence length, H is hidden dimension.
         target: (B*T) where each value is in [0, V-1]
-
+        weight: (V, H) where V is the number of classes
+        bias: (V) where V is the number of classes
         ignore_index: the index to ignore in the target
         """
         dtype = (
@@ -36,12 +37,12 @@ class LigerFusedLinearCrossEntropyFunction(torch.autograd.Function):
         # inputs have shape: BT x H
         # materialized activations will have shape: BT x V
         # the increase in memory = BT x V
-        # reduction can be achieved by
+        # reduction can be achieved by partitioning the number of tokens BT into smaller chunks.
         # for ex: if we were to achieve the same memory consumption as BT x H, then the chunk size should be:
         # inc_factor = (V+H-1)//H, chunk_size = (BT + inc_factor - 1)//inc_factor
         # for ex: BT = 4096*4, V = 32000, H = 4096 ==> inc_factor = 8, chunk_size = 2048
         BT, H = _input.shape
-        V =
+        V = weight.shape[0]
         BLOCK_SIZE = min(MAX_FUSED_SIZE, triton.next_power_of_2(V))

         inc_factor = triton.cdiv(V, H)  # (V + H - 1) // H
@@ -50,9 +51,9 @@ class LigerFusedLinearCrossEntropyFunction(torch.autograd.Function):
         )  # (BT + inc_factor - 1) // inc_factor
         num_chunks = triton.cdiv(BT, chunk_size)  # (BT + chunk_size - 1) // chunk_size

-
+        grad_weight = torch.zeros_like(weight, device=device)
         grad_input = torch.zeros_like(_input, device=device)
-
+        grad_bias = torch.zeros_like(bias, device=device) if bias is not None else None
         # we use fp32 for loss accumulator
         loss_1d = torch.zeros(BT, dtype=torch.float32, device=device)

@@ -64,7 +65,9 @@ class LigerFusedLinearCrossEntropyFunction(torch.autograd.Function):
             _input_chunk = _input[start_idx:end_idx]  # chunk_size x H

             # when doing matmul, use the original precision
-            logits_chunk = _input_chunk @
+            logits_chunk = _input_chunk @ weight.t()  # chunk_size x V
+            if bias is not None:
+                logits_chunk = logits_chunk + bias
             target_chunk = target[start_idx:end_idx]  # chunk_size,

             n_rows = logits_chunk.shape[0]
@@ -95,39 +98,52 @@ class LigerFusedLinearCrossEntropyFunction(torch.autograd.Function):
                 num_warps=32,
             )

-            # gradient of logits_chunk is computed
+            # gradient of logits_chunk is computed in-place by the above triton kernel.
             # Following HuggingFace model source code, we do the forward and backward
             # w.r.t. logits in fp32 for numerical stability especially as the num classes (vocab size) os huge.
             # (reference: https://github.com/huggingface/transformers/blob/v4.42.4/src/transformers/models/llama/modeling_llama.py#L1194)
             # Propagating to lm_head's backward, we'll switch back to the original dtype.
             logits_chunk = logits_chunk.to(dtype)

-            # gradient of logits_chunk is computed
+            # gradient of logits_chunk is computed in-place by the above triton kernel and is of shape: chunk_size x V
             # thus grad_input[start_idx: end_idx] should be of shape: chunk_size x H
             # additionally, since we are chunking the inputs, observe that the loss and gradients are calculated only
             # on `n_non_ignore` tokens. However, the gradient of the input should be calculated for all tokens.
             # Thus, we need an additional scaling factor of (n_non_ignore/total_n_non_ignore) to scale the gradients.
-            grad_logits_chunk = logits_chunk * (
-
-
+            grad_logits_chunk = logits_chunk * (
+                n_non_ignore / total_n_non_ignore
+            )  # chunk_size x V
+            grad_input[start_idx:end_idx] = grad_logits_chunk @ weight
             torch.addmm(
-                input=
+                input=grad_weight,
                 mat1=logits_chunk.t(),
                 mat2=_input_chunk,
-                out=
+                out=grad_weight,
                 alpha=n_non_ignore / total_n_non_ignore,
                 beta=1.0,
             )

+            if bias is not None:
+                torch.add(
+                    input=grad_bias,
+                    other=logits_chunk.sum(dim=0),
+                    out=grad_bias,
+                    alpha=n_non_ignore / total_n_non_ignore,
+                )
+
         loss = torch.sum(loss_1d) / total_n_non_ignore

         # downcast to dtype and store for backward
-        ctx.save_for_backward(
+        ctx.save_for_backward(
+            grad_input.detach(),
+            grad_weight.detach(),
+            grad_bias.detach() if bias is not None else None,
+        )
         return loss

     @staticmethod
     def backward(ctx, grad_output):
-        (grad_input,
+        (grad_input, grad_weight, grad_bias) = ctx.saved_tensors
         # If cross entropy is the last layer, grad_output is 1.0. Skip the mul to save time
         if torch.ne(grad_output, torch.tensor(1.0, device=grad_output.device)):
             # We use a Triton kernel instead of a PyTorch operation because modifying inputs in-place
@@ -145,17 +161,30 @@ class LigerFusedLinearCrossEntropyFunction(torch.autograd.Function):
                 num_warps=32,
             )

-            # handle
-            V, H =
+            # handle grad_weight
+            V, H = grad_weight.shape
             n_rows = V

             element_mul[(n_rows,)](
-
-
+                grad_weight,
+                grad_weight.stride(-2),
                 grad_output,
                 H,
                 BLOCK_SIZE=BLOCK_SIZE,
                 num_warps=32,
             )

-
+            if grad_bias is not None:
+                V = grad_bias.shape[0]
+                n_rows = V
+
+                element_mul[(n_rows,)](
+                    grad_bias,
+                    grad_bias.stride(-1),
+                    grad_output,
+                    1,
+                    BLOCK_SIZE=BLOCK_SIZE,
+                    num_warps=32,
+                )
+
+        return (grad_input, grad_weight, None, grad_bias, None)
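A quick check of the chunking arithmetic described in the comments above, using the example figures they quote (BT = 4096*4, V = 32000, H = 4096); `math.ceil` stands in for `triton.cdiv` here:

```python
import math

BT, V, H = 4096 * 4, 32000, 4096
inc_factor = math.ceil(V / H)            # (V + H - 1) // H            -> 8
chunk_size = math.ceil(BT / inc_factor)  # (BT + inc_factor - 1) // inc_factor -> 2048
num_chunks = math.ceil(BT / chunk_size)  # (BT + chunk_size - 1) // chunk_size -> 8
print(inc_factor, chunk_size, num_chunks)  # 8 2048 8
```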
liger_kernel/ops/geglu.py CHANGED
@@ -11,7 +11,12 @@ from liger_kernel.ops.utils import (
 )

 if compare_version("triton", operator.ge, "3.0.0"):
-
+    try:
+        # typical import path with dispatch available
+        from triton.language.extra.libdevice import tanh
+    except ModuleNotFoundError:
+        # for working with NGC containers
+        from triton.language.extra.cuda.libdevice import tanh
 else:
     from triton.language.math import tanh
liger_kernel/ops/rms_norm.py CHANGED
@@ -1,8 +1,29 @@
+import operator
+
 import torch
 import triton
 import triton.language as tl

-from liger_kernel.ops.utils import
+from liger_kernel.ops.utils import (
+    calculate_settings,
+    compare_version,
+    ensure_contiguous,
+)
+
+if compare_version("triton", operator.ge, "3.0.0"):
+    try:
+        # typical import path with dispatch available
+        from triton.language.extra.libdevice import rsqrt
+    except ModuleNotFoundError:
+        # for working with NGC containers
+        from triton.language.extra.cuda.libdevice import rsqrt
+else:
+    from triton.language.math import rsqrt
+
+
+_CASTING_MODE_NONE = tl.constexpr(-1)
+_CASTING_MODE_LLAMA = tl.constexpr(0)
+_CASTING_MODE_GEMMA = tl.constexpr(1)


 @triton.jit
@@ -17,10 +38,12 @@ def _rms_norm_forward(
     r_row_stride,
     n_cols,
     eps,
+    offset,
+    casting_mode: tl.constexpr,  # constexpr so the `if` blocks can be optimized out
     BLOCK_SIZE: tl.constexpr,
 ):
     """
-    y_i = (x_i / (RMS)) * wi, RMS = sqrt(sum(x_i^2) / N)
+    y_i = (x_i / (RMS)) * (offset + wi), RMS = sqrt(sum(x_i^2) / N)

     Reference:
     1. https://triton-lang.org/main/getting-started/tutorials/05-layer-norm.html
@@ -37,17 +60,33 @@ def _rms_norm_forward(
     r_ptr += row_idx * r_row_stride

     X_row = tl.load(X_ptr + col_offsets, mask=mask, other=0)
+    X_row_dtype = X_row.dtype
     W_row = tl.load(W_ptr + col_offsets, mask=mask, other=0)

+    # On Llama, only inv_rms is computed on fp32
+    if casting_mode == _CASTING_MODE_LLAMA:
+        X_row = X_row.to(tl.float32)
+
+    # Gemma computes everything on fp32, and then casts back the output to the original dtype
+    if casting_mode == _CASTING_MODE_GEMMA:
+        W_row = W_row.to(tl.float32)
+        X_row = X_row.to(tl.float32)
+
     mean_square = tl.sum(X_row * X_row, axis=0) / n_cols
-    inv_rms =
+    inv_rms = rsqrt(mean_square + eps)

     # We can save time by caching rms with minimal memory overhead
     # because rms is much smaller compared to X_row, as rms is for each row.
     # However, on the computation side, it can save 4 operations (*, sum, /, sqrt).
     tl.store(r_ptr, inv_rms)

-
+    X_row = X_row * inv_rms
+
+    # On Llama, the multiplication with the weight is done on the original dtype
+    if casting_mode == _CASTING_MODE_LLAMA:
+        X_row = X_row.to(X_row_dtype)
+
+    Y_row = X_row * (offset + W_row)

     tl.store(Y_ptr + col_offsets, Y_row, mask=mask)

@@ -66,10 +105,12 @@ def _rms_norm_backward(
     dW_row_stride,
     n_cols,
     eps,
+    offset,
+    casting_mode: tl.constexpr,
     BLOCK_SIZE: tl.constexpr,
 ):
     """
-    dx = (1 / RMS) * [dy * w
+    dx = (1 / RMS) * [dy * (w + offset - (1 / N) * (1 / RMS^2) * ((dy * (w + offset)) dot x) * x]. * means element-wise multiplication, whileas dot means dot product
     dw = sum(dy * (x / RMS)). summation over BxT dimension
     """

@@ -85,33 +126,95 @@ def _rms_norm_backward(
     dY_row = tl.load(dY_ptr + col_offsets, mask=mask, other=0)
     X_row = tl.load(X_ptr + col_offsets, mask=mask, other=0)
     W_row = tl.load(W_ptr + col_offsets, mask=mask, other=0)
+    original_x_dtype = X_row.dtype

     # Get cached rms
     inv_rms_row = tl.load(r_ptr)

-
-
-
-
-
-
-    *
-
-
+    W_row = W_row + offset
+
+    # Different bacward graphs for different casting modes
+    if casting_mode == _CASTING_MODE_LLAMA:
+        X_row = X_row.to(tl.float32)
+        m = (dY_row * W_row).to(tl.float32)
+        dX_row = inv_rms_row * m
+
+        dX_row += (inv_rms_row) * (
+            -(1 / n_cols)
+            * inv_rms_row
+            * inv_rms_row
+            * tl.sum(m * X_row, axis=0)
+            * X_row
+        )
+
+    if casting_mode == _CASTING_MODE_GEMMA:
+        dY_row, W_row, X_row = (
+            dY_row.to(tl.float32),
+            W_row.to(tl.float32),
+            X_row.to(tl.float32),
+        )
+        dX_row = inv_rms_row * dY_row * W_row
+
+        dX_row += (inv_rms_row) * (
+            -(1 / n_cols)
+            * inv_rms_row
+            * inv_rms_row
+            * tl.sum(dY_row * W_row * X_row, axis=0)
+            * X_row
+        )

     # calculate the gradient of W
-
+    if casting_mode == _CASTING_MODE_LLAMA:
+        dW_row = dY_row * (X_row * inv_rms_row).to(original_x_dtype)
+    else:
+        # here X_row is already in fp32 (see previous if block)
+        dW_row = dY_row * (X_row * inv_rms_row)
+
+    tl.store(dY_ptr + col_offsets, dX_row, mask=mask)
     tl.store(dW_ptr + col_offsets, dW_row, mask=mask)


+_str_to_casting_mode = {
+    "llama": _CASTING_MODE_LLAMA.value,
+    "gemma": _CASTING_MODE_GEMMA.value,
+    "none": _CASTING_MODE_NONE.value,
+}
+
+
 class LigerRMSNormFunction(torch.autograd.Function):
+    """
+    Performs RMSNorm (Root Mean Square Normalization), which normalizes the input tensor `X` using the
+    weight tensor `W`, with an optional offset and casting mode.
+
+    Some models use an 'offset' to shift the weight tensor `W` by a constant value. For example, Gemma
+    uses an offset of 1.0, so the computation becomes `(X / RMS(X)) * (W + 1.0)` instead of the usual
+    `(X / RMS(X)) * W`. You can pass the offset value as an argument to the forward function.
+
+    In addition, different models cast their inputs at different places during RMSNorm computation. For
+    example, Gemma casts everything to fp32 nefore starting the computation, while Llama casts only the
+    inverse RMS to fp32. You can specify the casting mode using the `casting_mode` argument. We currently
+    support the following casting modes (they match HuggingFace Transformers' implementations):
+    - 'llama': matches the Llama implementation, where only the inverse RMS is computed on fp32.
+    - 'gemma': matches the Gemma implementation, where everything is cast to fp32, then computed, then cast back to the original dtype.
+    - 'none': no casting is done. The computation is done in the original dtype. This saves memory and is slightly faster, but has more error w.r.t. the original implementation.
+    """
+
     @staticmethod
     @ensure_contiguous
-    def forward(ctx, X, W, eps):
+    def forward(ctx, X, W, eps, offset=0.0, casting_mode="llama"):
         """
         X: (B, T, H) or (BxT, H)
         W: (H,)
         """
+        if not isinstance(casting_mode, int):
+            assert (
+                casting_mode in _str_to_casting_mode
+            ), f"Invalid casting mode: {casting_mode}"
+            casting_mode = _str_to_casting_mode[casting_mode]
+        else:
+            assert (
+                casting_mode in _str_to_casting_mode.values()
+            ), f"Invalid casting mode: {casting_mode}"

         shape = X.shape
         dim = shape[-1]
@@ -121,7 +224,13 @@ class LigerRMSNormFunction(torch.autograd.Function):

         Y = torch.empty((n_rows, n_cols), dtype=X.dtype, device=X.device)
         # r is to cache (1/rms) for each row
-        r
+        # r is always computed/stored in fp32 if we are using Llama or Gemma casting mode
+        r_dtype = (
+            torch.float32
+            if casting_mode in (_CASTING_MODE_LLAMA.value, _CASTING_MODE_GEMMA.value)
+            else X.dtype
+        )
+        r = torch.empty(n_rows, dtype=r_dtype, device=X.device)

         # Check constraints.
         assert (
@@ -139,10 +248,14 @@ class LigerRMSNormFunction(torch.autograd.Function):
             r.stride(0),
             n_cols,
             eps,
+            offset,
+            casting_mode,
             BLOCK_SIZE=BLOCK_SIZE,
             num_warps=num_warps,
         )
         ctx.eps = eps
+        ctx.offset = offset
+        ctx.casting_mode = casting_mode
         ctx.BLOCK_SIZE = BLOCK_SIZE
         ctx.num_warps = num_warps

@@ -161,7 +274,14 @@ class LigerRMSNormFunction(torch.autograd.Function):
         dY = dY.view(-1, dim)
         X, W, r = ctx.saved_tensors
         n_rows, n_cols = dY.shape
-        dW = torch.
+        dW = torch.empty_like(
+            X,
+            dtype=(
+                torch.float32
+                if ctx.casting_mode == _CASTING_MODE_GEMMA.value
+                else W.dtype
+            ),
+        )

         # Here we use dY to store the value of dX to save memory
         _rms_norm_backward[(n_rows,)](
@@ -177,9 +297,11 @@ class LigerRMSNormFunction(torch.autograd.Function):
             dW.stride(0),
             n_cols,
             ctx.eps,
+            ctx.offset,
+            ctx.casting_mode,
             BLOCK_SIZE=ctx.BLOCK_SIZE,
             num_warps=ctx.num_warps,
         )
         dX = dY.view(*shape)
-        dW = torch.sum(dW, dim=0)
-        return dX, dW, None
+        dW = torch.sum(dW, dim=0).to(W.dtype)
+        return dX, dW, None, None, None
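Based on the new `forward(ctx, X, W, eps, offset=0.0, casting_mode="llama")` signature above, a usage sketch; it assumes a CUDA device with Triton installed, and the shapes and eps value are illustrative:

```python
import torch
from liger_kernel.ops.rms_norm import LigerRMSNormFunction

x = torch.randn(2, 8, 64, device="cuda", dtype=torch.bfloat16, requires_grad=True)
w = torch.ones(64, device="cuda", dtype=torch.bfloat16)

# Gemma-style RMSNorm: weight offset of 1.0, everything computed in fp32.
y = LigerRMSNormFunction.apply(x, w, 1e-6, 1.0, "gemma")
y.sum().backward()
```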
liger_kernel/ops/rope.py CHANGED
@@ -13,8 +13,8 @@ def _triton_rope(
     cos_row_stride,
     sin,
     sin_row_stride,
+    sl,
     bs: tl.constexpr,
-    sl: tl.constexpr,
     n_qh: tl.constexpr,
     n_kh: tl.constexpr,
     hd: tl.constexpr,
@@ -168,8 +168,8 @@ class LigerRopeFunction(torch.autograd.Function):
             cos.stride(-2),
             sin,
             sin.stride(-2),
-            batch_size,
             seq_len,
+            batch_size,
             n_q_head,
             n_kv_head,
             head_dim,
@@ -219,8 +219,8 @@ class LigerRopeFunction(torch.autograd.Function):
             cos.stride(-2),
             sin,
             sin.stride(-2),
-            batch_size,
             seq_len,
+            batch_size,
             n_q_head,
             n_kv_head,
             head_dim,
liger_kernel/transformers/__init__.py CHANGED
@@ -1,6 +1,12 @@
+from liger_kernel.transformers.auto_model import (  # noqa: F401
+    AutoLigerKernelForCausalLM,
+)
 from liger_kernel.transformers.monkey_patch import (  # noqa: F401
     apply_liger_kernel_to_gemma,
+    apply_liger_kernel_to_gemma2,
     apply_liger_kernel_to_llama,
     apply_liger_kernel_to_mistral,
     apply_liger_kernel_to_mixtral,
+    apply_liger_kernel_to_phi3,
+    apply_liger_kernel_to_qwen2,
 )
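These re-exports are the patching entry points. A usage sketch; the diff does not show these functions' keyword flags, so none are passed here, and the checkpoint name is only a placeholder:

```python
from liger_kernel.transformers import apply_liger_kernel_to_qwen2
from transformers import AutoModelForCausalLM

# Patch the HF Qwen2 modeling code in place before the model is instantiated.
apply_liger_kernel_to_qwen2()
model = AutoModelForCausalLM.from_pretrained("Qwen/Qwen2-7B-Instruct")
```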
liger_kernel/transformers/auto_model.py ADDED
@@ -0,0 +1,33 @@
+from transformers import AutoConfig, AutoModelForCausalLM
+
+from liger_kernel.transformers.monkey_patch import _apply_liger_kernel
+
+
+def _get_model_config(model_dir, **model_init_kwargs):
+    config = AutoConfig.from_pretrained(model_dir, **model_init_kwargs)
+    return config
+
+
+class AutoLigerKernelForCausalLM(AutoModelForCausalLM):
+    """
+    This class is a drop-in replacement for AutoModelForCausalLM that applies the Liger Kernel to the model
+    if applicable.
+    """
+
+    @classmethod
+    def from_pretrained(cls, pretrained_model_name_or_path, *model_args, **kwargs):
+        model_config = _get_model_config(pretrained_model_name_or_path, **kwargs)
+
+        # Determine the model type and apply the Liger Kernel if applicable
+        # Note: _apply_liger_kernel will only pass relevant kwargs to the apply_liger_kernel_to_* function
+        model_type = model_config.model_type
+        _apply_liger_kernel(model_type, **kwargs)
+
+        # Retain only the keyword args present in the model configuration
+        for k in list(kwargs.keys()):
+            if k not in model_config.__dict__:
+                del kwargs[k]
+
+        return super().from_pretrained(
+            pretrained_model_name_or_path, *model_args, **kwargs
+        )
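A usage sketch for the new wrapper class; the checkpoint name is a placeholder, not a recommendation:

```python
from liger_kernel.transformers import AutoLigerKernelForCausalLM

# Behaves like AutoModelForCausalLM.from_pretrained, but first monkey-patches the
# matching model type (llama, gemma, mistral, qwen2, phi3, ...) with Liger kernels.
model = AutoLigerKernelForCausalLM.from_pretrained("some-org/some-causal-lm")
```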
liger_kernel/transformers/fused_linear_cross_entropy.py CHANGED
@@ -9,7 +9,7 @@ class LigerFusedLinearCrossEntropyLoss(CrossEntropyLoss):
     def __init__(self, *args, **kwargs):
         super(LigerFusedLinearCrossEntropyLoss, self).__init__(*args, **kwargs)

-    def forward(self, lin_weight, _input, target):
+    def forward(self, lin_weight, _input, target, bias=None):
         return LigerFusedLinearCrossEntropyFunction.apply(
-            _input, lin_weight, target, self.ignore_index
+            _input, lin_weight, target, bias, self.ignore_index
         )
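With the new `bias` argument, an `lm_head` that carries a bias term can be passed straight through. A sketch assuming a CUDA device with Triton installed; the sizes are illustrative:

```python
import torch
from liger_kernel.transformers.fused_linear_cross_entropy import (
    LigerFusedLinearCrossEntropyLoss,
)

hidden = torch.randn(16, 32, device="cuda", requires_grad=True)  # (B*T, H)
lm_head = torch.nn.Linear(32, 100, bias=True, device="cuda")     # V = 100 classes
target = torch.randint(0, 100, (16,), device="cuda")             # (B*T,)

loss_fn = LigerFusedLinearCrossEntropyLoss()
loss = loss_fn(lm_head.weight, hidden, target, lm_head.bias)
loss.backward()
```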
liger_kernel/transformers/geglu.py CHANGED
@@ -13,8 +13,10 @@ class LigerGEGLUMLP(nn.Module):
         self.up_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False)
         self.down_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias=False)
         # TODO: support exact GELU
-
-
+        # Right now Gemma 1, 1.1 and 2 models are all using `gelu_pytorch_tanh`
+        # https://github.com/huggingface/transformers/blob/v4.40.1/src/transformers/models/gemma/modeling_gemma.py#L175
+        # https://github.com/huggingface/transformers/blob/v4.40.1/src/transformers/activations.py#L46
+        # So we can safely assume we use tanh approximation form all the time

     def forward(self, x):
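The new comments state that these Gemma models use `gelu_pytorch_tanh`, i.e. PyTorch's tanh-approximated GELU. A small standalone snippet (not part of the package) showing how it differs from exact GELU:

```python
import torch
import torch.nn.functional as F

x = torch.randn(4)
exact = F.gelu(x)                       # erf-based GELU
approx = F.gelu(x, approximate="tanh")  # what `gelu_pytorch_tanh` maps to
print((exact - approx).abs().max())     # small but non-zero difference
```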