liger-kernel 0.3.0__py3-none-any.whl → 0.3.1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- liger_kernel/ops/fused_linear_cross_entropy.py +1 -1
- liger_kernel/ops/geglu.py +2 -2
- liger_kernel/ops/kl_div.py +43 -32
- liger_kernel/ops/swiglu.py +2 -2
- liger_kernel/transformers/auto_model.py +18 -6
- liger_kernel/transformers/kl_div.py +3 -2
- liger_kernel/transformers/monkey_patch.py +96 -122
- {liger_kernel-0.3.0.dist-info → liger_kernel-0.3.1.dist-info}/METADATA +15 -8
- {liger_kernel-0.3.0.dist-info → liger_kernel-0.3.1.dist-info}/RECORD +13 -13
- {liger_kernel-0.3.0.dist-info → liger_kernel-0.3.1.dist-info}/WHEEL +1 -1
- {liger_kernel-0.3.0.dist-info → liger_kernel-0.3.1.dist-info}/LICENSE +0 -0
- {liger_kernel-0.3.0.dist-info → liger_kernel-0.3.1.dist-info}/NOTICE +0 -0
- {liger_kernel-0.3.0.dist-info → liger_kernel-0.3.1.dist-info}/top_level.txt +0 -0
liger_kernel/ops/fused_linear_cross_entropy.py
CHANGED

@@ -97,7 +97,7 @@ def fused_linear_cross_entropy_forward(
 
         # gradient of logits_chunk is computed in-place by the above triton kernel.
         # Following HuggingFace model source code, we do the forward and backward
-        # w.r.t. logits in fp32 for numerical stability especially as the num classes (vocab size)
+        # w.r.t. logits in fp32 for numerical stability especially as the num classes (vocab size) is huge.
         # (reference: https://github.com/huggingface/transformers/blob/v4.42.4/src/transformers/models/llama/modeling_llama.py#L1194)
         # Propagating to lm_head's backward, we'll switch back to the original dtype.
         logits_chunk = logits_chunk.to(dtype)
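The comment change above spells out the pattern the kernel already follows: do the loss math on an fp32 view of the logits chunk and hand the gradient back in the original dtype. A minimal PyTorch sketch of that pattern for context (sizes and dtypes are illustrative; this is not the Triton kernel itself):

```python
import torch
import torch.nn.functional as F

# A bf16 logits chunk such as the one produced by the lm_head matmul (made-up sizes).
logits_chunk = torch.randn(4, 32000, dtype=torch.bfloat16, requires_grad=True)
target = torch.randint(0, 32000, (4,))

loss = F.cross_entropy(logits_chunk.float(), target)  # forward/backward math in fp32
loss.backward()                                       # gradient lands on the bf16 leaf

print(loss.dtype, logits_chunk.grad.dtype)            # torch.float32 torch.bfloat16
```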
liger_kernel/ops/geglu.py
CHANGED

@@ -25,7 +25,7 @@ else:
 def _geglu_tanh_forward_kernel(
     a, b, c, stride, n_cols: tl.constexpr, BLOCK_SIZE: tl.constexpr
 ):
-    program_id = tl.program_id(0)
+    program_id = tl.program_id(0).cast(tl.int64)
 
     # locate start index
     a += program_id * stride

@@ -52,7 +52,7 @@ def _geglu_tanh_forward_kernel(
 def _geglu_tanh_backward_kernel(
     dc, a, b, stride, n_cols: tl.constexpr, BLOCK_SIZE: tl.constexpr
 ):
-    program_id = tl.program_id(0)
+    program_id = tl.program_id(0).cast(tl.int64)
 
     # locate start index
     dc += program_id * stride
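Both kernels (and the matching SwiGLU kernels further down) now cast the Triton program id to int64 before computing the row offset, presumably so that `program_id * stride` cannot overflow a 32-bit index once the launch grid and the row stride are both large. A rough NumPy illustration of the failure mode, with hypothetical sizes:

```python
import numpy as np

n_rows, n_cols = 70_000, 32_768   # hypothetical (rows, hidden) activation shape
stride = n_cols                   # elements per row
row = n_rows - 1                  # offset for the last program id

off32 = np.int32(row) * np.int32(stride)  # exceeds 2**31 - 1 and wraps (NumPy warns)
off64 = np.int64(row) * np.int64(stride)  # the intended element offset

print(off32, off64)
```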
liger_kernel/ops/kl_div.py
CHANGED

@@ -45,6 +45,7 @@ def _kldiv_kernel_forward(
     loss_ptr,  # [B] or [B, S] if reduction == _REDUCTION_MODE_NONE, output ptr
     loss_stride,  # int, output stride
     n_cols,  # int, number of columns in the input tensor
+    eps,
     BLOCK_SIZE: tl.constexpr,
     log_target: tl.constexpr = False,
     reduction: tl.constexpr = _REDUCTION_MODE_BATCHMEAN,

@@ -56,6 +57,7 @@ def _kldiv_kernel_forward(
 
     base_offsets = tl.arange(0, BLOCK_SIZE)
 
+    loss_sum = 0.0
     for i in range(0, n_cols, BLOCK_SIZE):
         offsets = i + base_offsets
         mask = offsets < n_cols

@@ -65,32 +67,33 @@ def _kldiv_kernel_forward(
         # KL(y_true || y) = y_true * (log(y_true) - log(y))
         # We compute KL(y_true || y) with y in the log-space
         if not log_target:
-            loss = y_true * (tl.log(y_true) - y)
+            loss = y_true * (tl.log(tl.maximum(y_true, eps)) - y)
         else:
             loss = tl.exp(y_true) * (y_true - y)
 
         if reduction == _REDUCTION_MODE_NONE:
             tl.store(loss_ptr + offsets, loss, mask=mask)
         else:
-
-
-
+            loss_sum += tl.sum(loss, axis=0)
+
+    if reduction != _REDUCTION_MODE_NONE:
+        tl.store(loss_ptr, loss_sum)
 
 
 @triton.jit
 def _kldiv_kernel_backward(
-    input_ptr,
-    input_stride,
     target_ptr,
     target_stride,
+    new_grads_ptr,
+    new_grads_stride,
     n_cols,
     BLOCK_SIZE: tl.constexpr,
     log_target: tl.constexpr = False,
 ):
     pid = tl.program_id(0).to(tl.int64)
 
-    input_ptr += pid * input_stride
     target_ptr += pid * target_stride
+    new_grads_ptr += pid * new_grads_stride
 
     offsets = tl.arange(0, BLOCK_SIZE)
     mask = offsets < n_cols

@@ -106,19 +109,19 @@ def _kldiv_kernel_backward(
     else:
         res = -tl.exp(target)
 
-    tl.store(
+    tl.store(new_grads_ptr + offsets, res, mask=mask)
 
 
-def kldiv_forward_triton(y_pred, y_true, log_target, reduction):  # [
-
+def kldiv_forward_triton(y_pred, y_true, log_target, reduction, eps):  # [BT, V]
+    BT, V = y_pred.shape
 
-    BLOCK_SIZE = min(MAX_FUSED_SIZE, triton.next_power_of_2(
+    BLOCK_SIZE = min(MAX_FUSED_SIZE, triton.next_power_of_2(V))
     num_warps = get_num_warps(BLOCK_SIZE)
 
-    grid = (
+    grid = (BT,)
     reduction = _str_to_reduction_mode[reduction]
 
-    out_size = (
+    out_size = (BT, V) if reduction == _REDUCTION_MODE_NONE.value else (BT,)
     output_tensor = torch.zeros(out_size, device=y_pred.device, dtype=torch.float32)
 
     _kldiv_kernel_forward[grid](

@@ -128,7 +131,8 @@ def kldiv_forward_triton(y_pred, y_true, log_target, reduction): # [B, S] # [B
         y_true.stride(0),
         output_tensor,
         output_tensor.stride(0),
-
+        V,
+        eps=eps,
         BLOCK_SIZE=BLOCK_SIZE,
         num_warps=num_warps,
         log_target=log_target,

@@ -139,30 +143,30 @@ def kldiv_forward_triton(y_pred, y_true, log_target, reduction): # [B, S] # [B
     # https://pytorch.org/docs/stable/generated/torch.nn.KLDivLoss.html
     # https://github.com/pytorch/pytorch/blob/d7b57c4d63edb42e1deeeba9497fcb5f1f748ff2/torch/nn/functional.py#L3372
     if reduction == _REDUCTION_MODE_BATCHMEAN.value:
-        return output_tensor.sum() /
+        return output_tensor.sum() / BT
     elif reduction == _REDUCTION_MODE_SUM.value:
         return output_tensor.sum(dim=0)
     elif reduction == _REDUCTION_MODE_MEAN.value:
-        return output_tensor.
+        return output_tensor.sum() / (BT * V)
     else:
         return output_tensor
 
 
-def kldiv_backward_triton(
-
+def kldiv_backward_triton(target, grad_output, new_grads, log_target):
+    BT, V = target.shape
 
-    BLOCK_SIZE = min(MAX_FUSED_SIZE, triton.next_power_of_2(
+    BLOCK_SIZE = min(MAX_FUSED_SIZE, triton.next_power_of_2(V))
     num_warps = get_num_warps(BLOCK_SIZE)
 
-    grid = (
+    grid = (BT,)
 
     # We store the gradients in-place in the input tensor
     _kldiv_kernel_backward[grid](
-        input,
-        input.stride(0),
         target,
         target.stride(0),
-
+        new_grads,
+        new_grads.stride(0),
+        V,
         BLOCK_SIZE=BLOCK_SIZE,
         num_warps=num_warps,
         log_target=log_target,

@@ -170,9 +174,9 @@ def kldiv_backward_triton(input, target, grad_output, log_target):
 
     # If cross entropy is the last layer, grad_output is 1.0. Skip the mul then.
     if torch.equal(grad_output, torch.tensor(1.0, device=grad_output.device)):
-        return
+        return new_grads
 
-    return
+    return new_grads * grad_output
 
 
 class LigerKLDivLossFunction(torch.autograd.Function):

@@ -196,6 +200,7 @@ class LigerKLDivLossFunction(torch.autograd.Function):
         y_true: torch.Tensor,
         reduction: REDUCTION_LITERAL = "batchmean",
         log_target: bool = False,
+        eps: float = 1e-10,
     ) -> torch.Tensor:
         """A forward pass for the KL Divergence Loss.
 

@@ -205,15 +210,16 @@ class LigerKLDivLossFunction(torch.autograd.Function):
             y_true (torch.Tensor): A tensor of shape (BT, V) containing the target values, expected to be either probabilities or log-probabilities, depending on the value of `log_target`.
             reduction (REDUCTION_LITERAL, optional): Reduction to be used. Defaults to "batchmean".
             log_target (bool, optional): If set to true, expects the ground truth to already be log-probabilities. Defaults to False.
+            eps: (float, optional): A small value to avoid division by zero. Defaults to 1e-10.
 
         Returns:
             torch.Tensor: The computed KL Divergence Loss, with shape (BT, V) if `reduction` is "none", else a scalar.
         """
-        ctx.save_for_backward(
+        ctx.save_for_backward(y_true)
         ctx.reduction = reduction
         ctx.log_target = log_target
         return kldiv_forward_triton(
-            y_pred, y_true, log_target=log_target, reduction=reduction
+            y_pred, y_true, log_target=log_target, reduction=reduction, eps=eps
         )
 
     @staticmethod

@@ -226,22 +232,27 @@ class LigerKLDivLossFunction(torch.autograd.Function):
             grad_output (torch.Tensor): The gradient of the loss with respect to the output.
 
         Returns:
-            tuple[torch.Tensor, None, None, None]: The gradient of the loss with respect to the inputs and None for the other arguments of the forward method.
+            tuple[torch.Tensor, None, None, None, None]: The gradient of the loss with respect to the inputs and None for the other arguments of the forward method.
         """
-
+        (y_true,) = ctx.saved_tensors
+
+        new_grads = torch.empty_like(y_true)
 
-        derivative = kldiv_backward_triton(
+        derivative = kldiv_backward_triton(
+            y_true, grad_output, new_grads, ctx.log_target
+        )
 
         if ctx.reduction == "batchmean":
-            derivative = derivative /
+            derivative = derivative / y_true.shape[0]
         elif ctx.reduction == "sum" or ctx.reduction == "none":
             pass
         elif ctx.reduction == "mean":
-            derivative = derivative / (
+            derivative = derivative / (y_true.shape[0] * y_true.shape[1])
 
         return (
             derivative,
             None,
             None,
             None,
+            None,
         )
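For reference, a plain PyTorch sketch of what the clamped forward pass computes for `log_target=False` with `reduction="batchmean"`: clamping `y_true` from below by `eps` keeps `log(y_true)` finite, so zero-probability targets contribute roughly zero instead of `0 * (-inf) = NaN`. Shapes below are made up for the example:

```python
import torch

def kl_div_batchmean(y_pred: torch.Tensor, y_true: torch.Tensor, eps: float = 1e-10) -> torch.Tensor:
    # y_pred: log-probabilities, y_true: probabilities (log_target=False)
    loss = y_true * (torch.log(torch.clamp(y_true, min=eps)) - y_pred)
    return loss.sum() / y_pred.shape[0]  # "batchmean": divide by BT

BT, V = 4, 8
y_pred = torch.log_softmax(torch.randn(BT, V), dim=-1)
y_true = torch.softmax(torch.randn(BT, V), dim=-1)
print(kl_div_batchmean(y_pred, y_true))
```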
liger_kernel/ops/swiglu.py
CHANGED

@@ -14,7 +14,7 @@ def silu(x):
 def _swiglu_forward_kernel(
     a_ptr, b_ptr, c_ptr, stride, n_cols: tl.constexpr, BLOCK_SIZE: tl.constexpr
 ):
-    program_id = tl.program_id(0)
+    program_id = tl.program_id(0).cast(tl.int64)
 
     # locate start index
     a_ptr += program_id * stride

@@ -35,7 +35,7 @@ def _swiglu_forward_kernel(
 def _swiglu_backward_kernel(
     dc_ptr, a_ptr, b_ptr, stride, n_cols: tl.constexpr, BLOCK_SIZE: tl.constexpr
 ):
-    program_id = tl.program_id(0)
+    program_id = tl.program_id(0).cast(tl.int64)
 
     # locate start index
     dc_ptr += program_id * stride
liger_kernel/transformers/auto_model.py
CHANGED

@@ -1,6 +1,11 @@
+import inspect
+
 from transformers import AutoConfig, AutoModelForCausalLM
 
-from liger_kernel.transformers.monkey_patch import
+from liger_kernel.transformers.monkey_patch import (
+    MODEL_TYPE_TO_APPLY_LIGER_FN,
+    _apply_liger_kernel,
+)
 
 
 def _get_model_config(model_dir, **model_init_kwargs):

@@ -21,13 +26,20 @@ class AutoLigerKernelForCausalLM(AutoModelForCausalLM):
         # Determine the model type and apply the Liger Kernel if applicable
         # Note: _apply_liger_kernel will only pass relevant kwargs to the apply_liger_kernel_to_* function
         model_type = model_config.model_type
+
         _apply_liger_kernel(model_type, **kwargs)
 
-        #
-
-
-
+        # Filter out kwargs that were passed to the apply_liger_* function, which will cause
+        # model initialization errors otherwise
+        apply_fn = MODEL_TYPE_TO_APPLY_LIGER_FN[model_type]
+        apply_fn_signature = inspect.signature(apply_fn)
+
+        applicable_kwargs = {
+            key: value
+            for key, value in kwargs.items()
+            if key not in apply_fn_signature.parameters
+        }
 
         return super().from_pretrained(
-            pretrained_model_name_or_path, *model_args, **
+            pretrained_model_name_or_path, *model_args, **applicable_kwargs
         )
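The new filtering step keeps only the keyword arguments that the matched `apply_liger_kernel_to_*` function does not consume, so the remainder can be forwarded to `from_pretrained` without raising unexpected-keyword errors. A standalone sketch of the same idea; the patch function below is a hypothetical stand-in, not the real API:

```python
import inspect

def apply_liger_kernel_stub(rope=True, swiglu=True, rms_norm=True, **_):
    # Hypothetical stand-in for an apply_liger_kernel_to_* function.
    pass

kwargs = {"rope": False, "torch_dtype": "bfloat16", "trust_remote_code": True}

sig = inspect.signature(apply_liger_kernel_stub)
applicable_kwargs = {k: v for k, v in kwargs.items() if k not in sig.parameters}

print(applicable_kwargs)  # {'torch_dtype': 'bfloat16', 'trust_remote_code': True}
```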
liger_kernel/transformers/kl_div.py
CHANGED

@@ -4,10 +4,11 @@ from liger_kernel.ops.kl_div import LigerKLDivLossFunction
 
 
 class LigerKLDIVLoss(nn.KLDivLoss):
-    def __init__(self, *args, **kwargs):
+    def __init__(self, eps: float = 1e-10, *args, **kwargs):
         super(LigerKLDIVLoss, self).__init__(*args, **kwargs)
+        self.eps = eps
 
     def forward(self, y_pred, y_true):
         return LigerKLDivLossFunction.apply(
-            y_pred, y_true, self.reduction, self.log_target
+            y_pred, y_true, self.reduction, self.log_target, self.eps
         )
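A hedged usage sketch of the updated wrapper, assuming liger-kernel 0.3.1 is installed and a CUDA device is available (the Triton kernels run on GPU); the new `eps` argument is forwarded to the autograd function together with `reduction` and `log_target`:

```python
import torch
from liger_kernel.transformers.kl_div import LigerKLDIVLoss

loss_fn = LigerKLDIVLoss(reduction="batchmean", eps=1e-10)

# y_pred: log-probabilities, y_true: probabilities, both of shape (BT, V).
y_pred = torch.log_softmax(torch.randn(8, 128, device="cuda"), dim=-1)
y_true = torch.softmax(torch.randn(8, 128, device="cuda"), dim=-1)

loss = loss_fn(y_pred, y_true)
```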
liger_kernel/transformers/monkey_patch.py
CHANGED

@@ -1,9 +1,9 @@
 import inspect
 import logging
 from functools import partial
+from typing import Callable
 
-from
-from transformers import PretrainedConfig, PreTrainedModel
+from transformers import PreTrainedModel
 
 from liger_kernel.transformers.cross_entropy import LigerCrossEntropyLoss
 from liger_kernel.transformers.geglu import LigerGEGLUMLP

@@ -25,6 +25,30 @@ from liger_kernel.transformers.swiglu import (
 logger = logging.getLogger(__name__)
 
 
+def _bind_method_to_module(module, method_name: str, new_method: Callable):
+    # Binds a new method to a module instance so that self is passed as the first argument
+    module.__dict__[method_name] = new_method.__get__(module, module.__class__)
+
+
+def _patch_rms_norm_module(module, offset=0.0, eps=1e-6, casting_mode="llama"):
+    module.offset = offset
+    module.casting_mode = casting_mode
+    module.variance_epsilon = (
+        getattr(module, "variance_epsilon", None) or getattr(module, "eps", None) or eps
+    )
+    _bind_method_to_module(module, "forward", LigerRMSNorm.forward)
+    _bind_method_to_module(module, "extra_repr", LigerRMSNorm.extra_repr)
+
+
+def _patch_layer_norm_module(module, eps=1e-6):
+    module.variance_epsilon = (
+        getattr(module, "variance_epsilon", None) or getattr(module, "eps", None) or eps
+    )
+    module.hidden_size = module.normalized_shape
+    _bind_method_to_module(module, "forward", LigerLayerNorm.forward)
+    _bind_method_to_module(module, "extra_repr", LigerLayerNorm.extra_repr)
+
+
 def apply_liger_kernel_to_llama(
     rope: bool = True,
     cross_entropy: bool = False,
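The helpers added above patch already-instantiated modules in place instead of swapping in freshly constructed Liger modules, which is what the per-model hunks below switch to. `_bind_method_to_module` leans on the descriptor protocol; a minimal standalone sketch of that trick:

```python
class Norm:
    def forward(self, x):
        return x

def fused_forward(self, x):
    # Stand-in for something like LigerRMSNorm.forward.
    return x * 2

a, b = Norm(), Norm()
# __get__ binds the plain function to instance `a`, so `self` is passed automatically.
a.__dict__["forward"] = fused_forward.__get__(a, Norm)

print(a.forward(3), b.forward(3))  # 6 3 -- only `a` is patched, `b` keeps the class method
```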
@@ -69,7 +93,6 @@ def apply_liger_kernel_to_llama(
     if model is not None:
         # The model instance already exists, so we need to additionally patch the
         # instance variables that reference already-instantiated modules (e.g. LlamaRMSNorm or LlamaMLP)
-        config: PretrainedConfig = model.config
 
         if hasattr(model, "model"):
             # The case for LlamaForCausalLM or LlamaForSequenceClassification, for example

@@ -81,22 +104,17 @@ def apply_liger_kernel_to_llama(
             # Direct LlamaModel
             base_model = model
 
-        torch_dtype = config.torch_dtype
         if rms_norm:
-            base_model.norm
-                config.hidden_size, eps=config.rms_norm_eps
-            ).to(torch_dtype)
+            _patch_rms_norm_module(base_model.norm)
 
         for decoder_layer in base_model.layers:
             if swiglu:
-
+                _bind_method_to_module(
+                    decoder_layer.mlp, "forward", LigerSwiGLUMLP.forward
+                )
             if rms_norm:
-                decoder_layer.input_layernorm
-
-                ).to(torch_dtype)
-                decoder_layer.post_attention_layernorm = LigerRMSNorm(
-                    config.hidden_size, eps=config.rms_norm_eps
-                ).to(torch_dtype)
+                _patch_rms_norm_module(decoder_layer.input_layernorm)
+                _patch_rms_norm_module(decoder_layer.post_attention_layernorm)
 
 
 def apply_liger_kernel_to_mistral(

@@ -143,7 +161,6 @@ def apply_liger_kernel_to_mistral(
     if model is not None:
         # The model instance already exists, so we need to additionally patch the
         # instance variables that reference already-instantiated modules
-        config: PretrainedConfig = model.config
 
         if hasattr(model, "model"):
             # The case for MistralForCausalLM, MistralForTokenClassification for example

@@ -152,22 +169,17 @@ def apply_liger_kernel_to_mistral(
             # Direct MistralModel
             base_model = model
 
-        torch_dtype = config.torch_dtype
         if rms_norm:
-            base_model.norm
-                config.hidden_size, eps=config.rms_norm_eps
-            ).to(torch_dtype)
+            _patch_rms_norm_module(base_model.norm)
 
         for decoder_layer in base_model.layers:
             if swiglu:
-
+                _bind_method_to_module(
+                    decoder_layer.mlp, "forward", LigerSwiGLUMLP.forward
+                )
             if rms_norm:
-                decoder_layer.input_layernorm
-
-                ).to(torch_dtype)
-                decoder_layer.post_attention_layernorm = LigerRMSNorm(
-                    config.hidden_size, eps=config.rms_norm_eps
-                ).to(torch_dtype)
+                _patch_rms_norm_module(decoder_layer.input_layernorm)
+                _patch_rms_norm_module(decoder_layer.post_attention_layernorm)
 
 
 def apply_liger_kernel_to_mixtral(

@@ -214,7 +226,6 @@ def apply_liger_kernel_to_mixtral(
     if model is not None:
         # The model instance already exists, so we need to additionally patch the
         # instance variables that reference already-instantiated modules
-        config: PretrainedConfig = model.config
 
         if hasattr(model, "model"):
             # The case for MixtralForCausalLM, MixtralForTokenClassification for example

@@ -223,29 +234,18 @@ def apply_liger_kernel_to_mixtral(
             # Direct MixtralModel
             base_model = model
 
-        torch_dtype = config.torch_dtype
         if rms_norm:
-            base_model.norm
-                config.hidden_size, eps=config.rms_norm_eps
-            ).to(torch_dtype)
+            _patch_rms_norm_module(base_model.norm)
 
         for decoder_layer in base_model.layers:
             if swiglu:
-
-
-
-
-                    for _ in range(block_sparse_moe.num_experts)
-                ]
-                )
-                decoder_layer.block_sparse_moe.experts = patched_experts.to(torch_dtype)
+                for expert in decoder_layer.block_sparse_moe.experts:
+                    _bind_method_to_module(
+                        expert, "forward", LigerBlockSparseTop2MLP.forward
+                    )
             if rms_norm:
-                decoder_layer.input_layernorm
-
-                ).to(torch_dtype)
-                decoder_layer.post_attention_layernorm = LigerRMSNorm(
-                    config.hidden_size, eps=config.rms_norm_eps
-                ).to(torch_dtype)
+                _patch_rms_norm_module(decoder_layer.input_layernorm)
+                _patch_rms_norm_module(decoder_layer.post_attention_layernorm)
 
 
 def apply_liger_kernel_to_gemma(

@@ -282,6 +282,9 @@ def apply_liger_kernel_to_gemma(
     LigerRMSNormForGemma = partial(
         LigerRMSNorm, offset=1.0, init_fn="zeros", casting_mode="gemma"
     )
+    _patch_rms_norm_module_for_gemma = partial(
+        _patch_rms_norm_module, casting_mode="gemma", offset=1.0
+    )
 
     if rope:
         modeling_gemma.apply_rotary_pos_emb = liger_rotary_pos_emb

@@ -297,7 +300,6 @@ def apply_liger_kernel_to_gemma(
     if model is not None:
         # The model instance already exists, so we need to additionally patch the
        # instance variables that reference already-instantiated modules
-        config: PretrainedConfig = model.config
 
         if hasattr(model, "model"):
             # The case for GemmaForCausalLM, GemmaForTokenClassification for example

@@ -306,22 +308,17 @@ def apply_liger_kernel_to_gemma(
             # Direct GemmaModel
             base_model = model
 
-        torch_dtype = config.torch_dtype
         if rms_norm:
-            base_model.norm
-                config.hidden_size, eps=config.rms_norm_eps
-            ).to(torch_dtype)
+            _patch_rms_norm_module_for_gemma(base_model.norm)
 
         for decoder_layer in base_model.layers:
             if geglu:
-
+                _bind_method_to_module(
+                    decoder_layer.mlp, "forward", LigerGEGLUMLP.forward
+                )
             if rms_norm:
-                decoder_layer.input_layernorm
-
-                ).to(torch_dtype)
-                decoder_layer.post_attention_layernorm = LigerRMSNormForGemma(
-                    config.hidden_size, eps=config.rms_norm_eps
-                ).to(torch_dtype)
+                _patch_rms_norm_module_for_gemma(decoder_layer.input_layernorm)
+                _patch_rms_norm_module_for_gemma(decoder_layer.post_attention_layernorm)
 
 
 def apply_liger_kernel_to_gemma2(

@@ -343,10 +340,15 @@ def apply_liger_kernel_to_gemma2(
         model (PreTrainedModel): The model instance to apply Liger kernels to, if the model has already been
             loaded. Default is None.
     """
-    print("Got here!")
     from transformers.models.gemma2 import modeling_gemma2
 
-    LigerRMSNormForGemma2 = partial(
+    LigerRMSNormForGemma2 = partial(
+        LigerRMSNorm, offset=1.0, casting_mode="gemma", init_fn="zeros"
+    )
+    _patch_rms_norm_module_for_gemma2 = partial(
+        _patch_rms_norm_module, offset=1.0, casting_mode="gemma"
+    )
+
     if rope:
         modeling_gemma2.apply_rotary_pos_emb = liger_rotary_pos_emb
     if rms_norm:

@@ -360,7 +362,6 @@ def apply_liger_kernel_to_gemma2(
     if model is not None:
         # The model instance already exists, so we need to additionally patch the
         # instance variables that reference already-instantiated modules
-        config: PretrainedConfig = model.config
 
         if hasattr(model, "model"):
             # The case for Gemma2ForCausalLM, Gemma2ForTokenClassification for example

@@ -369,28 +370,25 @@ def apply_liger_kernel_to_gemma2(
             # Direct Gemma2Model
             base_model = model
 
-        torch_dtype = config.torch_dtype
         if rms_norm:
-            base_model.norm
-                config.hidden_size, eps=config.rms_norm_eps
-            ).to(torch_dtype)
+            _patch_rms_norm_module_for_gemma2(base_model.norm)
 
         for decoder_layer in base_model.layers:
             if geglu:
-
+                _bind_method_to_module(
+                    decoder_layer.mlp, "forward", LigerGEGLUMLP.forward
+                )
             if rms_norm:
-                decoder_layer.input_layernorm
-
-
-
-
-
-
-
-
-
-                    config.hidden_size, eps=config.rms_norm_eps
-                ).to(torch_dtype)
+                _patch_rms_norm_module_for_gemma2(decoder_layer.input_layernorm)
+                _patch_rms_norm_module_for_gemma2(
+                    decoder_layer.post_attention_layernorm
+                )
+                _patch_rms_norm_module_for_gemma2(
+                    decoder_layer.pre_feedforward_layernorm
+                )
+                _patch_rms_norm_module_for_gemma2(
+                    decoder_layer.post_feedforward_layernorm
+                )
 
 
 def apply_liger_kernel_to_qwen2(

@@ -436,7 +434,6 @@ def apply_liger_kernel_to_qwen2(
     if model is not None:
         # The model instance already exists, so we need to additionally patch the
         # instance variables that reference already-instantiated modules
-        config: PretrainedConfig = model.config
 
         if hasattr(model, "model"):
             # The case for Qwen2ForCausalLM, Qwen2ForTokenClassification for example

@@ -445,22 +442,17 @@ def apply_liger_kernel_to_qwen2(
             # Direct Qwen2Model
             base_model = model
 
-        torch_dtype = config.torch_dtype
         if rms_norm:
-            base_model.norm
-                config.hidden_size, eps=config.rms_norm_eps
-            ).to(torch_dtype)
+            _patch_rms_norm_module(base_model.norm)
 
         for decoder_layer in base_model.layers:
             if swiglu:
-
+                _bind_method_to_module(
+                    decoder_layer.mlp, "forward", LigerSwiGLUMLP.forward
+                )
             if rms_norm:
-                decoder_layer.input_layernorm
-
-                ).to(torch_dtype)
-                decoder_layer.post_attention_layernorm = LigerRMSNorm(
-                    config.hidden_size, eps=config.rms_norm_eps
-                ).to(torch_dtype)
+                _patch_rms_norm_module(decoder_layer.input_layernorm)
+                _patch_rms_norm_module(decoder_layer.post_attention_layernorm)
 
 
 def apply_liger_kernel_to_qwen2_vl(

@@ -499,10 +491,9 @@ def apply_liger_kernel_to_qwen2_vl(
 
     # TODO: Support Qwen2-VL's multimodal RoPE implementation
 
-    LigerRMSNormForQwen2VL = partial(LigerRMSNorm, init_fn="ones", casting_mode="gemma")
     if rms_norm:
         # https://github.com/huggingface/transformers/blob/main/src/transformers/models/qwen2_vl/modeling_qwen2_vl.py#L439
-        modeling_qwen2_vl.Qwen2RMSNorm =
+        modeling_qwen2_vl.Qwen2RMSNorm = LigerRMSNorm
     if layer_norm:
         modeling_qwen2_vl.LayerNorm = LigerLayerNorm
     if cross_entropy:

@@ -515,9 +506,6 @@ def apply_liger_kernel_to_qwen2_vl(
     if model is not None:
         # The model instance already exists, so we need to additionally patch the
         # instance variables that reference already-instantiated modules
-        config: PretrainedConfig = model.config
-
-        torch_dtype = config.torch_dtype
 
         if hasattr(model, "model"):
             # The case for Qwen2VLForConditionalGeneration.

@@ -530,27 +518,19 @@ def apply_liger_kernel_to_qwen2_vl(
         # Patch Qwen2VisionTransformerPretrainedModel
         for vision_block in model.visual.blocks:
             if layer_norm:
-                vision_block.norm1
-
-                )
-                vision_block.norm2 = LigerLayerNorm(config.embed_dim, eps=1e-6).to(
-                    torch_dtype
-                )
+                _patch_layer_norm_module(vision_block.norm1)
+                _patch_layer_norm_module(vision_block.norm2)
 
         if rms_norm:
-            base_model.norm
-                config.hidden_size, eps=config.rms_norm_eps
-            ).to(torch_dtype)
+            _patch_rms_norm_module(base_model.norm)
         for decoder_layer in base_model.layers:
             if swiglu:
-
+                _bind_method_to_module(
+                    decoder_layer.mlp, "forward", LigerSwiGLUMLP.forward
+                )
             if rms_norm:
-                decoder_layer.input_layernorm
-
-                ).to(torch_dtype)
-                decoder_layer.post_attention_layernorm = LigerRMSNormForQwen2VL(
-                    config.hidden_size, eps=config.rms_norm_eps
-                ).to(torch_dtype)
+                _patch_rms_norm_module(decoder_layer.input_layernorm)
+                _patch_rms_norm_module(decoder_layer.post_attention_layernorm)
 
 
 def apply_liger_kernel_to_phi3(

@@ -596,7 +576,6 @@ def apply_liger_kernel_to_phi3(
     if model is not None:
         # The model instance already exists, so we need to additionally patch the
         # instance variables that reference already-instantiated modules
-        config: PretrainedConfig = model.config
 
         if hasattr(model, "model"):
             # The case for Phi3ForCausalLM, Phi3ForTokenClassification for example

@@ -605,22 +584,17 @@ def apply_liger_kernel_to_phi3(
             # Direct Phi3Model
             base_model = model
 
-        torch_dtype = config.torch_dtype
         if rms_norm:
-            base_model.norm
-                config.hidden_size, eps=config.rms_norm_eps
-            ).to(torch_dtype)
+            _patch_rms_norm_module(base_model.norm)
 
         for decoder_layer in base_model.layers:
             if swiglu:
-
+                _bind_method_to_module(
+                    decoder_layer.mlp, "forward", LigerPhi3SwiGLUMLP.forward
+                )
             if rms_norm:
-                decoder_layer.input_layernorm
-
-                ).to(torch_dtype)
-                decoder_layer.post_attention_layernorm = LigerRMSNorm(
-                    config.hidden_size, eps=config.rms_norm_eps
-                ).to(torch_dtype)
+                _patch_rms_norm_module(decoder_layer.input_layernorm)
+                _patch_rms_norm_module(decoder_layer.post_attention_layernorm)
 
 
 # Model type corresponds to the keys defined in transformers/models/auto/modeling_auto.py
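Taken together, these hunks let the `apply_liger_kernel_to_*` entry points patch a model that has already been loaded, rebinding forwards and norm attributes in place rather than replacing modules and their weights. A hedged usage sketch, with an illustrative checkpoint name and dtype:

```python
import torch
from transformers import AutoModelForCausalLM
from liger_kernel.transformers import apply_liger_kernel_to_llama

model = AutoModelForCausalLM.from_pretrained(
    "meta-llama/Meta-Llama-3-8B",       # illustrative checkpoint
    torch_dtype=torch.bfloat16,
)
apply_liger_kernel_to_llama(model=model)  # patches norms and MLPs on the live instance
```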
{liger_kernel-0.3.0.dist-info → liger_kernel-0.3.1.dist-info}/METADATA
CHANGED

@@ -1,6 +1,6 @@
 Metadata-Version: 2.1
 Name: liger_kernel
-Version: 0.3.0
+Version: 0.3.1
 Summary: Efficient Triton kernels for LLM Training
 License: BSD 2-CLAUSE LICENSE
 Copyright 2024 LinkedIn Corporation

@@ -32,16 +32,17 @@ License-File: LICENSE
 License-File: NOTICE
 Requires-Dist: torch>=2.1.2
 Requires-Dist: triton>=2.3.0
-Requires-Dist: transformers>=4.42.0
 Provides-Extra: dev
+Requires-Dist: transformers>=4.44.2; extra == "dev"
 Requires-Dist: matplotlib>=3.7.2; extra == "dev"
 Requires-Dist: flake8>=4.0.1.1; extra == "dev"
 Requires-Dist: black>=24.4.2; extra == "dev"
 Requires-Dist: isort>=5.13.2; extra == "dev"
 Requires-Dist: pytest>=7.1.2; extra == "dev"
 Requires-Dist: datasets>=2.19.2; extra == "dev"
-Requires-Dist: jupyter==1.0.0; extra == "dev"
 Requires-Dist: seaborn; extra == "dev"
+Provides-Extra: transformers
+Requires-Dist: transformers~=4.0; extra == "transformers"
 
 # Liger Kernel: Efficient Triton Kernels for LLM Training
 

@@ -74,8 +75,8 @@ Requires-Dist: seaborn; extra == "dev"
         </a>
       </td>
       <td style="padding: 10px;">
-        <a href="https://discord.gg/
-          <img src="https://dcbadge.vercel.app/api/server/
+        <a href="https://discord.gg/gpumode">
+          <img src="https://dcbadge.vercel.app/api/server/gpumode?style=flat" alt="Join Our Discord">
         </a>
       </td>
     </tr>

@@ -151,7 +152,10 @@ With one line of code, Liger Kernel can increase throughput by more than 20% and
 
 - `torch >= 2.1.2`
 - `triton >= 2.3.0`
-
+
+### Optional Dependencies
+
+- `transformers >= 4.x`: Required if you plan to use the transformers models patching APIs. The specific model you are working will dictate the minimum version of transformers.
 
 > **Note:**
 > Our kernels inherit the full spectrum of hardware compatibility offered by [Triton](https://github.com/triton-lang/triton).

@@ -174,7 +178,10 @@ To install from source:
 git clone https://github.com/linkedin/Liger-Kernel.git
 cd Liger-Kernel
 pip install -e .
+# or if using transformers
+pip install -e .[transformers]
 ```
+
 ## Getting Started
 
 There are a couple of ways to apply Liger kernels, depending on the level of customization required.

@@ -271,9 +278,9 @@ loss.backward()
 | Mixtral | `liger_kernel.transformers.apply_liger_kernel_to_mixtral` | RoPE, RMSNorm, SwiGLU, CrossEntropyLoss, FusedLinearCrossEntropy |
 | Gemma1 | `liger_kernel.transformers.apply_liger_kernel_to_gemma` | RoPE, RMSNorm, GeGLU, CrossEntropyLoss, FusedLinearCrossEntropy |
 | Gemma2 | `liger_kernel.transformers.apply_liger_kernel_to_gemma2` | RoPE, RMSNorm, GeGLU, CrossEntropyLoss |
-| Qwen2
+| Qwen2 & Qwen2.5 | `liger_kernel.transformers.apply_liger_kernel_to_qwen2` | RoPE, RMSNorm, SwiGLU, CrossEntropyLoss, FusedLinearCrossEntropy |
 | Qwen2-VL | `liger_kernel.transformers.apply_liger_kernel_to_qwen2_vl` | RMSNorm, LayerNorm, SwiGLU, CrossEntropyLoss, FusedLinearCrossEntropy |
-| Phi3
+| Phi3 & Phi3.5 | `liger_kernel.transformers.apply_liger_kernel_to_phi3` | RoPE, RMSNorm, SwiGLU, CrossEntropyLoss, FusedLinearCrossEntropy |
 
 
 
{liger_kernel-0.3.0.dist-info → liger_kernel-0.3.1.dist-info}/RECORD
CHANGED

@@ -1,24 +1,24 @@
 liger_kernel/env_report.py,sha256=LFUJ6UMkFFGPBYXBlqHFGy4bhsemEpSI-_1edSazlHI,1130
 liger_kernel/ops/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
 liger_kernel/ops/cross_entropy.py,sha256=6uoPScKpXJ7gdBlOpSnZcQ5fQe52JHYjUVsr_Bf4kCE,12317
-liger_kernel/ops/fused_linear_cross_entropy.py,sha256=
-liger_kernel/ops/geglu.py,sha256=
-liger_kernel/ops/kl_div.py,sha256=
+liger_kernel/ops/fused_linear_cross_entropy.py,sha256=XLKDHBMbqD6nH2mfFLmA1UoU-N7CpKWHp4L3itWoHCs,9321
+liger_kernel/ops/geglu.py,sha256=ErnNAgoMDCd8pqTh18Resl5JHCaRpRruH2jZ9_Y9CvA,4131
+liger_kernel/ops/kl_div.py,sha256=qnmtFQwuO3FR7Ovup_DDzpkD1A1LpwOaWlcO6K9ysHk,8342
 liger_kernel/ops/layer_norm.py,sha256=unGMYMOPqtkM9aTrokhcqgPmsV2AUN7Yzv86isVB9OI,7422
 liger_kernel/ops/rms_norm.py,sha256=4miEoDSdsc0GuhI3BpBRxt6iieFQcN2QnNp4o8PVB98,9921
 liger_kernel/ops/rope.py,sha256=jrzaA9-6Orn44y_IIam9_YNPQxOFK2FrIRNfFea4EtU,8513
-liger_kernel/ops/swiglu.py,sha256=
+liger_kernel/ops/swiglu.py,sha256=qxNpfYUB9abS-v8yiuzQn9oYHA2P_l4wT19m8GkCa_c,2998
 liger_kernel/ops/utils.py,sha256=Y5sbRuZVoswsMzITTTiFgITJN2QO0K4McAAUncE3UnE,1941
 liger_kernel/ops/experimental/embedding.py,sha256=LYR66dB-jhvhtUjeV4PnNro-n77J1mdlmpSLSxB3Y6U,4186
 liger_kernel/transformers/__init__.py,sha256=UP5NP8yJhkFkjLVTkFRU0w0CA49hwdhqwmIgaBAEcj0,1148
-liger_kernel/transformers/auto_model.py,sha256=
+liger_kernel/transformers/auto_model.py,sha256=RMIwQHSiXoksXFTIqFZ4PLBgoqkxJJAT3q1Qh47bGN8,1552
 liger_kernel/transformers/cross_entropy.py,sha256=gL30VByCSA_iQSkhV6no70x_IUqqFSTMJdytppico_w,804
 liger_kernel/transformers/functional.py,sha256=gXviuzvWjkSLfNGUWLKDnp4s6ATpvz7309kov6JKp0Y,906
 liger_kernel/transformers/fused_linear_cross_entropy.py,sha256=-07t8YRajZTrJOG2rUzt6Ur7kNuWgarWcqy7ou5Da8k,629
 liger_kernel/transformers/geglu.py,sha256=QcrME_8ooIn0xa59LaC0aoOdRrBIFd11Y0bAyF0NfCw,1130
-liger_kernel/transformers/kl_div.py,sha256=
+liger_kernel/transformers/kl_div.py,sha256=qVhjBg6tjRyue5iZ3NFxo8uySY4JuIFJyv0IM_50F24,431
 liger_kernel/transformers/layer_norm.py,sha256=fd6o4kSHJWolQMWxh-l1qObfgL08ruNbUoBiANKX1ow,972
-liger_kernel/transformers/monkey_patch.py,sha256=
+liger_kernel/transformers/monkey_patch.py,sha256=HtyeNNVJTOVN_UrI8piaG7_0An9-fgUXfIZfOlxx_os,28474
 liger_kernel/transformers/rms_norm.py,sha256=4XfMQI6dORF7s_5qUqVHKWv-3IUomaimU2dg-NwnpoM,1035
 liger_kernel/transformers/rope.py,sha256=m-ah8vZBYW8tfplTXCiAPMHJWlB1tdp_JPXJeWE-Boo,943
 liger_kernel/transformers/swiglu.py,sha256=0-tVJ8xEYfhxnduc16PflXFj8sZPxdx9sHUn3hfwCI4,2468

@@ -34,9 +34,9 @@ liger_kernel/transformers/model/qwen2.py,sha256=Va4uiZaVzCG2V7XKDfHjZyYTre5vPQM0
 liger_kernel/transformers/model/qwen2_vl.py,sha256=UajJdi49tUOfa68i2WHQ_2GZBF7d_N_uwOntER3bsl8,6607
 liger_kernel/triton/__init__.py,sha256=yfRe0zMb47QnqjecZWG7LnanfCTzeku7SgWRAwNVmzU,101
 liger_kernel/triton/monkey_patch.py,sha256=5BcGKTtdqeYchypBIBopGIWPx1-cFALz7sOKoEsqXJ0,1584
-liger_kernel-0.3.
-liger_kernel-0.3.
-liger_kernel-0.3.
-liger_kernel-0.3.
-liger_kernel-0.3.
-liger_kernel-0.3.
+liger_kernel-0.3.1.dist-info/LICENSE,sha256=OhzLDHJ0to4a8sodVLELZiCFylZ1NAAYLs-HrjPy0ag,1312
+liger_kernel-0.3.1.dist-info/METADATA,sha256=fHMAk1Nur5qcuMidT0iXL5an0DIs9aG4HDFcqzD4Gms,25763
+liger_kernel-0.3.1.dist-info/NOTICE,sha256=BXkXY9aWvEy_7MAB57zDu1z8uMYT1i1l9B6EpHuBa8s,173
+liger_kernel-0.3.1.dist-info/WHEEL,sha256=GV9aMThwP_4oNCtvEC2ec3qUYutgWeAzklro_0m4WJQ,91
+liger_kernel-0.3.1.dist-info/top_level.txt,sha256=2eghu4hA3LnkM7ElW92tQ8zegWKgSbeo-k-aGe1YnvY,13
+liger_kernel-0.3.1.dist-info/RECORD,,
{liger_kernel-0.3.0.dist-info → liger_kernel-0.3.1.dist-info}/LICENSE
File without changes

{liger_kernel-0.3.0.dist-info → liger_kernel-0.3.1.dist-info}/NOTICE
File without changes

{liger_kernel-0.3.0.dist-info → liger_kernel-0.3.1.dist-info}/top_level.txt
File without changes